diff --git a/modules/common/src/main/scala/docspell/common/Ident.scala b/modules/common/src/main/scala/docspell/common/Ident.scala index 928f0dd5..a0737022 100644 --- a/modules/common/src/main/scala/docspell/common/Ident.scala +++ b/modules/common/src/main/scala/docspell/common/Ident.scala @@ -80,4 +80,7 @@ object Ident { implicit val order: Order[Ident] = Order.by(_.id) + + implicit val ordering: Ordering[Ident] = + Ordering.by(_.id) } diff --git a/modules/store/src/main/scala/docspell/store/Db.scala b/modules/store/src/main/scala/docspell/store/Db.scala index 5017c115..22824b2b 100644 --- a/modules/store/src/main/scala/docspell/store/Db.scala +++ b/modules/store/src/main/scala/docspell/store/Db.scala @@ -13,6 +13,8 @@ import io.circe.{Decoder, Encoder} sealed trait Db { def name: String def driverClass: String + + def fold[A](fpg: => A, fm: => A, fh2: => A): A } object Db { @@ -20,16 +22,19 @@ object Db { case object PostgreSQL extends Db { val name = "postgresql" val driverClass = "org.postgresql.Driver" + def fold[A](fpg: => A, fm: => A, fh2: => A): A = fpg } case object MariaDB extends Db { val name = "mariadb" val driverClass = "org.mariadb.jdbc.Driver" + def fold[A](fpg: => A, fm: => A, fh2: => A): A = fm } case object H2 extends Db { val name = "h2" val driverClass = "org.h2.Driver" + def fold[A](fpg: => A, fm: => A, fh2: => A): A = fh2 } val all: NonEmptyList[Db] = NonEmptyList.of(PostgreSQL, MariaDB, H2) diff --git a/modules/store/src/main/scala/docspell/store/Store.scala b/modules/store/src/main/scala/docspell/store/Store.scala index ffb3000e..adb21536 100644 --- a/modules/store/src/main/scala/docspell/store/Store.scala +++ b/modules/store/src/main/scala/docspell/store/Store.scala @@ -36,6 +36,8 @@ trait Store[F[_]] { def add(insert: ConnectionIO[Int], exists: ConnectionIO[Boolean]): F[AddResult] def transactor: Transactor[F] + + def dbms: Db } object Store { diff --git a/modules/store/src/main/scala/docspell/store/impl/StoreImpl.scala b/modules/store/src/main/scala/docspell/store/impl/StoreImpl.scala index f49b1035..fb37a482 100644 --- a/modules/store/src/main/scala/docspell/store/impl/StoreImpl.scala +++ b/modules/store/src/main/scala/docspell/store/impl/StoreImpl.scala @@ -29,6 +29,8 @@ final class StoreImpl[F[_]: Async]( ) extends Store[F] { private[this] val xa = transactor + val dbms = jdbc.dbms + def createFileRepository( cfg: FileRepositoryConfig, withAttributeStore: Boolean diff --git a/modules/store/src/main/scala/docspell/store/impl/TempIdTable.scala b/modules/store/src/main/scala/docspell/store/impl/TempIdTable.scala new file mode 100644 index 00000000..ab5bbc64 --- /dev/null +++ b/modules/store/src/main/scala/docspell/store/impl/TempIdTable.scala @@ -0,0 +1,85 @@ +package docspell.store.impl + +import cats.Foldable +import cats.data.NonEmptyList +import cats.effect._ +import cats.syntax.all._ +import docspell.common.Ident +import docspell.store.Db +import docspell.store.qb.{Column, TableDef} +import docspell.store.impl.DoobieMeta._ +import doobie._ +import doobie.implicits._ + +/** Temporary table used to store item ids fetched from fulltext search */ +object TempIdTable { + case class Row(id: Ident) + case class Table(tableName: String, alias: Option[String], dbms: Db) extends TableDef { + val id: Column[Ident] = Column("id", this) + + val all: NonEmptyList[Column[_]] = NonEmptyList.of(id) + + def as(newAlias: String): Table = copy(alias = Some(newAlias)) + + def insertAll[F[_]: Foldable](rows: F[Row]): ConnectionIO[Int] = + insertBatch(this, rows) + + def dropTable: ConnectionIO[Int] = + TempIdTable.dropTable(Fragment.const0(tableName)).update.run + + def createIndex: ConnectionIO[Unit] = { + val analyze = dbms.fold( + TempIdTable.analyzeTablePg(this), + Sync[ConnectionIO].unit, + Sync[ConnectionIO].unit + ) + + TempIdTable.createIndex(this) *> analyze + } + } + + def createTable(db: Db, name: String): ConnectionIO[Table] = { + val stmt = db.fold( + createTablePostgreSQL(Fragment.const(name)), + createTableMariaDB(Fragment.const0(name)), + createTableH2(Fragment.const0(name)) + ) + stmt.as(Table(name, None, db)) + } + + private def dropTable(name: Fragment): Fragment = + sql"""DROP TABLE IF EXISTS $name""" + + private def createTableH2(name: Fragment): ConnectionIO[Int] = + sql"""${dropTable(name)}; CREATE LOCAL TEMPORARY TABLE $name ( + | id varchar not null + |);""".stripMargin.update.run + + private def createTableMariaDB(name: Fragment): ConnectionIO[Int] = + dropTable(name).update.run *> + sql"CREATE TEMPORARY TABLE $name (id varchar(254) not null);".update.run + + private def createTablePostgreSQL(name: Fragment): ConnectionIO[Int] = + sql"""CREATE TEMPORARY TABLE IF NOT EXISTS $name ( + | id varchar not null + |) ON COMMIT DROP;""".stripMargin.update.run + + private def createIndex(table: Table): ConnectionIO[Unit] = { + val idxName = Fragment.const0(s"${table.tableName}_id_idx") + val tableName = Fragment.const0(table.tableName) + val col = Fragment.const0(table.id.name) + sql"""CREATE INDEX IF NOT EXISTS $idxName ON $tableName($col);""".update.run.void + } + + private def analyzeTablePg(table: Table): ConnectionIO[Unit] = { + val tableName = Fragment.const0(table.tableName) + sql"ANALYZE $tableName".update.run.void + } + + private def insertBatch[F[_]: Foldable](table: Table, rows: F[Row]) = { + val sql = + s"INSERT INTO ${table.tableName} (${table.id.name}) VALUES (?)" + + Update[Row](sql).updateMany(rows) + } +} diff --git a/modules/store/src/test/scala/docspell/store/DatabaseTest.scala b/modules/store/src/test/scala/docspell/store/DatabaseTest.scala index d4180b66..a265c6ab 100644 --- a/modules/store/src/test/scala/docspell/store/DatabaseTest.scala +++ b/modules/store/src/test/scala/docspell/store/DatabaseTest.scala @@ -11,6 +11,7 @@ import docspell.common._ import docspell.logging.TestLoggingConfig import munit.CatsEffectSuite import org.testcontainers.utility.DockerImageName +import doobie._ import java.util.UUID @@ -19,6 +20,8 @@ trait DatabaseTest with TestContainersFixtures with TestLoggingConfig { + val cio: Sync[ConnectionIO] = Sync[ConnectionIO] + lazy val mariadbCnt = ForAllContainerFixture( MariaDBContainer.Def(DockerImageName.parse("mariadb:10.5")).createContainer() ) @@ -29,12 +32,12 @@ trait DatabaseTest lazy val pgDataSource = ResourceSuiteLocalFixture( "pgDataSource", - DatabaseTest.makeDataSourceFixture(postgresCnt()) + DatabaseTest.makeDataSourceFixture(IO(postgresCnt())) ) lazy val mariaDataSource = ResourceSuiteLocalFixture( "mariaDataSource", - DatabaseTest.makeDataSourceFixture(mariadbCnt()) + DatabaseTest.makeDataSourceFixture(IO(mariadbCnt())) ) lazy val h2DataSource = ResourceSuiteLocalFixture( @@ -50,34 +53,42 @@ trait DatabaseTest } yield (jdbc, ds)) lazy val pgStore = ResourceSuiteLocalFixture( - "pgStore", { - val (jdbc, ds) = pgDataSource() - StoreFixture.store(ds, jdbc) - } + "pgStore", + for { + t <- Resource.eval(IO(pgDataSource())) + store <- StoreFixture.store(t._2, t._1) + } yield store ) lazy val mariaStore = ResourceSuiteLocalFixture( - "mariaStore", { - val (jdbc, ds) = mariaDataSource() - StoreFixture.store(ds, jdbc) - } + "mariaStore", + for { + t <- Resource.eval(IO(mariaDataSource())) + store <- StoreFixture.store(t._2, t._1) + } yield store ) lazy val h2Store = ResourceSuiteLocalFixture( - "h2Store", { - val (jdbc, ds) = h2DataSource() - StoreFixture.store(ds, jdbc) - } + "h2Store", + for { + t <- Resource.eval(IO(h2DataSource())) + store <- StoreFixture.store(t._2, t._1) + } yield store ) + + def postgresAll = List(postgresCnt, pgDataSource, pgStore) + def mariaDbAll = List(mariadbCnt, mariaDataSource, mariaStore) + def h2All = List(h2DataSource, h2Store) } object DatabaseTest { private def jdbcConfig(cnt: JdbcDatabaseContainer) = JdbcConfig(LenientUri.unsafe(cnt.jdbcUrl), cnt.username, cnt.password) - private def makeDataSourceFixture(cnt: JdbcDatabaseContainer) = + private def makeDataSourceFixture(cnt: IO[JdbcDatabaseContainer]) = for { - jdbc <- Resource.eval(IO(jdbcConfig(cnt))) + c <- Resource.eval(cnt) + jdbc <- Resource.pure(jdbcConfig(c)) ds <- StoreFixture.dataSource(jdbc) } yield (jdbc, ds) } diff --git a/modules/store/src/test/scala/docspell/store/impl/TempIdTableTest.scala b/modules/store/src/test/scala/docspell/store/impl/TempIdTableTest.scala new file mode 100644 index 00000000..54df9879 --- /dev/null +++ b/modules/store/src/test/scala/docspell/store/impl/TempIdTableTest.scala @@ -0,0 +1,55 @@ +package docspell.store.impl + +import cats.effect.IO +import docspell.common.Ident +import docspell.store._ +import docspell.store.impl.TempIdTable.Row +import docspell.store.qb._ +import docspell.store.qb.DSL._ + +class TempIdTableTest extends DatabaseTest { + + override def munitFixtures = postgresAll ++ mariaDbAll ++ h2All + + def id(str: String): Ident = Ident.unsafe(str) + + test("create temporary table postgres") { + val store = pgStore() + assertCreateTempTable(store) + } + + test("create temporary table mariadb") { + val store = mariaStore() + assertCreateTempTable(store) + } + + test("create temporary table h2") { + val store = h2Store() + assertCreateTempTable(store) + } + + def assertCreateTempTable(store: Store[IO]) = { + val insertRows = List(Row(id("abc-def")), Row(id("abc-123")), Row(id("zyx-321"))) + val create = + for { + table <- TempIdTable.createTable(store.dbms, "tt") + n <- table.insertAll(insertRows) + _ <- table.createIndex + rows <- Select(select(table.all), from(table)) + .orderBy(table.id) + .build + .query[Row] + .to[List] + } yield (n, rows) + + val verify = + store.transact(create).map { case (inserted, rows) => + if (store.dbms != Db.MariaDB) { + assertEquals(inserted, 3) + } + assertEquals(rows, insertRows.sortBy(_.id)) + } + + verify *> verify + } +}