mirror of
https://github.com/TheAnachronism/docspell.git
synced 2025-06-04 14:15:59 +00:00
Implement learning a text classifier from collective data
This commit is contained in:
parent
68bb65572b
commit
316b490008
@ -26,7 +26,7 @@ final class StanfordTextClassifier[F[_]: Sync: ContextShift](
|
|||||||
.use { dir =>
|
.use { dir =>
|
||||||
for {
|
for {
|
||||||
rawData <- writeDataFile(blocker, dir, data)
|
rawData <- writeDataFile(blocker, dir, data)
|
||||||
_ <- logger.debug(s"Learning from ${rawData.count} items.")
|
_ <- logger.info(s"Learning from ${rawData.count} items.")
|
||||||
trainData <- splitData(logger, rawData)
|
trainData <- splitData(logger, rawData)
|
||||||
scores <- cfg.classifierConfigs.traverse(m => train(logger, trainData, m))
|
scores <- cfg.classifierConfigs.traverse(m => train(logger, trainData, m))
|
||||||
sorted = scores.sortBy(-_.score)
|
sorted = scores.sortBy(-_.score)
|
||||||
@ -43,7 +43,7 @@ final class StanfordTextClassifier[F[_]: Sync: ContextShift](
|
|||||||
val cls = ColumnDataClassifier.getClassifier(
|
val cls = ColumnDataClassifier.getClassifier(
|
||||||
model.model.normalize().toAbsolutePath().toString()
|
model.model.normalize().toAbsolutePath().toString()
|
||||||
)
|
)
|
||||||
val cat = cls.classOf(cls.makeDatumFromLine(normalisedText(text)))
|
val cat = cls.classOf(cls.makeDatumFromLine("\t\t" + normalisedText(text)))
|
||||||
Option(cat)
|
Option(cat)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -66,7 +66,7 @@ final class StanfordTextClassifier[F[_]: Sync: ContextShift](
|
|||||||
} yield res
|
} yield res
|
||||||
|
|
||||||
def splitData(logger: Logger[F], in: RawData): F[TrainData] = {
|
def splitData(logger: Logger[F], in: RawData): F[TrainData] = {
|
||||||
val nTest = (in.count * 0.25).toLong
|
val nTest = (in.count * 0.15).toLong
|
||||||
|
|
||||||
val td =
|
val td =
|
||||||
TrainData(in.file.resolveSibling("train.txt"), in.file.resolveSibling("test.txt"))
|
TrainData(in.file.resolveSibling("train.txt"), in.file.resolveSibling("test.txt"))
|
||||||
@ -106,9 +106,10 @@ final class StanfordTextClassifier[F[_]: Sync: ContextShift](
|
|||||||
counter <- Ref.of[F, Long](0L)
|
counter <- Ref.of[F, Long](0L)
|
||||||
_ <-
|
_ <-
|
||||||
data
|
data
|
||||||
.map(d => s"${d.cls}\t${d.ref}\t${normalisedText(d.text)}")
|
.filter(_.text.nonEmpty)
|
||||||
|
.map(d => s"${d.cls}\t${fixRef(d.ref)}\t${normalisedText(d.text)}")
|
||||||
.evalTap(_ => counter.update(_ + 1))
|
.evalTap(_ => counter.update(_ + 1))
|
||||||
.intersperse("\n")
|
.intersperse("\r\n")
|
||||||
.through(fs2.text.utf8Encode)
|
.through(fs2.text.utf8Encode)
|
||||||
.through(fs2.io.file.writeAll(target, blocker))
|
.through(fs2.io.file.writeAll(target, blocker))
|
||||||
.compile
|
.compile
|
||||||
@ -119,13 +120,16 @@ final class StanfordTextClassifier[F[_]: Sync: ContextShift](
|
|||||||
}
|
}
|
||||||
|
|
||||||
def normalisedText(text: String): String =
|
def normalisedText(text: String): String =
|
||||||
text.replaceAll("[\n\t]+", " ")
|
text.replaceAll("[\n\r\t]+", " ")
|
||||||
|
|
||||||
|
def fixRef(str: String): String =
|
||||||
|
str.replace('\t', '_')
|
||||||
|
|
||||||
def amendProps(
|
def amendProps(
|
||||||
trainData: TrainData,
|
trainData: TrainData,
|
||||||
props: Map[String, String]
|
props: Map[String, String]
|
||||||
): Map[String, String] =
|
): Map[String, String] =
|
||||||
prepend("2", props) ++ Map(
|
prepend("2.", props) ++ Map(
|
||||||
"trainFile" -> trainData.train.normalize().toAbsolutePath().toString(),
|
"trainFile" -> trainData.train.normalize().toAbsolutePath().toString(),
|
||||||
"testFile" -> trainData.test.normalize().toAbsolutePath().toString(),
|
"testFile" -> trainData.test.normalize().toAbsolutePath().toString(),
|
||||||
"serializeTo" -> trainData.modelFile.normalize().toAbsolutePath().toString()
|
"serializeTo" -> trainData.modelFile.normalize().toAbsolutePath().toString()
|
||||||
|
@ -3,7 +3,8 @@ package docspell.joex.learn
|
|||||||
import cats.data.Kleisli
|
import cats.data.Kleisli
|
||||||
import cats.data.OptionT
|
import cats.data.OptionT
|
||||||
import cats.effect._
|
import cats.effect._
|
||||||
import fs2.Stream
|
import cats.implicits._
|
||||||
|
import fs2.{Pipe, Stream}
|
||||||
|
|
||||||
import docspell.analysis.TextAnalyser
|
import docspell.analysis.TextAnalyser
|
||||||
import docspell.analysis.nlp.ClassifierModel
|
import docspell.analysis.nlp.ClassifierModel
|
||||||
@ -12,9 +13,13 @@ import docspell.backend.ops.OCollective
|
|||||||
import docspell.common._
|
import docspell.common._
|
||||||
import docspell.joex.Config
|
import docspell.joex.Config
|
||||||
import docspell.joex.scheduler._
|
import docspell.joex.scheduler._
|
||||||
|
import docspell.store.queries.QItem
|
||||||
import docspell.store.records.RClassifierSetting
|
import docspell.store.records.RClassifierSetting
|
||||||
|
|
||||||
|
import bitpeace.MimetypeHint
|
||||||
|
|
||||||
object LearnClassifierTask {
|
object LearnClassifierTask {
|
||||||
|
val noClass = "__NONE__"
|
||||||
|
|
||||||
type Args = LearnClassifierArgs
|
type Args = LearnClassifierArgs
|
||||||
|
|
||||||
@ -31,29 +36,58 @@ object LearnClassifierTask {
|
|||||||
sett <- findActiveSettings[F](ctx, cfg)
|
sett <- findActiveSettings[F](ctx, cfg)
|
||||||
data = selectItems(
|
data = selectItems(
|
||||||
ctx,
|
ctx,
|
||||||
math.min(cfg.classification.itemCount, sett.itemCount),
|
math.min(cfg.classification.itemCount, sett.itemCount).toLong,
|
||||||
sett.category.getOrElse("")
|
sett.category.getOrElse("")
|
||||||
)
|
)
|
||||||
_ <- OptionT.liftF(
|
_ <- OptionT.liftF(
|
||||||
analyser
|
analyser
|
||||||
.classifier(blocker)
|
.classifier(blocker)
|
||||||
.trainClassifier[Unit](ctx.logger, data)(Kleisli(handleModel(ctx)))
|
.trainClassifier[Unit](ctx.logger, data)(Kleisli(handleModel(ctx, blocker)))
|
||||||
)
|
)
|
||||||
} yield ())
|
} yield ())
|
||||||
.getOrElseF(logInactiveWarning(ctx.logger))
|
.getOrElseF(logInactiveWarning(ctx.logger))
|
||||||
}
|
}
|
||||||
|
|
||||||
private def handleModel[F[_]](
|
private def handleModel[F[_]: Sync: ContextShift](
|
||||||
ctx: Context[F, Args]
|
ctx: Context[F, Args],
|
||||||
|
blocker: Blocker
|
||||||
)(trainedModel: ClassifierModel): F[Unit] =
|
)(trainedModel: ClassifierModel): F[Unit] =
|
||||||
???
|
for {
|
||||||
|
oldFile <- ctx.store.transact(
|
||||||
|
RClassifierSetting.findById(ctx.args.collective).map(_.flatMap(_.fileId))
|
||||||
|
)
|
||||||
|
_ <- ctx.logger.info("Storing new trained model")
|
||||||
|
fileData = fs2.io.file.readAll(trainedModel.model, blocker, 4096)
|
||||||
|
newFile <-
|
||||||
|
ctx.store.bitpeace.saveNew(fileData, 4096, MimetypeHint.none).compile.lastOrError
|
||||||
|
_ <- ctx.store.transact(
|
||||||
|
RClassifierSetting.updateFile(ctx.args.collective, Ident.unsafe(newFile.id))
|
||||||
|
)
|
||||||
|
_ <- ctx.logger.debug(s"New model stored at file ${newFile.id}")
|
||||||
|
_ <- oldFile match {
|
||||||
|
case Some(fid) =>
|
||||||
|
ctx.logger.debug(s"Deleting old model file ${fid.id}") *>
|
||||||
|
ctx.store.bitpeace.delete(fid.id).compile.drain
|
||||||
|
case None => ().pure[F]
|
||||||
|
}
|
||||||
|
} yield ()
|
||||||
|
|
||||||
private def selectItems[F[_]](
|
private def selectItems[F[_]](
|
||||||
ctx: Context[F, Args],
|
ctx: Context[F, Args],
|
||||||
max: Int,
|
max: Long,
|
||||||
category: String
|
category: String
|
||||||
): Stream[F, Data] =
|
): Stream[F, Data] = {
|
||||||
???
|
val connStream =
|
||||||
|
for {
|
||||||
|
item <- QItem.findAllNewesFirst(ctx.args.collective, 10).through(restrictTo(max))
|
||||||
|
tt <- Stream.eval(QItem.resolveTextAndTag(ctx.args.collective, item, category))
|
||||||
|
} yield Data(tt.tag.map(_.name).getOrElse(noClass), item.id, tt.text.trim)
|
||||||
|
ctx.store.transact(connStream.filter(_.text.nonEmpty))
|
||||||
|
}
|
||||||
|
|
||||||
|
private def restrictTo[F[_], A](max: Long): Pipe[F, A, A] =
|
||||||
|
if (max <= 0) identity
|
||||||
|
else _.take(max)
|
||||||
|
|
||||||
private def findActiveSettings[F[_]: Sync](
|
private def findActiveSettings[F[_]: Sync](
|
||||||
ctx: Context[F, Args],
|
ctx: Context[F, Args],
|
||||||
|
@ -67,8 +67,8 @@ trait DoobieSyntax {
|
|||||||
Fragment.const(" FROM ") ++ table ++ this.where(where)
|
Fragment.const(" FROM ") ++ table ++ this.where(where)
|
||||||
|
|
||||||
def selectDistinct(cols: Seq[Column], table: Fragment, where: Fragment): Fragment =
|
def selectDistinct(cols: Seq[Column], table: Fragment, where: Fragment): Fragment =
|
||||||
Fragment.const("SELECT DISTINCT(") ++ commas(cols.map(_.f)) ++
|
Fragment.const("SELECT DISTINCT ") ++ commas(cols.map(_.f)) ++
|
||||||
Fragment.const(") FROM ") ++ table ++ this.where(where)
|
Fragment.const(" FROM ") ++ table ++ this.where(where)
|
||||||
|
|
||||||
def selectCount(col: Column, table: Fragment, where: Fragment): Fragment =
|
def selectCount(col: Column, table: Fragment, where: Fragment): Fragment =
|
||||||
Fragment.const("SELECT COUNT(") ++ col.f ++ Fragment.const(") FROM ") ++ table ++ this
|
Fragment.const("SELECT COUNT(") ++ col.f ++ Fragment.const(") FROM ") ++ table ++ this
|
||||||
|
@ -7,6 +7,7 @@ import cats.effect.concurrent.Ref
|
|||||||
import cats.implicits._
|
import cats.implicits._
|
||||||
import fs2.Stream
|
import fs2.Stream
|
||||||
|
|
||||||
|
import docspell.common.syntax.all._
|
||||||
import docspell.common.{IdRef, _}
|
import docspell.common.{IdRef, _}
|
||||||
import docspell.store.Store
|
import docspell.store.Store
|
||||||
import docspell.store.impl.Implicits._
|
import docspell.store.impl.Implicits._
|
||||||
@ -615,4 +616,74 @@ object QItem {
|
|||||||
.query[NameAndNotes]
|
.query[NameAndNotes]
|
||||||
.streamWithChunkSize(chunkSize)
|
.streamWithChunkSize(chunkSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def findAllNewesFirst(
|
||||||
|
collective: Ident,
|
||||||
|
chunkSize: Int
|
||||||
|
): Stream[ConnectionIO, Ident] = {
|
||||||
|
val cols = Seq(RItem.Columns.id)
|
||||||
|
(selectSimple(cols, RItem.table, RItem.Columns.cid.is(collective)) ++
|
||||||
|
orderBy(RItem.Columns.created.desc))
|
||||||
|
.query[Ident]
|
||||||
|
.streamWithChunkSize(chunkSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
case class TagName(id: Ident, name: String)
|
||||||
|
case class TextAndTag(itemId: Ident, text: String, tag: Option[TagName])
|
||||||
|
|
||||||
|
def resolveTextAndTag(
|
||||||
|
collective: Ident,
|
||||||
|
itemId: Ident,
|
||||||
|
tagCategory: String
|
||||||
|
): ConnectionIO[TextAndTag] = {
|
||||||
|
val aId = RAttachment.Columns.id.prefix("a")
|
||||||
|
val aItem = RAttachment.Columns.itemId.prefix("a")
|
||||||
|
val mId = RAttachmentMeta.Columns.id.prefix("m")
|
||||||
|
val mText = RAttachmentMeta.Columns.content.prefix("m")
|
||||||
|
val tiItem = RTagItem.Columns.itemId.prefix("ti")
|
||||||
|
val tiTag = RTagItem.Columns.tagId.prefix("ti")
|
||||||
|
val tId = RTag.Columns.tid.prefix("t")
|
||||||
|
val tName = RTag.Columns.name.prefix("t")
|
||||||
|
val tCat = RTag.Columns.category.prefix("t")
|
||||||
|
val iId = RItem.Columns.id.prefix("i")
|
||||||
|
val iColl = RItem.Columns.cid.prefix("i")
|
||||||
|
|
||||||
|
val cte = withCTE(
|
||||||
|
"tags" -> selectSimple(
|
||||||
|
Seq(tiItem, tId, tName),
|
||||||
|
RTagItem.table ++ fr"ti INNER JOIN" ++
|
||||||
|
RTag.table ++ fr"t ON" ++ tId.is(tiTag),
|
||||||
|
and(tiItem.is(itemId), tCat.is(tagCategory))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
val cols = Seq(mText, tId, tName)
|
||||||
|
|
||||||
|
val from = RItem.table ++ fr"i INNER JOIN" ++
|
||||||
|
RAttachment.table ++ fr"a ON" ++ aItem.is(iId) ++ fr"INNER JOIN" ++
|
||||||
|
RAttachmentMeta.table ++ fr"m ON" ++ aId.is(mId) ++ fr"LEFT JOIN" ++
|
||||||
|
fr"tags t ON" ++ RTagItem.Columns.itemId.prefix("t").is(iId)
|
||||||
|
|
||||||
|
val where =
|
||||||
|
and(
|
||||||
|
iId.is(itemId),
|
||||||
|
iColl.is(collective),
|
||||||
|
mText.isNotNull,
|
||||||
|
mText.isNot("")
|
||||||
|
)
|
||||||
|
|
||||||
|
val q = cte ++ selectDistinct(cols, from, where)
|
||||||
|
for {
|
||||||
|
_ <- logger.ftrace[ConnectionIO](
|
||||||
|
s"query: $q (${itemId.id}, ${collective.id}, ${tagCategory})"
|
||||||
|
)
|
||||||
|
texts <- q.query[(String, Option[TagName])].to[List]
|
||||||
|
_ <- logger.ftrace[ConnectionIO](
|
||||||
|
s"Got ${texts.size} text and tag entries for item ${itemId.id}"
|
||||||
|
)
|
||||||
|
tag = texts.headOption.flatMap(_._2)
|
||||||
|
txt = texts.map(_._1).mkString(" --n-- ")
|
||||||
|
} yield TextAndTag(itemId, txt, tag)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -61,6 +61,9 @@ object RClassifierSetting {
|
|||||||
sql.update.run
|
sql.update.run
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def updateFile(coll: Ident, fid: Ident): ConnectionIO[Int] =
|
||||||
|
updateRow(table, cid.is(coll), fileId.setTo(fid)).update.run
|
||||||
|
|
||||||
def updateSettings(v: RClassifierSetting): ConnectionIO[Int] =
|
def updateSettings(v: RClassifierSetting): ConnectionIO[Int] =
|
||||||
for {
|
for {
|
||||||
n1 <- updateRow(
|
n1 <- updateRow(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user