mirror of
				https://github.com/TheAnachronism/docspell.git
				synced 2025-11-04 12:30:12 +00:00 
			
		
		
		
	Refactor config validation
This commit is contained in:
		@@ -28,20 +28,21 @@ object ConfigFactory {
 | 
				
			|||||||
    *      the default config
 | 
					    *      the default config
 | 
				
			||||||
    */
 | 
					    */
 | 
				
			||||||
  def default[F[_]: Async, C: ClassTag: ConfigReader](logger: Logger[F], atPath: String)(
 | 
					  def default[F[_]: Async, C: ClassTag: ConfigReader](logger: Logger[F], atPath: String)(
 | 
				
			||||||
      args: List[String]
 | 
					      args: List[String],
 | 
				
			||||||
 | 
					      validation: Validation[C]
 | 
				
			||||||
  ): F[C] =
 | 
					  ): F[C] =
 | 
				
			||||||
    findFileFromArgs(args).flatMap {
 | 
					    findFileFromArgs(args).flatMap {
 | 
				
			||||||
      case Some(file) =>
 | 
					      case Some(file) =>
 | 
				
			||||||
        logger.info(s"Using config file: $file") *>
 | 
					        logger.info(s"Using config file: $file") *>
 | 
				
			||||||
          readFile[F, C](file, atPath)
 | 
					          readFile[F, C](file, atPath).map(validation.validOrThrow)
 | 
				
			||||||
      case None =>
 | 
					      case None =>
 | 
				
			||||||
        checkSystemProperty.value.flatMap {
 | 
					        checkSystemProperty.value.flatMap {
 | 
				
			||||||
          case Some(file) =>
 | 
					          case Some(file) =>
 | 
				
			||||||
            logger.info(s"Using config file from system property: $file") *>
 | 
					            logger.info(s"Using config file from system property: $file") *>
 | 
				
			||||||
              readConfig(atPath)
 | 
					              readConfig(atPath).map(validation.validOrThrow)
 | 
				
			||||||
          case None =>
 | 
					          case None =>
 | 
				
			||||||
            logger.info("Using config from environment variables!") *>
 | 
					            logger.info("Using config from environment variables!") *>
 | 
				
			||||||
              readEnv(atPath)
 | 
					              readEnv(atPath).map(validation.validOrThrow)
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -0,0 +1,71 @@
 | 
				
			|||||||
 | 
					/*
 | 
				
			||||||
 | 
					 * Copyright 2020 Eike K. & Contributors
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * SPDX-License-Identifier: AGPL-3.0-or-later
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package docspell.config
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import cats._
 | 
				
			||||||
 | 
					import cats.data.{NonEmptyChain, Validated, ValidatedNec}
 | 
				
			||||||
 | 
					import cats.implicits._
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					final case class Validation[C](run: C => ValidatedNec[String, C]) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def validOrThrow(c: C): C =
 | 
				
			||||||
 | 
					    run(c) match {
 | 
				
			||||||
 | 
					      case Validated.Valid(cfg) => cfg
 | 
				
			||||||
 | 
					      case Validated.Invalid(errs) =>
 | 
				
			||||||
 | 
					        val msg = errs.toList.mkString("- ", "\n- ", "\n")
 | 
				
			||||||
 | 
					        throw sys.error(s"\n\n$msg")
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def andThen(next: Validation[C]): Validation[C] =
 | 
				
			||||||
 | 
					    Validation(c =>
 | 
				
			||||||
 | 
					      run(c) match {
 | 
				
			||||||
 | 
					        case Validated.Valid(c2) => next.run(c2)
 | 
				
			||||||
 | 
					        case f: Validated.Invalid[NonEmptyChain[String]] =>
 | 
				
			||||||
 | 
					          next.run(c) match {
 | 
				
			||||||
 | 
					            case Validated.Valid(_) => f
 | 
				
			||||||
 | 
					            case Validated.Invalid(errs2) =>
 | 
				
			||||||
 | 
					              Validation.invalid(f.e ++ errs2)
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					object Validation {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def flatten[C](run: C => Validation[C]): Validation[C] =
 | 
				
			||||||
 | 
					    Validation(c => run(c).run(c))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def failWhen[C](isInvalid: C => Boolean, msg: => String): Validation[C] =
 | 
				
			||||||
 | 
					    Validation(c => if (isInvalid(c)) invalid(msg) else valid(c))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def okWhen[C](isValid: C => Boolean, msg: => String): Validation[C] =
 | 
				
			||||||
 | 
					    Validation(c => if (isValid(c)) valid(c) else invalid(msg))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def valid[C](c: C): ValidatedNec[String, C] =
 | 
				
			||||||
 | 
					    Validated.validNec(c)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def invalid[C](msgs: NonEmptyChain[String]): ValidatedNec[String, C] =
 | 
				
			||||||
 | 
					    Validated.Invalid(msgs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def invalid[C](msg: String, msgs: String*): ValidatedNec[String, C] =
 | 
				
			||||||
 | 
					    Validated.Invalid(NonEmptyChain(msg, msgs: _*))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def asValid[C]: Validation[C] =
 | 
				
			||||||
 | 
					    Validation(c => valid(c))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def insert[C](c: C): Validation[C] =
 | 
				
			||||||
 | 
					    Validation(_ => valid(c))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def error[C](msg: String, msgs: String*): Validation[C] =
 | 
				
			||||||
 | 
					    Validation(_ => invalid(msg, msgs: _*))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  implicit def validationMonoid[C]: Monoid[Validation[C]] =
 | 
				
			||||||
 | 
					    Monoid.instance(asValid, (v1, v2) => v1.andThen(v2))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def of[C](v1: Validation[C], vn: Validation[C]*): Validation[C] =
 | 
				
			||||||
 | 
					    Monoid[Validation[C]].combineAll(v1 :: vn.toList)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -0,0 +1,25 @@
 | 
				
			|||||||
 | 
					/*
 | 
				
			||||||
 | 
					 * Copyright 2020 Eike K. & Contributors
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * SPDX-License-Identifier: AGPL-3.0-or-later
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package docspell.config
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import munit.FunSuite
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ValidationTest extends FunSuite {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  test("thread value through validations") {
 | 
				
			||||||
 | 
					    val v1 = Validation[Int](n => Validation.valid(n + 1))
 | 
				
			||||||
 | 
					    assertEquals(v1.validOrThrow(0), 1)
 | 
				
			||||||
 | 
					    assertEquals(Validation.of(v1, v1, v1).validOrThrow(0), 3)
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  test("fail if there is at least one error") {
 | 
				
			||||||
 | 
					    val v1 = Validation[Int](n => Validation.valid(n + 1))
 | 
				
			||||||
 | 
					    val v2 = Validation.error[Int]("error")
 | 
				
			||||||
 | 
					    assertEquals(Validation.of(v1, v2).run(0), Validation.invalid("error"))
 | 
				
			||||||
 | 
					    assertEquals(Validation.of(v2, v1).run(0), Validation.invalid("error"))
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -6,14 +6,11 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
package docspell.joex
 | 
					package docspell.joex
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import cats.data.Validated
 | 
					 | 
				
			||||||
import cats.data.ValidatedNec
 | 
					 | 
				
			||||||
import cats.effect.Async
 | 
					import cats.effect.Async
 | 
				
			||||||
import cats.implicits._
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
import docspell.common.Logger
 | 
					import docspell.common.Logger
 | 
				
			||||||
import docspell.config.ConfigFactory
 | 
					 | 
				
			||||||
import docspell.config.Implicits._
 | 
					import docspell.config.Implicits._
 | 
				
			||||||
 | 
					import docspell.config.{ConfigFactory, Validation}
 | 
				
			||||||
import docspell.joex.scheduler.CountingScheme
 | 
					import docspell.joex.scheduler.CountingScheme
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import emil.MailAddress
 | 
					import emil.MailAddress
 | 
				
			||||||
@@ -28,13 +25,9 @@ object ConfigFile {
 | 
				
			|||||||
  def loadConfig[F[_]: Async](args: List[String]): F[Config] = {
 | 
					  def loadConfig[F[_]: Async](args: List[String]): F[Config] = {
 | 
				
			||||||
    val logger = Logger.log4s[F](org.log4s.getLogger)
 | 
					    val logger = Logger.log4s[F](org.log4s.getLogger)
 | 
				
			||||||
    ConfigFactory
 | 
					    ConfigFactory
 | 
				
			||||||
      .default[F, Config](logger, "docspell.joex")(args)
 | 
					      .default[F, Config](logger, "docspell.joex")(args, validate)
 | 
				
			||||||
      .map(cfg => validOrThrow(cfg))
 | 
					 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  private def validOrThrow(cfg: Config): Config =
 | 
					 | 
				
			||||||
    validate(cfg).fold(err => sys.error(err.toList.mkString("- ", "\n", "")), identity)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  object Implicits {
 | 
					  object Implicits {
 | 
				
			||||||
    implicit val countingSchemeReader: ConfigReader[CountingScheme] =
 | 
					    implicit val countingSchemeReader: ConfigReader[CountingScheme] =
 | 
				
			||||||
      ConfigReader[String].emap(reason(CountingScheme.readString))
 | 
					      ConfigReader[String].emap(reason(CountingScheme.readString))
 | 
				
			||||||
@@ -46,23 +39,19 @@ object ConfigFile {
 | 
				
			|||||||
      ConfigReader[String].emap(reason(MailAddress.parse))
 | 
					      ConfigReader[String].emap(reason(MailAddress.parse))
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def validate(cfg: Config): ValidatedNec[String, Config] =
 | 
					  def validate: Validation[Config] =
 | 
				
			||||||
    List(
 | 
					    Validation.of[Config](
 | 
				
			||||||
      failWhen(
 | 
					      Validation.failWhen(
 | 
				
			||||||
        cfg.updateCheck.enabled && cfg.updateCheck.recipients.isEmpty,
 | 
					        cfg => cfg.updateCheck.enabled && cfg.updateCheck.recipients.isEmpty,
 | 
				
			||||||
        "No recipients given for enabled update check!"
 | 
					        "No recipients given for enabled update check!"
 | 
				
			||||||
      ),
 | 
					      ),
 | 
				
			||||||
      failWhen(
 | 
					      Validation.failWhen(
 | 
				
			||||||
        cfg.updateCheck.enabled && cfg.updateCheck.smtpId.isEmpty,
 | 
					        cfg => cfg.updateCheck.enabled && cfg.updateCheck.smtpId.isEmpty,
 | 
				
			||||||
        "No recipients given for enabled update check!"
 | 
					        "No recipients given for enabled update check!"
 | 
				
			||||||
      ),
 | 
					      ),
 | 
				
			||||||
      failWhen(
 | 
					      Validation.failWhen(
 | 
				
			||||||
        cfg.updateCheck.enabled && cfg.updateCheck.subject.els.isEmpty,
 | 
					        cfg => cfg.updateCheck.enabled && cfg.updateCheck.subject.els.isEmpty,
 | 
				
			||||||
        "No subject given for enabled update check!"
 | 
					        "No subject given for enabled update check!"
 | 
				
			||||||
      )
 | 
					      )
 | 
				
			||||||
    ).reduce(_ |+| _).map(_ => cfg)
 | 
					    )
 | 
				
			||||||
 | 
					 | 
				
			||||||
  def failWhen(cond: Boolean, msg: => String): ValidatedNec[String, Unit] =
 | 
					 | 
				
			||||||
    Validated.condNec(!cond, (), msg)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -8,15 +8,13 @@ package docspell.restserver
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import java.security.SecureRandom
 | 
					import java.security.SecureRandom
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import cats.Semigroup
 | 
					import cats.Monoid
 | 
				
			||||||
import cats.data.{Validated, ValidatedNec}
 | 
					 | 
				
			||||||
import cats.effect.Async
 | 
					import cats.effect.Async
 | 
				
			||||||
import cats.implicits._
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
import docspell.backend.signup.{Config => SignupConfig}
 | 
					import docspell.backend.signup.{Config => SignupConfig}
 | 
				
			||||||
import docspell.common.Logger
 | 
					import docspell.common.Logger
 | 
				
			||||||
import docspell.config.ConfigFactory
 | 
					 | 
				
			||||||
import docspell.config.Implicits._
 | 
					import docspell.config.Implicits._
 | 
				
			||||||
 | 
					import docspell.config.{ConfigFactory, Validation}
 | 
				
			||||||
import docspell.oidc.{ProviderConfig, SignatureAlgo}
 | 
					import docspell.oidc.{ProviderConfig, SignatureAlgo}
 | 
				
			||||||
import docspell.restserver.auth.OpenId
 | 
					import docspell.restserver.auth.OpenId
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -30,9 +28,10 @@ object ConfigFile {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  def loadConfig[F[_]: Async](args: List[String]): F[Config] = {
 | 
					  def loadConfig[F[_]: Async](args: List[String]): F[Config] = {
 | 
				
			||||||
    val logger = Logger.log4s(unsafeLogger)
 | 
					    val logger = Logger.log4s(unsafeLogger)
 | 
				
			||||||
 | 
					    val validate =
 | 
				
			||||||
 | 
					      Validation.of(generateSecretIfEmpty, duplicateOpenIdProvider, signKeyVsUserUrl)
 | 
				
			||||||
    ConfigFactory
 | 
					    ConfigFactory
 | 
				
			||||||
      .default[F, Config](logger, "docspell.server")(args)
 | 
					      .default[F, Config](logger, "docspell.server")(args, validate)
 | 
				
			||||||
      .map(cfg => Validate(cfg))
 | 
					 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  object Implicits {
 | 
					  object Implicits {
 | 
				
			||||||
@@ -46,29 +45,8 @@ object ConfigFile {
 | 
				
			|||||||
      ConfigReader[String].emap(reason(OpenId.UserInfo.Extractor.fromString))
 | 
					      ConfigReader[String].emap(reason(OpenId.UserInfo.Extractor.fromString))
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  object Validate {
 | 
					  def generateSecretIfEmpty: Validation[Config] =
 | 
				
			||||||
 | 
					    Validation { cfg =>
 | 
				
			||||||
    implicit val firstConfigSemigroup: Semigroup[Config] =
 | 
					 | 
				
			||||||
      Semigroup.first
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def apply(config: Config): Config =
 | 
					 | 
				
			||||||
      all(config).foldLeft(valid(config))(_.combine(_)) match {
 | 
					 | 
				
			||||||
        case Validated.Valid(cfg) => cfg
 | 
					 | 
				
			||||||
        case Validated.Invalid(errs) =>
 | 
					 | 
				
			||||||
          val msg = errs.toList.mkString("- ", "\n- ", "\n")
 | 
					 | 
				
			||||||
          throw sys.error(s"\n\n$msg")
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def all(cfg: Config) = List(
 | 
					 | 
				
			||||||
      duplicateOpenIdProvider(cfg),
 | 
					 | 
				
			||||||
      signKeyVsUserUrl(cfg),
 | 
					 | 
				
			||||||
      generateSecretIfEmpty(cfg)
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    private def valid(cfg: Config): ValidatedNec[String, Config] =
 | 
					 | 
				
			||||||
      Validated.validNec(cfg)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def generateSecretIfEmpty(cfg: Config): ValidatedNec[String, Config] =
 | 
					 | 
				
			||||||
      if (cfg.auth.serverSecret.isEmpty) {
 | 
					      if (cfg.auth.serverSecret.isEmpty) {
 | 
				
			||||||
        unsafeLogger.warn(
 | 
					        unsafeLogger.warn(
 | 
				
			||||||
          "No serverSecret specified. Generating a random one. It is recommended to add a server-secret in the config file."
 | 
					          "No serverSecret specified. Generating a random one. It is recommended to add a server-secret in the config file."
 | 
				
			||||||
@@ -77,10 +55,12 @@ object ConfigFile {
 | 
				
			|||||||
        val buffer = new Array[Byte](32)
 | 
					        val buffer = new Array[Byte](32)
 | 
				
			||||||
        random.nextBytes(buffer)
 | 
					        random.nextBytes(buffer)
 | 
				
			||||||
        val secret = ByteVector.view(buffer)
 | 
					        val secret = ByteVector.view(buffer)
 | 
				
			||||||
        valid(cfg.copy(auth = cfg.auth.copy(serverSecret = secret)))
 | 
					        Validation.valid(cfg.copy(auth = cfg.auth.copy(serverSecret = secret)))
 | 
				
			||||||
      } else valid(cfg)
 | 
					      } else Validation.valid(cfg)
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def duplicateOpenIdProvider(cfg: Config): ValidatedNec[String, Config] = {
 | 
					  def duplicateOpenIdProvider: Validation[Config] =
 | 
				
			||||||
 | 
					    Validation { cfg =>
 | 
				
			||||||
      val dupes =
 | 
					      val dupes =
 | 
				
			||||||
        cfg.openid
 | 
					        cfg.openid
 | 
				
			||||||
          .filter(_.enabled)
 | 
					          .filter(_.enabled)
 | 
				
			||||||
@@ -90,27 +70,31 @@ object ConfigFile {
 | 
				
			|||||||
          .toList
 | 
					          .toList
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      val dupesStr = dupes.mkString(", ")
 | 
					      val dupesStr = dupes.mkString(", ")
 | 
				
			||||||
      if (dupes.isEmpty) valid(cfg)
 | 
					      if (dupes.isEmpty) Validation.valid(cfg)
 | 
				
			||||||
      else Validated.invalidNec(s"There is a duplicate openId provider: $dupesStr")
 | 
					      else Validation.invalid(s"There is a duplicate openId provider: $dupesStr")
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def signKeyVsUserUrl(cfg: Config): ValidatedNec[String, Config] = {
 | 
					  def signKeyVsUserUrl: Validation[Config] =
 | 
				
			||||||
      def checkProvider(p: ProviderConfig): ValidatedNec[String, Config] =
 | 
					    Validation.flatten { cfg =>
 | 
				
			||||||
 | 
					      def checkProvider(p: ProviderConfig): Validation[Config] =
 | 
				
			||||||
 | 
					        Validation { _ =>
 | 
				
			||||||
          if (p.signKey.isEmpty && p.userUrl.isEmpty)
 | 
					          if (p.signKey.isEmpty && p.userUrl.isEmpty)
 | 
				
			||||||
          Validated.invalidNec(
 | 
					            Validation.invalid(
 | 
				
			||||||
              s"Either user-url or sign-key must be set for provider ${p.providerId.id}"
 | 
					              s"Either user-url or sign-key must be set for provider ${p.providerId.id}"
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
          else if (p.signKey.nonEmpty && p.scope.isEmpty)
 | 
					          else if (p.signKey.nonEmpty && p.scope.isEmpty)
 | 
				
			||||||
          Validated.invalidNec(
 | 
					            Validation.invalid(
 | 
				
			||||||
              s"A scope is missing for OIDC auth at provider ${p.providerId.id}"
 | 
					              s"A scope is missing for OIDC auth at provider ${p.providerId.id}"
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
        else Validated.valid(cfg)
 | 
					          else Validation.valid(cfg)
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      Monoid[Validation[Config]]
 | 
				
			||||||
 | 
					        .combineAll(
 | 
				
			||||||
          cfg.openid
 | 
					          cfg.openid
 | 
				
			||||||
            .filter(_.enabled)
 | 
					            .filter(_.enabled)
 | 
				
			||||||
            .map(_.provider)
 | 
					            .map(_.provider)
 | 
				
			||||||
            .map(checkProvider)
 | 
					            .map(checkProvider)
 | 
				
			||||||
        .foldLeft(valid(cfg))(_.combine(_))
 | 
					        )
 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user