Add constraints from config to classifier training

For large and/or many documents, training the classifier can lead to
OOM errors. Some limits have been set by default.
This commit is contained in:
Eike Kettner 2021-01-21 17:46:39 +01:00
parent 363cf5aef0
commit 9957c3267e
7 changed files with 87 additions and 50 deletions

View File

@ -269,9 +269,9 @@ docspell.joex {
# All text to analyse must fit into RAM. A large document may take
# too much heap. Also, most important information is at the
# beginning of a document, so in most cases the first two pages
# should suffice. Default is 10000, which are about 2-3 pages
# (just a rough guess, of course).
max-length = 10000
# should suffice. Default is 8000, which are about 2-3 pages (just
# a rough guess, of course).
max-length = 8000
# A working directory for the analyser to store temporary/working
# files.
@ -363,7 +363,7 @@ docspell.joex {
# If concerned with memory consumption, this restricts the
# number of items to consider. More are better for training. A
# negative value or zero means to train on all items.
item-count = 0
item-count = 600
# These settings are used to configure the classifier. If
# multiple are given, they are all tried and the "best" is

View File

@ -94,5 +94,10 @@ object Config {
enabled: Boolean,
itemCount: Int,
classifiers: List[Map[String, String]]
)
) {
def itemCountOrWhenLower(other: Int): Int =
if (itemCount <= 0 || (itemCount > other && other > 0)) other
else itemCount
}
}

View File

