From 0c97b4ef762f30a0dc77dacbf3f06dac56df0752 Mon Sep 17 00:00:00 2001 From: Eike Kettner Date: Mon, 31 Aug 2020 22:35:27 +0200 Subject: [PATCH] Initial impl of a text classifier based on stanford-nlp --- .../docspell/analysis/TextAnalyser.scala | 16 +- .../analysis/TextAnalysisConfig.scala | 5 +- .../analysis/nlp/ClassifierModel.scala | 5 + .../docspell/analysis/nlp/PipelineCache.scala | 12 +- .../docspell/analysis/nlp/Properties.scala | 5 +- .../analysis/nlp/StanfordNerClassifier.scala | 2 +- ...ttings.scala => StanfordNerSettings.scala} | 6 +- .../analysis/nlp/StanfordTextClassifier.scala | 149 ++++++++++++++++++ .../analysis/nlp/TextClassifier.scala | 25 +++ .../analysis/nlp/TextClassifierConfig.scala | 10 ++ .../analysis/src/test/resources/test.ser.gz | Bin 0 -> 1682 bytes .../nlp/StanfordTextClassifierSuite.scala | 76 +++++++++ .../joex/src/main/resources/reference.conf | 2 +- .../src/main/scala/docspell/joex/Config.scala | 13 +- .../joex/learn/LearnClassifierTask.scala | 64 ++++++++ .../docspell/joex/process/TextAnalysis.scala | 4 +- 16 files changed, 376 insertions(+), 18 deletions(-) create mode 100644 modules/analysis/src/main/scala/docspell/analysis/nlp/ClassifierModel.scala rename modules/analysis/src/main/scala/docspell/analysis/nlp/{StanfordSettings.scala => StanfordNerSettings.scala} (88%) create mode 100644 modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordTextClassifier.scala create mode 100644 modules/analysis/src/main/scala/docspell/analysis/nlp/TextClassifier.scala create mode 100644 modules/analysis/src/main/scala/docspell/analysis/nlp/TextClassifierConfig.scala create mode 100644 modules/analysis/src/test/resources/test.ser.gz create mode 100644 modules/analysis/src/test/scala/docspell/analysis/nlp/StanfordTextClassifierSuite.scala create mode 100644 modules/joex/src/main/scala/docspell/joex/learn/LearnClassifierTask.scala diff --git a/modules/analysis/src/main/scala/docspell/analysis/TextAnalyser.scala b/modules/analysis/src/main/scala/docspell/analysis/TextAnalyser.scala index 75d07eef..44f7203b 100644 --- a/modules/analysis/src/main/scala/docspell/analysis/TextAnalyser.scala +++ b/modules/analysis/src/main/scala/docspell/analysis/TextAnalyser.scala @@ -7,18 +7,21 @@ import docspell.analysis.contact.Contact import docspell.analysis.date.DateFind import docspell.analysis.nlp.PipelineCache import docspell.analysis.nlp.StanfordNerClassifier -import docspell.analysis.nlp.StanfordSettings +import docspell.analysis.nlp.StanfordNerSettings +import docspell.analysis.nlp.StanfordTextClassifier +import docspell.analysis.nlp.TextClassifier import docspell.common._ trait TextAnalyser[F[_]] { def annotate( logger: Logger[F], - settings: StanfordSettings, + settings: StanfordNerSettings, cacheKey: Ident, text: String ): F[TextAnalyser.Result] + def classifier(blocker: Blocker)(implicit CS: ContextShift[F]): TextClassifier[F] } object TextAnalyser { @@ -35,7 +38,7 @@ object TextAnalyser { new TextAnalyser[F] { def annotate( logger: Logger[F], - settings: StanfordSettings, + settings: StanfordNerSettings, cacheKey: Ident, text: String ): F[TextAnalyser.Result] = @@ -48,6 +51,11 @@ object TextAnalyser { spans = NerLabelSpan.build(list) } yield Result(spans ++ list, dates) + def classifier(blocker: Blocker)(implicit + CS: ContextShift[F] + ): TextClassifier[F] = + new StanfordTextClassifier[F](cfg.classifier, blocker) + private def textLimit(logger: Logger[F], text: String): F[String] = if (text.length <= cfg.maxLength) text.pure[F] else @@ -56,7 +64,7 @@ object 
TextAnalyser { s" Analysing only first ${cfg.maxLength} characters." ) *> text.take(cfg.maxLength).pure[F] - private def stanfordNer(key: Ident, settings: StanfordSettings, text: String) + private def stanfordNer(key: Ident, settings: StanfordNerSettings, text: String) : F[Vector[NerLabel]] = StanfordNerClassifier.nerAnnotate[F](key.id, cache)(settings, text) diff --git a/modules/analysis/src/main/scala/docspell/analysis/TextAnalysisConfig.scala b/modules/analysis/src/main/scala/docspell/analysis/TextAnalysisConfig.scala index 577f6753..596a6247 100644 --- a/modules/analysis/src/main/scala/docspell/analysis/TextAnalysisConfig.scala +++ b/modules/analysis/src/main/scala/docspell/analysis/TextAnalysisConfig.scala @@ -1,5 +1,8 @@ package docspell.analysis +import docspell.analysis.nlp.TextClassifierConfig + case class TextAnalysisConfig( - maxLength: Int + maxLength: Int, + classifier: TextClassifierConfig ) diff --git a/modules/analysis/src/main/scala/docspell/analysis/nlp/ClassifierModel.scala b/modules/analysis/src/main/scala/docspell/analysis/nlp/ClassifierModel.scala new file mode 100644 index 00000000..82f9f9cc --- /dev/null +++ b/modules/analysis/src/main/scala/docspell/analysis/nlp/ClassifierModel.scala @@ -0,0 +1,5 @@ +package docspell.analysis.nlp + +import java.nio.file.Path + +case class ClassifierModel(model: Path) diff --git a/modules/analysis/src/main/scala/docspell/analysis/nlp/PipelineCache.scala b/modules/analysis/src/main/scala/docspell/analysis/nlp/PipelineCache.scala index 9787563f..88e13ee3 100644 --- a/modules/analysis/src/main/scala/docspell/analysis/nlp/PipelineCache.scala +++ b/modules/analysis/src/main/scala/docspell/analysis/nlp/PipelineCache.scala @@ -19,7 +19,7 @@ import org.log4s.getLogger */ trait PipelineCache[F[_]] { - def obtain(key: String, settings: StanfordSettings): F[StanfordCoreNLP] + def obtain(key: String, settings: StanfordNerSettings): F[StanfordCoreNLP] } @@ -28,7 +28,7 @@ object PipelineCache { def none[F[_]: Applicative]: PipelineCache[F] = new PipelineCache[F] { - def obtain(ignored: String, settings: StanfordSettings): F[StanfordCoreNLP] = + def obtain(ignored: String, settings: StanfordNerSettings): F[StanfordCoreNLP] = makeClassifier(settings).pure[F] } @@ -38,7 +38,7 @@ object PipelineCache { final private class Impl[F[_]: Sync](data: Ref[F, Map[String, Entry]]) extends PipelineCache[F] { - def obtain(key: String, settings: StanfordSettings): F[StanfordCoreNLP] = + def obtain(key: String, settings: StanfordNerSettings): F[StanfordCoreNLP] = for { id <- makeSettingsId(settings) nlp <- data.modify(cache => getOrCreate(key, id, cache, settings)) @@ -48,7 +48,7 @@ object PipelineCache { key: String, id: String, cache: Map[String, Entry], - settings: StanfordSettings + settings: StanfordNerSettings ): (Map[String, Entry], StanfordCoreNLP) = cache.get(key) match { case Some(entry) => @@ -68,7 +68,7 @@ object PipelineCache { (cache.updated(key, e), nlp) } - private def makeSettingsId(settings: StanfordSettings): F[String] = { + private def makeSettingsId(settings: StanfordNerSettings): F[String] = { val base = settings.copy(regexNer = None).toString val size: F[Long] = settings.regexNer match { @@ -81,7 +81,7 @@ object PipelineCache { } } - private def makeClassifier(settings: StanfordSettings): StanfordCoreNLP = { + private def makeClassifier(settings: StanfordNerSettings): StanfordCoreNLP = { logger.info(s"Creating ${settings.lang.name} Stanford NLP NER classifier...") new StanfordCoreNLP(Properties.forSettings(settings)) } diff --git 
a/modules/analysis/src/main/scala/docspell/analysis/nlp/Properties.scala b/modules/analysis/src/main/scala/docspell/analysis/nlp/Properties.scala index 314f04fb..46a614d1 100644 --- a/modules/analysis/src/main/scala/docspell/analysis/nlp/Properties.scala +++ b/modules/analysis/src/main/scala/docspell/analysis/nlp/Properties.scala @@ -7,6 +7,9 @@ import docspell.common._ object Properties { + def fromMap(m: Map[String, String]): JProps = + apply(m.toSeq: _*) + def apply(ps: (String, String)*): JProps = { val p = new JProps() for ((k, v) <- ps) @@ -14,7 +17,7 @@ object Properties { p } - def forSettings(settings: StanfordSettings): JProps = { + def forSettings(settings: StanfordNerSettings): JProps = { val regexNerFile = settings.regexNer .map(p => p.normalize().toAbsolutePath().toString()) settings.lang match { diff --git a/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordNerClassifier.scala b/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordNerClassifier.scala index 424396e5..383a07ea 100644 --- a/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordNerClassifier.scala +++ b/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordNerClassifier.scala @@ -25,7 +25,7 @@ object StanfordNerClassifier { def nerAnnotate[F[_]: Applicative]( cacheKey: String, cache: PipelineCache[F] - )(settings: StanfordSettings, text: String): F[Vector[NerLabel]] = + )(settings: StanfordNerSettings, text: String): F[Vector[NerLabel]] = cache .obtain(cacheKey, settings) .map(crf => runClassifier(crf, text)) diff --git a/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordSettings.scala b/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordNerSettings.scala similarity index 88% rename from modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordSettings.scala rename to modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordNerSettings.scala index c2f6f98c..06136a18 100644 --- a/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordSettings.scala +++ b/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordNerSettings.scala @@ -19,4 +19,8 @@ import docspell.common._ * as a last step to tag untagged tokens using the provided list of * regexps. 
 */
-case class StanfordSettings(lang: Language, highRecall: Boolean, regexNer: Option[Path])
+case class StanfordNerSettings(
+    lang: Language,
+    highRecall: Boolean,
+    regexNer: Option[Path]
+)
diff --git a/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordTextClassifier.scala b/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordTextClassifier.scala
new file mode 100644
index 00000000..3da3b5ba
--- /dev/null
+++ b/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordTextClassifier.scala
@@ -0,0 +1,149 @@
+package docspell.analysis.nlp
+
+import java.nio.file.Path
+
+import cats.effect._
+import cats.effect.concurrent.Ref
+import cats.implicits._
+import fs2.Stream
+
+import docspell.analysis.nlp.TextClassifier._
+import docspell.common._
+
+import edu.stanford.nlp.classify.ColumnDataClassifier
+
+final class StanfordTextClassifier[F[_]: Sync: ContextShift](
+    cfg: TextClassifierConfig,
+    blocker: Blocker
+) extends TextClassifier[F] {
+
+  def trainClassifier[A](
+      logger: Logger[F],
+      data: Stream[F, Data]
+  )(handler: TextClassifier.Handler[F, A]): F[A] =
+    File
+      .withTempDir(cfg.workingDir, "trainclassifier")
+      .use { dir =>
+        for {
+          rawData   <- writeDataFile(blocker, dir, data)
+          _         <- logger.debug(s"Learning from ${rawData.count} items.")
+          trainData <- splitData(logger, rawData)
+          scores    <- cfg.classifierConfigs.traverse(m => train(logger, trainData, m))
+          sorted = scores.sortBy(-_.score)
+          res <- handler(sorted.head.model)
+        } yield res
+      }
+
+  def classify(
+      logger: Logger[F],
+      model: ClassifierModel,
+      text: String
+  ): F[Option[String]] =
+    Sync[F].delay {
+      val cls = ColumnDataClassifier.getClassifier(
+        model.model.normalize().toAbsolutePath().toString()
+      )
+      // two leading tabs keep the text in column 2, matching the training format
+      val cat = cls.classOf(cls.makeDatumFromLine("\t\t" + normalisedText(text)))
+      Option(cat)
+    }
+
+  // --- helpers
+
+  def train(
+      logger: Logger[F],
+      in: TrainData,
+      props: Map[String, String]
+  ): F[TrainResult] =
+    for {
+      _ <- logger.debug(s"Training classifier from $props")
+      res <- Sync[F].delay {
+        val cdc = new ColumnDataClassifier(Properties.fromMap(amendProps(in, props)))
+        cdc.trainClassifier(in.train.toString())
+        val score = cdc.testClassifier(in.test.toString())
+        TrainResult(score.first(), ClassifierModel(in.modelFile))
+      }
+      _ <- logger.debug(s"Trained with result $res")
+    } yield res
+
+  def splitData(logger: Logger[F], in: RawData): F[TrainData] = {
+    val nTest = (in.count * 0.25).toLong
+
+    val td =
+      TrainData(in.file.resolveSibling("train.txt"), in.file.resolveSibling("test.txt"))
+
+    val fileLines =
+      fs2.io.file
+        .readAll(in.file, blocker, 4096)
+        .through(fs2.text.utf8Decode)
+        .through(fs2.text.lines)
+
+    for {
+      _ <- logger.debug(
+        s"Splitting raw data into test/train data. Testing with $nTest entries"
+      )
+      _ <-
+        fileLines
+          .take(nTest)
+          .intersperse("\n")
+          .through(fs2.text.utf8Encode)
+          .through(fs2.io.file.writeAll(td.test, blocker))
+          .compile
+          .drain
+      _ <-
+        fileLines
+          .drop(nTest)
+          .intersperse("\n")
+          .through(fs2.text.utf8Encode)
+          .through(fs2.io.file.writeAll(td.train, blocker))
+          .compile
+          .drain
+    } yield td
+  }
+
+  def writeDataFile(blocker: Blocker, dir: Path, data: Stream[F, Data]): F[RawData] = {
+    val target = dir.resolve("rawdata")
+    for {
+      counter <- Ref.of[F, Long](0L)
+      _ <-
+        data
+          .map(d => s"${d.cls}\t${d.ref}\t${normalisedText(d.text)}")
+          .evalTap(_ => counter.update(_ + 1))
+          .intersperse("\n")
+          .through(fs2.text.utf8Encode)
+          .through(fs2.io.file.writeAll(target, blocker))
+          .compile
+          .drain
+      lines <- counter.get
+    } yield RawData(lines, target)
+  }
+
+  def normalisedText(text: String): String =
+    text.replaceAll("[\n\t]+", " ")
+
+  def amendProps(
+      trainData: TrainData,
+      props: Map[String, String]
+  ): Map[String, String] =
+    prepend("2", props) ++ Map(
+      "trainFile"   -> trainData.train.normalize().toAbsolutePath().toString(),
+      "testFile"    -> trainData.test.normalize().toAbsolutePath().toString(),
+      "serializeTo" -> trainData.modelFile.normalize().toAbsolutePath().toString()
+    ).toList
+
+  case class RawData(count: Long, file: Path)
+
+  case class TrainData(train: Path, test: Path) {
+    val modelFile = train.resolveSibling("model.ser.gz")
+  }
+
+  case class TrainResult(score: Double, model: ClassifierModel)
+
+  def prepend(pre: String, data: Map[String, String]): Map[String, String] =
+    data.toList
+      .map({
+        case (k, v) =>
+          if (k.startsWith(pre)) (k, v)
+          else (pre + k, v)
+      })
+      .toMap
+}
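
Note on the pipeline above: writeDataFile renders every Data item as one tab-separated line (class, reference, text), splitData reserves the first quarter of those lines as the test set, and train runs Stanford's ColumnDataClassifier once per configured property map, keeping the model with the best score. The prepend("2", ...) call exists because ColumnDataClassifier scopes feature options by column index: column 0 is the label and column 1 an opaque reference, so text features live on column 2 and an option such as useSplitWords must arrive as 2.useSplitWords. A minimal standalone sketch of the same round trip, using only the ColumnDataClassifier calls that also appear in this patch; the file names and the useSplitWords flag are illustrative assumptions, not taken from this commit:

    import java.util.{Properties => JProps}
    import edu.stanford.nlp.classify.ColumnDataClassifier

    object CdcSketch {
      def main(args: Array[String]): Unit = {
        val props = new JProps()
        // feature option scoped to column 2, the text column
        props.setProperty("2.useSplitWords", "true")
        props.setProperty("trainFile", "target/train.txt")
        props.setProperty("testFile", "target/test.txt")
        // the patch relies on serializeTo writing TrainData.modelFile during training
        props.setProperty("serializeTo", "target/model.ser.gz")

        // train.txt and test.txt contain lines of: <class>\t<ref>\t<text>
        val cdc = new ColumnDataClassifier(props)
        cdc.trainClassifier("target/train.txt")
        // the first component of the returned pair is the score the patch records
        val score = cdc.testClassifier("target/test.txt")
        println(s"score = ${score.first()}")
      }
    }
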
diff --git a/modules/analysis/src/main/scala/docspell/analysis/nlp/TextClassifier.scala b/modules/analysis/src/main/scala/docspell/analysis/nlp/TextClassifier.scala
new file mode 100644
index 00000000..f2927d0c
--- /dev/null
+++ b/modules/analysis/src/main/scala/docspell/analysis/nlp/TextClassifier.scala
@@ -0,0 +1,25 @@
+package docspell.analysis.nlp
+
+import cats.data.Kleisli
+import fs2.Stream
+
+import docspell.analysis.nlp.TextClassifier.Data
+import docspell.common._
+
+trait TextClassifier[F[_]] {
+
+  def trainClassifier[A](logger: Logger[F], data: Stream[F, Data])(
+      handler: TextClassifier.Handler[F, A]
+  ): F[A]
+
+  def classify(logger: Logger[F], model: ClassifierModel, text: String): F[Option[String]]
+
+}
+
+object TextClassifier {
+
+  type Handler[F[_], A] = Kleisli[F, ClassifierModel, A]
+
+  case class Data(cls: String, ref: String, text: String)
+
+}
diff --git a/modules/analysis/src/main/scala/docspell/analysis/nlp/TextClassifierConfig.scala b/modules/analysis/src/main/scala/docspell/analysis/nlp/TextClassifierConfig.scala
new file mode 100644
index 00000000..e3baac46
--- /dev/null
+++ b/modules/analysis/src/main/scala/docspell/analysis/nlp/TextClassifierConfig.scala
@@ -0,0 +1,10 @@
+package docspell.analysis.nlp
+
+import java.nio.file.Path
+
+import cats.data.NonEmptyList
+
+case class TextClassifierConfig(
+    workingDir: Path,
+    classifierConfigs: NonEmptyList[Map[String, String]]
+)
diff --git a/modules/analysis/src/test/resources/test.ser.gz b/modules/analysis/src/test/resources/test.ser.gz
new file mode 100644
index 0000000000000000000000000000000000000000..b6d0956ba0f2100bc670502e77717ac428688a9d
GIT binary patch
literal 1682
[1682 bytes of base85-encoded binary data for the serialized test model; the payload was corrupted in this copy of the patch and is omitted]
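
The binary file above is a small pre-trained model used by the test suite that follows. Loading and querying such a model mirrors StanfordTextClassifier.classify; a short sketch, assuming a model trained on the three-column format (the model path is illustrative):

    import edu.stanford.nlp.classify.ColumnDataClassifier

    val cls = ColumnDataClassifier.getClassifier("target/model.ser.gz")
    // the line handed to makeDatumFromLine must mirror the training layout;
    // two leading tabs leave the label and ref columns empty so the text
    // lands in column 2, where the feature options point
    val line  = "\t\t" + "there is receipt always".replaceAll("[\n\t]+", " ")
    val label = Option(cls.classOf(cls.makeDatumFromLine(line)))
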
zmW>j~ACtzusF(Kywb(F|MfK26PrEQJ$b#ZqcfO5d#5J5nu`dQ=VdW%4d%n&~B4C zYWu6>@FXPfC`Lp4pU#@{f1~={yD??1>h%B_#_FQ$x zc7y*^TJft_zf}=!AVwP(f-Gvkswcl1dF={0w+5QBuvW4Y6fJAbwUO4ASS$84S@ZPk zJsYq5{lfX*63}EIz#=yl+w`ZD=wDA2fnRSlh1nH-;!N?;%ldeozq)-4H#V>CqJkqeXF*r6X2IzY`> zEs7%6uK$2_pu)+VP3yt#{qFsbFO{8lv_m?~O;Q(ZGx$I4Vb(fpSxYEr5!9s2RkdQU z2`?tIcu5k<(22CICb&Wk&aisKV7q`bP+E&%%r~I zjX|pQYp1sA=RTpAoAuX*fYt4P_i24k=tK3;`7?v>nz=0zYyy@&D5ekFl}jB~` z_oHyxE6;eY>j;&G;gEjB?`X5psn0=cf!UOkGR;D75h>$bR;UO($wGHgIG&4U146Cc z=Ae_wLzLM5W)Y_4m?srRFAdH1PC<)lX!U=~q*{Ar2-sz_GE>1OsLVm1?G$HRo2et1 zhe^ld3ZapP+vO5-Q_ev*E}5z$wAl{2XkyTF&{KPlqTGol{da@%B5n$o8oQ3viJmk_ zFuqJ;&=t7PIQ`QESveR9$c%PJmq2X z_Cbd%d4&R=cB(@-5U!$3>k7nhN`^tyTvp`hj7iXgAUmkjEgNE&ts~5oh>HQ*R3peE zOqHf*Cqf^n!0M65MI6~QUI&g%acQDryC>KFGOH_@$Y@5!)I?xCD2E&j z;rq0p1y-HG=(1`=*Rltv0vY z{gR<#G!I=t=hDTmO8)edHDIuT7z!&$-VAlN;-p!ba(x(x9fJ=AIE|;=zZdSwLuY-@ z)U<>F1NF2JN3*sWFl5%wWH7=kj*B8X+#d`N`u<7<(L}@Trq-P-5#euXG7sG`1*aNZ^T-^g5MB~4$VGHHKM<-b^d+b1E1aqr-LPY z^bJ1t%!wcNyf^s7^64QKiJsk0?OsJFlnZ6`5YmJnOVgVg0mg3BE8u$5u4~sfPwuEg zYH9>^!^auhY@#ax3A|ehdIG_J c3vL|H?ZIe-e@`Bh+e5s60M;VBa!(Ba0MQ0Ep#T5? literal 0 HcmV?d00001 diff --git a/modules/analysis/src/test/scala/docspell/analysis/nlp/StanfordTextClassifierSuite.scala b/modules/analysis/src/test/scala/docspell/analysis/nlp/StanfordTextClassifierSuite.scala new file mode 100644 index 00000000..b9596923 --- /dev/null +++ b/modules/analysis/src/test/scala/docspell/analysis/nlp/StanfordTextClassifierSuite.scala @@ -0,0 +1,76 @@ +package docspell.analysis.nlp + +import minitest._ +import cats.effect._ +import scala.concurrent.ExecutionContext +import java.nio.file.Paths +import cats.data.NonEmptyList +import docspell.common._ +import fs2.Stream +import cats.data.Kleisli +import TextClassifier.Data + +object StanfordTextClassifierSuite extends SimpleTestSuite { + val logger = Logger.log4s[IO](org.log4s.getLogger) + + implicit val CS = IO.contextShift(ExecutionContext.global) + + test("learn from data") { + val cfg = TextClassifierConfig(Paths.get("target"), NonEmptyList.of(Map())) + + val data = + Stream + .emit(Data("invoice", "n", "this is your invoice total $421")) + .repeat + .take(10) + .zip( + Stream + .emit(Data("receipt", "n", "shopping receipt cheese cake bar")) + .repeat + .take(10) + ) + .flatMap({ + case (a, b) => + Stream.emits(Seq(a, b)) + }) + .covary[IO] + + val modelExists = + Blocker[IO].use { blocker => + val classifier = new StanfordTextClassifier[IO](cfg, blocker) + classifier.trainClassifier[Boolean](logger, data)( + Kleisli(result => File.existsNonEmpty[IO](result.model)) + ) + } + assertEquals(modelExists.unsafeRunSync(), true) + } + + test("run classifier") { + val cfg = TextClassifierConfig(Paths.get("target"), NonEmptyList.of(Map())) + val things = for { + dir <- File.withTempDir[IO](Paths.get("target"), "testcls") + blocker <- Blocker[IO] + } yield (dir, blocker) + + things + .use { + case (dir, blocker) => + val classifier = new StanfordTextClassifier[IO](cfg, blocker) + + val modelFile = dir.resolve("test.ser.gz") + for { + _ <- + LenientUri + .fromJava(getClass.getResource("/test.ser.gz")) + .readURL[IO](4096, blocker) + .through(fs2.io.file.writeAll(modelFile, blocker)) + .compile + .drain + model = ClassifierModel(modelFile) + cat <- classifier.classify(logger, model, "there is receipt always") + _ = assertEquals(cat, Some("receipt")) + } yield () + } + .unsafeRunSync() + } +} diff --git a/modules/joex/src/main/resources/reference.conf b/modules/joex/src/main/resources/reference.conf index 746f7bac..e09bfd3b 100644 
--- a/modules/joex/src/main/resources/reference.conf +++ b/modules/joex/src/main/resources/reference.conf @@ -298,7 +298,7 @@ docspell.joex { # These settings are used to configure the classifier. If # multiple are given, they are all tried and the "best" is # chosen at the end. See - # https://nlp.stanford.edu/wiki/Software/Classifier/20_Newsgroups + # https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/classify/ColumnDataClassifier.html # for more info about these settings. The settings are almost # identical to them, as they yielded best results with *my* # dataset. diff --git a/modules/joex/src/main/scala/docspell/joex/Config.scala b/modules/joex/src/main/scala/docspell/joex/Config.scala index a90ad61a..cbbb4a33 100644 --- a/modules/joex/src/main/scala/docspell/joex/Config.scala +++ b/modules/joex/src/main/scala/docspell/joex/Config.scala @@ -2,7 +2,10 @@ package docspell.joex import java.nio.file.Path +import cats.data.NonEmptyList + import docspell.analysis.TextAnalysisConfig +import docspell.analysis.nlp.TextClassifierConfig import docspell.backend.Config.Files import docspell.common._ import docspell.convert.ConvertConfig @@ -62,7 +65,15 @@ object Config { ) { def textAnalysisConfig: TextAnalysisConfig = - TextAnalysisConfig(maxLength) + TextAnalysisConfig( + maxLength, + TextClassifierConfig( + workingDir, + NonEmptyList + .fromList(classification.classifiers) + .getOrElse(NonEmptyList.of(Map.empty)) + ) + ) def regexNerFileConfig: RegexNerFile.Config = RegexNerFile.Config(regexNer.enabled, workingDir, regexNer.fileCacheTime) diff --git a/modules/joex/src/main/scala/docspell/joex/learn/LearnClassifierTask.scala b/modules/joex/src/main/scala/docspell/joex/learn/LearnClassifierTask.scala new file mode 100644 index 00000000..a161417a --- /dev/null +++ b/modules/joex/src/main/scala/docspell/joex/learn/LearnClassifierTask.scala @@ -0,0 +1,64 @@ +package docspell.joex.learn + +import cats.data.Kleisli +import cats.data.OptionT +import cats.effect._ +import fs2.Stream + +import docspell.analysis.TextAnalyser +import docspell.analysis.nlp.ClassifierModel +import docspell.analysis.nlp.TextClassifier.Data +import docspell.backend.ops.OCollective +import docspell.common._ +import docspell.joex.Config +import docspell.joex.scheduler._ + +object LearnClassifierTask { + + type Args = LearnClassifierArgs + + def apply[F[_]: Sync: ContextShift]( + cfg: Config.TextAnalysis, + blocker: Blocker, + analyser: TextAnalyser[F] + ): Task[F, Args, Unit] = + Task { ctx => + (for { + sett <- findActiveSettings[F](ctx.args.collective, cfg) + data = selectItems( + ctx, + math.min(cfg.classification.itemCount, sett.itemCount), + sett.category.getOrElse("") + ) + _ <- OptionT.liftF( + analyser + .classifier(blocker) + .trainClassifier[Unit](ctx.logger, data)(Kleisli(handleModel(ctx))) + ) + } yield ()) + .getOrElseF(logInactiveWarning(ctx.logger)) + } + + private def handleModel[F[_]]( + ctx: Context[F, Args] + )(trainedModel: ClassifierModel): F[Unit] = + ??? + + private def selectItems[F[_]]( + ctx: Context[F, Args], + max: Int, + category: String + ): Stream[F, Data] = + ??? + + private def findActiveSettings[F[_]: Sync]( + coll: Ident, + cfg: Config.TextAnalysis + ): OptionT[F, OCollective.Classifier] = + ??? + + private def logInactiveWarning[F[_]: Sync](logger: Logger[F]): F[Unit] = + logger.warn( + "Classification is disabled. Check joex config and the collective settings." 
+ ) +} diff --git a/modules/joex/src/main/scala/docspell/joex/process/TextAnalysis.scala b/modules/joex/src/main/scala/docspell/joex/process/TextAnalysis.scala index abbb6870..92975a70 100644 --- a/modules/joex/src/main/scala/docspell/joex/process/TextAnalysis.scala +++ b/modules/joex/src/main/scala/docspell/joex/process/TextAnalysis.scala @@ -4,7 +4,7 @@ import cats.effect._ import cats.implicits._ import docspell.analysis.TextAnalyser -import docspell.analysis.nlp.StanfordSettings +import docspell.analysis.nlp.StanfordNerSettings import docspell.common._ import docspell.joex.analysis.RegexNerFile import docspell.joex.process.ItemData.AttachmentDates @@ -42,7 +42,7 @@ object TextAnalysis { analyser: TextAnalyser[F], nerFile: RegexNerFile[F] )(rm: RAttachmentMeta): F[(RAttachmentMeta, AttachmentDates)] = { - val settings = StanfordSettings(ctx.args.meta.language, false, None) + val settings = StanfordNerSettings(ctx.args.meta.language, false, None) for { customNer <- nerFile.makeFile(ctx.args.meta.collective) sett = settings.copy(regexNer = customNer)
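
Since trainClassifier runs inside File.withTempDir, the winning model file is deleted as soon as the resource is released; the handler is therefore a Kleisli[F, ClassifierModel, A] that must consume the model before training returns. LearnClassifierTask.handleModel is still a stub in this commit; a sketch of the shape such a handler could take, where persistModel is a hypothetical helper and not part of this patch:

    import cats.data.Kleisli
    import cats.effect.IO
    import docspell.analysis.nlp.{ClassifierModel, TextClassifier}

    // hypothetical: copy model.model into durable storage for the collective
    def persistModel(model: ClassifierModel): IO[Unit] =
      IO.unit

    val handler: TextClassifier.Handler[IO, Unit] =
      Kleisli(persistModel) // runs while the temp directory still exists

    // classifier.trainClassifier(logger, data)(handler)
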