mirror of
https://github.com/TheAnachronism/docspell.git
synced 2025-06-21 18:08:25 +00:00
@ -6,11 +6,10 @@
|
||||
|
||||
package docspell.backend.auth
|
||||
|
||||
import javax.crypto.Mac
|
||||
import javax.crypto.spec.SecretKeySpec
|
||||
|
||||
import cats.implicits._
|
||||
|
||||
import docspell.common.util.SignUtil
|
||||
|
||||
import scodec.bits._
|
||||
|
||||
private[auth] object TokenUtil {
|
||||
@ -34,11 +33,8 @@ private[auth] object TokenUtil {
|
||||
signRaw(raw, key)
|
||||
}
|
||||
|
||||
private def signRaw(data: String, key: ByteVector): String = {
|
||||
val mac = Mac.getInstance("HmacSHA1")
|
||||
mac.init(new SecretKeySpec(key.toArray, "HmacSHA1"))
|
||||
ByteVector.view(mac.doFinal(data.getBytes(utf8))).toBase64
|
||||
}
|
||||
private def signRaw(data: String, key: ByteVector): String =
|
||||
SignUtil.signString(data, key).toBase64
|
||||
|
||||
def b64enc(s: String): String =
|
||||
ByteVector.view(s.getBytes(utf8)).toBase64
|
||||
@ -52,5 +48,4 @@ private[auth] object TokenUtil {
|
||||
def constTimeEq(s1: String, s2: String): Boolean =
|
||||
s1.zip(s2)
|
||||
.foldLeft(true) { case (r, (c1, c2)) => r & c1 == c2 } & s1.length == s2.length
|
||||
|
||||
}
|
||||
|
@ -0,0 +1,39 @@
|
||||
/*
|
||||
* Copyright 2020 Eike K. & Contributors
|
||||
*
|
||||
* SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
*/
|
||||
|
||||
package docspell.common.util
|
||||
|
||||
import javax.crypto.Mac
|
||||
import javax.crypto.spec.SecretKeySpec
|
||||
|
||||
import scodec.bits.ByteVector
|
||||
|
||||
object SignUtil {
|
||||
private val utf8 = java.nio.charset.StandardCharsets.UTF_8
|
||||
|
||||
private val macAlgo = "HmacSHA1"
|
||||
|
||||
private def getMac(key: ByteVector) = {
|
||||
val mac = Mac.getInstance(macAlgo)
|
||||
mac.init(new SecretKeySpec(key.toArray, macAlgo))
|
||||
mac
|
||||
}
|
||||
|
||||
def signString(data: String, key: ByteVector): ByteVector = {
|
||||
val mac = getMac(key)
|
||||
ByteVector.view(mac.doFinal(data.getBytes(utf8)))
|
||||
}
|
||||
|
||||
def signBytes(data: ByteVector, key: ByteVector): ByteVector = {
|
||||
val mac = getMac(key)
|
||||
ByteVector.view(mac.doFinal(data.toArray))
|
||||
}
|
||||
|
||||
def isEqual(sig1: ByteVector, sig2: ByteVector): Boolean =
|
||||
sig1
|
||||
.zipWith(sig2)((b1, b2) => (b1 - b2).toByte)
|
||||
.foldLeft(true)(_ && _ == 0) && sig1.length == sig2.length
|
||||
}
|
@ -0,0 +1,20 @@
|
||||
/*
|
||||
* Copyright 2020 Eike K. & Contributors
|
||||
*
|
||||
* SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
*/
|
||||
|
||||
package docspell.common.util
|
||||
|
||||
import munit.FunSuite
|
||||
import scodec.bits.ByteVector
|
||||
|
||||
class SignUtilTest extends FunSuite {
|
||||
|
||||
private val key = ByteVector.fromValidHex("caffee")
|
||||
|
||||
test("create and validate") {
|
||||
val sig = SignUtil.signString("hello", key)
|
||||
assert(SignUtil.isEqual(sig, SignUtil.signString("hello", key)))
|
||||
}
|
||||
}
|
@ -9,6 +9,7 @@ package docspell.oidc
|
||||
import docspell.common._
|
||||
|
||||
import org.http4s.Request
|
||||
import scodec.bits.ByteVector
|
||||
|
||||
trait CodeFlowConfig[F[_]] {
|
||||
|
||||
@ -22,17 +23,20 @@ trait CodeFlowConfig[F[_]] {
|
||||
*/
|
||||
def findProvider(id: Ident): Option[ProviderConfig]
|
||||
|
||||
def serverSecret: ByteVector
|
||||
}
|
||||
|
||||
object CodeFlowConfig {
|
||||
|
||||
def apply[F[_]](
|
||||
url: Request[F] => LenientUri,
|
||||
provider: Ident => Option[ProviderConfig]
|
||||
provider: Ident => Option[ProviderConfig],
|
||||
secret: ByteVector
|
||||
): CodeFlowConfig[F] =
|
||||
new CodeFlowConfig[F] {
|
||||
def getEndpointUrl(req: Request[F]): LenientUri = url(req)
|
||||
def findProvider(id: Ident): Option[ProviderConfig] = provider(id)
|
||||
val serverSecret = secret
|
||||
}
|
||||
|
||||
private[oidc] def resumeUri[F[_]](
|
||||
@ -41,5 +45,4 @@ object CodeFlowConfig {
|
||||
cfg: CodeFlowConfig[F]
|
||||
): LenientUri =
|
||||
cfg.getEndpointUrl(req) / prov.providerId.id / "resume"
|
||||
|
||||
}
|
||||
|
@ -41,19 +41,23 @@ object CodeFlowRoutes {
|
||||
case req @ GET -> Root / Ident(id) =>
|
||||
config.findProvider(id) match {
|
||||
case Some(cfg) =>
|
||||
val uri = cfg.authorizeUrl
|
||||
.withQuery("client_id", cfg.clientId)
|
||||
.withQuery("scope", cfg.scope)
|
||||
.withQuery(
|
||||
"redirect_uri",
|
||||
CodeFlowConfig.resumeUri(req, cfg, config).asString
|
||||
for {
|
||||
state <- StateParam.generate[F](config.serverSecret)
|
||||
uri = cfg.authorizeUrl
|
||||
.withQuery("client_id", cfg.clientId)
|
||||
.withQuery("scope", cfg.scope)
|
||||
.withQuery(
|
||||
"redirect_uri",
|
||||
CodeFlowConfig.resumeUri(req, cfg, config).asString
|
||||
)
|
||||
.withQuery("response_type", "code")
|
||||
.withQuery("state", state.asString)
|
||||
_ <- logger.debug(
|
||||
s"Redirecting to OAuth/OIDC provider ${cfg.providerId.id}: ${uri.asString}"
|
||||
)
|
||||
.withQuery("response_type", "code")
|
||||
.withQuery("state", cfg.clientId)
|
||||
logger.debug(
|
||||
s"Redirecting to OAuth/OIDC provider ${cfg.providerId.id}: ${uri.asString}"
|
||||
) *>
|
||||
Found(Location(Uri.unsafeFromString(uri.asString)))
|
||||
resp <- Found(Location(Uri.unsafeFromString(uri.asString)))
|
||||
} yield resp
|
||||
|
||||
case None =>
|
||||
logger.debug(s"No OAuth/OIDC provider found with id '$id'") *>
|
||||
NotFound()
|
||||
@ -66,9 +70,20 @@ object CodeFlowRoutes {
|
||||
NotFound()
|
||||
case Some(provider) =>
|
||||
val codeFromReq = OptionT.fromOption[F](req.params.get("code"))
|
||||
val stateParamValid = req.params
|
||||
.get("state")
|
||||
.exists(state => StateParam.isValidStateParam(state, config.serverSecret))
|
||||
|
||||
val userInfo = for {
|
||||
_ <- OptionT.liftF(logger.info(s"Resume OAuth/OIDC flow for ${id.id}"))
|
||||
_ <-
|
||||
if (stateParamValid) OptionT.pure[F](())
|
||||
else
|
||||
OptionT(
|
||||
logger
|
||||
.warn(s"Invalid state parameter returned from Idp!")
|
||||
.as(Option.empty[Unit])
|
||||
)
|
||||
code <- codeFromReq
|
||||
_ <- OptionT.liftF(
|
||||
logger.trace(
|
||||
@ -91,7 +106,7 @@ object CodeFlowRoutes {
|
||||
s"$err$descr"
|
||||
}
|
||||
.map(err => s": $err")
|
||||
.getOrElse("")
|
||||
.getOrElse(": <no reason>")
|
||||
|
||||
logger.warn(s"Error resuming code flow from '${id.id}'$reason") *>
|
||||
onUserInfo.handle(req, provider, None)
|
||||
|
49
modules/oidc/src/main/scala/docspell/oidc/StateParam.scala
Normal file
49
modules/oidc/src/main/scala/docspell/oidc/StateParam.scala
Normal file
@ -0,0 +1,49 @@
|
||||
/*
|
||||
* Copyright 2020 Eike K. & Contributors
|
||||
*
|
||||
* SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
*/
|
||||
|
||||
package docspell.oidc
|
||||
|
||||
import cats.effect._
|
||||
import cats.syntax.all._
|
||||
|
||||
import docspell.common.util.{Random, SignUtil}
|
||||
|
||||
import scodec.bits.Bases.Alphabets
|
||||
import scodec.bits.ByteVector
|
||||
|
||||
final case class StateParam(value: String, sig: ByteVector) {
|
||||
def asString: String =
|
||||
s"$value$$${sig.toBase64UrlNoPad}"
|
||||
|
||||
def isValid(key: ByteVector): Boolean = {
|
||||
val actual = SignUtil.signString(value, key)
|
||||
SignUtil.isEqual(actual, sig)
|
||||
}
|
||||
}
|
||||
|
||||
object StateParam {
|
||||
|
||||
def generate[F[_]: Sync](key: ByteVector): F[StateParam] =
|
||||
Random[F].string(8).map { v =>
|
||||
val sig = SignUtil.signString(v, key)
|
||||
StateParam(v, sig)
|
||||
}
|
||||
|
||||
def fromString(str: String, key: ByteVector): Either[String, StateParam] =
|
||||
str.split('$') match {
|
||||
case Array(v, sig) =>
|
||||
ByteVector
|
||||
.fromBase64Descriptive(sig, Alphabets.Base64UrlNoPad)
|
||||
.map(s => StateParam(v, s))
|
||||
.filterOrElse(_.isValid(key), s"Invalid signature in state param: $str")
|
||||
|
||||
case _ =>
|
||||
Left(s"Invalid state parameter: $str")
|
||||
}
|
||||
|
||||
def isValidStateParam(state: String, key: ByteVector) =
|
||||
fromString(state, key).isRight
|
||||
}
|
@ -0,0 +1,47 @@
|
||||
/*
|
||||
* Copyright 2020 Eike K. & Contributors
|
||||
*
|
||||
* SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
*/
|
||||
|
||||
package docspell.oidc
|
||||
|
||||
import cats.effect._
|
||||
|
||||
import munit.CatsEffectSuite
|
||||
import scodec.bits.ByteVector
|
||||
|
||||
class StateParamTest extends CatsEffectSuite {
|
||||
|
||||
private val key = ByteVector.fromValidHex("caffee")
|
||||
private val key2 = key ++ key
|
||||
|
||||
test("generate") {
|
||||
for {
|
||||
p <- StateParam.generate[IO](key)
|
||||
_ = {
|
||||
assert(p.value.length > 8)
|
||||
assert(p.sig.nonEmpty)
|
||||
assert(p.isValid(key))
|
||||
assert(!p.isValid(key2))
|
||||
}
|
||||
} yield ()
|
||||
}
|
||||
|
||||
test("fromString") {
|
||||
for {
|
||||
p <- StateParam.generate[IO](key)
|
||||
str = p.asString
|
||||
p2 = StateParam.fromString(str, key).fold(sys.error, identity)
|
||||
p3 = StateParam.fromString(str, key2)
|
||||
p4 = StateParam.fromString("uiaeuiaeue", key)
|
||||
p5 = StateParam.fromString(str + "$" + str, key)
|
||||
_ = {
|
||||
assertEquals(p2, p)
|
||||
assert(p3.isLeft)
|
||||
assert(p4.isLeft)
|
||||
assert(p5.isLeft)
|
||||
}
|
||||
} yield ()
|
||||
}
|
||||
}
|
@ -31,7 +31,8 @@ object OpenId {
|
||||
ClientRequestInfo
|
||||
.getBaseUrl(config, req) / "api" / "v1" / "open" / "auth" / "openid",
|
||||
id =>
|
||||
config.openid.filter(_.enabled).find(_.provider.providerId == id).map(_.provider)
|
||||
config.openid.filter(_.enabled).find(_.provider.providerId == id).map(_.provider),
|
||||
config.auth.serverSecret
|
||||
)
|
||||
|
||||
def handle[F[_]: Async](backend: BackendApp[F], config: Config): OnUserInfo[F] =
|
||||
|
Reference in New Issue
Block a user