From 9957c3267ed56e30fdb363def8e821d0ae8a0d2f Mon Sep 17 00:00:00 2001 From: Eike Kettner <eike.kettner@posteo.de> Date: Thu, 21 Jan 2021 17:46:39 +0100 Subject: [PATCH] Add constraints from config to classifier training For large and/or many documents, training the classifier can lead to OOM errors. Some limits have been set by default. --- .../joex/src/main/resources/reference.conf | 8 +-- .../src/main/scala/docspell/joex/Config.scala | 7 ++- .../joex/learn/LearnClassifierTask.scala | 11 ++-- .../joex/learn/LearnItemEntities.scala | 31 ++++++----- .../scala/docspell/joex/learn/LearnTags.scala | 10 ++-- .../docspell/joex/learn/SelectItems.scala | 54 ++++++++++++------- .../scala/docspell/store/queries/QItem.scala | 16 +++--- 7 files changed, 87 insertions(+), 50 deletions(-) diff --git a/modules/joex/src/main/resources/reference.conf b/modules/joex/src/main/resources/reference.conf index 00f8d435..7f2ee7d0 100644 --- a/modules/joex/src/main/resources/reference.conf +++ b/modules/joex/src/main/resources/reference.conf @@ -269,9 +269,9 @@ docspell.joex { # All text to analyse must fit into RAM. A large document may take # too much heap. Also, most important information is at the # beginning of a document, so in most cases the first two pages - # should suffice. Default is 10000, which are about 2-3 pages - # (just a rough guess, of course). - max-length = 10000 + # should suffice. Default is 8000, which are about 2-3 pages (just + # a rough guess, of course). + max-length = 8000 # A working directory for the analyser to store temporary/working # files. @@ -363,7 +363,7 @@ docspell.joex { # If concerned with memory consumption, this restricts the # number of items to consider. More are better for training. A # negative value or zero means to train on all items. - item-count = 0 + item-count = 600 # These settings are used to configure the classifier. If # multiple are given, they are all tried and the "best" is diff --git a/modules/joex/src/main/scala/docspell/joex/Config.scala b/modules/joex/src/main/scala/docspell/joex/Config.scala index 922e83c7..e995e757 100644 --- a/modules/joex/src/main/scala/docspell/joex/Config.scala +++ b/modules/joex/src/main/scala/docspell/joex/Config.scala @@ -94,5 +94,10 @@ object Config { enabled: Boolean, itemCount: Int, classifiers: List[Map[String, String]] - ) + ) { + + def itemCountOrWhenLower(other: Int): Int = + if (itemCount <= 0 || (itemCount > other && other > 0)) other + else itemCount + } } diff --git a/modules/joex/src/main/scala/docspell/joex/learn/LearnClassifierTask.scala b/modules/joex/src/main/scala/docspell/joex/learn/LearnClassifierTask.scala index e3aae66f..be3d7143 100644 --- a/modules/joex/src/main/scala/docspell/joex/learn/LearnClassifierTask.scala +++ b/modules/joex/src/main/scala/docspell/joex/learn/LearnClassifierTask.scala @@ -37,7 +37,8 @@ object LearnClassifierTask { .learnAll( analyser, ctx.args.collective, - cfg.classification.itemCount + cfg.classification.itemCount, + cfg.maxLength ) .run(ctx) else ().pure[F] @@ -51,10 +52,14 @@ object LearnClassifierTask { val learnTags = for { sett <- findActiveSettings[F](ctx, cfg) - maxItems = math.min(cfg.classification.itemCount, sett.itemCount) + maxItems = cfg.classification.itemCountOrWhenLower(sett.itemCount) _ <- OptionT.liftF( LearnTags - .learnAllTagCategories(analyser)(ctx.args.collective, maxItems) + .learnAllTagCategories(analyser)( + ctx.args.collective, + maxItems, + cfg.maxLength + ) .run(ctx) ) } yield () diff --git a/modules/joex/src/main/scala/docspell/joex/learn/LearnItemEntities.scala b/modules/joex/src/main/scala/docspell/joex/learn/LearnItemEntities.scala index 1dc48975..f47f1e9c 100644 --- a/modules/joex/src/main/scala/docspell/joex/learn/LearnItemEntities.scala +++ b/modules/joex/src/main/scala/docspell/joex/learn/LearnItemEntities.scala @@ -14,51 +14,56 @@ object LearnItemEntities { def learnAll[F[_]: Sync: ContextShift, A]( analyser: TextAnalyser[F], collective: Ident, - maxItems: Int + maxItems: Int, + maxTextLen: Int ): Task[F, A, Unit] = - learnCorrOrg(analyser, collective, maxItems) - .flatMap(_ => learnCorrPerson[F, A](analyser, collective, maxItems)) - .flatMap(_ => learnConcPerson(analyser, collective, maxItems)) - .flatMap(_ => learnConcEquip(analyser, collective, maxItems)) + learnCorrOrg(analyser, collective, maxItems, maxTextLen) + .flatMap(_ => learnCorrPerson[F, A](analyser, collective, maxItems, maxTextLen)) + .flatMap(_ => learnConcPerson(analyser, collective, maxItems, maxTextLen)) + .flatMap(_ => learnConcEquip(analyser, collective, maxItems, maxTextLen)) def learnCorrOrg[F[_]: Sync: ContextShift, A]( analyser: TextAnalyser[F], collective: Ident, - maxItems: Int + maxItems: Int, + maxTextLen: Int ): Task[F, A, Unit] = learn(analyser, collective)( ClassifierName.correspondentOrg, - ctx => SelectItems.forCorrOrg(ctx.store, collective, maxItems) + ctx => SelectItems.forCorrOrg(ctx.store, collective, maxItems, maxTextLen) ) def learnCorrPerson[F[_]: Sync: ContextShift, A]( analyser: TextAnalyser[F], collective: Ident, - maxItems: Int + maxItems: Int, + maxTextLen: Int ): Task[F, A, Unit] = learn(analyser, collective)( ClassifierName.correspondentPerson, - ctx => SelectItems.forCorrPerson(ctx.store, collective, maxItems) + ctx => SelectItems.forCorrPerson(ctx.store, collective, maxItems, maxTextLen) ) def learnConcPerson[F[_]: Sync: ContextShift, A]( analyser: TextAnalyser[F], collective: Ident, - maxItems: Int + maxItems: Int, + maxTextLen: Int ): Task[F, A, Unit] = learn(analyser, collective)( ClassifierName.concernedPerson, - ctx => SelectItems.forConcPerson(ctx.store, collective, maxItems) + ctx => SelectItems.forConcPerson(ctx.store, collective, maxItems, maxTextLen) ) def learnConcEquip[F[_]: Sync: ContextShift, A]( analyser: TextAnalyser[F], collective: Ident, - maxItems: Int + maxItems: Int, + maxTextLen: Int ): Task[F, A, Unit] = learn(analyser, collective)( ClassifierName.concernedEquip, - ctx => SelectItems.forConcEquip(ctx.store, collective, maxItems) + ctx => SelectItems.forConcEquip(ctx.store, collective, maxItems, maxTextLen) ) private def learn[F[_]: Sync: ContextShift, A]( diff --git a/modules/joex/src/main/scala/docspell/joex/learn/LearnTags.scala b/modules/joex/src/main/scala/docspell/joex/learn/LearnTags.scala index b24eb28d..234a548f 100644 --- a/modules/joex/src/main/scala/docspell/joex/learn/LearnTags.scala +++ b/modules/joex/src/main/scala/docspell/joex/learn/LearnTags.scala @@ -14,12 +14,13 @@ object LearnTags { def learnTagCategory[F[_]: Sync: ContextShift, A]( analyser: TextAnalyser[F], collective: Ident, - maxItems: Int + maxItems: Int, + maxTextLen: Int )( category: String ): Task[F, A, Unit] = Task { ctx => - val data = SelectItems.forCategory(ctx, collective)(maxItems, category) + val data = SelectItems.forCategory(ctx, collective)(maxItems, category, maxTextLen) ctx.logger.info(s"Learn classifier for tag category: $category") *> analyser.classifier.trainClassifier(ctx.logger, data)( Kleisli( @@ -34,12 +35,13 @@ object LearnTags { def learnAllTagCategories[F[_]: Sync: ContextShift, A](analyser: TextAnalyser[F])( collective: Ident, - maxItems: Int + maxItems: Int, + maxTextLen: Int ): Task[F, A, Unit] = Task { ctx => for { cats <- ctx.store.transact(RClassifierSetting.getActiveCategories(collective)) - task = learnTagCategory[F, A](analyser, collective, maxItems) _ + task = learnTagCategory[F, A](analyser, collective, maxItems, maxTextLen) _ _ <- cats.map(task).traverse(_.run(ctx)) } yield () } diff --git a/modules/joex/src/main/scala/docspell/joex/learn/SelectItems.scala b/modules/joex/src/main/scala/docspell/joex/learn/SelectItems.scala index c6dab2f0..8ce77f62 100644 --- a/modules/joex/src/main/scala/docspell/joex/learn/SelectItems.scala +++ b/modules/joex/src/main/scala/docspell/joex/learn/SelectItems.scala @@ -16,20 +16,24 @@ object SelectItems { val noClass = LearnClassifierTask.noClass def forCategory[F[_]](ctx: Context[F, _], collective: Ident)( - max: Int, - category: String + maxItems: Int, + category: String, + maxTextLen: Int ): Stream[F, Data] = - forCategory(ctx.store, collective, max, category) + forCategory(ctx.store, collective, maxItems, category, maxTextLen) def forCategory[F[_]]( store: Store[F], collective: Ident, - max: Int, - category: String + maxItems: Int, + category: String, + maxTextLen: Int ): Stream[F, Data] = { val connStream = - allItems(collective, max) - .evalMap(item => QItem.resolveTextAndTag(collective, item, category, pageSep)) + allItems(collective, maxItems) + .evalMap(item => + QItem.resolveTextAndTag(collective, item, category, maxTextLen, pageSep) + ) .through(mkData) store.transact(connStream) } @@ -37,11 +41,14 @@ object SelectItems { def forCorrOrg[F[_]]( store: Store[F], collective: Ident, - max: Int + maxItems: Int, + maxTextLen: Int ): Stream[F, Data] = { val connStream = - allItems(collective, max) - .evalMap(item => QItem.resolveTextAndCorrOrg(collective, item, pageSep)) + allItems(collective, maxItems) + .evalMap(item => + QItem.resolveTextAndCorrOrg(collective, item, maxTextLen, pageSep) + ) .through(mkData) store.transact(connStream) } @@ -49,11 +56,14 @@ object SelectItems { def forCorrPerson[F[_]]( store: Store[F], collective: Ident, - max: Int + maxItems: Int, + maxTextLen: Int ): Stream[F, Data] = { val connStream = - allItems(collective, max) - .evalMap(item => QItem.resolveTextAndCorrPerson(collective, item, pageSep)) + allItems(collective, maxItems) + .evalMap(item => + QItem.resolveTextAndCorrPerson(collective, item, maxTextLen, pageSep) + ) .through(mkData) store.transact(connStream) } @@ -61,11 +71,14 @@ object SelectItems { def forConcPerson[F[_]]( store: Store[F], collective: Ident, - max: Int + maxItems: Int, + maxTextLen: Int ): Stream[F, Data] = { val connStream = - allItems(collective, max) - .evalMap(item => QItem.resolveTextAndConcPerson(collective, item, pageSep)) + allItems(collective, maxItems) + .evalMap(item => + QItem.resolveTextAndConcPerson(collective, item, maxTextLen, pageSep) + ) .through(mkData) store.transact(connStream) } @@ -73,11 +86,14 @@ object SelectItems { def forConcEquip[F[_]]( store: Store[F], collective: Ident, - max: Int + maxItems: Int, + maxTextLen: Int ): Stream[F, Data] = { val connStream = - allItems(collective, max) - .evalMap(item => QItem.resolveTextAndConcEquip(collective, item, pageSep)) + allItems(collective, maxItems) + .evalMap(item => + QItem.resolveTextAndConcEquip(collective, item, maxTextLen, pageSep) + ) .through(mkData) store.transact(connStream) } diff --git a/modules/store/src/main/scala/docspell/store/queries/QItem.scala b/modules/store/src/main/scala/docspell/store/queries/QItem.scala index 7a53a192..b8ee49e2 100644 --- a/modules/store/src/main/scala/docspell/store/queries/QItem.scala +++ b/modules/store/src/main/scala/docspell/store/queries/QItem.scala @@ -547,7 +547,6 @@ object QItem { chunkSize: Int, limit: Batch ): Stream[ConnectionIO, Ident] = { - val i = RItem.as("i") Select(i.id.s, from(i), i.cid === collective && i.state === ItemState.confirmed) .orderBy(i.created.desc) @@ -561,6 +560,7 @@ object QItem { collective: Ident, itemId: Ident, tagCategory: String, + maxLen: Int, pageSep: String ): ConnectionIO[TextAndTag] = { val tags = TableDef("tags").as("tt") @@ -578,7 +578,7 @@ object QItem { ) )( Select( - select(m.content, tagsTid, tagsName), + select(substring(m.content.s, 0, maxLen).s, tagsTid.s, tagsName.s), from(i) .innerJoin(a, a.itemId === i.id) .innerJoin(m, a.id === m.id) @@ -592,11 +592,12 @@ object QItem { def resolveTextAndCorrOrg( collective: Ident, itemId: Ident, + maxLen: Int, pageSep: String ): ConnectionIO[TextAndTag] = readTextAndTag(collective, itemId, pageSep) { Select( - select(m.content, org.oid, org.name), + select(substring(m.content.s, 0, maxLen).s, org.oid.s, org.name.s), from(i) .innerJoin(a, a.itemId === i.id) .innerJoin(m, m.id === a.id) @@ -608,11 +609,12 @@ object QItem { def resolveTextAndCorrPerson( collective: Ident, itemId: Ident, + maxLen: Int, pageSep: String ): ConnectionIO[TextAndTag] = readTextAndTag(collective, itemId, pageSep) { Select( - select(m.content, pers0.pid, pers0.name), + select(substring(m.content.s, 0, maxLen).s, pers0.pid.s, pers0.name.s), from(i) .innerJoin(a, a.itemId === i.id) .innerJoin(m, m.id === a.id) @@ -624,11 +626,12 @@ object QItem { def resolveTextAndConcPerson( collective: Ident, itemId: Ident, + maxLen: Int, pageSep: String ): ConnectionIO[TextAndTag] = readTextAndTag(collective, itemId, pageSep) { Select( - select(m.content, pers0.pid, pers0.name), + select(substring(m.content.s, 0, maxLen).s, pers0.pid.s, pers0.name.s), from(i) .innerJoin(a, a.itemId === i.id) .innerJoin(m, m.id === a.id) @@ -640,11 +643,12 @@ object QItem { def resolveTextAndConcEquip( collective: Ident, itemId: Ident, + maxLen: Int, pageSep: String ): ConnectionIO[TextAndTag] = readTextAndTag(collective, itemId, pageSep) { Select( - select(m.content, equip.eid, equip.name), + select(substring(m.content.s, 0, maxLen).s, equip.eid.s, equip.name.s), from(i) .innerJoin(a, a.itemId === i.id) .innerJoin(m, m.id === a.id)