diff --git a/modules/analysis/src/main/scala/docspell/analysis/TextAnalyser.scala b/modules/analysis/src/main/scala/docspell/analysis/TextAnalyser.scala
index 75d07eef..44f7203b 100644
--- a/modules/analysis/src/main/scala/docspell/analysis/TextAnalyser.scala
+++ b/modules/analysis/src/main/scala/docspell/analysis/TextAnalyser.scala
@@ -7,18 +7,21 @@ import docspell.analysis.contact.Contact
 import docspell.analysis.date.DateFind
 import docspell.analysis.nlp.PipelineCache
 import docspell.analysis.nlp.StanfordNerClassifier
-import docspell.analysis.nlp.StanfordSettings
+import docspell.analysis.nlp.StanfordNerSettings
+import docspell.analysis.nlp.StanfordTextClassifier
+import docspell.analysis.nlp.TextClassifier
 import docspell.common._
 
 trait TextAnalyser[F[_]] {
 
   def annotate(
       logger: Logger[F],
-      settings: StanfordSettings,
+      settings: StanfordNerSettings,
       cacheKey: Ident,
       text: String
   ): F[TextAnalyser.Result]
+
+  def classifier(blocker: Blocker)(implicit CS: ContextShift[F]): TextClassifier[F]
 }
 object TextAnalyser {
@@ -35,7 +38,7 @@ object TextAnalyser {
     new TextAnalyser[F] {
       def annotate(
           logger: Logger[F],
-          settings: StanfordSettings,
+          settings: StanfordNerSettings,
           cacheKey: Ident,
           text: String
       ): F[TextAnalyser.Result] =
@@ -48,6 +51,11 @@ object TextAnalyser {
           spans = NerLabelSpan.build(list)
         } yield Result(spans ++ list, dates)
 
+      def classifier(blocker: Blocker)(implicit
+          CS: ContextShift[F]
+      ): TextClassifier[F] =
+        new StanfordTextClassifier[F](cfg.classifier, blocker)
+
       private def textLimit(logger: Logger[F], text: String): F[String] =
         if (text.length <= cfg.maxLength) text.pure[F]
         else
@@ -56,7 +64,7 @@ object TextAnalyser {
             s" Analysing only first ${cfg.maxLength} characters."
           ) *> text.take(cfg.maxLength).pure[F]
 
-      private def stanfordNer(key: Ident, settings: StanfordSettings, text: String)
+      private def stanfordNer(key: Ident, settings: StanfordNerSettings, text: String)
           : F[Vector[NerLabel]] =
         StanfordNerClassifier.nerAnnotate[F](key.id, cache)(settings, text)
diff --git a/modules/analysis/src/main/scala/docspell/analysis/TextAnalysisConfig.scala b/modules/analysis/src/main/scala/docspell/analysis/TextAnalysisConfig.scala
index 577f6753..596a6247 100644
--- a/modules/analysis/src/main/scala/docspell/analysis/TextAnalysisConfig.scala
+++ b/modules/analysis/src/main/scala/docspell/analysis/TextAnalysisConfig.scala
@@ -1,5 +1,8 @@
 package docspell.analysis
 
+import docspell.analysis.nlp.TextClassifierConfig
+
 case class TextAnalysisConfig(
-    maxLength: Int
+    maxLength: Int,
+    classifier: TextClassifierConfig
 )
diff --git a/modules/analysis/src/main/scala/docspell/analysis/nlp/ClassifierModel.scala b/modules/analysis/src/main/scala/docspell/analysis/nlp/ClassifierModel.scala
new file mode 100644
index 00000000..82f9f9cc
--- /dev/null
+++ b/modules/analysis/src/main/scala/docspell/analysis/nlp/ClassifierModel.scala
@@ -0,0 +1,5 @@
+package docspell.analysis.nlp
+
+import java.nio.file.Path
+
+case class ClassifierModel(model: Path)
diff --git a/modules/analysis/src/main/scala/docspell/analysis/nlp/PipelineCache.scala b/modules/analysis/src/main/scala/docspell/analysis/nlp/PipelineCache.scala
index 9787563f..88e13ee3 100644
--- a/modules/analysis/src/main/scala/docspell/analysis/nlp/PipelineCache.scala
+++ b/modules/analysis/src/main/scala/docspell/analysis/nlp/PipelineCache.scala
@@ -19,7 +19,7 @@ import org.log4s.getLogger
  */
 trait PipelineCache[F[_]] {
 
-  def obtain(key: String, settings: StanfordSettings): F[StanfordCoreNLP]
+  def obtain(key: String, settings: StanfordNerSettings): F[StanfordCoreNLP]
 
 }
 
@@ -28,7 +28,7 @@ object PipelineCache {
 
   def none[F[_]: Applicative]: PipelineCache[F] =
     new PipelineCache[F] {
-      def obtain(ignored: String, settings: StanfordSettings): F[StanfordCoreNLP] =
+      def obtain(ignored: String, settings: StanfordNerSettings): F[StanfordCoreNLP] =
         makeClassifier(settings).pure[F]
     }
 
@@ -38,7 +38,7 @@ object PipelineCache {
   final private class Impl[F[_]: Sync](data: Ref[F, Map[String, Entry]])
       extends PipelineCache[F] {
 
-    def obtain(key: String, settings: StanfordSettings): F[StanfordCoreNLP] =
+    def obtain(key: String, settings: StanfordNerSettings): F[StanfordCoreNLP] =
       for {
         id  <- makeSettingsId(settings)
         nlp <- data.modify(cache => getOrCreate(key, id, cache, settings))
@@ -48,7 +48,7 @@ object PipelineCache {
         key: String,
         id: String,
         cache: Map[String, Entry],
-        settings: StanfordSettings
+        settings: StanfordNerSettings
     ): (Map[String, Entry], StanfordCoreNLP) =
       cache.get(key) match {
        case Some(entry) =>
@@ -68,7 +68,7 @@ object PipelineCache {
           (cache.updated(key, e), nlp)
       }
 
-    private def makeSettingsId(settings: StanfordSettings): F[String] = {
+    private def makeSettingsId(settings: StanfordNerSettings): F[String] = {
       val base = settings.copy(regexNer = None).toString
       val size: F[Long] =
         settings.regexNer match {
@@ -81,7 +81,7 @@
     }
   }
 
-  private def makeClassifier(settings: StanfordSettings): StanfordCoreNLP = {
+  private def makeClassifier(settings: StanfordNerSettings): StanfordCoreNLP = {
     logger.info(s"Creating ${settings.lang.name} Stanford NLP NER classifier...")
     new StanfordCoreNLP(Properties.forSettings(settings))
   }
diff --git a/modules/analysis/src/main/scala/docspell/analysis/nlp/Properties.scala b/modules/analysis/src/main/scala/docspell/analysis/nlp/Properties.scala
index 314f04fb..46a614d1 100644
--- a/modules/analysis/src/main/scala/docspell/analysis/nlp/Properties.scala
+++ b/modules/analysis/src/main/scala/docspell/analysis/nlp/Properties.scala
@@ -7,6 +7,9 @@ import docspell.common._
 
 object Properties {
 
+  def fromMap(m: Map[String, String]): JProps =
+    apply(m.toSeq: _*)
+
   def apply(ps: (String, String)*): JProps = {
     val p = new JProps()
     for ((k, v) <- ps)
@@ -14,7 +17,7 @@
     p
   }
 
-  def forSettings(settings: StanfordSettings): JProps = {
+  def forSettings(settings: StanfordNerSettings): JProps = {
     val regexNerFile = settings.regexNer
       .map(p => p.normalize().toAbsolutePath().toString())
     settings.lang match {
diff --git a/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordNerClassifier.scala b/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordNerClassifier.scala
index 424396e5..383a07ea 100644
--- a/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordNerClassifier.scala
+++ b/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordNerClassifier.scala
@@ -25,7 +25,7 @@ object StanfordNerClassifier {
   def nerAnnotate[F[_]: Applicative](
       cacheKey: String,
       cache: PipelineCache[F]
-  )(settings: StanfordSettings, text: String): F[Vector[NerLabel]] =
+  )(settings: StanfordNerSettings, text: String): F[Vector[NerLabel]] =
     cache
       .obtain(cacheKey, settings)
       .map(crf => runClassifier(crf, text))
diff --git a/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordSettings.scala b/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordNerSettings.scala
similarity index 88%
rename from modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordSettings.scala
rename to modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordNerSettings.scala
index c2f6f98c..06136a18 100644
--- a/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordSettings.scala
+++ b/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordNerSettings.scala
@@ -19,4 +19,8 @@ import docspell.common._
  * as a last step to tag untagged tokens using the provided list of
  * regexps.
  */
-case class StanfordSettings(lang: Language, highRecall: Boolean, regexNer: Option[Path])
+case class StanfordNerSettings(
+    lang: Language,
+    highRecall: Boolean,
+    regexNer: Option[Path]
+)
diff --git a/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordTextClassifier.scala b/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordTextClassifier.scala
new file mode 100644
index 00000000..3da3b5ba
--- /dev/null
+++ b/modules/analysis/src/main/scala/docspell/analysis/nlp/StanfordTextClassifier.scala
@@ -0,0 +1,149 @@
+package docspell.analysis.nlp
+
+import java.nio.file.Path
+
+import cats.effect._
+import cats.effect.concurrent.Ref
+import cats.implicits._
+import fs2.Stream
+
+import docspell.analysis.nlp.TextClassifier._
+import docspell.common._
+
+import edu.stanford.nlp.classify.ColumnDataClassifier
+
+final class StanfordTextClassifier[F[_]: Sync: ContextShift](
+    cfg: TextClassifierConfig,
+    blocker: Blocker
+) extends TextClassifier[F] {
+
+  def trainClassifier[A](
+      logger: Logger[F],
+      data: Stream[F, Data]
+  )(handler: TextClassifier.Handler[F, A]): F[A] =
+    File
+      .withTempDir(cfg.workingDir, "trainclassifier")
+      .use { dir =>
+        for {
+          rawData   <- writeDataFile(blocker, dir, data)
+          _         <- logger.debug(s"Learning from ${rawData.count} items.")
+          trainData <- splitData(logger, rawData)
+          scores    <- cfg.classifierConfigs.traverse(m => train(logger, trainData, m))
+          sorted = scores.sortBy(-_.score)
+          res <- handler(sorted.head.model)
+        } yield res
+      }
+
+  def classify(
+      logger: Logger[F],
+      model: ClassifierModel,
+      text: String
+  ): F[Option[String]] =
+    Sync[F].delay {
+      val cls = ColumnDataClassifier.getClassifier(
+        model.model.normalize().toAbsolutePath().toString()
+      )
+      val cat = cls.classOf(cls.makeDatumFromLine(normalisedText(text)))
+      Option(cat)
+    }
+
+  // --- helpers
+
+  def train(
+      logger: Logger[F],
+      in: TrainData,
+      props: Map[String, String]
+  ): F[TrainResult] =
+    for {
+      _ <- logger.debug(s"Training classifier from $props")
+      res <- Sync[F].delay {
+        val cdc = new ColumnDataClassifier(Properties.fromMap(amendProps(in, props)))
+        cdc.trainClassifier(in.train.toString())
+        val score = cdc.testClassifier(in.test.toString())
+        TrainResult(score.first(), ClassifierModel(in.modelFile))
+      }
+      _ <- logger.debug(s"Trained with result $res")
+    } yield res
+
+  def splitData(logger: Logger[F], in: RawData): F[TrainData] = {
+    val nTest = (in.count * 0.25).toLong
+
+    val td =
+      TrainData(in.file.resolveSibling("train.txt"), in.file.resolveSibling("test.txt"))
+
+    val fileLines =
+      fs2.io.file
+        .readAll(in.file, blocker, 4096)
+        .through(fs2.text.utf8Decode)
+        .through(fs2.text.lines)
+
+    for {
+      _ <- logger.debug(
+        s"Splitting raw data into test/train data. Testing with $nTest entries"
+      )
+      _ <-
+        fileLines
+          .take(nTest)
+          .intersperse("\n")
+          .through(fs2.text.utf8Encode)
+          .through(fs2.io.file.writeAll(td.test, blocker))
+          .compile
+          .drain
+      _ <-
+        fileLines
+          .drop(nTest)
+          .intersperse("\n")
+          .through(fs2.text.utf8Encode)
+          .through(fs2.io.file.writeAll(td.train, blocker))
+          .compile
+          .drain
+    } yield td
+  }
+
+  def writeDataFile(blocker: Blocker, dir: Path, data: Stream[F, Data]): F[RawData] = {
+    val target = dir.resolve("rawdata")
+    for {
+      counter <- Ref.of[F, Long](0L)
+      _ <-
+        data
+          .map(d => s"${d.cls}\t${d.ref}\t${normalisedText(d.text)}")
+          .evalTap(_ => counter.update(_ + 1))
+          .intersperse("\n")
+          .through(fs2.text.utf8Encode)
+          .through(fs2.io.file.writeAll(target, blocker))
+          .compile
+          .drain
+      lines <- counter.get
+    } yield RawData(lines, target)
+
+  }
+
+  def normalisedText(text: String): String =
+    text.replaceAll("[\n\t]+", " ")
+
+  def amendProps(
+      trainData: TrainData,
+      props: Map[String, String]
+  ): Map[String, String] =
+    prepend("2", props) ++ Map(
+      "trainFile"   -> trainData.train.normalize().toAbsolutePath().toString(),
+      "testFile"    -> trainData.test.normalize().toAbsolutePath().toString(),
+      "serializeTo" -> trainData.modelFile.normalize().toAbsolutePath().toString()
+    ).toList
+
+  case class RawData(count: Long, file: Path)
+  case class TrainData(train: Path, test: Path) {
+    val modelFile = train.resolveSibling("model.ser.gz")
+  }
+
+  case class TrainResult(score: Double, model: ClassifierModel)
+
+  def prepend(pre: String, data: Map[String, String]): Map[String, String] =
+    data.toList
+      .map({
+        case (k, v) =>
+          if (k.startsWith(pre)) (k, v)
+          else (pre + k, v)
+      })
+      .toMap
+}
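Note on the input format: StanfordTextClassifier above consumes a stream of TextClassifier.Data values, and writeDataFile serialises each one as a single tab-separated line (class, ref, text) after normalisedText collapses embedded newlines and tabs; that is the column layout the ColumnDataClassifier is then trained on. A minimal sketch of assembling such a stream, with invented categories and texts (they mirror the test suite further down):

    import cats.effect.IO
    import fs2.Stream
    import docspell.analysis.nlp.TextClassifier.Data

    // One Data item per document: cls is the target category, ref an opaque
    // reference carried along but not used for learning, text the raw
    // document body (normalisedText later strips newlines and tabs).
    val trainingData: Stream[IO, Data] =
      Stream
        .emits(
          List(
            Data("invoice", "item-1", "this is your invoice\ttotal $421"),
            Data("receipt", "item-2", "shopping receipt\ncheese cake bar")
          )
        )
        .covary[IO]

Since splitData reserves the first 25% of the written lines as the test set, callers presumably want the stream reasonably interleaved between categories rather than sorted by class.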
diff --git a/modules/analysis/src/main/scala/docspell/analysis/nlp/TextClassifier.scala b/modules/analysis/src/main/scala/docspell/analysis/nlp/TextClassifier.scala
new file mode 100644
index 00000000..f2927d0c
--- /dev/null
+++ b/modules/analysis/src/main/scala/docspell/analysis/nlp/TextClassifier.scala
@@ -0,0 +1,25 @@
+package docspell.analysis.nlp
+
+import cats.data.Kleisli
+import fs2.Stream
+
+import docspell.analysis.nlp.TextClassifier.Data
+import docspell.common._
+
+trait TextClassifier[F[_]] {
+
+  def trainClassifier[A](logger: Logger[F], data: Stream[F, Data])(
+      handler: TextClassifier.Handler[F, A]
+  ): F[A]
+
+  def classify(logger: Logger[F], model: ClassifierModel, text: String): F[Option[String]]
+
+}
+
+object TextClassifier {
+
+  type Handler[F[_], A] = Kleisli[F, ClassifierModel, A]
+
+  case class Data(cls: String, ref: String, text: String)
+
+}
diff --git a/modules/analysis/src/main/scala/docspell/analysis/nlp/TextClassifierConfig.scala b/modules/analysis/src/main/scala/docspell/analysis/nlp/TextClassifierConfig.scala
new file mode 100644
index 00000000..e3baac46
--- /dev/null
+++ b/modules/analysis/src/main/scala/docspell/analysis/nlp/TextClassifierConfig.scala
@@ -0,0 +1,10 @@
+package docspell.analysis.nlp
+
+import java.nio.file.Path
+
+import cats.data.NonEmptyList
+
+case class TextClassifierConfig(
+    workingDir: Path,
+    classifierConfigs: NonEmptyList[Map[String, String]]
+)
diff --git a/modules/analysis/src/test/resources/test.ser.gz b/modules/analysis/src/test/resources/test.ser.gz
new file mode 100644
index 00000000..b6d0956b
Binary files /dev/null and b/modules/analysis/src/test/resources/test.ser.gz differ
diff --git a/modules/analysis/src/test/scala/docspell/analysis/nlp/StanfordTextClassifierSuite.scala b/modules/analysis/src/test/scala/docspell/analysis/nlp/StanfordTextClassifierSuite.scala
new file mode 100644
index 00000000..b9596923
--- /dev/null
+++ b/modules/analysis/src/test/scala/docspell/analysis/nlp/StanfordTextClassifierSuite.scala
@@ -0,0 +1,76 @@
+package docspell.analysis.nlp
+
+import minitest._
+import cats.effect._
+import scala.concurrent.ExecutionContext
+import java.nio.file.Paths
+import cats.data.NonEmptyList
+import docspell.common._
+import fs2.Stream
+import cats.data.Kleisli
+import TextClassifier.Data
+
+object StanfordTextClassifierSuite extends SimpleTestSuite {
+  val logger = Logger.log4s[IO](org.log4s.getLogger)
+
+  implicit val CS = IO.contextShift(ExecutionContext.global)
+
+  test("learn from data") {
+    val cfg = TextClassifierConfig(Paths.get("target"), NonEmptyList.of(Map()))
+
+    val data =
+      Stream
+        .emit(Data("invoice", "n", "this is your invoice total $421"))
+        .repeat
+        .take(10)
+        .zip(
+          Stream
+            .emit(Data("receipt", "n", "shopping receipt cheese cake bar"))
+            .repeat
+            .take(10)
+        )
+        .flatMap({
+          case (a, b) =>
+            Stream.emits(Seq(a, b))
+        })
+        .covary[IO]
+
+    val modelExists =
+      Blocker[IO].use { blocker =>
+        val classifier = new StanfordTextClassifier[IO](cfg, blocker)
+        classifier.trainClassifier[Boolean](logger, data)(
+          Kleisli(result => File.existsNonEmpty[IO](result.model))
+        )
+      }
+    assertEquals(modelExists.unsafeRunSync(), true)
+  }
+
+  test("run classifier") {
+    val cfg = TextClassifierConfig(Paths.get("target"), NonEmptyList.of(Map()))
+    val things = for {
+      dir     <- File.withTempDir[IO](Paths.get("target"), "testcls")
+      blocker <- Blocker[IO]
+    } yield (dir, blocker)
+
+    things
+      .use {
+        case (dir, blocker) =>
+          val classifier = new StanfordTextClassifier[IO](cfg, blocker)
+
+          val modelFile = dir.resolve("test.ser.gz")
+          for {
+            _ <-
+              LenientUri
+                .fromJava(getClass.getResource("/test.ser.gz"))
+                .readURL[IO](4096, blocker)
+                .through(fs2.io.file.writeAll(modelFile, blocker))
+                .compile
+                .drain
+            model = ClassifierModel(modelFile)
+            cat <- classifier.classify(logger, model, "there is receipt always")
+            _ = assertEquals(cat, Some("receipt"))
+          } yield ()
+      }
+      .unsafeRunSync()
+  }
+}
diff --git a/modules/joex/src/main/resources/reference.conf b/modules/joex/src/main/resources/reference.conf
index 746f7bac..e09bfd3b 100644
--- a/modules/joex/src/main/resources/reference.conf
+++ b/modules/joex/src/main/resources/reference.conf
@@ -298,7 +298,7 @@ docspell.joex {
       # These settings are used to configure the classifier. If
       # multiple are given, they are all tried and the "best" is
       # chosen at the end. See
-      # https://nlp.stanford.edu/wiki/Software/Classifier/20_Newsgroups
+      # https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/classify/ColumnDataClassifier.html
       # for more info about these settings. The settings are almost
       # identical to them, as they yielded best results with *my*
       # dataset.
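The classifiers list referenced in the config comment above becomes TextClassifierConfig.classifierConfigs (see the Config.scala change below); each list element is one candidate property map for Stanford's ColumnDataClassifier. A hedged sketch of building that config directly in code follows; the property names are standard ColumnDataClassifier feature flags, but these particular values are illustrative assumptions, not defaults shipped with this patch:

    import java.nio.file.Paths
    import cats.data.NonEmptyList
    import docspell.analysis.nlp.TextClassifierConfig

    // Two candidate configurations: train() runs once per map, scores each
    // on the held-out test split, and the best-scoring model wins. Keys are
    // given without a column prefix; amendProps prefixes them for the text
    // column and injects trainFile/testFile/serializeTo.
    val classifierCfg = TextClassifierConfig(
      Paths.get("/var/tmp/docspell"),
      NonEmptyList.of(
        Map("useSplitWords" -> "true"),
        Map("useSplitWords" -> "true", "useNGrams" -> "true", "maxNGramLeng" -> "3")
      )
    )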
diff --git a/modules/joex/src/main/scala/docspell/joex/Config.scala b/modules/joex/src/main/scala/docspell/joex/Config.scala
index a90ad61a..cbbb4a33 100644
--- a/modules/joex/src/main/scala/docspell/joex/Config.scala
+++ b/modules/joex/src/main/scala/docspell/joex/Config.scala
@@ -2,7 +2,10 @@ package docspell.joex
 
 import java.nio.file.Path
 
+import cats.data.NonEmptyList
+
 import docspell.analysis.TextAnalysisConfig
+import docspell.analysis.nlp.TextClassifierConfig
 import docspell.backend.Config.Files
 import docspell.common._
 import docspell.convert.ConvertConfig
@@ -62,7 +65,15 @@ object Config {
   ) {
 
     def textAnalysisConfig: TextAnalysisConfig =
-      TextAnalysisConfig(maxLength)
+      TextAnalysisConfig(
+        maxLength,
+        TextClassifierConfig(
+          workingDir,
+          NonEmptyList
+            .fromList(classification.classifiers)
+            .getOrElse(NonEmptyList.of(Map.empty))
+        )
+      )
 
     def regexNerFileConfig: RegexNerFile.Config =
       RegexNerFile.Config(regexNer.enabled, workingDir, regexNer.fileCacheTime)
diff --git a/modules/joex/src/main/scala/docspell/joex/learn/LearnClassifierTask.scala b/modules/joex/src/main/scala/docspell/joex/learn/LearnClassifierTask.scala
new file mode 100644
index 00000000..a161417a
--- /dev/null
+++ b/modules/joex/src/main/scala/docspell/joex/learn/LearnClassifierTask.scala
@@ -0,0 +1,64 @@
+package docspell.joex.learn
+
+import cats.data.Kleisli
+import cats.data.OptionT
+import cats.effect._
+import fs2.Stream
+
+import docspell.analysis.TextAnalyser
+import docspell.analysis.nlp.ClassifierModel
+import docspell.analysis.nlp.TextClassifier.Data
+import docspell.backend.ops.OCollective
+import docspell.common._
+import docspell.joex.Config
+import docspell.joex.scheduler._
+
+object LearnClassifierTask {
+
+  type Args = LearnClassifierArgs
+
+  def apply[F[_]: Sync: ContextShift](
+      cfg: Config.TextAnalysis,
+      blocker: Blocker,
+      analyser: TextAnalyser[F]
+  ): Task[F, Args, Unit] =
+    Task { ctx =>
+      (for {
+        sett <- findActiveSettings[F](ctx.args.collective, cfg)
+        data = selectItems(
+          ctx,
+          math.min(cfg.classification.itemCount, sett.itemCount),
+          sett.category.getOrElse("")
+        )
+        _ <- OptionT.liftF(
+          analyser
+            .classifier(blocker)
+            .trainClassifier[Unit](ctx.logger, data)(Kleisli(handleModel(ctx)))
+        )
+      } yield ())
+        .getOrElseF(logInactiveWarning(ctx.logger))
+    }
+
+  private def handleModel[F[_]](
+      ctx: Context[F, Args]
+  )(trainedModel: ClassifierModel): F[Unit] =
+    ???
+
+  private def selectItems[F[_]](
+      ctx: Context[F, Args],
+      max: Int,
+      category: String
+  ): Stream[F, Data] =
+    ???
+
+  private def findActiveSettings[F[_]: Sync](
+      coll: Ident,
+      cfg: Config.TextAnalysis
+  ): OptionT[F, OCollective.Classifier] =
+    ???
+
+  private def logInactiveWarning[F[_]: Sync](logger: Logger[F]): F[Unit] =
+    logger.warn(
+      "Classification is disabled. Check joex config and the collective settings."
+    )
+}
diff --git a/modules/joex/src/main/scala/docspell/joex/process/TextAnalysis.scala b/modules/joex/src/main/scala/docspell/joex/process/TextAnalysis.scala
index abbb6870..92975a70 100644
--- a/modules/joex/src/main/scala/docspell/joex/process/TextAnalysis.scala
+++ b/modules/joex/src/main/scala/docspell/joex/process/TextAnalysis.scala
@@ -4,7 +4,7 @@ import cats.effect._
 import cats.implicits._
 
 import docspell.analysis.TextAnalyser
-import docspell.analysis.nlp.StanfordSettings
+import docspell.analysis.nlp.StanfordNerSettings
 import docspell.common._
 import docspell.joex.analysis.RegexNerFile
 import docspell.joex.process.ItemData.AttachmentDates
@@ -42,7 +42,7 @@
       analyser: TextAnalyser[F],
       nerFile: RegexNerFile[F]
   )(rm: RAttachmentMeta): F[(RAttachmentMeta, AttachmentDates)] = {
-    val settings = StanfordSettings(ctx.args.meta.language, false, None)
+    val settings = StanfordNerSettings(ctx.args.meta.language, false, None)
     for {
       customNer <- nerFile.makeFile(ctx.args.meta.collective)
       sett = settings.copy(regexNer = customNer)