Merge pull request #1611 from eikek/refactor-rememberme

Use uid as foreign key in rememberme
This commit is contained in:
mergify[bot] 2022-06-27 21:42:49 +00:00 committed by GitHub
commit f8a0ea9c62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 69 additions and 28 deletions

View File

@ -267,7 +267,9 @@ object Login {
config: Config config: Config
): F[RememberToken] = ): F[RememberToken] =
for { for {
rme <- RRememberMe.generate[F](acc) uid <- OptionT(store.transact(RUser.findIdByAccount(acc)))
.getOrRaise(new IllegalStateException(s"No user_id found for account: $acc"))
rme <- RRememberMe.generate[F](uid)
_ <- store.transact(RRememberMe.insert(rme)) _ <- store.transact(RRememberMe.insert(rme))
token <- RememberToken.user(rme.id, config.serverSecret) token <- RememberToken.user(rme.id, config.serverSecret)
} yield token } yield token

View File

@ -0,0 +1,14 @@
alter table "rememberme" add column "user_id" varchar(254);
update "rememberme" m
set "user_id" = (select "uid" from "user_" where "login" = m."login" and "cid" = m."cid");
alter table "rememberme" alter column "user_id" set not null;
alter table "rememberme" drop constraint "CONSTRAINT_20F";
drop index "rememberme_cid_login_idx";
alter table "rememberme" drop column "login";
alter table "rememberme" drop column "cid";
create index "rememberme_user_id_idx" on "rememberme"("user_id");
alter table "rememberme" add constraint "remember_user_id_fk" foreign key("user_id") references "user_"("uid");

View File

@ -0,0 +1,13 @@
alter table `rememberme` add column (`user_id` varchar(254));
update `rememberme` m
set `user_id` = (select `uid` from `user_` where `login` = m.`login` and `cid` = m.`cid`);
alter table `rememberme` modify `user_id` varchar(254) NOT NULL;
alter table `rememberme` drop foreign key `rememberme_ibfk_1`;
alter table `rememberme` drop column `login` cascade;
alter table `rememberme` drop column `cid` cascade;
create index `rememberme_user_id_idx` on `rememberme`(`user_id`);
alter table `rememberme` add constraint `remember_user_id_fk` foreign key(`user_id`) references `user_`(`uid`);

View File

@ -0,0 +1,12 @@
alter table "rememberme" add column "user_id" varchar(254);
update "rememberme" m
set "user_id" = (select "uid" from "user_" where "login" = m."login" and "cid" = m."cid");
alter table "rememberme" alter column "user_id" set not null;
alter table "rememberme" drop column "login" cascade;
alter table "rememberme" drop column "cid" cascade;
create index "rememberme_user_id_idx" on "rememberme"("user_id");
alter table "rememberme" add constraint "remember_user_id_fk" foreign key("user_id") references "user_"("uid");

View File

