Add a state parameter to oidc flow

Refs: #1619
This commit is contained in:
eikek
2022-07-06 23:46:16 +02:00
parent 44243b3d6d
commit 2e5ad4960b
8 changed files with 194 additions and 25 deletions

View File

@ -6,11 +6,10 @@
package docspell.backend.auth
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec
import cats.implicits._
import docspell.common.util.SignUtil
import scodec.bits._
private[auth] object TokenUtil {
@ -34,11 +33,8 @@ private[auth] object TokenUtil {
signRaw(raw, key)
}
private def signRaw(data: String, key: ByteVector): String = {
val mac = Mac.getInstance("HmacSHA1")
mac.init(new SecretKeySpec(key.toArray, "HmacSHA1"))
ByteVector.view(mac.doFinal(data.getBytes(utf8))).toBase64
}
private def signRaw(data: String, key: ByteVector): String =
SignUtil.signString(data, key).toBase64
def b64enc(s: String): String =
ByteVector.view(s.getBytes(utf8)).toBase64
@ -52,5 +48,4 @@ private[auth] object TokenUtil {
def constTimeEq(s1: String, s2: String): Boolean =
s1.zip(s2)
.foldLeft(true) { case (r, (c1, c2)) => r & c1 == c2 } & s1.length == s2.length
}

View File

@ -0,0 +1,39 @@
/*
* Copyright 2020 Eike K. & Contributors
*
* SPDX-License-Identifier: AGPL-3.0-or-later
*/
package docspell.common.util
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec
import scodec.bits.ByteVector
object SignUtil {
private val utf8 = java.nio.charset.StandardCharsets.UTF_8
private val macAlgo = "HmacSHA1"
private def getMac(key: ByteVector) = {
val mac = Mac.getInstance(macAlgo)
mac.init(new SecretKeySpec(key.toArray, macAlgo))
mac
}
def signString(data: String, key: ByteVector): ByteVector = {
val mac = getMac(key)
ByteVector.view(mac.doFinal(data.getBytes(utf8)))
}
def signBytes(data: ByteVector, key: ByteVector): ByteVector = {
val mac = getMac(key)
ByteVector.view(mac.doFinal(data.toArray))
}
def isEqual(sig1: ByteVector, sig2: ByteVector): Boolean =
sig1
.zipWith(sig2)((b1, b2) => (b1 - b2).toByte)
.foldLeft(true)(_ && _ == 0) && sig1.length == sig2.length
}

View File

@ -0,0 +1,20 @@
/*
* Copyright 2020 Eike K. & Contributors
*
* SPDX-License-Identifier: AGPL-3.0-or-later
*/
package docspell.common.util
import munit.FunSuite
import scodec.bits.ByteVector
class SignUtilTest extends FunSuite {
private val key = ByteVector.fromValidHex("caffee")
test("create and validate") {
val sig = SignUtil.signString("hello", key)
assert(SignUtil.isEqual(sig, SignUtil.signString("hello", key)))
}
}

View File

@ -9,6 +9,7 @@ package docspell.oidc
import docspell.common._
import org.http4s.Request
import scodec.bits.ByteVector
trait CodeFlowConfig[F[_]] {
@ -22,17 +23,20 @@ trait CodeFlowConfig[F[_]] {
*/
def findProvider(id: Ident): Option[ProviderConfig]
def serverSecret: ByteVector
}
object CodeFlowConfig {
def apply[F[_]](
url: Request[F] => LenientUri,
provider: Ident => Option[ProviderConfig]
provider: Ident => Option[ProviderConfig],
secret: ByteVector
): CodeFlowConfig[F] =
new CodeFlowConfig[F] {
def getEndpointUrl(req: Request[F]): LenientUri = url(req)
def findProvider(id: Ident): Option[ProviderConfig] = provider(id)
val serverSecret = secret
}
private[oidc] def resumeUri[F[_]](
@ -41,5 +45,4 @@ object CodeFlowConfig {
cfg: CodeFlowConfig[F]
): LenientUri =
cfg.getEndpointUrl(req) / prov.providerId.id / "resume"
}

View File

@ -41,19 +41,23 @@ object CodeFlowRoutes {
case req @ GET -> Root / Ident(id) =>
config.findProvider(id) match {
case Some(cfg) =>
val uri = cfg.authorizeUrl
.withQuery("client_id", cfg.clientId)
.withQuery("scope", cfg.scope)
.withQuery(
"redirect_uri",
CodeFlowConfig.resumeUri(req, cfg, config).asString
for {
state <- StateParam.generate[F](config.serverSecret)
uri = cfg.authorizeUrl
.withQuery("client_id", cfg.clientId)
.withQuery("scope", cfg.scope)
.withQuery(
"redirect_uri",
CodeFlowConfig.resumeUri(req, cfg, config).asString
)
.withQuery("response_type", "code")
.withQuery("state", state.asString)
_ <- logger.debug(
s"Redirecting to OAuth/OIDC provider ${cfg.providerId.id}: ${uri.asString}"
)
.withQuery("response_type", "code")
.withQuery("state", cfg.clientId)
logger.debug(
s"Redirecting to OAuth/OIDC provider ${cfg.providerId.id}: ${uri.asString}"
) *>
Found(Location(Uri.unsafeFromString(uri.asString)))
resp <- Found(Location(Uri.unsafeFromString(uri.asString)))
} yield resp
case None =>
logger.debug(s"No OAuth/OIDC provider found with id '$id'") *>
NotFound()
@ -66,9 +70,20 @@ object CodeFlowRoutes {
NotFound()
case Some(provider) =>
val codeFromReq = OptionT.fromOption[F](req.params.get("code"))
val stateParamValid = req.params
.get("state")
.exists(state => StateParam.isValidStateParam(state, config.serverSecret))
val userInfo = for {
_ <- OptionT.liftF(logger.info(s"Resume OAuth/OIDC flow for ${id.id}"))
_ <-
if (stateParamValid) OptionT.pure[F](())
else
OptionT(
logger
.warn(s"Invalid state parameter returned from Idp!")
.as(Option.empty[Unit])
)
code <- codeFromReq
_ <- OptionT.liftF(
logger.trace(
@ -91,7 +106,7 @@ object CodeFlowRoutes {
s"$err$descr"
}
.map(err => s": $err")
.getOrElse("")
.getOrElse(": <no reason>")
logger.warn(s"Error resuming code flow from '${id.id}'$reason") *>
onUserInfo.handle(req, provider, None)

View File

@ -0,0 +1,49 @@
/*
* Copyright 2020 Eike K. & Contributors
*
* SPDX-License-Identifier: AGPL-3.0-or-later
*/
package docspell.oidc
import cats.effect._
import cats.syntax.all._
import docspell.common.util.{Random, SignUtil}
import scodec.bits.Bases.Alphabets
import scodec.bits.ByteVector
final case class StateParam(value: String, sig: ByteVector) {
def asString: String =
s"$value$$${sig.toBase64UrlNoPad}"
def isValid(key: ByteVector): Boolean = {
val actual = SignUtil.signString(value, key)
SignUtil.isEqual(actual, sig)
}
}
object StateParam {
def generate[F[_]: Sync](key: ByteVector): F[StateParam] =
Random[F].string(8).map { v =>
val sig = SignUtil.signString(v, key)
StateParam(v, sig)
}
def fromString(str: String, key: ByteVector): Either[String, StateParam] =
str.split('$') match {
case Array(v, sig) =>
ByteVector
.fromBase64Descriptive(sig, Alphabets.Base64UrlNoPad)
.map(s => StateParam(v, s))
.filterOrElse(_.isValid(key), s"Invalid signature in state param: $str")
case _ =>
Left(s"Invalid state parameter: $str")
}
def isValidStateParam(state: String, key: ByteVector) =
fromString(state, key).isRight
}

View File

@ -0,0 +1,47 @@
/*
* Copyright 2020 Eike K. & Contributors
*
* SPDX-License-Identifier: AGPL-3.0-or-later
*/
package docspell.oidc
import cats.effect._
import munit.CatsEffectSuite
import scodec.bits.ByteVector
class StateParamTest extends CatsEffectSuite {
private val key = ByteVector.fromValidHex("caffee")
private val key2 = key ++ key
test("generate") {
for {
p <- StateParam.generate[IO](key)
_ = {
assert(p.value.length > 8)
assert(p.sig.nonEmpty)
assert(p.isValid(key))
assert(!p.isValid(key2))
}
} yield ()
}
test("fromString") {
for {
p <- StateParam.generate[IO](key)
str = p.asString
p2 = StateParam.fromString(str, key).fold(sys.error, identity)
p3 = StateParam.fromString(str, key2)
p4 = StateParam.fromString("uiaeuiaeue", key)
p5 = StateParam.fromString(str + "$" + str, key)
_ = {
assertEquals(p2, p)
assert(p3.isLeft)
assert(p4.isLeft)
assert(p5.isLeft)
}
} yield ()
}
}

View File

@ -31,7 +31,8 @@ object OpenId {
ClientRequestInfo
.getBaseUrl(config, req) / "api" / "v1" / "open" / "auth" / "openid",
id =>
config.openid.filter(_.enabled).find(_.provider.providerId == id).map(_.provider)
config.openid.filter(_.enabled).find(_.provider.providerId == id).map(_.provider),
config.auth.serverSecret
)
def handle[F[_]: Async](backend: BackendApp[F], config: Config): OnUserInfo[F] =