mirror of
				https://github.com/TheAnachronism/docspell.git
				synced 2025-11-03 18:00:11 +00:00 
			
		
		
		
	Merge pull request #1611 from eikek/refactor-rememberme
Use uid as foreign key in rememberme
This commit is contained in:
		@@ -267,7 +267,9 @@ object Login {
 | 
				
			|||||||
          config: Config
 | 
					          config: Config
 | 
				
			||||||
      ): F[RememberToken] =
 | 
					      ): F[RememberToken] =
 | 
				
			||||||
        for {
 | 
					        for {
 | 
				
			||||||
          rme <- RRememberMe.generate[F](acc)
 | 
					          uid <- OptionT(store.transact(RUser.findIdByAccount(acc)))
 | 
				
			||||||
 | 
					            .getOrRaise(new IllegalStateException(s"No user_id found for account: $acc"))
 | 
				
			||||||
 | 
					          rme <- RRememberMe.generate[F](uid)
 | 
				
			||||||
          _ <- store.transact(RRememberMe.insert(rme))
 | 
					          _ <- store.transact(RRememberMe.insert(rme))
 | 
				
			||||||
          token <- RememberToken.user(rme.id, config.serverSecret)
 | 
					          token <- RememberToken.user(rme.id, config.serverSecret)
 | 
				
			||||||
        } yield token
 | 
					        } yield token
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -0,0 +1,14 @@
 | 
				
			|||||||
 | 
					alter table "rememberme" add column "user_id" varchar(254);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					update "rememberme" m
 | 
				
			||||||
 | 
					set "user_id" = (select "uid" from "user_" where "login" = m."login" and "cid" = m."cid");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					alter table "rememberme" alter column "user_id" set not null;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					alter table "rememberme" drop constraint "CONSTRAINT_20F";
 | 
				
			||||||
 | 
					drop index "rememberme_cid_login_idx";
 | 
				
			||||||
 | 
					alter table "rememberme" drop column "login";
 | 
				
			||||||
 | 
					alter table "rememberme" drop column "cid";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					create index "rememberme_user_id_idx" on "rememberme"("user_id");
 | 
				
			||||||
 | 
					alter table "rememberme" add constraint "remember_user_id_fk" foreign key("user_id") references "user_"("uid");
 | 
				
			||||||
@@ -0,0 +1,13 @@
 | 
				
			|||||||
 | 
					alter table `rememberme` add column (`user_id` varchar(254));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					update `rememberme` m
 | 
				
			||||||
 | 
					set `user_id` = (select `uid` from `user_` where `login` = m.`login` and `cid` = m.`cid`);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					alter table `rememberme` modify `user_id` varchar(254) NOT NULL;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					alter table `rememberme` drop foreign key `rememberme_ibfk_1`;
 | 
				
			||||||
 | 
					alter table `rememberme` drop column `login` cascade;
 | 
				
			||||||
 | 
					alter table `rememberme` drop column `cid` cascade;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					create index `rememberme_user_id_idx` on `rememberme`(`user_id`);
 | 
				
			||||||
 | 
					alter table `rememberme` add constraint `remember_user_id_fk` foreign key(`user_id`) references `user_`(`uid`);
 | 
				
			||||||
@@ -0,0 +1,12 @@
 | 
				
			|||||||
 | 
					alter table "rememberme" add column "user_id" varchar(254);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					update "rememberme" m
 | 
				
			||||||
 | 
					set "user_id" = (select "uid" from "user_" where "login" = m."login" and "cid" = m."cid");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					alter table "rememberme" alter column "user_id" set not null;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					alter table "rememberme" drop column "login" cascade;
 | 
				
			||||||
 | 
					alter table "rememberme" drop column "cid" cascade;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					create index "rememberme_user_id_idx" on "rememberme"("user_id");
 | 
				
			||||||
 | 
					alter table "rememberme" add constraint "remember_user_id_fk" foreign key("user_id") references "user_"("uid");
 | 
				
			||||||
@@ -7,6 +7,7 @@
 | 
				
			|||||||
package docspell.store.queries
 | 
					package docspell.store.queries
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import cats.data.OptionT
 | 
					import cats.data.OptionT
 | 
				
			||||||
 | 
					import cats.syntax.all._
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import docspell.common._
 | 
					import docspell.common._
 | 
				
			||||||
import docspell.store.qb.DSL._
 | 
					import docspell.store.qb.DSL._
 | 
				
			||||||
@@ -15,10 +16,9 @@ import docspell.store.records.{RCollective, RRememberMe, RUser}
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import doobie._
 | 
					import doobie._
 | 
				
			||||||
import doobie.implicits._
 | 
					import doobie.implicits._
 | 
				
			||||||
import org.log4s._
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
object QLogin {
 | 
					object QLogin {
 | 
				
