Merge pull request #1633 from eikek/oidc-improvements

OIDC improvements
This commit is contained in:
mergify[bot] 2022-07-07 13:41:41 +00:00 committed by GitHub
commit d0d8a8fbe7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 251 additions and 74 deletions

View File

@ -50,7 +50,7 @@ object Contact {
p match { p match {
case LenientUri.RootPath => false case LenientUri.RootPath => false
case LenientUri.EmptyPath => false case LenientUri.EmptyPath => false
case LenientUri.NonEmptyPath(segs) => case LenientUri.NonEmptyPath(segs, _) =>
Ident.fromString(segs.last).isRight && Ident.fromString(segs.last).isRight &&
segs.init.takeRight(3) == List("open", "upload", "item") segs.init.takeRight(3) == List("open", "upload", "item")
} }

View File

@ -6,11 +6,10 @@
package docspell.backend.auth package docspell.backend.auth
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec
import cats.implicits._ import cats.implicits._
import docspell.common.util.SignUtil
import scodec.bits._ import scodec.bits._
private[auth] object TokenUtil { private[auth] object TokenUtil {
@ -34,11 +33,8 @@ private[auth] object TokenUtil {
signRaw(raw, key) signRaw(raw, key)
} }
private def signRaw(data: String, key: ByteVector): String = { private def signRaw(data: String, key: ByteVector): String =
val mac = Mac.getInstance("HmacSHA1") SignUtil.signString(data, key).toBase64
mac.init(new SecretKeySpec(key.toArray, "HmacSHA1"))
ByteVector.view(mac.doFinal(data.getBytes(utf8))).toBase64
}
def b64enc(s: String): String = def b64enc(s: String): String =
ByteVector.view(s.getBytes(utf8)).toBase64 ByteVector.view(s.getBytes(utf8)).toBase64
@ -52,5 +48,4 @@ private[auth] object TokenUtil {
def constTimeEq(s1: String, s2: String): Boolean = def constTimeEq(s1: String, s2: String): Boolean =
s1.zip(s2) s1.zip(s2)
.foldLeft(true) { case (r, (c1, c2)) => r & c1 == c2 } & s1.length == s2.length .foldLeft(true) { case (r, (c1, c2)) => r & c1 == c2 } & s1.length == s2.length
} }

View File

@ -20,8 +20,10 @@ trait OSignup[F[_]] {
def register(cfg: Config)(data: RegisterData): F[SignupResult] def register(cfg: Config)(data: RegisterData): F[SignupResult]
/** Creates the given account if it doesn't exist. */ /** Creates the given account if it doesn't exist. This is independent from signup
def setupExternal(cfg: Config)(data: ExternalAccount): F[SignupResult] * configuration.
*/
def setupExternal(data: ExternalAccount): F[SignupResult]
def newInvite(cfg: Config)(password: Password): F[NewInviteResult] def newInvite(cfg: Config)(password: Password): F[NewInviteResult]
} }
@ -77,36 +79,31 @@ object OSignup {
} }
} }
def setupExternal(cfg: Config)(data: ExternalAccount): F[SignupResult] = def setupExternal(data: ExternalAccount): F[SignupResult] =
cfg.mode match { if (data.source == AccountSource.Local)
case Config.Mode.Closed => SignupResult
SignupResult.signupClosed.pure[F] .failure(new Exception("Account source must not be LOCAL!"))
case _ => .pure[F]
if (data.source == AccountSource.Local) else
SignupResult for {
.failure(new Exception("Account source must not be LOCAL!")) recs <- makeRecords(data.collName, data.login, Password(""), data.source)
.pure[F] cres <- store.add(
else RCollective.insert(recs._1),
for { RCollective.existsById(data.collName)
recs <- makeRecords(data.collName, data.login, Password(""), data.source) )
cres <- store.add( ures <- store.add(RUser.insert(recs._2), RUser.exists(data.login))
RCollective.insert(recs._1), res = cres match {
RCollective.existsById(data.collName) case AddResult.Failure(ex) =>
) SignupResult.failure(ex)
ures <- store.add(RUser.insert(recs._2), RUser.exists(data.login)) case _ =>
res = cres match { ures match {
case AddResult.Failure(ex) => case AddResult.Failure(ex) =>
SignupResult.failure(ex) SignupResult.failure(ex)
case _ => case _ =>
ures match { SignupResult.success
case AddResult.Failure(ex) =>
SignupResult.failure(ex)
case _ =>
SignupResult.success
}
} }
} yield res }
} } yield res
private def retryInvite(res: SignupResult): Boolean = private def retryInvite(res: SignupResult): Boolean =
res match { res match {

View File

@ -121,7 +121,7 @@ object LenientUri {
val isRoot = true val isRoot = true
val isEmpty = false val isEmpty = false
def /(seg: String): Path = def /(seg: String): Path =
NonEmptyPath(NonEmptyList.of(seg)) NonEmptyPath(NonEmptyList.of(seg), false)
def asString = "/" def asString = "/"
} }
case object EmptyPath extends Path { case object EmptyPath extends Path {
@ -129,20 +129,22 @@ object LenientUri {
val isRoot = false val isRoot = false
val isEmpty = true val isEmpty = true
def /(seg: String): Path = def /(seg: String): Path =
NonEmptyPath(NonEmptyList.of(seg)) NonEmptyPath(NonEmptyList.of(seg), false)
def asString = "" def asString = ""
} }
case class NonEmptyPath(segs: NonEmptyList[String]) extends Path { case class NonEmptyPath(segs: NonEmptyList[String], trailingSlash: Boolean)
extends Path {
def segments = segs.toList def segments = segs.toList
val isEmpty = false val isEmpty = false
val isRoot = false val isRoot = false
private val slashSuffix = if (trailingSlash) "/" else ""
def /(seg: String): Path = def /(seg: String): Path =
copy(segs = segs.append(seg)) copy(segs = segs.append(seg))
def asString = def asString =
segs.head match { segs.head match {
case "." => segments.map(percentEncode).mkString("/") case "." => segments.map(percentEncode).mkString("/") + slashSuffix
case ".." => segments.map(percentEncode).mkString("/") case ".." => segments.map(percentEncode).mkString("/") + slashSuffix
case _ => "/" + segments.map(percentEncode).mkString("/") case _ => "/" + segments.map(percentEncode).mkString("/") + slashSuffix
} }
} }
@ -157,14 +159,14 @@ object LenientUri {
str.trim match { str.trim match {
case "/" => Right(RootPath) case "/" => Right(RootPath)
case "" => Right(EmptyPath) case "" => Right(EmptyPath)
case _ => case uriStr =>
Either.fromOption( Either.fromOption(
stripLeading(str, '/') stripLeading(uriStr, '/')
.split('/') .split('/')
.toList .toList
.traverse(percentDecode) .traverse(percentDecode)
.flatMap(NonEmptyList.fromList) .flatMap(NonEmptyList.fromList)
.map(NonEmptyPath.apply), .map(NonEmptyPath(_, uriStr.endsWith("/"))),
s"Invalid path: $str" s"Invalid path: $str"
) )
} }

View File

@ -62,7 +62,7 @@ object UrlMatcher {
// strip path to only match prefixes // strip path to only match prefixes
val mPath: LenientUri.Path = val mPath: LenientUri.Path =
NonEmptyList.fromList(url.path.segments.take(pathSegmentCount)) match { NonEmptyList.fromList(url.path.segments.take(pathSegmentCount)) match {
case Some(nel) => LenientUri.NonEmptyPath(nel) case Some(nel) => LenientUri.NonEmptyPath(nel, false)
case None => LenientUri.RootPath case None => LenientUri.RootPath
} }

View File

@ -0,0 +1,39 @@
/*
* Copyright 2020 Eike K. & Contributors
*
* SPDX-License-Identifier: AGPL-3.0-or-later
*/
package docspell.common.util
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec
import scodec.bits.ByteVector
object SignUtil {
private val utf8 = java.nio.charset.StandardCharsets.UTF_8
private val macAlgo = "HmacSHA1"
private def getMac(key: ByteVector) = {
val mac = Mac.getInstance(macAlgo)
mac.init(new SecretKeySpec(key.toArray, macAlgo))
mac
}
def signString(data: String, key: ByteVector): ByteVector = {
val mac = getMac(key)
ByteVector.view(mac.doFinal(data.getBytes(utf8)))
}
def signBytes(data: ByteVector, key: ByteVector): ByteVector = {
val mac = getMac(key)
ByteVector.view(mac.doFinal(data.toArray))
}
def isEqual(sig1: ByteVector, sig2: ByteVector): Boolean =
sig1
.zipWith(sig2)((b1, b2) => (b1 - b2).toByte)
.foldLeft(true)(_ && _ == 0) && sig1.length == sig2.length
}

View File

@ -29,4 +29,11 @@ class LenientUriTest extends FunSuite {
) )
assertEquals(LenientUri.percentDecode("a%25b%5Cc%7Cd%23e"), "a%b\\c|d#e".some) assertEquals(LenientUri.percentDecode("a%25b%5Cc%7Cd%23e"), "a%b\\c|d#e".some)
} }
test("parse with trailing slash") {
assertEquals(LenientUri.unsafe("http://a.com/").asString, "http://a.com/")
assertEquals(LenientUri.unsafe("http://a.com").asString, "http://a.com")
assertEquals(LenientUri.unsafe("http://a.com/path").asString, "http://a.com/path")
assertEquals(LenientUri.unsafe("http://a.com/path/").asString, "http://a.com/path/")
}
} }

