mirror of
https://github.com/TheAnachronism/docspell.git
synced 2025-06-21 18:08:25 +00:00
Implement learning a text classifier from collective data
This commit is contained in:
@ -26,7 +26,7 @@ final class StanfordTextClassifier[F[_]: Sync: ContextShift](
|
||||
.use { dir =>
|
||||
for {
|
||||
rawData <- writeDataFile(blocker, dir, data)
|
||||
_ <- logger.debug(s"Learning from ${rawData.count} items.")
|
||||
_ <- logger.info(s"Learning from ${rawData.count} items.")
|
||||
trainData <- splitData(logger, rawData)
|
||||
scores <- cfg.classifierConfigs.traverse(m => train(logger, trainData, m))
|
||||
sorted = scores.sortBy(-_.score)
|
||||
@ -43,7 +43,7 @@ final class StanfordTextClassifier[F[_]: Sync: ContextShift](
|
||||
val cls = ColumnDataClassifier.getClassifier(
|
||||
model.model.normalize().toAbsolutePath().toString()
|
||||
)
|
||||
val cat = cls.classOf(cls.makeDatumFromLine(normalisedText(text)))
|
||||
val cat = cls.classOf(cls.makeDatumFromLine("\t\t" + normalisedText(text)))
|
||||
Option(cat)
|
||||
}
|
||||
|
||||
@ -66,7 +66,7 @@ final class StanfordTextClassifier[F[_]: Sync: ContextShift](
|
||||
} yield res
|
||||
|
||||
def splitData(logger: Logger[F], in: RawData): F[TrainData] = {
|
||||
val nTest = (in.count * 0.25).toLong
|
||||
val nTest = (in.count * 0.15).toLong
|
||||
|
||||
val td =
|
||||
TrainData(in.file.resolveSibling("train.txt"), in.file.resolveSibling("test.txt"))
|
||||
@ -106,9 +106,10 @@ final class StanfordTextClassifier[F[_]: Sync: ContextShift](
|
||||
counter <- Ref.of[F, Long](0L)
|
||||
_ <-
|
||||
data
|
||||
.map(d => s"${d.cls}\t${d.ref}\t${normalisedText(d.text)}")
|
||||
.filter(_.text.nonEmpty)
|
||||
.map(d => s"${d.cls}\t${fixRef(d.ref)}\t${normalisedText(d.text)}")
|
||||
.evalTap(_ => counter.update(_ + 1))
|
||||
.intersperse("\n")
|
||||
.intersperse("\r\n")
|
||||
.through(fs2.text.utf8Encode)
|
||||
.through(fs2.io.file.writeAll(target, blocker))
|
||||
.compile
|
||||
@ -119,13 +120,16 @@ final class StanfordTextClassifier[F[_]: Sync: ContextShift](
|
||||
}
|
||||
|
||||
/** Collapses line-structure characters in `text` so it can be stored as a
  * single field on a single line.
  *
  * Each run of the matched characters is replaced by one space.
  *
  * NOTE(review): this is a diff rendering, so two bodies appear below. Judging
  * by the hunk header (`-119,13 +120,16`), the first `replaceAll` is the
  * pre-commit line and the second is its replacement — confirm against the
  * repository.
  *
  * @param text the raw text to flatten
  * @return `text` with every run of the matched characters replaced by a space
  */
def normalisedText(text: String): String =
|
||||
// Earlier form: collapses runs of newline and tab only.
text.replaceAll("[\n\t]+", " ")
|
||||
// Later form: additionally treats carriage return as part of a run.
text.replaceAll("[\n\r\t]+", " ")
|
||||
|
||||
/** Makes a reference safe to embed as one column of a tab-separated,
  * one-record-per-line data file.
  *
  * The reference is written between the class and the text as
  * `cls<TAB>ref<TAB>text`, with records separated by line breaks. A tab inside
  * the reference would shift the columns, and a line feed or carriage return
  * would split the record in two. Each such character is therefore replaced
  * by an underscore. This matches the character set `normalisedText` strips
  * from the text column.
  *
  * @param str the raw reference value
  * @return `str` with every tab, line feed and carriage return replaced by
  *         `_`; all other characters are kept as they are
  */
def fixRef(str: String): String =
  str.map {
    // One-for-one replacement keeps the length (and tab behaviour) of the
    // previous implementation.
    case '\t' | '\n' | '\r' => '_'
    case other              => other
  }
|
||||
|
||||
def amendProps(
|
||||
trainData: TrainData,
|
||||
props: Map[String, String]
|
||||
): Map[String, String] =
|
||||
prepend("2", props) ++ Map(
|
||||
prepend("2.", props) ++ Map(
|
||||
"trainFile" -> trainData.train.normalize().toAbsolutePath().toString(),
|
||||
"testFile" -> trainData.test.normalize().toAbsolutePath().toString(),
|
||||
"serializeTo" -> trainData.modelFile.normalize().toAbsolutePath().toString()
|
||||
|
Reference in New Issue
Block a user