From 2e5ad4960b338861e2855cd31874c2b5e4454472 Mon Sep 17 00:00:00 2001 From: eikek Date: Wed, 6 Jul 2022 23:46:16 +0200 Subject: [PATCH] Add a state parameter to oidc flow Refs: #1619 --- .../docspell/backend/auth/TokenUtil.scala | 13 ++--- .../scala/docspell/common/util/SignUtil.scala | 39 +++++++++++++++ .../docspell/common/util/SignUtilTest.scala | 20 ++++++++ .../scala/docspell/oidc/CodeFlowConfig.scala | 7 ++- .../scala/docspell/oidc/CodeFlowRoutes.scala | 41 +++++++++++----- .../main/scala/docspell/oidc/StateParam.scala | 49 +++++++++++++++++++ .../scala/docspell/oidc/StateParamTest.scala | 47 ++++++++++++++++++ .../docspell/restserver/auth/OpenId.scala | 3 +- 8 files changed, 194 insertions(+), 25 deletions(-) create mode 100644 modules/common/src/main/scala/docspell/common/util/SignUtil.scala create mode 100644 modules/common/src/test/scala/docspell/common/util/SignUtilTest.scala create mode 100644 modules/oidc/src/main/scala/docspell/oidc/StateParam.scala create mode 100644 modules/oidc/src/test/scala/docspell/oidc/StateParamTest.scala diff --git a/modules/backend/src/main/scala/docspell/backend/auth/TokenUtil.scala b/modules/backend/src/main/scala/docspell/backend/auth/TokenUtil.scala index 329e8e4a..060bc439 100644 --- a/modules/backend/src/main/scala/docspell/backend/auth/TokenUtil.scala +++ b/modules/backend/src/main/scala/docspell/backend/auth/TokenUtil.scala @@ -6,11 +6,10 @@ package docspell.backend.auth -import javax.crypto.Mac -import javax.crypto.spec.SecretKeySpec - import cats.implicits._ +import docspell.common.util.SignUtil + import scodec.bits._ private[auth] object TokenUtil { @@ -34,11 +33,8 @@ private[auth] object TokenUtil { signRaw(raw, key) } - private def signRaw(data: String, key: ByteVector): String = { - val mac = Mac.getInstance("HmacSHA1") - mac.init(new SecretKeySpec(key.toArray, "HmacSHA1")) - ByteVector.view(mac.doFinal(data.getBytes(utf8))).toBase64 - } + private def signRaw(data: String, key: ByteVector): String = + SignUtil.signString(data, key).toBase64 def b64enc(s: String): String = ByteVector.view(s.getBytes(utf8)).toBase64 @@ -52,5 +48,4 @@ private[auth] object TokenUtil { def constTimeEq(s1: String, s2: String): Boolean = s1.zip(s2) .foldLeft(true) { case (r, (c1, c2)) => r & c1 == c2 } & s1.length == s2.length - } diff --git a/modules/common/src/main/scala/docspell/common/util/SignUtil.scala b/modules/common/src/main/scala/docspell/common/util/SignUtil.scala new file mode 100644 index 00000000..04ccd455 --- /dev/null +++ b/modules/common/src/main/scala/docspell/common/util/SignUtil.scala @@ -0,0 +1,39 @@ +/* + * Copyright 2020 Eike K. & Contributors + * + * SPDX-License-Identifier: AGPL-3.0-or-later + */ + +package docspell.common.util + +import javax.crypto.Mac +import javax.crypto.spec.SecretKeySpec + +import scodec.bits.ByteVector + +object SignUtil { + private val utf8 = java.nio.charset.StandardCharsets.UTF_8 + + private val macAlgo = "HmacSHA1" + + private def getMac(key: ByteVector) = { + val mac = Mac.getInstance(macAlgo) + mac.init(new SecretKeySpec(key.toArray, macAlgo)) + mac + } + + def signString(data: String, key: ByteVector): ByteVector = { + val mac = getMac(key) + ByteVector.view(mac.doFinal(data.getBytes(utf8))) + } + + def signBytes(data: ByteVector, key: ByteVector): ByteVector = { + val mac = getMac(key) + ByteVector.view(mac.doFinal(data.toArray)) + } + + def isEqual(sig1: ByteVector, sig2: ByteVector): Boolean = + sig1 + .zipWith(sig2)((b1, b2) => (b1 - b2).toByte) + .foldLeft(true)(_ && _ == 0) && sig1.length == sig2.length +} diff --git a/modules/common/src/test/scala/docspell/common/util/SignUtilTest.scala b/modules/common/src/test/scala/docspell/common/util/SignUtilTest.scala new file mode 100644 index 00000000..d985f014 --- /dev/null +++ b/modules/common/src/test/scala/docspell/common/util/SignUtilTest.scala @@ -0,0 +1,20 @@ +/* + * Copyright 2020 Eike K. & Contributors + * + * SPDX-License-Identifier: AGPL-3.0-or-later + */ + +package docspell.common.util + +import munit.FunSuite +import scodec.bits.ByteVector + +class SignUtilTest extends FunSuite { + + private val key = ByteVector.fromValidHex("caffee") + + test("create and validate") { + val sig = SignUtil.signString("hello", key) + assert(SignUtil.isEqual(sig, SignUtil.signString("hello", key))) + } +} diff --git a/modules/oidc/src/main/scala/docspell/oidc/CodeFlowConfig.scala b/modules/oidc/src/main/scala/docspell/oidc/CodeFlowConfig.scala index 1be95021..8d626e3e 100644 --- a/modules/oidc/src/main/scala/docspell/oidc/CodeFlowConfig.scala +++ b/modules/oidc/src/main/scala/docspell/oidc/CodeFlowConfig.scala @@ -9,6 +9,7 @@ package docspell.oidc import docspell.common._ import org.http4s.Request +import scodec.bits.ByteVector trait CodeFlowConfig[F[_]] { @@ -22,17 +23,20 @@ trait CodeFlowConfig[F[_]] { */ def findProvider(id: Ident): Option[ProviderConfig] + def serverSecret: ByteVector } object CodeFlowConfig { def apply[F[_]]( url: Request[F] => LenientUri, - provider: Ident => Option[ProviderConfig] + provider: Ident => Option[ProviderConfig], + secret: ByteVector ): CodeFlowConfig[F] = new CodeFlowConfig[F] { def getEndpointUrl(req: Request[F]): LenientUri = url(req) def findProvider(id: Ident): Option[ProviderConfig] = provider(id) + val serverSecret = secret } private[oidc] def resumeUri[F[_]]( @@ -41,5 +45,4 @@ object CodeFlowConfig { cfg: CodeFlowConfig[F] ): LenientUri = cfg.getEndpointUrl(req) / prov.providerId.id / "resume" - } diff --git a/modules/oidc/src/main/scala/docspell/oidc/CodeFlowRoutes.scala b/modules/oidc/src/main/scala/docspell/oidc/CodeFlowRoutes.scala index c0adcdd2..af383a37 100644 --- a/modules/oidc/src/main/scala/docspell/oidc/CodeFlowRoutes.scala +++ b/modules/oidc/src/main/scala/docspell/oidc/CodeFlowRoutes.scala @@ -41,19 +41,23 @@ object CodeFlowRoutes { case req @ GET -> Root / Ident(id) => config.findProvider(id) match { case Some(cfg) => - val uri = cfg.authorizeUrl - .withQuery("client_id", cfg.clientId) - .withQuery("scope", cfg.scope) - .withQuery( - "redirect_uri", - CodeFlowConfig.resumeUri(req, cfg, config).asString + for { + state <- StateParam.generate[F](config.serverSecret) + uri = cfg.authorizeUrl + .withQuery("client_id", cfg.clientId) + .withQuery("scope", cfg.scope) + .withQuery( + "redirect_uri", + CodeFlowConfig.resumeUri(req, cfg, config).asString + ) + .withQuery("response_type", "code") + .withQuery("state", state.asString) + _ <- logger.debug( + s"Redirecting to OAuth/OIDC provider ${cfg.providerId.id}: ${uri.asString}" ) - .withQuery("response_type", "code") - .withQuery("state", cfg.clientId) - logger.debug( - s"Redirecting to OAuth/OIDC provider ${cfg.providerId.id}: ${uri.asString}" - ) *> - Found(Location(Uri.unsafeFromString(uri.asString))) + resp <- Found(Location(Uri.unsafeFromString(uri.asString))) + } yield resp + case None => logger.debug(s"No OAuth/OIDC provider found with id '$id'") *> NotFound() @@ -66,9 +70,20 @@ object CodeFlowRoutes { NotFound() case Some(provider) => val codeFromReq = OptionT.fromOption[F](req.params.get("code")) + val stateParamValid = req.params + .get("state") + .exists(state => StateParam.isValidStateParam(state, config.serverSecret)) val userInfo = for { _ <- OptionT.liftF(logger.info(s"Resume OAuth/OIDC flow for ${id.id}")) + _ <- + if (stateParamValid) OptionT.pure[F](()) + else + OptionT( + logger + .warn(s"Invalid state parameter returned from Idp!") + .as(Option.empty[Unit]) + ) code <- codeFromReq _ <- OptionT.liftF( logger.trace( @@ -91,7 +106,7 @@ object CodeFlowRoutes { s"$err$descr" } .map(err => s": $err") - .getOrElse("") + .getOrElse(": ") logger.warn(s"Error resuming code flow from '${id.id}'$reason") *> onUserInfo.handle(req, provider, None) diff --git a/modules/oidc/src/main/scala/docspell/oidc/StateParam.scala b/modules/oidc/src/main/scala/docspell/oidc/StateParam.scala new file mode 100644 index 00000000..cbf70334 --- /dev/null +++ b/modules/oidc/src/main/scala/docspell/oidc/StateParam.scala @@ -0,0 +1,49 @@ +/* + * Copyright 2020 Eike K. & Contributors + * + * SPDX-License-Identifier: AGPL-3.0-or-later + */ + +package docspell.oidc + +import cats.effect._ +import cats.syntax.all._ + +import docspell.common.util.{Random, SignUtil} + +import scodec.bits.Bases.Alphabets +import scodec.bits.ByteVector + +final case class StateParam(value: String, sig: ByteVector) { + def asString: String = + s"$value$$${sig.toBase64UrlNoPad}" + + def isValid(key: ByteVector): Boolean = { + val actual = SignUtil.signString(value, key) + SignUtil.isEqual(actual, sig) + } +} + +object StateParam { + + def generate[F[_]: Sync](key: ByteVector): F[StateParam] = + Random[F].string(8).map { v => + val sig = SignUtil.signString(v, key) + StateParam(v, sig) + } + + def fromString(str: String, key: ByteVector): Either[String, StateParam] = + str.split('$') match { + case Array(v, sig) => + ByteVector + .fromBase64Descriptive(sig, Alphabets.Base64UrlNoPad) + .map(s => StateParam(v, s)) + .filterOrElse(_.isValid(key), s"Invalid signature in state param: $str") + + case _ => + Left(s"Invalid state parameter: $str") + } + + def isValidStateParam(state: String, key: ByteVector) = + fromString(state, key).isRight +} diff --git a/modules/oidc/src/test/scala/docspell/oidc/StateParamTest.scala b/modules/oidc/src/test/scala/docspell/oidc/StateParamTest.scala new file mode 100644 index 00000000..a189d935 --- /dev/null +++ b/modules/oidc/src/test/scala/docspell/oidc/StateParamTest.scala @@ -0,0 +1,47 @@ +/* + * Copyright 2020 Eike K. & Contributors + * + * SPDX-License-Identifier: AGPL-3.0-or-later + */ + +package docspell.oidc + +import cats.effect._ + +import munit.CatsEffectSuite +import scodec.bits.ByteVector + +class StateParamTest extends CatsEffectSuite { + + private val key = ByteVector.fromValidHex("caffee") + private val key2 = key ++ key + + test("generate") { + for { + p <- StateParam.generate[IO](key) + _ = { + assert(p.value.length > 8) + assert(p.sig.nonEmpty) + assert(p.isValid(key)) + assert(!p.isValid(key2)) + } + } yield () + } + + test("fromString") { + for { + p <- StateParam.generate[IO](key) + str = p.asString + p2 = StateParam.fromString(str, key).fold(sys.error, identity) + p3 = StateParam.fromString(str, key2) + p4 = StateParam.fromString("uiaeuiaeue", key) + p5 = StateParam.fromString(str + "$" + str, key) + _ = { + assertEquals(p2, p) + assert(p3.isLeft) + assert(p4.isLeft) + assert(p5.isLeft) + } + } yield () + } +} diff --git a/modules/restserver/src/main/scala/docspell/restserver/auth/OpenId.scala b/modules/restserver/src/main/scala/docspell/restserver/auth/OpenId.scala index 9ec8e7ea..e57fe256 100644 --- a/modules/restserver/src/main/scala/docspell/restserver/auth/OpenId.scala +++ b/modules/restserver/src/main/scala/docspell/restserver/auth/OpenId.scala @@ -31,7 +31,8 @@ object OpenId { ClientRequestInfo .getBaseUrl(config, req) / "api" / "v1" / "open" / "auth" / "openid", id => - config.openid.filter(_.enabled).find(_.provider.providerId == id).map(_.provider) + config.openid.filter(_.enabled).find(_.provider.providerId == id).map(_.provider), + config.auth.serverSecret ) def handle[F[_]: Async](backend: BackendApp[F], config: Config): OnUserInfo[F] =