View File

@ -0,0 +1,20 @@
/*
* Copyright 2020 Eike K. & Contributors
*
* SPDX-License-Identifier: AGPL-3.0-or-later
*/
package docspell.common.util
import munit.FunSuite
import scodec.bits.ByteVector
class SignUtilTest extends FunSuite {
private val key = ByteVector.fromValidHex("caffee")
test("create and validate") {
val sig = SignUtil.signString("hello", key)
assert(SignUtil.isEqual(sig, SignUtil.signString("hello", key)))
}
}

View File

@ -111,15 +111,19 @@ object CodeFlow {
token <- r.attemptAs[AccessToken].value token <- r.attemptAs[AccessToken].value
_ <- token match { _ <- token match {
case Right(t) => case Right(t) =>
logger.trace(s"Got token response: $t") logger.trace(s"Got token response (status=${r.status.code}): $t")
case Left(err) => case Left(err) =>
logger.error(err)(s"Error decoding access token: ${err.getMessage}") logger.error(err)(
s"Error decoding access token (status=${r.status.code}): ${err.getMessage}"
)
} }
} yield token.toOption } yield token.toOption
case r => case r =>
logger for {
.error(s"Error obtaining access token '${r.status.code}' / ${r.as[String]}") body <- r.bodyText.compile.string
.map(_ => None) _ <- logger
.error(s"Error obtaining access token status=${r.status.code}, body=$body")
} yield None
}) })
} }
@ -177,5 +181,4 @@ object CodeFlow {
logAction = Some((msg: String) => logger.trace(msg)) logAction = Some((msg: String) => logger.trace(msg))
)(c) )(c)
} }
} }

