mirror of
https://github.com/TheAnachronism/docspell.git
synced 2025-06-22 02:18:26 +00:00
Extend guessing tags to all tag categories
This commit is contained in:
@ -169,7 +169,7 @@ object JoexAppImpl {
|
||||
.withTask(
|
||||
JobTask.json(
|
||||
LearnClassifierArgs.taskName,
|
||||
LearnClassifierTask[F](cfg.textAnalysis, blocker, analyser),
|
||||
LearnClassifierTask[F](cfg.textAnalysis, analyser),
|
||||
LearnClassifierTask.onCancel[F]
|
||||
)
|
||||
)
|
||||
|
@ -0,0 +1,45 @@
|
||||
package docspell.joex.learn
|
||||
|
||||
import cats.data.NonEmptyList
|
||||
import cats.implicits._
|
||||
import docspell.common.Ident
|
||||
import docspell.store.records.{RClassifierModel, RTag}
|
||||
import doobie._
|
||||
|
||||
/** Name under which a trained classifier model is stored.
  *
  * A value class wrapping the raw string, so a model name is a distinct type
  * and not just any `String`.
  */
final class ClassifierName(val name: String) extends AnyVal

object ClassifierName {

  /** Wraps the given raw string into a [[ClassifierName]]. */
  def apply(name: String): ClassifierName =
    new ClassifierName(name)

  /** Name handed to `RTag.listCategories` when listing tag categories.
    *
    * NOTE(review): presumably the stand-in for tags without a category —
    * confirm against `RTag.listCategories`.
    */
  val noCategory: ClassifierName =
    ClassifierName("__docspell_no_category__")

  /** Prefix put in front of a tag category to build its classifier name. */
  val categoryPrefix = "tagcategory-"

  /** Classifier name for the given tag category: the prefix followed by `cat`. */
  def tagCategory(cat: String): ClassifierName =
    ClassifierName(categoryPrefix + cat)

  val concernedPerson: ClassifierName =
    ClassifierName("concernedperson")

  val concernedEquip: ClassifierName =
    ClassifierName("concernedequip")

  val correspondentOrg: ClassifierName =
    ClassifierName("correspondentorg")

  val correspondentPerson: ClassifierName =
    ClassifierName("correspondentperson")

  /** Loads the stored models of all tag categories of a collective.
    *
    * The categories are listed first; their classifier names are then looked
    * up in one query. When there are no categories, no second query is issued
    * and an empty list is returned.
    *
    * @param coll the collective whose tag models are requested
    * @return the models found for the collective's tag categories
    */
  def findTagModels[F[_]](coll: Ident): ConnectionIO[List[RClassifierModel]] =
    RTag.listCategories(coll, noCategory.name).flatMap { categories =>
      NonEmptyList
        .fromList(categories)
        .fold(List.empty[RClassifierModel].pure[ConnectionIO]) { names =>
          RClassifierModel.findAllByName(coll, names.map(cat => tagCategory(cat).name))
        }
    }
}
|
@ -4,23 +4,16 @@ import cats.data.Kleisli
|
||||
import cats.data.OptionT
|
||||
import cats.effect._
|
||||
import cats.implicits._
|
||||
import fs2.{Pipe, Stream}
|
||||
|
||||
import docspell.analysis.TextAnalyser
|
||||
import docspell.analysis.classifier.ClassifierModel
|
||||
import docspell.analysis.classifier.TextClassifier.Data
|
||||
import docspell.backend.ops.OCollective
|
||||
import docspell.common._
|
||||
import docspell.joex.Config
|
||||
import docspell.joex.scheduler._
|
||||
import docspell.store.queries.QItem
|
||||
import docspell.store.records.RClassifierSetting
|
||||
|
||||
import bitpeace.MimetypeHint
|
||||
import docspell.store.records.{RClassifierSetting, RTag}
|
||||
|
||||
object LearnClassifierTask {
|
||||
val noClass = "__NONE__"
|
||||
val pageSep = " --n-- "
|
||||
val noClass = "__NONE__"
|
||||
|
||||
type Args = LearnClassifierArgs
|
||||
|
||||
@ -29,67 +22,53 @@ object LearnClassifierTask {
|
||||
|
||||
def apply[F[_]: Sync: ContextShift](
|
||||
cfg: Config.TextAnalysis,
|
||||
blocker: Blocker,
|
||||
analyser: TextAnalyser[F]
|
||||
): Task[F, Args, Unit] =
|
||||
Task { ctx =>
|
||||
(for {
|
||||
sett <- findActiveSettings[F](ctx, cfg)
|
||||
data = selectItems(
|
||||
ctx,
|
||||
math.min(cfg.classification.itemCount, sett.itemCount).toLong,
|
||||
sett.category.getOrElse("")
|
||||
)
|
||||
maxItems = math.min(cfg.classification.itemCount, sett.itemCount)
|
||||
_ <- OptionT.liftF(
|
||||
analyser.classifier
|
||||
.trainClassifier[Unit](ctx.logger, data)(Kleisli(handleModel(ctx, blocker)))
|
||||
learnAllTagCategories(analyser)(ctx.args.collective, maxItems).run(ctx)
|
||||
)
|
||||
} yield ())
|
||||
.getOrElseF(logInactiveWarning(ctx.logger))
|
||||
}
|
||||
|
||||
/** Persists a freshly trained model for the collective of the current job.
  *
  * Steps, in order: remember the file id currently referenced by the
  * classifier settings, save the new model file, point the settings at the
  * new file and finally delete the previously referenced file (if any).
  *
  * @param ctx          job context giving access to store, logger and arguments
  * @param blocker      blocker used for reading the model file
  * @param trainedModel the model produced by training
  */
private def handleModel[F[_]: Sync: ContextShift](
    ctx: Context[F, Args],
    blocker: Blocker
)(trainedModel: ClassifierModel): F[Unit] = {
  val collective = ctx.args.collective
  // Only a description of the read; nothing is read until the stream is run.
  val modelBytes = fs2.io.file.readAll(trainedModel.model, blocker, 4096)

  for {
    previous <- ctx.store.transact(
      RClassifierSetting.findById(collective).map(_.flatMap(_.fileId))
    )
    _ <- ctx.logger.info("Storing new trained model")
    saved <-
      ctx.store.bitpeace.saveNew(modelBytes, 4096, MimetypeHint.none).compile.lastOrError
    _ <- ctx.store.transact(
      RClassifierSetting.updateFile(collective, Ident.unsafe(saved.id))
    )
    _ <- ctx.logger.debug(s"New model stored at file ${saved.id}")
    // The old file is removed only after the settings reference the new one.
    _ <- previous.fold(().pure[F]) { fid =>
      ctx.logger.debug(s"Deleting old model file ${fid.id}") *>
        ctx.store.bitpeace.delete(fid.id).compile.drain
    }
  } yield ()
}
|
||||
|
||||
private def selectItems[F[_]](
|
||||
ctx: Context[F, Args],
|
||||
max: Long,
|
||||
def learnTagCategory[F[_]: Sync: ContextShift, A](
|
||||
analyser: TextAnalyser[F],
|
||||
collective: Ident,
|
||||
maxItems: Int
|
||||
)(
|
||||
category: String
|
||||
): Stream[F, Data] = {
|
||||
val connStream =
|
||||
for {
|
||||
item <- QItem.findAllNewesFirst(ctx.args.collective, 10).through(restrictTo(max))
|
||||
tt <- Stream.eval(
|
||||
QItem.resolveTextAndTag(ctx.args.collective, item, category, pageSep)
|
||||
): Task[F, A, Unit] =
|
||||
Task { ctx =>
|
||||
val data = SelectItems.forCategory(ctx, collective)(maxItems, category)
|
||||
ctx.logger.info(s"Learn classifier for tag category: $category") *>
|
||||
analyser.classifier.trainClassifier(ctx.logger, data)(
|
||||
Kleisli(
|
||||
StoreClassifierModel.handleModel(
|
||||
ctx,
|
||||
collective,
|
||||
ClassifierName.tagCategory(category)
|
||||
)
|
||||
)
|
||||
)
|
||||
} yield Data(tt.tag.map(_.name).getOrElse(noClass), item.id, tt.text.trim)
|
||||
ctx.store.transact(connStream.filter(_.text.nonEmpty))
|
||||
}
|
||||
}
|
||||
|
||||
/** Pipe that lets at most `max` elements through.
  *
  * A `max` of zero or less means "no limit": the stream is passed on unchanged.
  */
private def restrictTo[F[_], A](max: Long): Pipe[F, A, A] = {
  val unlimited = max <= 0
  in => if (unlimited) in else in.take(max)
}
|
||||
/** Task that trains one classifier per tag category of a collective.
  *
  * The categories are listed from the store and `learnTagCategory` is run
  * for each of them, one after another, in the listed order.
  *
  * @param analyser   text analyser passed on to `learnTagCategory`
  * @param collective the collective whose tag categories are learned
  * @param maxItems   item limit passed on to `learnTagCategory`
  */
def learnAllTagCategories[F[_]: Sync: ContextShift, A](analyser: TextAnalyser[F])(
    collective: Ident,
    maxItems: Int
): Task[F, A, Unit] =
  Task { ctx =>
    val learn = learnTagCategory[F, A](analyser, collective, maxItems) _
    ctx.store
      .transact(RTag.listCategories(collective, ClassifierName.noCategory.name))
      .flatMap { categories =>
        // Only the effects matter here; the per-category results are units.
        categories.traverse_(category => learn(category).run(ctx))
      }
  }
|
||||
|
||||
private def findActiveSettings[F[_]: Sync](
|
||||
ctx: Context[F, Args],
|
||||
@ -98,7 +77,6 @@ object LearnClassifierTask {
|
||||
if (cfg.classification.enabled)
|
||||
OptionT(ctx.store.transact(RClassifierSetting.findById(ctx.args.collective)))
|
||||
.filter(_.enabled)
|
||||
.filter(_.category.nonEmpty)
|
||||
.map(OCollective.Classifier.fromRecord)
|
||||
else
|
||||
OptionT.none
|
||||
|
@ -0,0 +1,39 @@
|
||||
package docspell.joex.learn
|
||||
|
||||
import fs2.Stream
|
||||
|
||||
import docspell.analysis.classifier.TextClassifier.Data
|
||||
import docspell.common._
|
||||
import docspell.joex.scheduler.Context
|
||||
import docspell.store.Store
|
||||
import docspell.store.qb.Batch
|
||||
import docspell.store.queries.QItem
|
||||
|
||||
/** Selects the training data (item text plus its tag) for a tag category. */
object SelectItems {
  val pageSep = LearnClassifierTask.pageSep
  val noClass = LearnClassifierTask.noClass

  /** Same as the `Store`-based variant, using the store of the given job context. */
  def forCategory[F[_]](ctx: Context[F, _], collective: Ident)(
      max: Int,
      category: String
  ): Stream[F, Data] =
    forCategory(ctx.store, collective, max, category)

  /** Streams training data for `category` from the items of a collective.
    *
    * For every item its text and its tag in `category` are resolved. Items
    * without such a tag get the class `noClass`. Entries whose trimmed text is
    * empty are dropped.
    *
    * @param store      store used to run the database stream
    * @param collective the collective whose items are read
    * @param max        maximum number of items to query; zero or less means no limit
    * @param category   the tag category to resolve for each item
    */
  def forCategory[F[_]](
      store: Store[F],
      collective: Ident,
      max: Int,
      category: String
  ): Stream[F, Data] = {
    val batch =
      if (max > 0) Batch.limit(max)
      else Batch.all

    // NOTE(review): the literal 10 is presumably a chunk size — confirm in QItem.
    val entries =
      QItem
        .findAllNewesFirst(collective, 10, batch)
        .evalMap { item =>
          QItem
            .resolveTextAndTag(collective, item, category, pageSep)
            .map(tt => Data(tt.tag.fold(noClass)(_.name), item.id, tt.text.trim))
        }
        .filter(_.text.nonEmpty)

    store.transact(entries)
  }

}
|
@ -0,0 +1,53 @@
|
||||
package docspell.joex.learn
|
||||
|
||||
import cats.effect._
|
||||
import cats.implicits._
|
||||
|
||||
import docspell.analysis.classifier.ClassifierModel
|
||||
import docspell.common._
|
||||
import docspell.joex.scheduler._
|
||||
import docspell.store.Store
|
||||
import docspell.store.records.RClassifierModel
|
||||
|
||||
import bitpeace.MimetypeHint
|
||||
|
||||
/** Stores trained classifier models and removes the files they replace. */
object StoreClassifierModel {

  /** Same as the `Store`-based variant, taking store, blocker and logger
    * from the given job context.
    */
  def handleModel[F[_]: Sync: ContextShift](
      ctx: Context[F, _],
      collective: Ident,
      modelName: ClassifierName
  )(
      trainedModel: ClassifierModel
  ): F[Unit] =
    handleModel(ctx.store, ctx.blocker, ctx.logger)(collective, modelName, trainedModel)

  /** Persists a trained model under `modelName` for a collective.
    *
    * Steps, in order: look up the file currently registered for the name,
    * save the new model file, register the new file for the name and finally
    * delete the previously registered file (if there was one).
    *
    * @param store        store used for database and file access
    * @param blocker      blocker used for reading the model file
    * @param logger       receives debug messages about the progress
    * @param collective   the collective owning the model
    * @param modelName    name the model is registered under
    * @param trainedModel the model produced by training
    */
  def handleModel[F[_]: Sync: ContextShift](
      store: Store[F],
      blocker: Blocker,
      logger: Logger[F]
  )(
      collective: Ident,
      modelName: ClassifierName,
      trainedModel: ClassifierModel
  ): F[Unit] = {
    // Only a description of the read; nothing is read until the stream is run.
    val modelBytes = fs2.io.file.readAll(trainedModel.model, blocker, 4096)

    val lookupPrevious =
      store.transact(
        RClassifierModel.findByName(collective, modelName.name).map(_.map(_.fileId))
      )

    for {
      previous <- lookupPrevious
      _        <- logger.debug(s"Storing new trained model for: ${modelName.name}")
      saved <-
        store.bitpeace.saveNew(modelBytes, 4096, MimetypeHint.none).compile.lastOrError
      _ <- store.transact(
        RClassifierModel.updateFile(collective, modelName.name, Ident.unsafe(saved.id))
      )
      _ <- logger.debug(s"New model stored at file ${saved.id}")
      // The old file is removed only after the new one has been registered.
      _ <- previous.traverse_ { fid =>
        logger.debug(s"Deleting old model file ${fid.id}") *>
          store.bitpeace.delete(fid.id).compile.drain
      }
    } yield ()
  }
}
|
@ -9,12 +9,11 @@ import docspell.analysis.{NlpSettings, TextAnalyser}
|
||||
import docspell.common._
|
||||
import docspell.joex.Config
|
||||
import docspell.joex.analysis.RegexNerFile
|
||||
import docspell.joex.learn.LearnClassifierTask
|
||||
import docspell.joex.learn.{ClassifierName, LearnClassifierTask}
|
||||
import docspell.joex.process.ItemData.AttachmentDates
|
||||
import docspell.joex.scheduler.Context
|
||||
import docspell.joex.scheduler.Task
|
||||
import docspell.store.records.RAttachmentMeta
|
||||
import docspell.store.records.RClassifierSetting
|
||||
import docspell.store.records.{RAttachmentMeta, RClassifierSetting}
|
||||
|
||||
import bitpeace.RangeDef
|
||||
|
||||
@ -42,10 +41,13 @@ object TextAnalysis {
|
||||
e <- s
|
||||
_ <- ctx.logger.info(s"Text-Analysis finished in ${e.formatExact}")
|
||||
v = t.toVector
|
||||
tag <- predictTag(ctx, cfg, item.metas, analyser.classifier).value
|
||||
classifierEnabled <- getActive(ctx, cfg)
|
||||
tag <-
|
||||
if (classifierEnabled) predictTags(ctx, cfg, item.metas, analyser.classifier)
|
||||
else List.empty[String].pure[F]
|
||||
} yield item
|
||||
.copy(metas = v.map(_._1), dateLabels = v.map(_._2))
|
||||
.appendTags(tag.toSeq)
|
||||
.appendTags(tag)
|
||||
}
|
||||
|
||||
def annotateAttachment[F[_]: Sync](
|
||||
@ -66,15 +68,29 @@ object TextAnalysis {
|
||||
} yield (rm.copy(nerlabels = labels.all.toList), AttachmentDates(rm, labels.dates))
|
||||
}
|
||||
|
||||
/** Guesses tags for an item, one guess per stored tag-category model.
  *
  * All tag models of the item's collective are loaded and `predictTag` is
  * run for each model file in turn. Models that yield no tag contribute
  * nothing to the result.
  *
  * @param ctx        job context giving access to store, logger and arguments
  * @param cfg        text analysis configuration, passed on to `predictTag`
  * @param metas      attachment metadata, passed on to `predictTag`
  * @param classifier the classifier, passed on to `predictTag`
  * @return the guessed tag names, at most one per model
  */
def predictTags[F[_]: Sync: ContextShift](
    ctx: Context[F, Args],
    cfg: Config.TextAnalysis,
    metas: Vector[RAttachmentMeta],
    classifier: TextClassifier[F]
): F[List[String]] =
  ctx.store
    .transact(ClassifierName.findTagModels(ctx.args.meta.collective))
    .flatMap { models =>
      ctx.logger.debug(s"Guessing tags for ${models.size} categories") *>
        models
          .traverse(model => predictTag(ctx, cfg, metas, classifier)(model.fileId.some))
          .map(_.flatten)
    }
|
||||
|
||||
def predictTag[F[_]: Sync: ContextShift](
|
||||
ctx: Context[F, Args],
|
||||
cfg: Config.TextAnalysis,
|
||||
metas: Vector[RAttachmentMeta],
|
||||
classifier: TextClassifier[F]
|
||||
): OptionT[F, String] =
|
||||
for {
|
||||
model <- findActiveModel(ctx, cfg)
|
||||
_ <- OptionT.liftF(ctx.logger.info(s"Guessing tag …"))
|
||||
)(modelFileId: Option[Ident]): F[Option[String]] =
|
||||
(for {
|
||||
_ <- OptionT.liftF(ctx.logger.info(s"Guessing tag for ${modelFileId} …"))
|
||||
model <- OptionT.fromOption[F](modelFileId)
|
||||
text = metas.flatMap(_.content).mkString(LearnClassifierTask.pageSep)
|
||||
modelData =
|
||||
ctx.store.bitpeace
|
||||
@ -90,20 +106,21 @@ object TextAnalysis {
|
||||
.flatMap(_ => classifier.classify(ctx.logger, ClassifierModel(modelFile), text))
|
||||
}).filter(_ != LearnClassifierTask.noClass)
|
||||
_ <- OptionT.liftF(ctx.logger.debug(s"Guessed tag: ${cls}"))
|
||||
} yield cls
|
||||
} yield cls).value
|
||||
|
||||
private def findActiveModel[F[_]: Sync](
|
||||
private def getActive[F[_]: Sync](
|
||||
ctx: Context[F, Args],
|
||||
cfg: Config.TextAnalysis
|
||||
): OptionT[F, Ident] =
|
||||
(if (cfg.classification.enabled)
|
||||
OptionT(ctx.store.transact(RClassifierSetting.findById(ctx.args.meta.collective)))
|
||||
.filter(_.enabled)
|
||||
.mapFilter(_.fileId)
|
||||
else
|
||||
OptionT.none[F, Ident]).orElse(
|
||||
OptionT.liftF(ctx.logger.info("Classification is disabled.")) *> OptionT
|
||||
.none[F, Ident]
|
||||
)
|
||||
): F[Boolean] =
|
||||
if (cfg.classification.enabled)
|
||||
ctx.store
|
||||
.transact(RClassifierSetting.findById(ctx.args.meta.collective))
|
||||
.map(_.exists(_.enabled))
|
||||
.flatTap(enabled =>
|
||||
if (enabled) ().pure[F]
|
||||
else ctx.logger.info("Classification is disabled. Check config or settings.")
|
||||
)
|
||||
else
|
||||
ctx.logger.info("Classification is disabled.") *> false.pure[F]
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user