From 11de82402eb61c6e721f66887b322cf495ef973c Mon Sep 17 00:00:00 2001
From: eikek <eike.kettner@posteo.de>
Date: Mon, 6 Sep 2021 13:49:59 +0200
Subject: [PATCH] Add cross checks for the server config

---
 .../docspell/restserver/ConfigFile.scala      | 63 ++++++++++++++++++-
 1 file changed, 61 insertions(+), 2 deletions(-)

diff --git a/modules/restserver/src/main/scala/docspell/restserver/ConfigFile.scala b/modules/restserver/src/main/scala/docspell/restserver/ConfigFile.scala
index 8818e92a..c838ca07 100644
--- a/modules/restserver/src/main/scala/docspell/restserver/ConfigFile.scala
+++ b/modules/restserver/src/main/scala/docspell/restserver/ConfigFile.scala
@@ -6,9 +6,13 @@
 
 package docspell.restserver
 
+import cats.Semigroup
+import cats.data.{Validated, ValidatedNec}
+import cats.implicits._
+
 import docspell.backend.signup.{Config => SignupConfig}
 import docspell.common.config.Implicits._
-import docspell.oidc.SignatureAlgo
+import docspell.oidc.{ProviderConfig, SignatureAlgo}
 import docspell.restserver.auth.OpenId
 
 import pureconfig._
@@ -18,7 +22,7 @@ object ConfigFile {
   import Implicits._
 
   def loadConfig: Config =
-    ConfigSource.default.at("docspell.server").loadOrThrow[Config]
+    Validate(ConfigSource.default.at("docspell.server").loadOrThrow[Config])
 
   object Implicits {
     implicit val signupModeReader: ConfigReader[SignupConfig.Mode] =
@@ -30,4 +34,59 @@ object ConfigFile {
     implicit val openIdExtractorReader: ConfigReader[OpenId.UserInfo.Extractor] =
       ConfigReader[String].emap(reason(OpenId.UserInfo.Extractor.fromString))
   }
+
+  object Validate {
+
+    implicit val firstConfigSemigroup: Semigroup[Config] =
+      Semigroup.first
+
+    def apply(config: Config): Config =
+      all(config).foldLeft(valid(config))(_.combine(_)) match {
+        case Validated.Valid(cfg) => cfg
+        case Validated.Invalid(errs) =>
+          val msg = errs.toList.mkString("- ", "\n- ", "\n")
+          throw sys.error(s"\n\n$msg")
+      }
+
+    def all(cfg: Config) = List(
+      duplicateOpenIdProvider(cfg),
+      signKeyVsUserUrl(cfg)
+    )
+
+    private def valid(cfg: Config): ValidatedNec[String, Config] =
+      Validated.validNec(cfg)
+
+    def duplicateOpenIdProvider(cfg: Config): ValidatedNec[String, Config] = {
+      val dupes =
+        cfg.openid
+          .filter(_.enabled)
+          .groupBy(_.provider.providerId)
+          .filter(_._2.size > 1)
+          .map(_._1.id)
+          .toList
+
+      val dupesStr = dupes.mkString(", ")
+      if (dupes.isEmpty) valid(cfg)
+      else Validated.invalidNec(s"There is a duplicate openId provider: $dupesStr")
+    }
+
+    def signKeyVsUserUrl(cfg: Config): ValidatedNec[String, Config] = {
+      def checkProvider(p: ProviderConfig): ValidatedNec[String, Config] =
+        if (p.signKey.isEmpty && p.userUrl.isEmpty)
+          Validated.invalidNec(
+            s"Either user-url or sign-key must be set for provider ${p.providerId.id}"
+          )
+        else if (p.signKey.nonEmpty && p.scope.isEmpty)
+          Validated.invalidNec(
+            s"A scope is missing for OIDC auth at provider ${p.providerId.id}"
+          )
+        else Validated.valid(cfg)
+
+      cfg.openid
+        .filter(_.enabled)
+        .map(_.provider)
+        .map(checkProvider)
+        .foldLeft(valid(cfg))(_.combine(_))
+    }
+  }
 }