Merge pull request #1633 from eikek/oidc-improvements

OIDC improvements
This commit is contained in:
mergify[bot] 2022-07-07 13:41:41 +00:00 committed by GitHub
commit d0d8a8fbe7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 251 additions and 74 deletions

View File

@ -50,7 +50,7 @@ object Contact {
p match {
case LenientUri.RootPath => false
case LenientUri.EmptyPath => false
case LenientUri.NonEmptyPath(segs) =>
case LenientUri.NonEmptyPath(segs, _) =>
Ident.fromString(segs.last).isRight &&
segs.init.takeRight(3) == List("open", "upload", "item")
}

View File

@ -6,11 +6,10 @@
package docspell.backend.auth
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec
import cats.implicits._
import docspell.common.util.SignUtil
import scodec.bits._
private[auth] object TokenUtil {
@ -34,11 +33,8 @@ private[auth] object TokenUtil {
signRaw(raw, key)
}
private def signRaw(data: String, key: ByteVector): String = {
val mac = Mac.getInstance("HmacSHA1")
mac.init(new SecretKeySpec(key.toArray, "HmacSHA1"))
ByteVector.view(mac.doFinal(data.getBytes(utf8))).toBase64
}
private def signRaw(data: String, key: ByteVector): String =
SignUtil.signString(data, key).toBase64
def b64enc(s: String): String =
ByteVector.view(s.getBytes(utf8)).toBase64
@ -52,5 +48,4 @@ private[auth] object TokenUtil {
def constTimeEq(s1: String, s2: String): Boolean =
s1.zip(s2)
.foldLeft(true) { case (r, (c1, c2)) => r & c1 == c2 } & s1.length == s2.length
}

View File

@ -20,8 +20,10 @@ trait OSignup[F[_]] {
def register(cfg: Config)(data: RegisterData): F[SignupResult]
/** Creates the given account if it doesn't exist. */
def setupExternal(cfg: Config)(data: ExternalAccount): F[SignupResult]
/** Creates the given account if it doesn't exist. This is independent from signup
* configuration.
*/
def setupExternal(data: ExternalAccount): F[SignupResult]
def newInvite(cfg: Config)(password: Password): F[NewInviteResult]
}
@ -77,36 +79,31 @@ object OSignup {
}
}
def setupExternal(cfg: Config)(data: ExternalAccount): F[SignupResult] =
cfg.mode match {
case Config.Mode.Closed =>
SignupResult.signupClosed.pure[F]
case _ =>
if (data.source == AccountSource.Local)
SignupResult
.failure(new Exception("Account source must not be LOCAL!"))
.pure[F]
else
for {
recs <- makeRecords(data.collName, data.login, Password(""), data.source)
cres <- store.add(
RCollective.insert(recs._1),
RCollective.existsById(data.collName)
)
ures <- store.add(RUser.insert(recs._2), RUser.exists(data.login))
res = cres match {
def setupExternal(data: ExternalAccount): F[SignupResult] =
if (data.source == AccountSource.Local)
SignupResult
.failure(new Exception("Account source must not be LOCAL!"))
.pure[F]
else
for {
recs <- makeRecords(data.collName, data.login, Password(""), data.source)
cres <- store.add(
RCollective.insert(recs._1),
RCollective.existsById(data.collName)
)
ures <- store.add(RUser.insert(recs._2), RUser.exists(data.login))
res = cres match {
case AddResult.Failure(ex) =>
SignupResult.failure(ex)
case _ =>
ures match {
case AddResult.Failure(ex) =>
SignupResult.failure(ex)
case _ =>
ures match {
case AddResult.Failure(ex) =>
SignupResult.failure(ex)
case _ =>
SignupResult.success
}
SignupResult.success
}
} yield res
}
}
} yield res
private def retryInvite(res: SignupResult): Boolean =
res match {

View File

@ -121,7 +121,7 @@ object LenientUri {
val isRoot = true
val isEmpty = false
def /(seg: String): Path =
NonEmptyPath(NonEmptyList.of(seg))
NonEmptyPath(NonEmptyList.of(seg), false)
def asString = "/"
}
case object EmptyPath extends Path {
@ -129,20 +129,22 @@ object LenientUri {
val isRoot = false
val isEmpty = true
def /(seg: String): Path =
NonEmptyPath(NonEmptyList.of(seg))
NonEmptyPath(NonEmptyList.of(seg), false)
def asString = ""
}
case class NonEmptyPath(segs: NonEmptyList[String]) extends Path {
case class NonEmptyPath(segs: NonEmptyList[String], trailingSlash: Boolean)
extends Path {
def segments = segs.toList
val isEmpty = false
val isRoot = false
private val slashSuffix = if (trailingSlash) "/" else ""
def /(seg: String): Path =
copy(segs = segs.append(seg))
def asString =
segs.head match {
case "." => segments.map(percentEncode).mkString("/")
case ".." => segments.map(percentEncode).mkString("/")
case _ => "/" + segments.map(percentEncode).mkString("/")
case "." => segments.map(percentEncode).mkString("/") + slashSuffix
case ".." => segments.map(percentEncode).mkString("/") + slashSuffix
case _ => "/" + segments.map(percentEncode).mkString("/") + slashSuffix
}
}
@ -157,14 +159,14 @@ object LenientUri {
str.trim match {
case "/" => Right(RootPath)
case "" => Right(EmptyPath)
case _ =>
case uriStr =>
Either.fromOption(
stripLeading(str, '/')
stripLeading(uriStr, '/')
.split('/')
.toList
.traverse(percentDecode)
.flatMap(NonEmptyList.fromList)
.map(NonEmptyPath.apply),
.map(NonEmptyPath(_, uriStr.endsWith("/"))),
s"Invalid path: $str"
)
}

View File

@ -62,7 +62,7 @@ object UrlMatcher {
// strip path to only match prefixes
val mPath: LenientUri.Path =
NonEmptyList.fromList(url.path.segments.take(pathSegmentCount)) match {
case Some(nel) => LenientUri.NonEmptyPath(nel)
case Some(nel) => LenientUri.NonEmptyPath(nel, false)
case None => LenientUri.RootPath
}

View File

@ -0,0 +1,39 @@
/*
* Copyright 2020 Eike K. & Contributors
*
* SPDX-License-Identifier: AGPL-3.0-or-later
*/
package docspell.common.util
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec
import scodec.bits.ByteVector
object SignUtil {
private val utf8 = java.nio.charset.StandardCharsets.UTF_8
private val macAlgo = "HmacSHA1"
private def getMac(key: ByteVector) = {
val mac = Mac.getInstance(macAlgo)
mac.init(new SecretKeySpec(key.toArray, macAlgo))
mac
}
def signString(data: String, key: ByteVector): ByteVector = {
val mac = getMac(key)
ByteVector.view(mac.doFinal(data.getBytes(utf8)))
}
def signBytes(data: ByteVector, key: ByteVector): ByteVector = {
val mac = getMac(key)
ByteVector.view(mac.doFinal(data.toArray))
}
def isEqual(sig1: ByteVector, sig2: ByteVector): Boolean =
sig1
.zipWith(sig2)((b1, b2) => (b1 - b2).toByte)
.foldLeft(true)(_ && _ == 0) && sig1.length == sig2.length
}

View File

@ -29,4 +29,11 @@ class LenientUriTest extends FunSuite {
)
assertEquals(LenientUri.percentDecode("a%25b%5Cc%7Cd%23e"), "a%b\\c|d#e".some)
}
test("parse with trailing slash") {
assertEquals(LenientUri.unsafe("http://a.com/").asString, "http://a.com/")
assertEquals(LenientUri.unsafe("http://a.com").asString, "http://a.com")
assertEquals(LenientUri.unsafe("http://a.com/path").asString, "http://a.com/path")
assertEquals(LenientUri.unsafe("http://a.com/path/").asString, "http://a.com/path/")
}
}

View File

@ -0,0 +1,20 @@
/*
* Copyright 2020 Eike K. & Contributors
*
* SPDX-License-Identifier: AGPL-3.0-or-later
*/
package docspell.common.util
import munit.FunSuite
import scodec.bits.ByteVector
class SignUtilTest extends FunSuite {
private val key = ByteVector.fromValidHex("caffee")
test("create and validate") {
val sig = SignUtil.signString("hello", key)
assert(SignUtil.isEqual(sig, SignUtil.signString("hello", key)))
}
}

View File

@ -111,15 +111,19 @@ object CodeFlow {
token <- r.attemptAs[AccessToken].value
_ <- token match {
case Right(t) =>
logger.trace(s"Got token response: $t")
logger.trace(s"Got token response (status=${r.status.code}): $t")
case Left(err) =>
logger.error(err)(s"Error decoding access token: ${err.getMessage}")
logger.error(err)(
s"Error decoding access token (status=${r.status.code}): ${err.getMessage}"
)
}
} yield token.toOption
case r =>
logger
.error(s"Error obtaining access token '${r.status.code}' / ${r.as[String]}")
.map(_ => None)
for {
body <- r.bodyText.compile.string
_ <- logger
.error(s"Error obtaining access token status=${r.status.code}, body=$body")
} yield None
})
}
@ -177,5 +181,4 @@ object CodeFlow {
logAction = Some((msg: String) => logger.trace(msg))
)(c)
}
}

