mirror of
https://github.com/TheAnachronism/docspell.git
synced 2025-06-22 02:18:26 +00:00
@ -6,11 +6,10 @@
|
|||||||
|
|
||||||
package docspell.backend.auth
|
package docspell.backend.auth
|
||||||
|
|
||||||
import javax.crypto.Mac
|
|
||||||
import javax.crypto.spec.SecretKeySpec
|
|
||||||
|
|
||||||
import cats.implicits._
|
import cats.implicits._
|
||||||
|
|
||||||
|
import docspell.common.util.SignUtil
|
||||||
|
|
||||||
import scodec.bits._
|
import scodec.bits._
|
||||||
|
|
||||||
private[auth] object TokenUtil {
|
private[auth] object TokenUtil {
|
||||||
@ -34,11 +33,8 @@ private[auth] object TokenUtil {
|
|||||||
signRaw(raw, key)
|
signRaw(raw, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
private def signRaw(data: String, key: ByteVector): String = {
|
private def signRaw(data: String, key: ByteVector): String =
|
||||||
val mac = Mac.getInstance("HmacSHA1")
|
SignUtil.signString(data, key).toBase64
|
||||||
mac.init(new SecretKeySpec(key.toArray, "HmacSHA1"))
|
|
||||||
ByteVector.view(mac.doFinal(data.getBytes(utf8))).toBase64
|
|
||||||
}
|
|
||||||
|
|
||||||
def b64enc(s: String): String =
|
def b64enc(s: String): String =
|
||||||
ByteVector.view(s.getBytes(utf8)).toBase64
|
ByteVector.view(s.getBytes(utf8)).toBase64
|
||||||
@ -52,5 +48,4 @@ private[auth] object TokenUtil {
|
|||||||
def constTimeEq(s1: String, s2: String): Boolean =
|
def constTimeEq(s1: String, s2: String): Boolean =
|
||||||
s1.zip(s2)
|
s1.zip(s2)
|
||||||
.foldLeft(true) { case (r, (c1, c2)) => r & c1 == c2 } & s1.length == s2.length
|
.foldLeft(true) { case (r, (c1, c2)) => r & c1 == c2 } & s1.length == s2.length
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,39 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2020 Eike K. & Contributors
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
*/
|
||||||
|
|
||||||
|
package docspell.common.util
|
||||||
|
|
||||||
|
import javax.crypto.Mac
|
||||||
|
import javax.crypto.spec.SecretKeySpec
|
||||||
|
|
||||||
|
import scodec.bits.ByteVector
|
||||||
|
|
||||||
|
object SignUtil {
|
||||||
|
private val utf8 = java.nio.charset.StandardCharsets.UTF_8
|
||||||
|
|
||||||
|
private val macAlgo = "HmacSHA1"
|
||||||
|
|
||||||
|
private def getMac(key: ByteVector) = {
|
||||||
|
val mac = Mac.getInstance(macAlgo)
|
||||||
|
mac.init(new SecretKeySpec(key.toArray, macAlgo))
|
||||||
|
mac
|
||||||
|
}
|
||||||
|
|
||||||
|
def signString(data: String, key: ByteVector): ByteVector = {
|
||||||
|
val mac = getMac(key)
|
||||||
|
ByteVector.view(mac.doFinal(data.getBytes(utf8)))
|
||||||
|
}
|
||||||
|
|
||||||
|
def signBytes(data: ByteVector, key: ByteVector): ByteVector = {
|
||||||
|
val mac = getMac(key)
|
||||||
|
ByteVector.view(mac.doFinal(data.toArray))
|
||||||
|
}
|
||||||
|
|
||||||
|
def isEqual(sig1: ByteVector, sig2: ByteVector): Boolean =
|
||||||
|
sig1
|
||||||
|
.zipWith(sig2)((b1, b2) => (b1 - b2).toByte)
|
||||||
|
.foldLeft(true)(_ && _ == 0) && sig1.length == sig2.length
|
||||||
|
}
|
@ -0,0 +1,20 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2020 Eike K. & Contributors
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
*/
|
||||||
|
|
||||||
|
package docspell.common.util
|
||||||
|
|
||||||
|
import munit.FunSuite
|
||||||
|
import scodec.bits.ByteVector
|
||||||
|
|
||||||
|
class SignUtilTest extends FunSuite {
|
||||||
|
|
||||||
|
private val key = ByteVector.fromValidHex("caffee")
|
||||||
|
|
||||||
|
test("create and validate") {
|
||||||
|
val sig = SignUtil.signString("hello", key)
|
||||||
|
assert(SignUtil.isEqual(sig, SignUtil.signString("hello", key)))
|
||||||
|
}
|
||||||
|
}
|
@ -9,6 +9,7 @@ package docspell.oidc
|
|||||||
import docspell.common._
|
import docspell.common._
|
||||||
|
|
||||||
import org.http4s.Request
|
import org.http4s.Request
|
||||||
|
import scodec.bits.ByteVector
|
||||||
|
|
||||||
trait CodeFlowConfig[F[_]] {
|
trait CodeFlowConfig[F[_]] {
|
||||||
|
|
||||||
@ -22,17 +23,20 @@ trait CodeFlowConfig[F[_]] {
|
|||||||
*/
|
*/
|
||||||
def findProvider(id: Ident): Option[ProviderConfig]
|
def findProvider(id: Ident): Option[ProviderConfig]
|
||||||
|
|
||||||
|
def serverSecret: ByteVector
|
||||||
}
|
}
|
||||||
|
|
||||||
object CodeFlowConfig {
|
object CodeFlowConfig {
|
||||||
|
|
||||||
def apply[F[_]](
|
def apply[F[_]](
|
||||||
url: Request[F] => LenientUri,
|
url: Request[F] => LenientUri,
|
||||||
provider: Ident => Option[ProviderConfig]
|
provider: Ident => Option[ProviderConfig],
|
||||||
|
secret: ByteVector
|
||||||
): CodeFlowConfig[F] =
|
): CodeFlowConfig[F] =
|
||||||
new CodeFlowConfig[F] {
|
new CodeFlowConfig[F] {
|
||||||
def getEndpointUrl(req: Request[F]): LenientUri = url(req)
|
def getEndpointUrl(req: Request[F]): LenientUri = url(req)
|
||||||
def findProvider(id: Ident): Option[ProviderConfig] = provider(id)
|
def findProvider(id: Ident): Option[ProviderConfig] = provider(id)
|
||||||
|
val serverSecret = secret
|
||||||
}
|
}
|
||||||
|
|
||||||
private[oidc] def resumeUri[F[_]](
|
private[oidc] def resumeUri[F[_]](
|
||||||
@ -41,5 +45,4 @@ object CodeFlowConfig {
|
|||||||
cfg: CodeFlowConfig[F]
|
cfg: CodeFlowConfig[F]
|
||||||
): LenientUri =
|
): LenientUri =
|
||||||
cfg.getEndpointUrl(req) / prov.providerId.id / "resume"
|
cfg.getEndpointUrl(req) / prov.providerId.id / "resume"
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -41,19 +41,23 @@ object CodeFlowRoutes {
|
|||||||
case req @ GET -> Root / Ident(id) =>
|
case req @ GET -> Root / Ident(id) =>
|
||||||
config.findProvider(id) match {
|
config.findProvider(id) match {
|
||||||
case Some(cfg) =>
|
case Some(cfg) =>
|
||||||
val uri = cfg.authorizeUrl
|
for {
|
||||||
.withQuery("client_id", cfg.clientId)
|
state <- StateParam.generate[F](config.serverSecret)
|
||||||
.withQuery("scope", cfg.scope)
|
uri = cfg.authorizeUrl
|
||||||
.withQuery(
|
.withQuery("client_id", cfg.clientId)
|
||||||
"redirect_uri",
|
.withQuery("scope", cfg.scope)
|
||||||
CodeFlowConfig.resumeUri(req, cfg, config).asString
|
.withQuery(
|
||||||
|
"redirect_uri",
|
||||||
|
CodeFlowConfig.resumeUri(req, cfg, config).asString
|
||||||
|
)
|
||||||
|
.withQuery("response_type", "code")
|
||||||
|
.withQuery("state", state.asString)
|
||||||
|
_ <- logger.debug(
|
||||||
|
s"Redirecting to OAuth/OIDC provider ${cfg.providerId.id}: ${uri.asString}"
|
||||||
)
|
)
|
||||||
.withQuery("response_type", "code")
|
resp <- Found(Location(Uri.unsafeFromString(uri.asString)))
|
||||||
.withQuery("state", cfg.clientId)
|
} yield resp
|
||||||
logger.debug(
|
|
||||||
s"Redirecting to OAuth/OIDC provider ${cfg.providerId.id}: ${uri.asString}"
|
|
||||||
) *>
|
|
||||||
Found(Location(Uri.unsafeFromString(uri.asString)))
|
|
||||||
case None =>
|
case None =>
|
||||||
logger.debug(s"No OAuth/OIDC provider found with id '$id'") *>
|
logger.debug(s"No OAuth/OIDC provider found with id '$id'") *>
|
||||||
NotFound()
|
NotFound()
|
||||||
@ -66,9 +70,20 @@ object CodeFlowRoutes {
|
|||||||
NotFound()
|
NotFound()
|
||||||
case Some(provider) =>
|
case Some(provider) =>
|
||||||
val codeFromReq = OptionT.fromOption[F](req.params.get("code"))
|
val codeFromReq = OptionT.fromOption[F](req.params.get("code"))
|
||||||
|
val stateParamValid = req.params
|
||||||
|
.get("state")
|
||||||
|
.exists(state => StateParam.isValidStateParam(state, config.serverSecret))
|
||||||
|
|
||||||
val userInfo = for {
|
val userInfo = for {
|
||||||
_ <- OptionT.liftF(logger.info(s"Resume OAuth/OIDC flow for ${id.id}"))
|
_ <- OptionT.liftF(logger.info(s"Resume OAuth/OIDC flow for ${id.id}"))
|
||||||
|
_ <-
|
||||||
|
if (stateParamValid) OptionT.pure[F](())
|
||||||
|
else
|
||||||
|
OptionT(
|
||||||
|
logger
|
||||||
|
.warn(s"Invalid state parameter returned from Idp!")
|
||||||
|
.as(Option.empty[Unit])
|
||||||
|
)
|
||||||
code <- codeFromReq
|
code <- codeFromReq
|
||||||
_ <- OptionT.liftF(
|
_ <- OptionT.liftF(
|
||||||
logger.trace(
|
logger.trace(
|
||||||
@ -91,7 +106,7 @@ object CodeFlowRoutes {
|
|||||||
s"$err$descr"
|
s"$err$descr"
|
||||||
}
|
}
|
||||||
.map(err => s": $err")
|
.map(err => s": $err")
|
||||||
.getOrElse("")
|
.getOrElse(": <no reason>")
|
||||||
|
|
||||||
logger.warn(s"Error resuming code flow from '${id.id}'$reason") *>
|
logger.warn(s"Error resuming code flow from '${id.id}'$reason") *>
|
||||||
onUserInfo.handle(req, provider, None)
|
onUserInfo.handle(req, provider, None)
|
||||||
|
49
modules/oidc/src/main/scala/docspell/oidc/StateParam.scala
Normal file
49
modules/oidc/src/main/scala/docspell/oidc/StateParam.scala
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2020 Eike K. & Contributors
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
*/
|
||||||
|
|
||||||
|
package docspell.oidc
|
||||||
|
|
||||||
|
import cats.effect._
|
||||||
|
import cats.syntax.all._
|
||||||
|
|
||||||
|
import docspell.common.util.{Random, SignUtil}
|
||||||
|
|
||||||
|
import scodec.bits.Bases.Alphabets
|
||||||
|
import scodec.bits.ByteVector
|
||||||
|
|
||||||
|
final case class StateParam(value: String, sig: ByteVector) {
|
||||||
|
def asString: String =
|
||||||
|
s"$value$$${sig.toBase64UrlNoPad}"
|
||||||
|
|
||||||
|
def isValid(key: ByteVector): Boolean = {
|
||||||
|
val actual = SignUtil.signString(value, key)
|
||||||
|
SignUtil.isEqual(actual, sig)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
object StateParam {
|
||||||
|
|
||||||
|
def generate[F[_]: Sync](key: ByteVector): F[StateParam] =
|
||||||
|
Random[F].string(8).map { v =>
|
||||||
|
val sig = SignUtil.signString(v, key)
|
||||||
|
StateParam(v, sig)
|
||||||
|
}
|
||||||
|
|
||||||
|
def fromString(str: String, key: ByteVector): Either[String, StateParam] =
|
||||||
|
str.split('$') match {
|
||||||
|
case Array(v, sig) =>
|
||||||
|
ByteVector
|
||||||
|
.fromBase64Descriptive(sig, Alphabets.Base64UrlNoPad)
|
||||||
|
.map(s => StateParam(v, s))
|
||||||
|
.filterOrElse(_.isValid(key), s"Invalid signature in state param: $str")
|
||||||
|
|
||||||
|
case _ =>
|
||||||
|
Left(s"Invalid state parameter: $str")
|
||||||
|
}
|
||||||
|
|
||||||
|
def isValidStateParam(state: String, key: ByteVector) =
|
||||||
|
fromString(state, key).isRight
|
||||||
|
}
|
@ -0,0 +1,47 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2020 Eike K. & Contributors
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
*/
|
||||||
|
|
||||||
|
package docspell.oidc
|
||||||
|
|
||||||
|
import cats.effect._
|
||||||
|
|
||||||
|
import munit.CatsEffectSuite
|
||||||
|
import scodec.bits.ByteVector
|
||||||
|
|
||||||
|
class StateParamTest extends CatsEffectSuite {
|
||||||
|
|
||||||
|
private val key = ByteVector.fromValidHex("caffee")
|
||||||
|
private val key2 = key ++ key
|
||||||
|
|
||||||
|
test("generate") {
|
||||||
|
for {
|
||||||
|
p <- StateParam.generate[IO](key)
|
||||||
|
_ = {
|
||||||
|
assert(p.value.length > 8)
|
||||||
|
assert(p.sig.nonEmpty)
|
||||||
|
assert(p.isValid(key))
|
||||||
|
assert(!p.isValid(key2))
|
||||||
|
}
|
||||||
|
} yield ()
|
||||||
|
}
|
||||||
|
|
||||||
|
test("fromString") {
|
||||||
|
for {
|
||||||
|
p <- StateParam.generate[IO](key)
|
||||||
|
str = p.asString
|
||||||
|
p2 = StateParam.fromString(str, key).fold(sys.error, identity)
|
||||||
|
p3 = StateParam.fromString(str, key2)
|
||||||
|
p4 = StateParam.fromString("uiaeuiaeue", key)
|
||||||
|
p5 = StateParam.fromString(str + "$" + str, key)
|
||||||
|
_ = {
|
||||||
|
assertEquals(p2, p)
|
||||||
|
assert(p3.isLeft)
|
||||||
|
assert(p4.isLeft)
|
||||||
|
assert(p5.isLeft)
|
||||||
|
}
|
||||||
|
} yield ()
|
||||||
|
}
|
||||||
|
}
|
@ -31,7 +31,8 @@ object OpenId {
|
|||||||
ClientRequestInfo
|
ClientRequestInfo
|
||||||
.getBaseUrl(config, req) / "api" / "v1" / "open" / "auth" / "openid",
|
.getBaseUrl(config, req) / "api" / "v1" / "open" / "auth" / "openid",
|
||||||
id =>
|
id =>
|
||||||
config.openid.filter(_.enabled).find(_.provider.providerId == id).map(_.provider)
|
config.openid.filter(_.enabled).find(_.provider.providerId == id).map(_.provider),
|
||||||
|
config.auth.serverSecret
|
||||||
)
|
)
|
||||||
|
|
||||||
def handle[F[_]: Async](backend: BackendApp[F], config: Config): OnUserInfo[F] =
|
def handle[F[_]: Async](backend: BackendApp[F], config: Config): OnUserInfo[F] =
|
||||||
|
Reference in New Issue
Block a user