			||||||
  private[this] val logger = getLogger
 | 
					  private[this] val logger = docspell.logging.getLogger[ConnectionIO]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  case class Data(
 | 
					  case class Data(
 | 
				
			||||||
      account: AccountId,
 | 
					      account: AccountId,
 | 
				
			||||||
@@ -28,25 +28,33 @@ object QLogin {
 | 
				
			|||||||
      source: AccountSource
 | 
					      source: AccountSource
 | 
				
			||||||
  )
 | 
					  )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def findUser(acc: AccountId): ConnectionIO[Option[Data]] = {
 | 
					  private def findUser0(
 | 
				
			||||||
 | 
					      where: (RUser.Table, RCollective.Table) => Condition
 | 
				
			||||||
 | 
					  ): ConnectionIO[Option[Data]] = {
 | 
				
			||||||
    val user = RUser.as("u")
 | 
					    val user = RUser.as("u")
 | 
				
			||||||
    val coll = RCollective.as("c")
 | 
					    val coll = RCollective.as("c")
 | 
				
			||||||
    val sql =
 | 
					    val sql =
 | 
				
			||||||
      Select(
 | 
					      Select(
 | 
				
			||||||
        select(user.cid, user.login, user.password, coll.state, user.state, user.source),
 | 
					        select(user.cid, user.login, user.password, coll.state, user.state, user.source),
 | 
				
			||||||
        from(user).innerJoin(coll, user.cid === coll.id),
 | 
					        from(user).innerJoin(coll, user.cid === coll.id),
 | 
				
			||||||
        user.login === acc.user && user.cid === acc.collective
 | 
					        where(user, coll)
 | 
				
			||||||
      ).build
 | 
					      ).build
 | 
				
			||||||
    logger.trace(s"SQL : $sql")
 | 
					    logger.trace(s"SQL : $sql") *>
 | 
				
			||||||
    sql.query[Data].option
 | 
					      sql.query[Data].option
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def findUser(acc: AccountId): ConnectionIO[Option[Data]] =
 | 
				
			||||||
 | 
					    findUser0((user, _) => user.login === acc.user && user.cid === acc.collective)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def findUser(userId: Ident): ConnectionIO[Option[Data]] =
 | 
				
			||||||
 | 
					    findUser0((user, _) => user.uid === userId)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def findByRememberMe(
 | 
					  def findByRememberMe(
 | 
				
			||||||
      rememberId: Ident,
 | 
					      rememberId: Ident,
 | 
				
			||||||
      minCreated: Timestamp
 | 
					      minCreated: Timestamp
 | 
				
			||||||
  ): OptionT[ConnectionIO, Data] =
 | 
					  ): OptionT[ConnectionIO, Data] =
 | 
				
			||||||
    for {
 | 
					    for {
 | 
				
			||||||
      rem <- OptionT(RRememberMe.useRememberMe(rememberId, minCreated))
 | 
					      rem <- OptionT(RRememberMe.useRememberMe(rememberId, minCreated))
 | 
				
			||||||
      acc <- OptionT(findUser(rem.accountId))
 | 
					      acc <- OptionT(findUser(rem.userId))
 | 
				
			||||||
    } yield acc
 | 
					    } yield acc
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -64,7 +64,7 @@ object QUser {
 | 
				
			|||||||
      n2 <- deleteUserSentMails(uid)
 | 
					      n2 <- deleteUserSentMails(uid)
 | 
				
			||||||
      _ <- logger.info(s"Removed $n2 sent mails")
 | 
					      _ <- logger.info(s"Removed $n2 sent mails")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      n3 <- deleteRememberMe(accountId)
 | 
					      n3 <- deleteRememberMe(uid)
 | 
				
			||||||
      _ <- logger.info(s"Removed $n3 remember me tokens")
 | 
					      _ <- logger.info(s"Removed $n3 remember me tokens")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      n4 <- deleteTotp(uid)
 | 
					      n4 <- deleteTotp(uid)
 | 
				
			||||||
@@ -111,11 +111,8 @@ object QUser {
 | 
				
			|||||||
    } yield n1.sum + n2.sum
 | 
					    } yield n1.sum + n2.sum
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def deleteRememberMe(id: AccountId): ConnectionIO[Int] =
 | 
					  def deleteRememberMe(userId: Ident): ConnectionIO[Int] =
 | 
				
			||||||
    DML.delete(
 | 
					    DML.delete(RRememberMe.T, RRememberMe.T.userId === userId)
 | 
				
			||||||
      RRememberMe.T,
 | 
					 | 
				
			||||||
      RRememberMe.T.cid === id.collective && RRememberMe.T.username === id.user
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def deleteTotp(uid: Ident): ConnectionIO[Int] =
 | 
					  def deleteTotp(uid: Ident): ConnectionIO[Int] =
 | 
				
			||||||
    DML.delete(RTotp.T, RTotp.T.userId === uid)
 | 
					    DML.delete(RTotp.T, RTotp.T.userId === uid)
 | 
				
			||||||
@@ -130,10 +127,6 @@ object QUser {
 | 
				
			|||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  private def loadUserId(id: AccountId): ConnectionIO[Option[Ident]] =
 | 
					  private def loadUserId(id: AccountId): ConnectionIO[Option[Ident]] =
 | 
				
			||||||
    run(
 | 
					    RUser.findIdByAccount(id)
 | 
				
			||||||
      select(RUser.T.uid),
 | 
					 | 
				
			||||||
      from(RUser.T),
 | 
					 | 
				
			||||||
      RUser.T.cid === id.collective && RUser.T.login === id.user
 | 
					 | 
				
			||||||
    ).query[Ident].option
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -17,39 +17,38 @@ import docspell.store.qb._
 | 
				
			|||||||
import doobie._
 | 
					import doobie._
 | 
				
			||||||
import doobie.implicits._
 | 
					import doobie.implicits._
 | 
				
			||||||
 | 
					
 | 
				
			||||||
case class RRememberMe(id: Ident, accountId: AccountId, created: Timestamp, uses: Int) {}
 | 
					case class RRememberMe(id: Ident, userId: Ident, created: Timestamp, uses: Int) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
object RRememberMe {
 | 
					object RRememberMe {
 | 
				
			||||||
  final case class Table(alias: Option[String]) extends TableDef {
 | 
					  final case class Table(alias: Option[String]) extends TableDef {
 | 
				
			||||||
    val tableName = "rememberme"
 | 
					    val tableName = "rememberme"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    val id = Column[Ident]("id", this)
 | 
					    val id = Column[Ident]("id", this)
 | 
				
			||||||
    val cid = Column[Ident]("cid", this)
 | 
					    val userId = Column[Ident]("user_id", this)
 | 
				
			||||||
    val username = Column[Ident]("login", this)
 | 
					 | 
				
			||||||
    val created = Column[Timestamp]("created", this)
 | 
					    val created = Column[Timestamp]("created", this)
 | 
				
			||||||
    val uses = Column[Int]("uses", this)
 | 
					    val uses = Column[Int]("uses", this)
 | 
				
			||||||
    val all = NonEmptyList.of[Column[_]](id, cid, username, created, uses)
 | 
					    val all = NonEmptyList.of[Column[_]](id, userId, created, uses)
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  val T = Table(None)
 | 
					  val T = Table(None)
 | 
				
			||||||
  def as(alias: String): Table =
 | 
					  def as(alias: String): Table =
 | 
				
			||||||
    Table(Some(alias))
 | 
					    Table(Some(alias))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def generate[F[_]: Sync](account: AccountId): F[RRememberMe] =
 | 
					  def generate[F[_]: Sync](userId: Ident): F[RRememberMe] =
 | 
				
			||||||
    for {
 | 
					    for {
 | 
				
			||||||
      c <- Timestamp.current[F]
 | 
					      c <- Timestamp.current[F]
 | 
				
			||||||
      i <- Ident.randomId[F]
 | 
					      i <- Ident.randomId[F]
 | 
				
			||||||
    } yield RRememberMe(i, account, c, 0)
 | 
					    } yield RRememberMe(i, userId, c, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def insert(v: RRememberMe): ConnectionIO[Int] =
 | 
					  def insert(v: RRememberMe): ConnectionIO[Int] =
 | 
				
			||||||
    DML.insert(
 | 
					    DML.insert(
 | 
				
			||||||
      T,
 | 
					      T,
 | 
				
			||||||
      T.all,
 | 
					      T.all,
 | 
				
			||||||
      fr"${v.id},${v.accountId.collective},${v.accountId.user},${v.created},${v.uses}"
 | 
					      fr"${v.id},${v.userId},${v.created},${v.uses}"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def insertNew(acc: AccountId): ConnectionIO[RRememberMe] =
 | 
					  def insertNew(userId: Ident): ConnectionIO[RRememberMe] =
 | 
				
			||||||
    generate[ConnectionIO](acc).flatMap(v => insert(v).map(_ => v))
 | 
					    generate[ConnectionIO](userId).flatMap(v => insert(v).map(_ => v))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def findById(rid: Ident): ConnectionIO[Option[RRememberMe]] =
 | 
					  def findById(rid: Ident): ConnectionIO[Option[RRememberMe]] =
 | 
				
			||||||
    run(select(T.all), from(T), T.id === rid).query[RRememberMe].option
 | 
					    run(select(T.all), from(T), T.id === rid).query[RRememberMe].option
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user