@ -7,6 +7,7 @@
package docspell.store.queries package docspell.store.queries
import cats.data.OptionT import cats.data.OptionT
import cats.syntax.all._
import docspell.common._ import docspell.common._
import docspell.store.qb.DSL._ import docspell.store.qb.DSL._
@ -15,10 +16,9 @@ import docspell.store.records.{RCollective, RRememberMe, RUser}
import doobie._ import doobie._
import doobie.implicits._ import doobie.implicits._
import org.log4s._
object QLogin { object QLogin {
private[this] val logger = getLogger private[this] val logger = docspell.logging.getLogger[ConnectionIO]
case class Data( case class Data(
account: AccountId, account: AccountId,
@ -28,25 +28,33 @@ object QLogin {
source: AccountSource source: AccountSource
) )
def findUser(acc: AccountId): ConnectionIO[Option[Data]] = { private def findUser0(
where: (RUser.Table, RCollective.Table) => Condition
): ConnectionIO[Option[Data]] = {
val user = RUser.as("u") val user = RUser.as("u")
val coll = RCollective.as("c") val coll = RCollective.as("c")
val sql = val sql =
Select( Select(
select(user.cid, user.login, user.password, coll.state, user.state, user.source), select(user.cid, user.login, user.password, coll.state, user.state, user.source),
from(user).innerJoin(coll, user.cid === coll.id), from(user).innerJoin(coll, user.cid === coll.id),
user.login === acc.user && user.cid === acc.collective where(user, coll)
).build ).build
logger.trace(s"SQL : $sql") logger.trace(s"SQL : $sql") *>
sql.query[Data].option sql.query[Data].option
} }
def findUser(acc: AccountId): ConnectionIO[Option[Data]] =
findUser0((user, _) => user.login === acc.user && user.cid === acc.collective)
def findUser(userId: Ident): ConnectionIO[Option[Data]] =
findUser0((user, _) => user.uid === userId)
def findByRememberMe( def findByRememberMe(
rememberId: Ident, rememberId: Ident,
minCreated: Timestamp minCreated: Timestamp
): OptionT[ConnectionIO, Data] = ): OptionT[ConnectionIO, Data] =
for { for {
rem <- OptionT(RRememberMe.useRememberMe(rememberId, minCreated)) rem <- OptionT(RRememberMe.useRememberMe(rememberId, minCreated))
acc <- OptionT(findUser(rem.accountId)) acc <- OptionT(findUser(rem.userId))
} yield acc } yield acc
} }

View File

@ -64,7 +64,7 @@ object QUser {
n2 <- deleteUserSentMails(uid) n2 <- deleteUserSentMails(uid)
_ <- logger.info(s"Removed $n2 sent mails") _ <- logger.info(s"Removed $n2 sent mails")
n3 <- deleteRememberMe(accountId) n3 <- deleteRememberMe(uid)
_ <- logger.info(s"Removed $n3 remember me tokens") _ <- logger.info(s"Removed $n3 remember me tokens")
n4 <- deleteTotp(uid) n4 <- deleteTotp(uid)
@ -111,11 +111,8 @@ object QUser {
} yield n1.sum + n2.sum } yield n1.sum + n2.sum
} }
def deleteRememberMe(id: AccountId): ConnectionIO[Int] = def deleteRememberMe(userId: Ident): ConnectionIO[Int] =
DML.delete( DML.delete(RRememberMe.T, RRememberMe.T.userId === userId)
RRememberMe.T,
RRememberMe.T.cid === id.collective && RRememberMe.T.username === id.user
)
def deleteTotp(uid: Ident): ConnectionIO[Int] = def deleteTotp(uid: Ident): ConnectionIO[Int] =
DML.delete(RTotp.T, RTotp.T.userId === uid) DML.delete(RTotp.T, RTotp.T.userId === uid)
@ -130,10 +127,6 @@ object QUser {
} }
private def loadUserId(id: AccountId): ConnectionIO[Option[Ident]] = private def loadUserId(id: AccountId): ConnectionIO[Option[Ident]] =
run( RUser.findIdByAccount(id)
select(RUser.T.uid),
from(RUser.T),
RUser.T.cid === id.collective && RUser.T.login === id.user
).query[Ident].option
} }

View File

@ -17,39 +17,38 @@ import docspell.store.qb._
import doobie._ import doobie._
import doobie.implicits._ import doobie.implicits._
case class RRememberMe(id: Ident, accountId: AccountId, created: Timestamp, uses: Int) {} case class RRememberMe(id: Ident, userId: Ident, created: Timestamp, uses: Int) {}
object RRememberMe { object RRememberMe {
final case class Table(alias: Option[String]) extends TableDef { final case class Table(alias: Option[String]) extends TableDef {
val tableName = "rememberme" val tableName = "rememberme"
val id = Column[Ident]("id", this) val id = Column[Ident]("id", this)
val cid = Column[Ident]("cid", this) val userId = Column[Ident]("user_id", this)
val username = Column[Ident]("login", this)
val created = Column[Timestamp]("created", this) val created = Column[Timestamp]("created", this)
val uses = Column[Int]("uses", this) val uses = Column[Int]("uses", this)
val all = NonEmptyList.of[Column[_]](id, cid, username, created, uses) val all = NonEmptyList.of[Column[_]](id, userId, created, uses)
} }
val T = Table(None) val T = Table(None)
def as(alias: String): Table = def as(alias: String): Table =
Table(Some(alias)) Table(Some(alias))
def generate[F[_]: Sync](account: AccountId): F[RRememberMe] = def generate[F[_]: Sync](userId: Ident): F[RRememberMe] =
for { for {
c <- Timestamp.current[F] c <- Timestamp.current[F]
i <- Ident.randomId[F] i <- Ident.randomId[F]
} yield RRememberMe(i, account, c, 0) } yield RRememberMe(i, userId, c, 0)
def insert(v: RRememberMe): ConnectionIO[Int] = def insert(v: RRememberMe): ConnectionIO[Int] =
DML.insert( DML.insert(
T, T,
T.all, T.all,
fr"${v.id},${v.accountId.collective},${v.accountId.user},${v.created},${v.uses}" fr"${v.id},${v.userId},${v.created},${v.uses}"
) )
def insertNew(acc: AccountId): ConnectionIO[RRememberMe] = def insertNew(userId: Ident): ConnectionIO[RRememberMe] =
generate[ConnectionIO](acc).flatMap(v => insert(v).map(_ => v)) generate[ConnectionIO](userId).flatMap(v => insert(v).map(_ => v))
def findById(rid: Ident): ConnectionIO[Option[RRememberMe]] = def findById(rid: Ident): ConnectionIO[Option[RRememberMe]] =
run(select(T.all), from(T), T.id === rid).query[RRememberMe].option run(select(T.all), from(T), T.id === rid).query[RRememberMe].option