Mirror of https://github.com/TheAnachronism/docspell.git
(synced 2025-02-15 20:33:26 +00:00)

Commit 0c97b4ef76 (parent 8c4f2e702b)
Initial impl of a text classifier based on stanford-nlp
@@ -7,18 +7,21 @@ import docspell.analysis.contact.Contact
 import docspell.analysis.date.DateFind
 import docspell.analysis.nlp.PipelineCache
 import docspell.analysis.nlp.StanfordNerClassifier
-import docspell.analysis.nlp.StanfordSettings
+import docspell.analysis.nlp.StanfordNerSettings
+import docspell.analysis.nlp.StanfordTextClassifier
+import docspell.analysis.nlp.TextClassifier
 import docspell.common._
 
 trait TextAnalyser[F[_]] {
 
   def annotate(
       logger: Logger[F],
-      settings: StanfordSettings,
+      settings: StanfordNerSettings,
       cacheKey: Ident,
       text: String
   ): F[TextAnalyser.Result]
+
+  def classifier(blocker: Blocker)(implicit CS: ContextShift[F]): TextClassifier[F]
 }
 object TextAnalyser {
 
@@ -35,7 +38,7 @@ object TextAnalyser {
       new TextAnalyser[F] {
         def annotate(
             logger: Logger[F],
-            settings: StanfordSettings,
+            settings: StanfordNerSettings,
            cacheKey: Ident,
            text: String
        ): F[TextAnalyser.Result] =
@@ -48,6 +51,11 @@ object TextAnalyser {
           spans = NerLabelSpan.build(list)
         } yield Result(spans ++ list, dates)
 
+        def classifier(blocker: Blocker)(implicit
+            CS: ContextShift[F]
+        ): TextClassifier[F] =
+          new StanfordTextClassifier[F](cfg.classifier, blocker)
+
         private def textLimit(logger: Logger[F], text: String): F[String] =
           if (text.length <= cfg.maxLength) text.pure[F]
           else
@@ -56,7 +64,7 @@ object TextAnalyser {
             s" Analysing only first ${cfg.maxLength} characters."
           ) *> text.take(cfg.maxLength).pure[F]
 
-        private def stanfordNer(key: Ident, settings: StanfordSettings, text: String)
+        private def stanfordNer(key: Ident, settings: StanfordNerSettings, text: String)
            : F[Vector[NerLabel]] =
          StanfordNerClassifier.nerAnnotate[F](key.id, cache)(settings, text)
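
The new classifier method hands out a TextClassifier backed by the given
Blocker for its file I/O. A minimal usage sketch (illustrative, not part of
the commit; `analyser`, `logger` and `data` are assumed to be in scope, and
the classify call happens inside the handler because the trained model file
lives in a temporary directory that is gone once training finishes):

    import cats.data.Kleisli
    import cats.effect.{Blocker, ContextShift, IO}
    import fs2.Stream

    import docspell.analysis.TextAnalyser
    import docspell.analysis.nlp.TextClassifier.Data
    import docspell.common.Logger

    def trainAndProbe(
        analyser: TextAnalyser[IO],
        logger: Logger[IO],
        data: Stream[IO, Data]
    )(implicit CS: ContextShift[IO]): IO[Option[String]] =
      Blocker[IO].use { blocker =>
        val classifier = analyser.classifier(blocker)
        // train, then classify a sample while the model file still exists
        classifier.trainClassifier(logger, data)(
          Kleisli(model => classifier.classify(logger, model, "sample text"))
        )
      }
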
@@ -1,5 +1,8 @@
 package docspell.analysis
 
+import docspell.analysis.nlp.TextClassifierConfig
+
 case class TextAnalysisConfig(
-    maxLength: Int
+    maxLength: Int,
+    classifier: TextClassifierConfig
 )
@@ -0,0 +1,5 @@
+package docspell.analysis.nlp
+
+import java.nio.file.Path
+
+case class ClassifierModel(model: Path)
@@ -19,7 +19,7 @@ import org.log4s.getLogger
   */
 trait PipelineCache[F[_]] {
 
-  def obtain(key: String, settings: StanfordSettings): F[StanfordCoreNLP]
+  def obtain(key: String, settings: StanfordNerSettings): F[StanfordCoreNLP]
 
 }
 
@@ -28,7 +28,7 @@ object PipelineCache {
 
   def none[F[_]: Applicative]: PipelineCache[F] =
     new PipelineCache[F] {
-      def obtain(ignored: String, settings: StanfordSettings): F[StanfordCoreNLP] =
+      def obtain(ignored: String, settings: StanfordNerSettings): F[StanfordCoreNLP] =
         makeClassifier(settings).pure[F]
     }
 
@@ -38,7 +38,7 @@ object PipelineCache {
   final private class Impl[F[_]: Sync](data: Ref[F, Map[String, Entry]])
       extends PipelineCache[F] {
 
-    def obtain(key: String, settings: StanfordSettings): F[StanfordCoreNLP] =
+    def obtain(key: String, settings: StanfordNerSettings): F[StanfordCoreNLP] =
       for {
         id  <- makeSettingsId(settings)
        nlp <- data.modify(cache => getOrCreate(key, id, cache, settings))
@@ -48,7 +48,7 @@ object PipelineCache {
       key: String,
       id: String,
       cache: Map[String, Entry],
-      settings: StanfordSettings
+      settings: StanfordNerSettings
   ): (Map[String, Entry], StanfordCoreNLP) =
     cache.get(key) match {
       case Some(entry) =>
@@ -68,7 +68,7 @@ object PipelineCache {
         (cache.updated(key, e), nlp)
       }
 
-    private def makeSettingsId(settings: StanfordSettings): F[String] = {
+    private def makeSettingsId(settings: StanfordNerSettings): F[String] = {
       val base = settings.copy(regexNer = None).toString
       val size: F[Long] =
         settings.regexNer match {
@@ -81,7 +81,7 @@ object PipelineCache {
     }
 
   }
-  private def makeClassifier(settings: StanfordSettings): StanfordCoreNLP = {
+  private def makeClassifier(settings: StanfordNerSettings): StanfordCoreNLP = {
     logger.info(s"Creating ${settings.lang.name} Stanford NLP NER classifier...")
     new StanfordCoreNLP(Properties.forSettings(settings))
   }
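
The cache keys each entry by the caller-supplied key plus an id derived from
the settings (including the regexNer file size), so a changed configuration
transparently rebuilds the pipeline. A small sketch of the interface
(illustrative; `PipelineCache.none` shown here rebuilds on every call, while
the Ref-based Impl reuses the pipeline for unchanged settings; running this
for real needs the Stanford English models on the classpath):

    import cats.effect.IO

    import docspell.analysis.nlp.{PipelineCache, StanfordNerSettings}
    import docspell.common.Language

    val settings = StanfordNerSettings(Language.English, false, None)
    val cache    = PipelineCache.none[IO]
    val twice = for {
      nlp1 <- cache.obtain("collective-1", settings)
      nlp2 <- cache.obtain("collective-1", settings) // `none`: built again; Impl: reused
    } yield (nlp1, nlp2)
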
@@ -7,6 +7,9 @@ import docspell.common._
 
 object Properties {
 
+  def fromMap(m: Map[String, String]): JProps =
+    apply(m.toSeq: _*)
+
   def apply(ps: (String, String)*): JProps = {
     val p = new JProps()
     for ((k, v) <- ps)
@@ -14,7 +17,7 @@ object Properties {
     p
   }
 
-  def forSettings(settings: StanfordSettings): JProps = {
+  def forSettings(settings: StanfordNerSettings): JProps = {
     val regexNerFile = settings.regexNer
       .map(p => p.normalize().toAbsolutePath().toString())
     settings.lang match {
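
The new fromMap helper is what lets the classifier feed its
Map[String, String] settings into ColumnDataClassifier, which expects
java.util.Properties. A quick illustration (flag names are standard
ColumnDataClassifier options, values illustrative):

    import docspell.analysis.nlp.Properties

    val jprops = Properties.fromMap(
      Map(
        "2.useSplitWords" -> "true",
        "2.maxNGramLeng"  -> "4"
      )
    )
    // jprops.getProperty("2.useSplitWords") == "true"
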
@@ -25,7 +25,7 @@ object StanfordNerClassifier {
   def nerAnnotate[F[_]: Applicative](
       cacheKey: String,
       cache: PipelineCache[F]
-  )(settings: StanfordSettings, text: String): F[Vector[NerLabel]] =
+  )(settings: StanfordNerSettings, text: String): F[Vector[NerLabel]] =
     cache
       .obtain(cacheKey, settings)
       .map(crf => runClassifier(crf, text))
@@ -19,4 +19,8 @@ import docspell.common._
   * as a last step to tag untagged tokens using the provided list of
   * regexps.
   */
-case class StanfordSettings(lang: Language, highRecall: Boolean, regexNer: Option[Path])
+case class StanfordNerSettings(
+    lang: Language,
+    highRecall: Boolean,
+    regexNer: Option[Path]
+)
@@ -0,0 +1,149 @@
+package docspell.analysis.nlp
+
+import java.nio.file.Path
+
+import cats.effect._
+import cats.effect.concurrent.Ref
+import cats.implicits._
+import fs2.Stream
+
+import docspell.analysis.nlp.TextClassifier._
+import docspell.common._
+
+import edu.stanford.nlp.classify.ColumnDataClassifier
+
+final class StanfordTextClassifier[F[_]: Sync: ContextShift](
+    cfg: TextClassifierConfig,
+    blocker: Blocker
+) extends TextClassifier[F] {
+
+  def trainClassifier[A](
+      logger: Logger[F],
+      data: Stream[F, Data]
+  )(handler: TextClassifier.Handler[F, A]): F[A] =
+    File
+      .withTempDir(cfg.workingDir, "trainclassifier")
+      .use { dir =>
+        for {
+          rawData   <- writeDataFile(blocker, dir, data)
+          _         <- logger.debug(s"Learning from ${rawData.count} items.")
+          trainData <- splitData(logger, rawData)
+          scores    <- cfg.classifierConfigs.traverse(m => train(logger, trainData, m))
+          sorted = scores.sortBy(-_.score)
+          res <- handler(sorted.head.model)
+        } yield res
+      }
+
+  def classify(
+      logger: Logger[F],
+      model: ClassifierModel,
+      text: String
+  ): F[Option[String]] =
+    Sync[F].delay {
+      val cls = ColumnDataClassifier.getClassifier(
+        model.model.normalize().toAbsolutePath().toString()
+      )
+      val cat = cls.classOf(cls.makeDatumFromLine(normalisedText(text)))
+      Option(cat)
+    }
+
+  // --- helpers
+
+  def train(
+      logger: Logger[F],
+      in: TrainData,
+      props: Map[String, String]
+  ): F[TrainResult] =
+    for {
+      _ <- logger.debug(s"Training classifier from $props")
+      res <- Sync[F].delay {
+        val cdc = new ColumnDataClassifier(Properties.fromMap(amendProps(in, props)))
+        cdc.trainClassifier(in.train.toString())
+        val score = cdc.testClassifier(in.test.toString())
+        TrainResult(score.first(), ClassifierModel(in.modelFile))
+      }
+      _ <- logger.debug(s"Trained with result $res")
+    } yield res
+
+  def splitData(logger: Logger[F], in: RawData): F[TrainData] = {
+    val nTest = (in.count * 0.25).toLong
+
+    val td =
+      TrainData(in.file.resolveSibling("train.txt"), in.file.resolveSibling("test.txt"))
+
+    val fileLines =
+      fs2.io.file
+        .readAll(in.file, blocker, 4096)
+        .through(fs2.text.utf8Decode)
+        .through(fs2.text.lines)
+
+    for {
+      _ <- logger.debug(
+        s"Splitting raw data into test/train data. Testing with $nTest entries"
+      )
+      _ <-
+        fileLines
+          .take(nTest)
+          .intersperse("\n")
+          .through(fs2.text.utf8Encode)
+          .through(fs2.io.file.writeAll(td.test, blocker))
+          .compile
+          .drain
+      _ <-
+        fileLines
+          .drop(nTest)
+          .intersperse("\n")
+          .through(fs2.text.utf8Encode)
+          .through(fs2.io.file.writeAll(td.train, blocker))
+          .compile
+          .drain
+    } yield td
+  }
+
+  def writeDataFile(blocker: Blocker, dir: Path, data: Stream[F, Data]): F[RawData] = {
+    val target = dir.resolve("rawdata")
+    for {
+      counter <- Ref.of[F, Long](0L)
+      _ <-
+        data
+          .map(d => s"${d.cls}\t${d.ref}\t${normalisedText(d.text)}")
+          .evalTap(_ => counter.update(_ + 1))
+          .intersperse("\n")
+          .through(fs2.text.utf8Encode)
+          .through(fs2.io.file.writeAll(target, blocker))
+          .compile
+          .drain
+      lines <- counter.get
+    } yield RawData(lines, target)
+
+  }
+
+  def normalisedText(text: String): String =
+    text.replaceAll("[\n\t]+", " ")
+
+  def amendProps(
+      trainData: TrainData,
+      props: Map[String, String]
+  ): Map[String, String] =
+    prepend("2", props) ++ Map(
+      "trainFile"   -> trainData.train.normalize().toAbsolutePath().toString(),
+      "testFile"    -> trainData.test.normalize().toAbsolutePath().toString(),
+      "serializeTo" -> trainData.modelFile.normalize().toAbsolutePath().toString()
+    ).toList
+
+  case class RawData(count: Long, file: Path)
+  case class TrainData(train: Path, test: Path) {
+    val modelFile = train.resolveSibling("model.ser.gz")
+  }
+
+  case class TrainResult(score: Double, model: ClassifierModel)
+
+  def prepend(pre: String, data: Map[String, String]): Map[String, String] =
+    data.toList
+      .map({
+        case (k, v) =>
+          if (k.startsWith(pre)) (k, v)
+          else (pre + k, v)
+      })
+      .toMap
+}
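
For reference: writeDataFile turns every Data(cls, ref, text) into one
tab-separated line (class, reference, normalised text), splitData then routes
the first 25% of lines to test.txt and the rest to train.txt, and train runs
one ColumnDataClassifier per configured property map, keeping the model with
the best test score. A sketch of the line format (values illustrative):

    import docspell.analysis.nlp.TextClassifier.Data

    val d = Data("invoice", "item-1", "this is\nyour invoice\ttotal $421")
    // normalisedText collapses tabs/newlines so the TSV stays one record per line
    val line = s"${d.cls}\t${d.ref}\t${d.text.replaceAll("[\n\t]+", " ")}"
    // line == "invoice\titem-1\tthis is your invoice total $421"
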
@@ -0,0 +1,25 @@
+package docspell.analysis.nlp
+
+import cats.data.Kleisli
+import fs2.Stream
+
+import docspell.analysis.nlp.TextClassifier.Data
+import docspell.common._
+
+trait TextClassifier[F[_]] {
+
+  def trainClassifier[A](logger: Logger[F], data: Stream[F, Data])(
+      handler: TextClassifier.Handler[F, A]
+  ): F[A]
+
+  def classify(logger: Logger[F], model: ClassifierModel, text: String): F[Option[String]]
+
+}
+
+object TextClassifier {
+
+  type Handler[F[_], A] = Kleisli[F, ClassifierModel, A]
+
+  case class Data(cls: String, ref: String, text: String)
+
+}
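
Because training runs inside File.withTempDir, the Handler (a
Kleisli[F, ClassifierModel, A]) is the only place the freshly trained model
file is guaranteed to still exist; a handler that wants to keep the model has
to copy it out. A hedged sketch (the target path is illustrative):

    import java.nio.file.{Files, Path, Paths, StandardCopyOption}

    import cats.data.Kleisli
    import cats.effect.IO

    import docspell.analysis.nlp.{ClassifierModel, TextClassifier}

    val keepModel: TextClassifier.Handler[IO, Path] =
      Kleisli { (model: ClassifierModel) =>
        IO {
          val target = Paths.get("models", "latest-model.ser.gz") // illustrative
          Files.createDirectories(target.getParent)
          Files.copy(model.model, target, StandardCopyOption.REPLACE_EXISTING)
        }
      }
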
@@ -0,0 +1,10 @@
+package docspell.analysis.nlp
+
+import java.nio.file.Path
+
+import cats.data.NonEmptyList
+
+case class TextClassifierConfig(
+    workingDir: Path,
+    classifierConfigs: NonEmptyList[Map[String, String]]
+)
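
A possible configuration, to make the shape concrete: each map in
classifierConfigs is trained once and the best-scoring model wins. Flag names
below are standard ColumnDataClassifier options; the values are illustrative.

    import java.nio.file.Paths

    import cats.data.NonEmptyList

    import docspell.analysis.nlp.TextClassifierConfig

    val cfg = TextClassifierConfig(
      Paths.get("target"),
      NonEmptyList.of(
        Map("2.useSplitWords" -> "true", "2.minNGramLeng" -> "1", "2.maxNGramLeng" -> "3"),
        Map("2.useSplitWords" -> "true", "2.usePrefixSuffixNGrams" -> "true")
      )
    )
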
BIN modules/analysis/src/test/resources/test.ser.gz (new file; binary file not shown)
@@ -0,0 +1,76 @@
+package docspell.analysis.nlp
+
+import minitest._
+import cats.effect._
+import scala.concurrent.ExecutionContext
+import java.nio.file.Paths
+import cats.data.NonEmptyList
+import docspell.common._
+import fs2.Stream
+import cats.data.Kleisli
+import TextClassifier.Data
+
+object StanfordTextClassifierSuite extends SimpleTestSuite {
+  val logger = Logger.log4s[IO](org.log4s.getLogger)
+
+  implicit val CS = IO.contextShift(ExecutionContext.global)
+
+  test("learn from data") {
+    val cfg = TextClassifierConfig(Paths.get("target"), NonEmptyList.of(Map()))
+
+    val data =
+      Stream
+        .emit(Data("invoice", "n", "this is your invoice total $421"))
+        .repeat
+        .take(10)
+        .zip(
+          Stream
+            .emit(Data("receipt", "n", "shopping receipt cheese cake bar"))
+            .repeat
+            .take(10)
+        )
+        .flatMap({
+          case (a, b) =>
+            Stream.emits(Seq(a, b))
+        })
+        .covary[IO]
+
+    val modelExists =
+      Blocker[IO].use { blocker =>
+        val classifier = new StanfordTextClassifier[IO](cfg, blocker)
+        classifier.trainClassifier[Boolean](logger, data)(
+          Kleisli(result => File.existsNonEmpty[IO](result.model))
+        )
+      }
+    assertEquals(modelExists.unsafeRunSync(), true)
+  }
+
+  test("run classifier") {
+    val cfg = TextClassifierConfig(Paths.get("target"), NonEmptyList.of(Map()))
+    val things = for {
+      dir     <- File.withTempDir[IO](Paths.get("target"), "testcls")
+      blocker <- Blocker[IO]
+    } yield (dir, blocker)
+
+    things
+      .use {
+        case (dir, blocker) =>
+          val classifier = new StanfordTextClassifier[IO](cfg, blocker)
+
+          val modelFile = dir.resolve("test.ser.gz")
+          for {
+            _ <-
+              LenientUri
+                .fromJava(getClass.getResource("/test.ser.gz"))
+                .readURL[IO](4096, blocker)
+                .through(fs2.io.file.writeAll(modelFile, blocker))
+                .compile
+                .drain
+            model = ClassifierModel(modelFile)
+            cat <- classifier.classify(logger, model, "there is receipt always")
+            _ = assertEquals(cat, Some("receipt"))
+          } yield ()
+      }
+      .unsafeRunSync()
+  }
+}
@@ -298,7 +298,7 @@ docspell.joex {
       # These settings are used to configure the classifier. If
       # multiple are given, they are all tried and the "best" is
      # chosen at the end. See
-      # https://nlp.stanford.edu/wiki/Software/Classifier/20_Newsgroups
+      # https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/classify/ColumnDataClassifier.html
      # for more info about these settings. The settings are almost
      # identical to them, as they yielded best results with *my*
      # dataset.
@@ -2,7 +2,10 @@ package docspell.joex
 
 import java.nio.file.Path
 
+import cats.data.NonEmptyList
+
 import docspell.analysis.TextAnalysisConfig
+import docspell.analysis.nlp.TextClassifierConfig
 import docspell.backend.Config.Files
 import docspell.common._
 import docspell.convert.ConvertConfig
@@ -62,7 +65,15 @@ object Config {
   ) {
 
     def textAnalysisConfig: TextAnalysisConfig =
-      TextAnalysisConfig(maxLength)
+      TextAnalysisConfig(
+        maxLength,
+        TextClassifierConfig(
+          workingDir,
+          NonEmptyList
+            .fromList(classification.classifiers)
+            .getOrElse(NonEmptyList.of(Map.empty))
+        )
+      )
 
     def regexNerFileConfig: RegexNerFile.Config =
       RegexNerFile.Config(regexNer.enabled, workingDir, regexNer.fileCacheTime)
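
The NonEmptyList.fromList(...).getOrElse(...) step guarantees
TextClassifierConfig's non-empty invariant even when no classifiers are
configured: training then runs once with ColumnDataClassifier defaults. In
isolation:

    import cats.data.NonEmptyList

    val classifiers: List[Map[String, String]] = Nil // nothing configured
    val configs = NonEmptyList
      .fromList(classifiers)
      .getOrElse(NonEmptyList.of(Map.empty[String, String]))
    // configs == NonEmptyList.of(Map.empty)
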
@@ -0,0 +1,64 @@
+package docspell.joex.learn
+
+import cats.data.Kleisli
+import cats.data.OptionT
+import cats.effect._
+import fs2.Stream
+
+import docspell.analysis.TextAnalyser
+import docspell.analysis.nlp.ClassifierModel
+import docspell.analysis.nlp.TextClassifier.Data
+import docspell.backend.ops.OCollective
+import docspell.common._
+import docspell.joex.Config
+import docspell.joex.scheduler._
+
+object LearnClassifierTask {
+
+  type Args = LearnClassifierArgs
+
+  def apply[F[_]: Sync: ContextShift](
+      cfg: Config.TextAnalysis,
+      blocker: Blocker,
+      analyser: TextAnalyser[F]
+  ): Task[F, Args, Unit] =
+    Task { ctx =>
+      (for {
+        sett <- findActiveSettings[F](ctx.args.collective, cfg)
+        data = selectItems(
+          ctx,
+          math.min(cfg.classification.itemCount, sett.itemCount),
+          sett.category.getOrElse("")
+        )
+        _ <- OptionT.liftF(
+          analyser
+            .classifier(blocker)
+            .trainClassifier[Unit](ctx.logger, data)(Kleisli(handleModel(ctx)))
+        )
+      } yield ())
+        .getOrElseF(logInactiveWarning(ctx.logger))
+    }
+
+  private def handleModel[F[_]](
+      ctx: Context[F, Args]
+  )(trainedModel: ClassifierModel): F[Unit] =
+    ???
+
+  private def selectItems[F[_]](
+      ctx: Context[F, Args],
+      max: Int,
+      category: String
+  ): Stream[F, Data] =
+    ???
+
+  private def findActiveSettings[F[_]: Sync](
+      coll: Ident,
+      cfg: Config.TextAnalysis
+  ): OptionT[F, OCollective.Classifier] =
+    ???
+
+  private def logInactiveWarning[F[_]: Sync](logger: Logger[F]): F[Unit] =
+    logger.warn(
+      "Classification is disabled. Check joex config and the collective settings."
+    )
+}
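
selectItems, handleModel and findActiveSettings are deliberately left as ???
stubs in this commit; only the wiring is in place. The Kleisli(...) wrapping
works because handleModel(ctx) has the shape ClassifierModel => F[Unit],
which is exactly a TextClassifier.Handler[F, Unit]. An illustrative stand-in
for the still-missing handler (hypothetical, not the commit's eventual
implementation):

    import cats.data.Kleisli
    import cats.effect.IO

    import docspell.analysis.nlp.{ClassifierModel, TextClassifier}

    // hypothetical: the real handler will have to persist the model somewhere
    val handler: TextClassifier.Handler[IO, Unit] =
      Kleisli(model => IO(println(s"trained model written to ${model.model}")))
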
@@ -4,7 +4,7 @@ import cats.effect._
 import cats.implicits._
 
 import docspell.analysis.TextAnalyser
-import docspell.analysis.nlp.StanfordSettings
+import docspell.analysis.nlp.StanfordNerSettings
 import docspell.common._
 import docspell.joex.analysis.RegexNerFile
 import docspell.joex.process.ItemData.AttachmentDates
@@ -42,7 +42,7 @@ object TextAnalysis {
       analyser: TextAnalyser[F],
       nerFile: RegexNerFile[F]
   )(rm: RAttachmentMeta): F[(RAttachmentMeta, AttachmentDates)] = {
-    val settings = StanfordSettings(ctx.args.meta.language, false, None)
+    val settings = StanfordNerSettings(ctx.args.meta.language, false, None)
     for {
       customNer <- nerFile.makeFile(ctx.args.meta.collective)
       sett = settings.copy(regexNer = customNer)