diff --git a/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordTextClassifier.scala b/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordTextClassifier.scala index 3da3b5ba..d8846fc4 100644 --- a/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordTextClassifier.scala +++ b/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordTextClassifier.scala @@ -26,7 +26,7 @@ final class StanfordTextClassifier[F[_]: Sync: ContextShift]( .use { dir => for { rawData <- writeDataFile(blocker, dir, data) - _ <- logger.debug(s"Learning from ${rawData.count} items.") + _ <- logger.info(s"Learning from ${rawData.count} items.") trainData <- splitData(logger, rawData) scores <- cfg.classifierConfigs.traverse(m => train(logger, trainData, m)) sorted = scores.sortBy(-_.score) @@ -43,7 +43,7 @@ final class StanfordTextClassifier[F[_]: Sync: ContextShift]( val cls = ColumnDataClassifier.getClassifier( model.model.normalize().toAbsolutePath().toString() ) - val cat = cls.classOf(cls.makeDatumFromLine(normalisedText(text))) + val cat = cls.classOf(cls.makeDatumFromLine("\t\t" + normalisedText(text))) Option(cat) } @@ -66,7 +66,7 @@ final class StanfordTextClassifier[F[_]: Sync: ContextShift]( } yield res def splitData(logger: Logger[F], in: RawData): F[TrainData] = { - val nTest = (in.count * 0.25).toLong + val nTest = (in.count * 0.15).toLong val td = TrainData(in.file.resolveSibling("train.txt"), in.file.resolveSibling("test.txt")) @@ -106,9 +106,10 @@ final class StanfordTextClassifier[F[_]: Sync: ContextShift]( counter <- Ref.of[F, Long](0L) _ <- data - .map(d => s"${d.cls}\t${d.ref}\t${normalisedText(d.text)}") + .filter(_.text.nonEmpty) + .map(d => s"${d.cls}\t${fixRef(d.ref)}\t${normalisedText(d.text)}") .evalTap(_ => counter.update(_ + 1)) - .intersperse("\n") + .intersperse("\r\n") .through(fs2.text.utf8Encode) .through(fs2.io.file.writeAll(target, blocker)) .compile @@ -119,13 +120,16 @@ final class StanfordTextClassifier[F[_]: Sync: ContextShift]( } def normalisedText(text: String): String = - text.replaceAll("[\n\t]+", " ") + text.replaceAll("[\n\r\t]+", " ") + + def fixRef(str: String): String = + str.replace('\t', '_') def amendProps( trainData: TrainData, props: Map[String, String] ): Map[String, String] = - prepend("2", props) ++ Map( + prepend("2.", props) ++ Map( "trainFile" -> trainData.train.normalize().toAbsolutePath().toString(), "testFile" -> trainData.test.normalize().toAbsolutePath().toString(), "serializeTo" -> trainData.modelFile.normalize().toAbsolutePath().toString() diff --git a/modules/joex/src/main/scala/docspell/joex/learn/LearnClassifierTask.scala b/modules/joex/src/main/scala/docspell/joex/learn/LearnClassifierTask.scala index 6c11fecf..013cd215 100644 --- a/modules/joex/src/main/scala/docspell/joex/learn/LearnClassifierTask.scala +++ b/modules/joex/src/main/scala/docspell/joex/learn/LearnClassifierTask.scala @@ -3,7 +3,8 @@ package docspell.joex.learn import cats.data.Kleisli import cats.data.OptionT import cats.effect._ -import fs2.Stream +import cats.implicits._ +import fs2.{Pipe, Stream} import docspell.analysis.TextAnalyser import docspell.analysis.nlp.ClassifierModel @@ -12,9 +13,13 @@ import docspell.backend.ops.OCollective import docspell.common._ import docspell.joex.Config import docspell.joex.scheduler._ +import docspell.store.queries.QItem import docspell.store.records.RClassifierSetting +import bitpeace.MimetypeHint + object LearnClassifierTask { + val noClass = "__NONE__" type Args = LearnClassifierArgs @@ -31,29 +36,58 @@ object LearnClassifierTask { sett <- findActiveSettings[F](ctx, cfg) data = selectItems( ctx, - math.min(cfg.classification.itemCount, sett.itemCount), + math.min(cfg.classification.itemCount, sett.itemCount).toLong, sett.category.getOrElse("") ) _ <- OptionT.liftF( analyser .classifier(blocker) - .trainClassifier[Unit](ctx.logger, data)(Kleisli(handleModel(ctx))) + .trainClassifier[Unit](ctx.logger, data)(Kleisli(handleModel(ctx, blocker))) ) } yield ()) .getOrElseF(logInactiveWarning(ctx.logger)) } - private def handleModel[F[_]]( - ctx: Context[F, Args] + private def handleModel[F[_]: Sync: ContextShift]( + ctx: Context[F, Args], + blocker: Blocker )(trainedModel: ClassifierModel): F[Unit] = - ??? + for { + oldFile <- ctx.store.transact( + RClassifierSetting.findById(ctx.args.collective).map(_.flatMap(_.fileId)) + ) + _ <- ctx.logger.info("Storing new trained model") + fileData = fs2.io.file.readAll(trainedModel.model, blocker, 4096) + newFile <- + ctx.store.bitpeace.saveNew(fileData, 4096, MimetypeHint.none).compile.lastOrError + _ <- ctx.store.transact( + RClassifierSetting.updateFile(ctx.args.collective, Ident.unsafe(newFile.id)) + ) + _ <- ctx.logger.debug(s"New model stored at file ${newFile.id}") + _ <- oldFile match { + case Some(fid) => + ctx.logger.debug(s"Deleting old model file ${fid.id}") *> + ctx.store.bitpeace.delete(fid.id).compile.drain + case None => ().pure[F] + } + } yield () private def selectItems[F[_]]( ctx: Context[F, Args], - max: Int, + max: Long, category: String - ): Stream[F, Data] = - ??? + ): Stream[F, Data] = { + val connStream = + for { + item <- QItem.findAllNewesFirst(ctx.args.collective, 10).through(restrictTo(max)) + tt <- Stream.eval(QItem.resolveTextAndTag(ctx.args.collective, item, category)) + } yield Data(tt.tag.map(_.name).getOrElse(noClass), item.id, tt.text.trim) + ctx.store.transact(connStream.filter(_.text.nonEmpty)) + } + + private def restrictTo[F[_], A](max: Long): Pipe[F, A, A] = + if (max <= 0) identity + else _.take(max) private def findActiveSettings[F[_]: Sync]( ctx: Context[F, Args], diff --git a/modules/store/src/main/scala/docspell/store/impl/DoobieSyntax.scala b/modules/store/src/main/scala/docspell/store/impl/DoobieSyntax.scala index e4a67538..3a992b71 100644 --- a/modules/store/src/main/scala/docspell/store/impl/DoobieSyntax.scala +++ b/modules/store/src/main/scala/docspell/store/impl/DoobieSyntax.scala @@ -67,8 +67,8 @@ trait DoobieSyntax { Fragment.const(" FROM ") ++ table ++ this.where(where) def selectDistinct(cols: Seq[Column], table: Fragment, where: Fragment): Fragment = - Fragment.const("SELECT DISTINCT(") ++ commas(cols.map(_.f)) ++ - Fragment.const(") FROM ") ++ table ++ this.where(where) + Fragment.const("SELECT DISTINCT ") ++ commas(cols.map(_.f)) ++ + Fragment.const(" FROM ") ++ table ++ this.where(where) def selectCount(col: Column, table: Fragment, where: Fragment): Fragment = Fragment.const("SELECT COUNT(") ++ col.f ++ Fragment.const(") FROM ") ++ table ++ this diff --git a/modules/store/src/main/scala/docspell/store/queries/QItem.scala b/modules/store/src/main/scala/docspell/store/queries/QItem.scala index 1240d4a7..312523ce 100644 --- a/modules/store/src/main/scala/docspell/store/queries/QItem.scala +++ b/modules/store/src/main/scala/docspell/store/queries/QItem.scala @@ -7,6 +7,7 @@ import cats.effect.concurrent.Ref import cats.implicits._ import fs2.Stream +import docspell.common.syntax.all._ import docspell.common.{IdRef, _} import docspell.store.Store import docspell.store.impl.Implicits._ @@ -615,4 +616,74 @@ object QItem { .query[NameAndNotes] .streamWithChunkSize(chunkSize) } + + def findAllNewesFirst( + collective: Ident, + chunkSize: Int + ): Stream[ConnectionIO, Ident] = { + val cols = Seq(RItem.Columns.id) + (selectSimple(cols, RItem.table, RItem.Columns.cid.is(collective)) ++ + orderBy(RItem.Columns.created.desc)) + .query[Ident] + .streamWithChunkSize(chunkSize) + } + + case class TagName(id: Ident, name: String) + case class TextAndTag(itemId: Ident, text: String, tag: Option[TagName]) + + def resolveTextAndTag( + collective: Ident, + itemId: Ident, + tagCategory: String + ): ConnectionIO[TextAndTag] = { + val aId = RAttachment.Columns.id.prefix("a") + val aItem = RAttachment.Columns.itemId.prefix("a") + val mId = RAttachmentMeta.Columns.id.prefix("m") + val mText = RAttachmentMeta.Columns.content.prefix("m") + val tiItem = RTagItem.Columns.itemId.prefix("ti") + val tiTag = RTagItem.Columns.tagId.prefix("ti") + val tId = RTag.Columns.tid.prefix("t") + val tName = RTag.Columns.name.prefix("t") + val tCat = RTag.Columns.category.prefix("t") + val iId = RItem.Columns.id.prefix("i") + val iColl = RItem.Columns.cid.prefix("i") + + val cte = withCTE( + "tags" -> selectSimple( + Seq(tiItem, tId, tName), + RTagItem.table ++ fr"ti INNER JOIN" ++ + RTag.table ++ fr"t ON" ++ tId.is(tiTag), + and(tiItem.is(itemId), tCat.is(tagCategory)) + ) + ) + + val cols = Seq(mText, tId, tName) + + val from = RItem.table ++ fr"i INNER JOIN" ++ + RAttachment.table ++ fr"a ON" ++ aItem.is(iId) ++ fr"INNER JOIN" ++ + RAttachmentMeta.table ++ fr"m ON" ++ aId.is(mId) ++ fr"LEFT JOIN" ++ + fr"tags t ON" ++ RTagItem.Columns.itemId.prefix("t").is(iId) + + val where = + and( + iId.is(itemId), + iColl.is(collective), + mText.isNotNull, + mText.isNot("") + ) + + val q = cte ++ selectDistinct(cols, from, where) + for { + _ <- logger.ftrace[ConnectionIO]( + s"query: $q (${itemId.id}, ${collective.id}, ${tagCategory})" + ) + texts <- q.query[(String, Option[TagName])].to[List] + _ <- logger.ftrace[ConnectionIO]( + s"Got ${texts.size} text and tag entries for item ${itemId.id}" + ) + tag = texts.headOption.flatMap(_._2) + txt = texts.map(_._1).mkString(" --n-- ") + } yield TextAndTag(itemId, txt, tag) + } + } diff --git a/modules/store/src/main/scala/docspell/store/records/RClassifierSetting.scala b/modules/store/src/main/scala/docspell/store/records/RClassifierSetting.scala index c15f870c..680741a0 100644 --- a/modules/store/src/main/scala/docspell/store/records/RClassifierSetting.scala +++ b/modules/store/src/main/scala/docspell/store/records/RClassifierSetting.scala @@ -61,6 +61,9 @@ object RClassifierSetting { sql.update.run } + def updateFile(coll: Ident, fid: Ident): ConnectionIO[Int] = + updateRow(table, cid.is(coll), fileId.setTo(fid)).update.run + def updateSettings(v: RClassifierSetting): ConnectionIO[Int] = for { n1 <- updateRow(