@ -37,7 +37,8 @@ object LearnClassifierTask {
.learnAll(
analyser,
ctx.args.collective,
cfg.classification.itemCount
cfg.classification.itemCount,
cfg.maxLength
)
.run(ctx)
else ().pure[F]
@ -51,10 +52,14 @@ object LearnClassifierTask {
val learnTags =
for {
sett <- findActiveSettings[F](ctx, cfg)
maxItems = math.min(cfg.classification.itemCount, sett.itemCount)
maxItems = cfg.classification.itemCountOrWhenLower(sett.itemCount)
_ <- OptionT.liftF(
LearnTags
.learnAllTagCategories(analyser)(ctx.args.collective, maxItems)
.learnAllTagCategories(analyser)(
ctx.args.collective,
maxItems,
cfg.maxLength
)
.run(ctx)
)
} yield ()

View File

@ -14,51 +14,56 @@ object LearnItemEntities {
def learnAll[F[_]: Sync: ContextShift, A](
analyser: TextAnalyser[F],
collective: Ident,
maxItems: Int
maxItems: Int,
maxTextLen: Int
): Task[F, A, Unit] =
learnCorrOrg(analyser, collective, maxItems)
.flatMap(_ => learnCorrPerson[F, A](analyser, collective, maxItems))
.flatMap(_ => learnConcPerson(analyser, collective, maxItems))
.flatMap(_ => learnConcEquip(analyser, collective, maxItems))
learnCorrOrg(analyser, collective, maxItems, maxTextLen)
.flatMap(_ => learnCorrPerson[F, A](analyser, collective, maxItems, maxTextLen))
.flatMap(_ => learnConcPerson(analyser, collective, maxItems, maxTextLen))
.flatMap(_ => learnConcEquip(analyser, collective, maxItems, maxTextLen))
def learnCorrOrg[F[_]: Sync: ContextShift, A](
analyser: TextAnalyser[F],
collective: Ident,
maxItems: Int
maxItems: Int,
maxTextLen: Int
): Task[F, A, Unit] =
learn(analyser, collective)(
ClassifierName.correspondentOrg,
ctx => SelectItems.forCorrOrg(ctx.store, collective, maxItems)
ctx => SelectItems.forCorrOrg(ctx.store, collective, maxItems, maxTextLen)
)
def learnCorrPerson[F[_]: Sync: ContextShift, A](
analyser: TextAnalyser[F],
collective: Ident,
maxItems: Int
maxItems: Int,
maxTextLen: Int
): Task[F, A, Unit] =
learn(analyser, collective)(
ClassifierName.correspondentPerson,
ctx => SelectItems.forCorrPerson(ctx.store, collective, maxItems)
ctx => SelectItems.forCorrPerson(ctx.store, collective, maxItems, maxTextLen)
)
def learnConcPerson[F[_]: Sync: ContextShift, A](
analyser: TextAnalyser[F],
collective: Ident,
maxItems: Int
maxItems: Int,
maxTextLen: Int
): Task[F, A, Unit] =
learn(analyser, collective)(
ClassifierName.concernedPerson,
ctx => SelectItems.forConcPerson(ctx.store, collective, maxItems)
ctx => SelectItems.forConcPerson(ctx.store, collective, maxItems, maxTextLen)
)
def learnConcEquip[F[_]: Sync: ContextShift, A](
analyser: TextAnalyser[F],
collective: Ident,
maxItems: Int
maxItems: Int,
maxTextLen: Int
): Task[F, A, Unit] =
learn(analyser, collective)(
ClassifierName.concernedEquip,
ctx => SelectItems.forConcEquip(ctx.store, collective, maxItems)
ctx => SelectItems.forConcEquip(ctx.store, collective, maxItems, maxTextLen)
)
private def learn[F[_]: Sync: ContextShift, A](

View File

@ -14,12 +14,13 @@ object LearnTags {
def learnTagCategory[F[_]: Sync: ContextShift, A](
analyser: TextAnalyser[F],
collective: Ident,
maxItems: Int
maxItems: Int,
maxTextLen: Int
)(
category: String
): Task[F, A, Unit] =
Task { ctx =>
val data = SelectItems.forCategory(ctx, collective)(maxItems, category)
val data = SelectItems.forCategory(ctx, collective)(maxItems, category, maxTextLen)
ctx.logger.info(s"Learn classifier for tag category: $category") *>
analyser.classifier.trainClassifier(ctx.logger, data)(
Kleisli(
@ -34,12 +35,13 @@ object LearnTags {
def learnAllTagCategories[F[_]: Sync: ContextShift, A](analyser: TextAnalyser[F])(
collective: Ident,
maxItems: Int
maxItems: Int,
maxTextLen: Int
): Task[F, A, Unit] =
Task { ctx =>
for {
cats <- ctx.store.transact(RClassifierSetting.getActiveCategories(collective))
task = learnTagCategory[F, A](analyser, collective, maxItems) _
task = learnTagCategory[F, A](analyser, collective, maxItems, maxTextLen) _
_ <- cats.map(task).traverse(_.run(ctx))
} yield ()
}

View File

@ -16,20 +16,24 @@ object SelectItems {
val noClass = LearnClassifierTask.noClass
def forCategory[F[_]](ctx: Context[F, _], collective: Ident)(
max: Int,
category: String
maxItems: Int,
category: String,
maxTextLen: Int
): Stream[F, Data] =
forCategory(ctx.store, collective, max, category)
forCategory(ctx.store, collective, maxItems, category, maxTextLen)
def forCategory[F[_]](
store: Store[F],
collective: Ident,
max: Int,
category: String
maxItems: Int,
category: String,
maxTextLen: Int
): Stream[F, Data] = {
val connStream =
allItems(collective, max)
.evalMap(item => QItem.resolveTextAndTag(collective, item, category, pageSep))
allItems(collective, maxItems)
.evalMap(item =>
QItem.resolveTextAndTag(collective, item, category, maxTextLen, pageSep)
)
.through(mkData)
store.transact(connStream)
}
@ -37,11 +41,14 @@ object SelectItems {
def forCorrOrg[F[_]](
store: Store[F],
collective: Ident,
max: Int
maxItems: Int,
maxTextLen: Int
): Stream[F, Data] = {
val connStream =
allItems(collective, max)
.evalMap(item => QItem.resolveTextAndCorrOrg(collective, item, pageSep))
allItems(collective, maxItems)
.evalMap(item =>
QItem.resolveTextAndCorrOrg(collective, item, maxTextLen, pageSep)
)
.through(mkData)
store.transact(connStream)
}
@ -49,11 +56,14 @@ object SelectItems {
def forCorrPerson[F[_]](
store: Store[F],
collective: Ident,
max: Int
maxItems: Int,
maxTextLen: Int
): Stream[F, Data] = {
val connStream =
allItems(collective, max)
.evalMap(item => QItem.resolveTextAndCorrPerson(collective, item, pageSep))
allItems(collective, maxItems)
.evalMap(item =>
QItem.resolveTextAndCorrPerson(collective, item, maxTextLen, pageSep)
)
.through(mkData)
store.transact(connStream)
}
@ -61,11 +71,14 @@ object SelectItems {
def forConcPerson[F[_]](
store: Store[F],
collective: Ident,
max: Int
maxItems: Int,
maxTextLen: Int
): Stream[F, Data] = {
val connStream =
allItems(collective, max)
.evalMap(item => QItem.resolveTextAndConcPerson(collective, item, pageSep))
allItems(collective, maxItems)
.evalMap(item =>
QItem.resolveTextAndConcPerson(collective, item, maxTextLen, pageSep)
)
.through(mkData)
store.transact(connStream)
}
@ -73,11 +86,14 @@ object SelectItems {
def forConcEquip[F[_]](
store: Store[F],
collective: Ident,
max: Int
maxItems: Int,
maxTextLen: Int
): Stream[F, Data] = {
val connStream =
allItems(collective, max)
.evalMap(item => QItem.resolveTextAndConcEquip(collective, item, pageSep))
allItems(collective, maxItems)
.evalMap(item =>
QItem.resolveTextAndConcEquip(collective, item, maxTextLen, pageSep)
)
.through(mkData)
store.transact(connStream)
}

View File

@ -547,7 +547,6 @@ object QItem {
chunkSize: Int,
limit: Batch
): Stream[ConnectionIO, Ident] = {
val i = RItem.as("i")
Select(i.id.s, from(i), i.cid === collective && i.state === ItemState.confirmed)
.orderBy(i.created.desc)
@ -561,6 +560,7 @@ object QItem {
collective: Ident,
itemId: Ident,
tagCategory: String,
maxLen: Int,
pageSep: String
): ConnectionIO[TextAndTag] = {
val tags = TableDef("tags").as("tt")
@ -578,7 +578,7 @@ object QItem {
)
)(
Select(
select(m.content, tagsTid, tagsName),
select(substring(m.content.s, 0, maxLen).s, tagsTid.s, tagsName.s),
from(i)
.innerJoin(a, a.itemId === i.id)
.innerJoin(m, a.id === m.id)
@ -592,11 +592,12 @@ object QItem {
def resolveTextAndCorrOrg(
collective: Ident,
itemId: Ident,
maxLen: Int,
pageSep: String
): ConnectionIO[TextAndTag] =
readTextAndTag(collective, itemId, pageSep) {
Select(
select(m.content, org.oid, org.name),
select(substring(m.content.s, 0, maxLen).s, org.oid.s, org.name.s),
from(i)
.innerJoin(a, a.itemId === i.id)
.innerJoin(m, m.id === a.id)
@ -608,11 +609,12 @@ object QItem {
def resolveTextAndCorrPerson(
collective: Ident,
itemId: Ident,
maxLen: Int,
pageSep: String
): ConnectionIO[TextAndTag] =
readTextAndTag(collective, itemId, pageSep) {
Select(
select(m.content, pers0.pid, pers0.name),
select(substring(m.content.s, 0, maxLen).s, pers0.pid.s, pers0.name.s),
from(i)
.innerJoin(a, a.itemId === i.id)
.innerJoin(m, m.id === a.id)
@ -624,11 +626,12 @@ object QItem {
def resolveTextAndConcPerson(
collective: Ident,
itemId: Ident,
maxLen: Int,
pageSep: String
): ConnectionIO[TextAndTag] =
readTextAndTag(collective, itemId, pageSep) {
Select(
select(m.content, pers0.pid, pers0.name),
select(substring(m.content.s, 0, maxLen).s, pers0.pid.s, pers0.name.s),
from(i)
.innerJoin(a, a.itemId === i.id)
.innerJoin(m, m.id === a.id)
@ -640,11 +643,12 @@ object QItem {
def resolveTextAndConcEquip(
collective: Ident,
itemId: Ident,
maxLen: Int,
pageSep: String
): ConnectionIO[TextAndTag] =
readTextAndTag(collective, itemId, pageSep) {
Select(
select(m.content, equip.eid, equip.name),
select(substring(m.content.s, 0, maxLen).s, equip.eid.s, equip.name.s),
from(i)
.innerJoin(a, a.itemId === i.id)
.innerJoin(m, m.id === a.id)