View File

@ -9,6 +9,7 @@ package docspell.oidc
import docspell.common._
import org.http4s.Request
import scodec.bits.ByteVector
trait CodeFlowConfig[F[_]] {
@ -22,17 +23,20 @@ trait CodeFlowConfig[F[_]] {
*/
def findProvider(id: Ident): Option[ProviderConfig]
def serverSecret: ByteVector
}
object CodeFlowConfig {
def apply[F[_]](
url: Request[F] => LenientUri,
provider: Ident => Option[ProviderConfig]
provider: Ident => Option[ProviderConfig],
secret: ByteVector
): CodeFlowConfig[F] =
new CodeFlowConfig[F] {
def getEndpointUrl(req: Request[F]): LenientUri = url(req)
def findProvider(id: Ident): Option[ProviderConfig] = provider(id)
val serverSecret = secret
}
private[oidc] def resumeUri[F[_]](
@ -41,5 +45,4 @@ object CodeFlowConfig {
cfg: CodeFlowConfig[F]
): LenientUri =
cfg.getEndpointUrl(req) / prov.providerId.id / "resume"
}

View File

@ -41,19 +41,23 @@ object CodeFlowRoutes {
case req @ GET -> Root / Ident(id) =>
config.findProvider(id) match {
case Some(cfg) =>
val uri = cfg.authorizeUrl
.withQuery("client_id", cfg.clientId)
.withQuery("scope", cfg.scope)
.withQuery(
"redirect_uri",
CodeFlowConfig.resumeUri(req, cfg, config).asString
for {
state <- StateParam.generate[F](config.serverSecret)
uri = cfg.authorizeUrl
.withQuery("client_id", cfg.clientId)
.withQuery("scope", cfg.scope)
.withQuery(
"redirect_uri",
CodeFlowConfig.resumeUri(req, cfg, config).asString
)
.withQuery("response_type", "code")
.withQuery("state", state.asString)
_ <- logger.debug(
s"Redirecting to OAuth/OIDC provider ${cfg.providerId.id}: ${uri.asString}"
)
.withQuery("response_type", "code")
.withQuery("state", cfg.clientId)
logger.debug(
s"Redirecting to OAuth/OIDC provider ${cfg.providerId.id}: ${uri.asString}"
) *>
Found(Location(Uri.unsafeFromString(uri.asString)))
resp <- Found(Location(Uri.unsafeFromString(uri.asString)))
} yield resp
case None =>
logger.debug(s"No OAuth/OIDC provider found with id '$id'") *>
NotFound()
@ -66,9 +70,20 @@ object CodeFlowRoutes {
NotFound()
case Some(provider) =>
val codeFromReq = OptionT.fromOption[F](req.params.get("code"))
val stateParamValid = req.params
.get("state")
.exists(state => StateParam.isValidStateParam(state, config.serverSecret))
val userInfo = for {
_ <- OptionT.liftF(logger.info(s"Resume OAuth/OIDC flow for ${id.id}"))
_ <-
if (stateParamValid) OptionT.pure[F](())
else
OptionT(
logger
.warn(s"Invalid state parameter returned from Idp!")
.as(Option.empty[Unit])
)
code <- codeFromReq
_ <- OptionT.liftF(
logger.trace(
@ -91,7 +106,7 @@ object CodeFlowRoutes {
s"$err$descr"
}
.map(err => s": $err")
.getOrElse("")
.getOrElse(": <no reason>")
logger.warn(s"Error resuming code flow from '${id.id}'$reason") *>
onUserInfo.handle(req, provider, None)

View File

@ -0,0 +1,49 @@
/*
* Copyright 2020 Eike K. & Contributors
*
* SPDX-License-Identifier: AGPL-3.0-or-later
*/
package docspell.oidc
import cats.effect._
import cats.syntax.all._
import docspell.common.util.{Random, SignUtil}
import scodec.bits.Bases.Alphabets
import scodec.bits.ByteVector
final case class StateParam(value: String, sig: ByteVector) {
def asString: String =
s"$value$$${sig.toBase64UrlNoPad}"
def isValid(key: ByteVector): Boolean = {
val actual = SignUtil.signString(value, key)
SignUtil.isEqual(actual, sig)
}
}
object StateParam {
def generate[F[_]: Sync](key: ByteVector): F[StateParam] =
Random[F].string(8).map { v =>
val sig = SignUtil.signString(v, key)
StateParam(v, sig)
}
def fromString(str: String, key: ByteVector): Either[String, StateParam] =
str.split('$') match {
case Array(v, sig) =>
ByteVector
.fromBase64Descriptive(sig, Alphabets.Base64UrlNoPad)
.map(s => StateParam(v, s))
.filterOrElse(_.isValid(key), s"Invalid signature in state param: $str")
case _ =>
Left(s"Invalid state parameter: $str")
}
def isValidStateParam(state: String, key: ByteVector) =
fromString(state, key).isRight
}

View File

@ -0,0 +1,47 @@
/*
* Copyright 2020 Eike K. & Contributors
*
* SPDX-License-Identifier: AGPL-3.0-or-later
*/
package docspell.oidc
import cats.effect._
import munit.CatsEffectSuite
import scodec.bits.ByteVector
class StateParamTest extends CatsEffectSuite {
private val key = ByteVector.fromValidHex("caffee")
private val key2 = key ++ key
test("generate") {
for {
p <- StateParam.generate[IO](key)
_ = {
assert(p.value.length > 8)
assert(p.sig.nonEmpty)
assert(p.isValid(key))
assert(!p.isValid(key2))
}
} yield ()
}
test("fromString") {
for {
p <- StateParam.generate[IO](key)
str = p.asString
p2 = StateParam.fromString(str, key).fold(sys.error, identity)
p3 = StateParam.fromString(str, key2)
p4 = StateParam.fromString("uiaeuiaeue", key)
p5 = StateParam.fromString(str + "$" + str, key)
_ = {
assertEquals(p2, p)
assert(p3.isLeft)
assert(p4.isLeft)
assert(p5.isLeft)
}
} yield ()
}
}

View File

@ -31,7 +31,8 @@ object OpenId {
ClientRequestInfo
.getBaseUrl(config, req) / "api" / "v1" / "open" / "auth" / "openid",
id =>
config.openid.filter(_.enabled).find(_.provider.providerId == id).map(_.provider)
config.openid.filter(_.enabled).find(_.provider.providerId == id).map(_.provider),
config.auth.serverSecret
)
def handle[F[_]: Async](backend: BackendApp[F], config: Config): OnUserInfo[F] =
@ -104,9 +105,7 @@ object OpenId {
import dsl._
for {
setup <- backend.signup.setupExternal(cfg.backend.signup)(
ExternalAccount(accountId)
)
setup <- backend.signup.setupExternal(ExternalAccount(accountId))
res <- setup match {
case SignupResult.Failure(ex) =>
logger.error(ex)(s"Error when creating external account!") *>

View File

@ -24,7 +24,8 @@ object FileUrlReader {
scheme = Nel.of(scheme),
authority = Some(""),
path = LenientUri.NonEmptyPath(
Nel.of(key.collective.id, key.category.id.id, key.id.id)
Nel.of(key.collective.id, key.category.id.id, key.id.id),
false
),
query = None,
fragment = None