View File

@ -9,6 +9,7 @@ package docspell.oidc
import docspell.common._ import docspell.common._
import org.http4s.Request import org.http4s.Request
import scodec.bits.ByteVector
trait CodeFlowConfig[F[_]] { trait CodeFlowConfig[F[_]] {
@ -22,17 +23,20 @@ trait CodeFlowConfig[F[_]] {
*/ */
def findProvider(id: Ident): Option[ProviderConfig] def findProvider(id: Ident): Option[ProviderConfig]
def serverSecret: ByteVector
} }
object CodeFlowConfig { object CodeFlowConfig {
def apply[F[_]]( def apply[F[_]](
url: Request[F] => LenientUri, url: Request[F] => LenientUri,
provider: Ident => Option[ProviderConfig] provider: Ident => Option[ProviderConfig],
secret: ByteVector
): CodeFlowConfig[F] = ): CodeFlowConfig[F] =
new CodeFlowConfig[F] { new CodeFlowConfig[F] {
def getEndpointUrl(req: Request[F]): LenientUri = url(req) def getEndpointUrl(req: Request[F]): LenientUri = url(req)
def findProvider(id: Ident): Option[ProviderConfig] = provider(id) def findProvider(id: Ident): Option[ProviderConfig] = provider(id)
val serverSecret = secret
} }
private[oidc] def resumeUri[F[_]]( private[oidc] def resumeUri[F[_]](
@ -41,5 +45,4 @@ object CodeFlowConfig {
cfg: CodeFlowConfig[F] cfg: CodeFlowConfig[F]
): LenientUri = ): LenientUri =
cfg.getEndpointUrl(req) / prov.providerId.id / "resume" cfg.getEndpointUrl(req) / prov.providerId.id / "resume"
} }

View File

@ -41,19 +41,23 @@ object CodeFlowRoutes {
case req @ GET -> Root / Ident(id) => case req @ GET -> Root / Ident(id) =>
config.findProvider(id) match { config.findProvider(id) match {
case Some(cfg) => case Some(cfg) =>
val uri = cfg.authorizeUrl for {
.withQuery("client_id", cfg.clientId) state <- StateParam.generate[F](config.serverSecret)
.withQuery("scope", cfg.scope) uri = cfg.authorizeUrl
.withQuery( .withQuery("client_id", cfg.clientId)
"redirect_uri", .withQuery("scope", cfg.scope)
CodeFlowConfig.resumeUri(req, cfg, config).asString .withQuery(
"redirect_uri",
CodeFlowConfig.resumeUri(req, cfg, config).asString
)
.withQuery("response_type", "code")
.withQuery("state", state.asString)
_ <- logger.debug(
s"Redirecting to OAuth/OIDC provider ${cfg.providerId.id}: ${uri.asString}"
) )
.withQuery("response_type", "code") resp <- Found(Location(Uri.unsafeFromString(uri.asString)))
.withQuery("state", cfg.clientId) } yield resp
logger.debug(
s"Redirecting to OAuth/OIDC provider ${cfg.providerId.id}: ${uri.asString}"
) *>
Found(Location(Uri.unsafeFromString(uri.asString)))
case None => case None =>
logger.debug(s"No OAuth/OIDC provider found with id '$id'") *> logger.debug(s"No OAuth/OIDC provider found with id '$id'") *>
NotFound() NotFound()
@ -66,9 +70,20 @@ object CodeFlowRoutes {
NotFound() NotFound()
case Some(provider) => case Some(provider) =>
val codeFromReq = OptionT.fromOption[F](req.params.get("code")) val codeFromReq = OptionT.fromOption[F](req.params.get("code"))
val stateParamValid = req.params
.get("state")
.exists(state => StateParam.isValidStateParam(state, config.serverSecret))
val userInfo = for { val userInfo = for {
_ <- OptionT.liftF(logger.info(s"Resume OAuth/OIDC flow for ${id.id}")) _ <- OptionT.liftF(logger.info(s"Resume OAuth/OIDC flow for ${id.id}"))
_ <-
if (stateParamValid) OptionT.pure[F](())
else
OptionT(
logger
.warn(s"Invalid state parameter returned from Idp!")
.as(Option.empty[Unit])
)
code <- codeFromReq code <- codeFromReq
_ <- OptionT.liftF( _ <- OptionT.liftF(
logger.trace( logger.trace(
@ -91,7 +106,7 @@ object CodeFlowRoutes {
s"$err$descr" s"$err$descr"
} }
.map(err => s": $err") .map(err => s": $err")
.getOrElse("") .getOrElse(": <no reason>")
logger.warn(s"Error resuming code flow from '${id.id}'$reason") *> logger.warn(s"Error resuming code flow from '${id.id}'$reason") *>
onUserInfo.handle(req, provider, None) onUserInfo.handle(req, provider, None)

View File

@ -0,0 +1,49 @@
/*
* Copyright 2020 Eike K. & Contributors
*
* SPDX-License-Identifier: AGPL-3.0-or-later
*/
package docspell.oidc
import cats.effect._
import cats.syntax.all._
import docspell.common.util.{Random, SignUtil}
import scodec.bits.Bases.Alphabets
import scodec.bits.ByteVector
final case class StateParam(value: String, sig: ByteVector) {
def asString: String =
s"$value$$${sig.toBase64UrlNoPad}"
def isValid(key: ByteVector): Boolean = {
val actual = SignUtil.signString(value, key)
SignUtil.isEqual(actual, sig)
}
}
object StateParam {
def generate[F[_]: Sync](key: ByteVector): F[StateParam] =
Random[F].string(8).map { v =>
val sig = SignUtil.signString(v, key)
StateParam(v, sig)
}
def fromString(str: String, key: ByteVector): Either[String, StateParam] =
str.split('$') match {
case Array(v, sig) =>
ByteVector
.fromBase64Descriptive(sig, Alphabets.Base64UrlNoPad)
.map(s => StateParam(v, s))
.filterOrElse(_.isValid(key), s"Invalid signature in state param: $str")
case _ =>
Left(s"Invalid state parameter: $str")
}
def isValidStateParam(state: String, key: ByteVector) =
fromString(state, key).isRight
}

View File

@ -0,0 +1,47 @@
/*
* Copyright 2020 Eike K. & Contributors
*
* SPDX-License-Identifier: AGPL-3.0-or-later
*/
package docspell.oidc
import cats.effect._
import munit.CatsEffectSuite
import scodec.bits.ByteVector
class StateParamTest extends CatsEffectSuite {
private val key = ByteVector.fromValidHex("caffee")
private val key2 = key ++ key
test("generate") {
for {
p <- StateParam.generate[IO](key)
_ = {
assert(p.value.length > 8)
assert(p.sig.nonEmpty)
assert(p.isValid(key))
assert(!p.isValid(key2))
}
} yield ()
}
test("fromString") {
for {
p <- StateParam.generate[IO](key)
str = p.asString
p2 = StateParam.fromString(str, key).fold(sys.error, identity)
p3 = StateParam.fromString(str, key2)
p4 = StateParam.fromString("uiaeuiaeue", key)
p5 = StateParam.fromString(str + "$" + str, key)
_ = {
assertEquals(p2, p)
assert(p3.isLeft)
assert(p4.isLeft)
assert(p5.isLeft)
}
} yield ()
}
}

View File

@ -31,7 +31,8 @@ object OpenId {
ClientRequestInfo ClientRequestInfo
.getBaseUrl(config, req) / "api" / "v1" / "open" / "auth" / "openid", .getBaseUrl(config, req) / "api" / "v1" / "open" / "auth" / "openid",
id => id =>
config.openid.filter(_.enabled).find(_.provider.providerId == id).map(_.provider) config.openid.filter(_.enabled).find(_.provider.providerId == id).map(_.provider),
config.auth.serverSecret
) )
def handle[F[_]: Async](backend: BackendApp[F], config: Config): OnUserInfo[F] = def handle[F[_]: Async](backend: BackendApp[F], config: Config): OnUserInfo[F] =
@ -104,9 +105,7 @@ object OpenId {
import dsl._ import dsl._
for { for {
setup <- backend.signup.setupExternal(cfg.backend.signup)( setup <- backend.signup.setupExternal(ExternalAccount(accountId))
ExternalAccount(accountId)
)
res <- setup match { res <- setup match {
case SignupResult.Failure(ex) => case SignupResult.Failure(ex) =>
logger.error(ex)(s"Error when creating external account!") *> logger.error(ex)(s"Error when creating external account!") *>

View File

@ -24,7 +24,8 @@ object FileUrlReader {
scheme = Nel.of(scheme), scheme = Nel.of(scheme),
authority = Some(""), authority = Some(""),
path = LenientUri.NonEmptyPath( path = LenientUri.NonEmptyPath(
Nel.of(key.collective.id, key.category.id.id, key.id.id) Nel.of(key.collective.id, key.category.id.id, key.id.id),
false
), ),
query = None, query = None,
fragment = None fragment = None