mirror of
https://github.com/TheAnachronism/docspell.git
synced 2025-04-05 10:59:33 +00:00
Merge pull request #1633 from eikek/oidc-improvements
OIDC improvements
This commit is contained in:
commit
d0d8a8fbe7
@ -50,7 +50,7 @@ object Contact {
|
||||
p match {
|
||||
case LenientUri.RootPath => false
|
||||
case LenientUri.EmptyPath => false
|
||||
case LenientUri.NonEmptyPath(segs) =>
|
||||
case LenientUri.NonEmptyPath(segs, _) =>
|
||||
Ident.fromString(segs.last).isRight &&
|
||||
segs.init.takeRight(3) == List("open", "upload", "item")
|
||||
}
|
||||
|
@ -6,11 +6,10 @@
|
||||
|
||||
package docspell.backend.auth
|
||||
|
||||
import javax.crypto.Mac
|
||||
import javax.crypto.spec.SecretKeySpec
|
||||
|
||||
import cats.implicits._
|
||||
|
||||
import docspell.common.util.SignUtil
|
||||
|
||||
import scodec.bits._
|
||||
|
||||
private[auth] object TokenUtil {
|
||||
@ -34,11 +33,8 @@ private[auth] object TokenUtil {
|
||||
signRaw(raw, key)
|
||||
}
|
||||
|
||||
private def signRaw(data: String, key: ByteVector): String = {
|
||||
val mac = Mac.getInstance("HmacSHA1")
|
||||
mac.init(new SecretKeySpec(key.toArray, "HmacSHA1"))
|
||||
ByteVector.view(mac.doFinal(data.getBytes(utf8))).toBase64
|
||||
}
|
||||
private def signRaw(data: String, key: ByteVector): String =
|
||||
SignUtil.signString(data, key).toBase64
|
||||
|
||||
def b64enc(s: String): String =
|
||||
ByteVector.view(s.getBytes(utf8)).toBase64
|
||||
@ -52,5 +48,4 @@ private[auth] object TokenUtil {
|
||||
def constTimeEq(s1: String, s2: String): Boolean =
|
||||
s1.zip(s2)
|
||||
.foldLeft(true) { case (r, (c1, c2)) => r & c1 == c2 } & s1.length == s2.length
|
||||
|
||||
}
|
||||
|
@ -20,8 +20,10 @@ trait OSignup[F[_]] {
|
||||
|
||||
def register(cfg: Config)(data: RegisterData): F[SignupResult]
|
||||
|
||||
/** Creates the given account if it doesn't exist. */
|
||||
def setupExternal(cfg: Config)(data: ExternalAccount): F[SignupResult]
|
||||
/** Creates the given account if it doesn't exist. This is independent from signup
|
||||
* configuration.
|
||||
*/
|
||||
def setupExternal(data: ExternalAccount): F[SignupResult]
|
||||
|
||||
def newInvite(cfg: Config)(password: Password): F[NewInviteResult]
|
||||
}
|
||||
@ -77,36 +79,31 @@ object OSignup {
|
||||
}
|
||||
}
|
||||
|
||||
def setupExternal(cfg: Config)(data: ExternalAccount): F[SignupResult] =
|
||||
cfg.mode match {
|
||||
case Config.Mode.Closed =>
|
||||
SignupResult.signupClosed.pure[F]
|
||||
case _ =>
|
||||
if (data.source == AccountSource.Local)
|
||||
SignupResult
|
||||
.failure(new Exception("Account source must not be LOCAL!"))
|
||||
.pure[F]
|
||||
else
|
||||
for {
|
||||
recs <- makeRecords(data.collName, data.login, Password(""), data.source)
|
||||
cres <- store.add(
|
||||
RCollective.insert(recs._1),
|
||||
RCollective.existsById(data.collName)
|
||||
)
|
||||
ures <- store.add(RUser.insert(recs._2), RUser.exists(data.login))
|
||||
res = cres match {
|
||||
def setupExternal(data: ExternalAccount): F[SignupResult] =
|
||||
if (data.source == AccountSource.Local)
|
||||
SignupResult
|
||||
.failure(new Exception("Account source must not be LOCAL!"))
|
||||
.pure[F]
|
||||
else
|
||||
for {
|
||||
recs <- makeRecords(data.collName, data.login, Password(""), data.source)
|
||||
cres <- store.add(
|
||||
RCollective.insert(recs._1),
|
||||
RCollective.existsById(data.collName)
|
||||
)
|
||||
ures <- store.add(RUser.insert(recs._2), RUser.exists(data.login))
|
||||
res = cres match {
|
||||
case AddResult.Failure(ex) =>
|
||||
SignupResult.failure(ex)
|
||||
case _ =>
|
||||
ures match {
|
||||
case AddResult.Failure(ex) =>
|
||||
SignupResult.failure(ex)
|
||||
case _ =>
|
||||
ures match {
|
||||
case AddResult.Failure(ex) =>
|
||||
SignupResult.failure(ex)
|
||||
case _ =>
|
||||
SignupResult.success
|
||||
}
|
||||
SignupResult.success
|
||||
}
|
||||
} yield res
|
||||
}
|
||||
}
|
||||
} yield res
|
||||
|
||||
private def retryInvite(res: SignupResult): Boolean =
|
||||
res match {
|
||||
|
@ -121,7 +121,7 @@ object LenientUri {
|
||||
val isRoot = true
|
||||
val isEmpty = false
|
||||
def /(seg: String): Path =
|
||||
NonEmptyPath(NonEmptyList.of(seg))
|
||||
NonEmptyPath(NonEmptyList.of(seg), false)
|
||||
def asString = "/"
|
||||
}
|
||||
case object EmptyPath extends Path {
|
||||
@ -129,20 +129,22 @@ object LenientUri {
|
||||
val isRoot = false
|
||||
val isEmpty = true
|
||||
def /(seg: String): Path =
|
||||
NonEmptyPath(NonEmptyList.of(seg))
|
||||
NonEmptyPath(NonEmptyList.of(seg), false)
|
||||
def asString = ""
|
||||
}
|
||||
case class NonEmptyPath(segs: NonEmptyList[String]) extends Path {
|
||||
case class NonEmptyPath(segs: NonEmptyList[String], trailingSlash: Boolean)
|
||||
extends Path {
|
||||
def segments = segs.toList
|
||||
val isEmpty = false
|
||||
val isRoot = false
|
||||
private val slashSuffix = if (trailingSlash) "/" else ""
|
||||
def /(seg: String): Path =
|
||||
copy(segs = segs.append(seg))
|
||||
def asString =
|
||||
segs.head match {
|
||||
case "." => segments.map(percentEncode).mkString("/")
|
||||
case ".." => segments.map(percentEncode).mkString("/")
|
||||
case _ => "/" + segments.map(percentEncode).mkString("/")
|
||||
case "." => segments.map(percentEncode).mkString("/") + slashSuffix
|
||||
case ".." => segments.map(percentEncode).mkString("/") + slashSuffix
|
||||
case _ => "/" + segments.map(percentEncode).mkString("/") + slashSuffix
|
||||
}
|
||||
}
|
||||
|
||||
@ -157,14 +159,14 @@ object LenientUri {
|
||||
str.trim match {
|
||||
case "/" => Right(RootPath)
|
||||
case "" => Right(EmptyPath)
|
||||
case _ =>
|
||||
case uriStr =>
|
||||
Either.fromOption(
|
||||
stripLeading(str, '/')
|
||||
stripLeading(uriStr, '/')
|
||||
.split('/')
|
||||
.toList
|
||||
.traverse(percentDecode)
|
||||
.flatMap(NonEmptyList.fromList)
|
||||
.map(NonEmptyPath.apply),
|
||||
.map(NonEmptyPath(_, uriStr.endsWith("/"))),
|
||||
s"Invalid path: $str"
|
||||
)
|
||||
}
|
||||
|
@ -62,7 +62,7 @@ object UrlMatcher {
|
||||
// strip path to only match prefixes
|
||||
val mPath: LenientUri.Path =
|
||||
NonEmptyList.fromList(url.path.segments.take(pathSegmentCount)) match {
|
||||
case Some(nel) => LenientUri.NonEmptyPath(nel)
|
||||
case Some(nel) => LenientUri.NonEmptyPath(nel, false)
|
||||
case None => LenientUri.RootPath
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,39 @@
|
||||
/*
|
||||
* Copyright 2020 Eike K. & Contributors
|
||||
*
|
||||
* SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
*/
|
||||
|
||||
package docspell.common.util
|
||||
|
||||
import javax.crypto.Mac
|
||||
import javax.crypto.spec.SecretKeySpec
|
||||
|
||||
import scodec.bits.ByteVector
|
||||
|
||||
object SignUtil {
|
||||
private val utf8 = java.nio.charset.StandardCharsets.UTF_8
|
||||
|
||||
private val macAlgo = "HmacSHA1"
|
||||
|
||||
private def getMac(key: ByteVector) = {
|
||||
val mac = Mac.getInstance(macAlgo)
|
||||
mac.init(new SecretKeySpec(key.toArray, macAlgo))
|
||||
mac
|
||||
}
|
||||
|
||||
def signString(data: String, key: ByteVector): ByteVector = {
|
||||
val mac = getMac(key)
|
||||
ByteVector.view(mac.doFinal(data.getBytes(utf8)))
|
||||
}
|
||||
|
||||
def signBytes(data: ByteVector, key: ByteVector): ByteVector = {
|
||||
val mac = getMac(key)
|
||||
ByteVector.view(mac.doFinal(data.toArray))
|
||||
}
|
||||
|
||||
def isEqual(sig1: ByteVector, sig2: ByteVector): Boolean =
|
||||
sig1
|
||||
.zipWith(sig2)((b1, b2) => (b1 - b2).toByte)
|
||||
.foldLeft(true)(_ && _ == 0) && sig1.length == sig2.length
|
||||
}
|
@ -29,4 +29,11 @@ class LenientUriTest extends FunSuite {
|
||||
)
|
||||
assertEquals(LenientUri.percentDecode("a%25b%5Cc%7Cd%23e"), "a%b\\c|d#e".some)
|
||||
}
|
||||
|
||||
test("parse with trailing slash") {
|
||||
assertEquals(LenientUri.unsafe("http://a.com/").asString, "http://a.com/")
|
||||
assertEquals(LenientUri.unsafe("http://a.com").asString, "http://a.com")
|
||||
assertEquals(LenientUri.unsafe("http://a.com/path").asString, "http://a.com/path")
|
||||
assertEquals(LenientUri.unsafe("http://a.com/path/").asString, "http://a.com/path/")
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,20 @@
|
||||
/*
|
||||
* Copyright 2020 Eike K. & Contributors
|
||||
*
|
||||
* SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
*/
|
||||
|
||||
package docspell.common.util
|
||||
|
||||
import munit.FunSuite
|
||||
import scodec.bits.ByteVector
|
||||
|
||||
class SignUtilTest extends FunSuite {
|
||||
|
||||
private val key = ByteVector.fromValidHex("caffee")
|
||||
|
||||
test("create and validate") {
|
||||
val sig = SignUtil.signString("hello", key)
|
||||
assert(SignUtil.isEqual(sig, SignUtil.signString("hello", key)))
|
||||
}
|
||||
}
|
@ -111,15 +111,19 @@ object CodeFlow {
|
||||
token <- r.attemptAs[AccessToken].value
|
||||
_ <- token match {
|
||||
case Right(t) =>
|
||||
logger.trace(s"Got token response: $t")
|
||||
logger.trace(s"Got token response (status=${r.status.code}): $t")
|
||||
case Left(err) =>
|
||||
logger.error(err)(s"Error decoding access token: ${err.getMessage}")
|
||||
logger.error(err)(
|
||||
s"Error decoding access token (status=${r.status.code}): ${err.getMessage}"
|
||||
)
|
||||
}
|
||||
} yield token.toOption
|
||||
case r =>
|
||||
logger
|
||||
.error(s"Error obtaining access token '${r.status.code}' / ${r.as[String]}")
|
||||
.map(_ => None)
|
||||
for {
|
||||
body <- r.bodyText.compile.string
|
||||
_ <- logger
|
||||
.error(s"Error obtaining access token status=${r.status.code}, body=$body")
|
||||
} yield None
|
||||
})
|
||||
}
|
||||
|
||||
@ -177,5 +181,4 @@ object CodeFlow {
|
||||
logAction = Some((msg: String) => logger.trace(msg))
|
||||
)(c)
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -9,6 +9,7 @@ package docspell.oidc
|
||||
import docspell.common._
|
||||
|
||||
import org.http4s.Request
|
||||
import scodec.bits.ByteVector
|
||||
|
||||
trait CodeFlowConfig[F[_]] {
|
||||
|
||||
@ -22,17 +23,20 @@ trait CodeFlowConfig[F[_]] {
|
||||
*/
|
||||
def findProvider(id: Ident): Option[ProviderConfig]
|
||||
|
||||
def serverSecret: ByteVector
|
||||
}
|
||||
|
||||
object CodeFlowConfig {
|
||||
|
||||
def apply[F[_]](
|
||||
url: Request[F] => LenientUri,
|
||||
provider: Ident => Option[ProviderConfig]
|
||||
provider: Ident => Option[ProviderConfig],
|
||||
secret: ByteVector
|
||||
): CodeFlowConfig[F] =
|
||||
new CodeFlowConfig[F] {
|
||||
def getEndpointUrl(req: Request[F]): LenientUri = url(req)
|
||||
def findProvider(id: Ident): Option[ProviderConfig] = provider(id)
|
||||
val serverSecret = secret
|
||||
}
|
||||
|
||||
private[oidc] def resumeUri[F[_]](
|
||||
@ -41,5 +45,4 @@ object CodeFlowConfig {
|
||||
cfg: CodeFlowConfig[F]
|
||||
): LenientUri =
|
||||
cfg.getEndpointUrl(req) / prov.providerId.id / "resume"
|
||||
|
||||
}
|
||||
|
@ -41,19 +41,23 @@ object CodeFlowRoutes {
|
||||
case req @ GET -> Root / Ident(id) =>
|
||||
config.findProvider(id) match {
|
||||
case Some(cfg) =>
|
||||
val uri = cfg.authorizeUrl
|
||||
.withQuery("client_id", cfg.clientId)
|
||||
.withQuery("scope", cfg.scope)
|
||||
.withQuery(
|
||||
"redirect_uri",
|
||||
CodeFlowConfig.resumeUri(req, cfg, config).asString
|
||||
for {
|
||||
state <- StateParam.generate[F](config.serverSecret)
|
||||
uri = cfg.authorizeUrl
|
||||
.withQuery("client_id", cfg.clientId)
|
||||
.withQuery("scope", cfg.scope)
|
||||
.withQuery(
|
||||
"redirect_uri",
|
||||
CodeFlowConfig.resumeUri(req, cfg, config).asString
|
||||
)
|
||||
.withQuery("response_type", "code")
|
||||
.withQuery("state", state.asString)
|
||||
_ <- logger.debug(
|
||||
s"Redirecting to OAuth/OIDC provider ${cfg.providerId.id}: ${uri.asString}"
|
||||
)
|
||||
.withQuery("response_type", "code")
|
||||
.withQuery("state", cfg.clientId)
|
||||
logger.debug(
|
||||
s"Redirecting to OAuth/OIDC provider ${cfg.providerId.id}: ${uri.asString}"
|
||||
) *>
|
||||
Found(Location(Uri.unsafeFromString(uri.asString)))
|
||||
resp <- Found(Location(Uri.unsafeFromString(uri.asString)))
|
||||
} yield resp
|
||||
|
||||
case None =>
|
||||
logger.debug(s"No OAuth/OIDC provider found with id '$id'") *>
|
||||
NotFound()
|
||||
@ -66,9 +70,20 @@ object CodeFlowRoutes {
|
||||
NotFound()
|
||||
case Some(provider) =>
|
||||
val codeFromReq = OptionT.fromOption[F](req.params.get("code"))
|
||||
val stateParamValid = req.params
|
||||
.get("state")
|
||||
.exists(state => StateParam.isValidStateParam(state, config.serverSecret))
|
||||
|
||||
val userInfo = for {
|
||||
_ <- OptionT.liftF(logger.info(s"Resume OAuth/OIDC flow for ${id.id}"))
|
||||
_ <-
|
||||
if (stateParamValid) OptionT.pure[F](())
|
||||
else
|
||||
OptionT(
|
||||
logger
|
||||
.warn(s"Invalid state parameter returned from Idp!")
|
||||
.as(Option.empty[Unit])
|
||||
)
|
||||
code <- codeFromReq
|
||||
_ <- OptionT.liftF(
|
||||
logger.trace(
|
||||
@ -91,7 +106,7 @@ object CodeFlowRoutes {
|
||||
s"$err$descr"
|
||||
}
|
||||
.map(err => s": $err")
|
||||
.getOrElse("")
|
||||
.getOrElse(": <no reason>")
|
||||
|
||||
logger.warn(s"Error resuming code flow from '${id.id}'$reason") *>
|
||||
onUserInfo.handle(req, provider, None)
|
||||
|
49
modules/oidc/src/main/scala/docspell/oidc/StateParam.scala
Normal file
49
modules/oidc/src/main/scala/docspell/oidc/StateParam.scala
Normal file
@ -0,0 +1,49 @@
|
||||
/*
|
||||
* Copyright 2020 Eike K. & Contributors
|
||||
*
|
||||
* SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
*/
|
||||
|
||||
package docspell.oidc
|
||||
|
||||
import cats.effect._
|
||||
import cats.syntax.all._
|
||||
|
||||
import docspell.common.util.{Random, SignUtil}
|
||||
|
||||
import scodec.bits.Bases.Alphabets
|
||||
import scodec.bits.ByteVector
|
||||
|
||||
final case class StateParam(value: String, sig: ByteVector) {
|
||||
def asString: String =
|
||||
s"$value$$${sig.toBase64UrlNoPad}"
|
||||
|
||||
def isValid(key: ByteVector): Boolean = {
|
||||
val actual = SignUtil.signString(value, key)
|
||||
SignUtil.isEqual(actual, sig)
|
||||
}
|
||||
}
|
||||
|
||||
object StateParam {
|
||||
|
||||
def generate[F[_]: Sync](key: ByteVector): F[StateParam] =
|
||||
Random[F].string(8).map { v =>
|
||||
val sig = SignUtil.signString(v, key)
|
||||
StateParam(v, sig)
|
||||
}
|
||||
|
||||
def fromString(str: String, key: ByteVector): Either[String, StateParam] =
|
||||
str.split('$') match {
|
||||
case Array(v, sig) =>
|
||||
ByteVector
|
||||
.fromBase64Descriptive(sig, Alphabets.Base64UrlNoPad)
|
||||
.map(s => StateParam(v, s))
|
||||
.filterOrElse(_.isValid(key), s"Invalid signature in state param: $str")
|
||||
|
||||
case _ =>
|
||||
Left(s"Invalid state parameter: $str")
|
||||
}
|
||||
|
||||
def isValidStateParam(state: String, key: ByteVector) =
|
||||
fromString(state, key).isRight
|
||||
}
|
@ -0,0 +1,47 @@
|
||||
/*
|
||||
* Copyright 2020 Eike K. & Contributors
|
||||
*
|
||||
* SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
*/
|
||||
|
||||
package docspell.oidc
|
||||
|
||||
import cats.effect._
|
||||
|
||||
import munit.CatsEffectSuite
|
||||
import scodec.bits.ByteVector
|
||||
|
||||
class StateParamTest extends CatsEffectSuite {
|
||||
|
||||
private val key = ByteVector.fromValidHex("caffee")
|
||||
private val key2 = key ++ key
|
||||
|
||||
test("generate") {
|
||||
for {
|
||||
p <- StateParam.generate[IO](key)
|
||||
_ = {
|
||||
assert(p.value.length > 8)
|
||||
assert(p.sig.nonEmpty)
|
||||
assert(p.isValid(key))
|
||||
assert(!p.isValid(key2))
|
||||
}
|
||||
} yield ()
|
||||
}
|
||||
|
||||
test("fromString") {
|
||||
for {
|
||||
p <- StateParam.generate[IO](key)
|
||||
str = p.asString
|
||||
p2 = StateParam.fromString(str, key).fold(sys.error, identity)
|
||||
p3 = StateParam.fromString(str, key2)
|
||||
p4 = StateParam.fromString("uiaeuiaeue", key)
|
||||
p5 = StateParam.fromString(str + "$" + str, key)
|
||||
_ = {
|
||||
assertEquals(p2, p)
|
||||
assert(p3.isLeft)
|
||||
assert(p4.isLeft)
|
||||
assert(p5.isLeft)
|
||||
}
|
||||
} yield ()
|
||||
}
|
||||
}
|
@ -31,7 +31,8 @@ object OpenId {
|
||||
ClientRequestInfo
|
||||
.getBaseUrl(config, req) / "api" / "v1" / "open" / "auth" / "openid",
|
||||
id =>
|
||||
config.openid.filter(_.enabled).find(_.provider.providerId == id).map(_.provider)
|
||||
config.openid.filter(_.enabled).find(_.provider.providerId == id).map(_.provider),
|
||||
config.auth.serverSecret
|
||||
)
|
||||
|
||||
def handle[F[_]: Async](backend: BackendApp[F], config: Config): OnUserInfo[F] =
|
||||
@ -104,9 +105,7 @@ object OpenId {
|
||||
import dsl._
|
||||
|
||||
for {
|
||||
setup <- backend.signup.setupExternal(cfg.backend.signup)(
|
||||
ExternalAccount(accountId)
|
||||
)
|
||||
setup <- backend.signup.setupExternal(ExternalAccount(accountId))
|
||||
res <- setup match {
|
||||
case SignupResult.Failure(ex) =>
|
||||
logger.error(ex)(s"Error when creating external account!") *>
|
||||
|
@ -24,7 +24,8 @@ object FileUrlReader {
|
||||
scheme = Nel.of(scheme),
|
||||
authority = Some(""),
|
||||
path = LenientUri.NonEmptyPath(
|
||||
Nel.of(key.collective.id, key.category.id.id, key.id.id)
|
||||
Nel.of(key.collective.id, key.category.id.id, key.id.id),
|
||||
false
|
||||
),
|
||||
query = None,
|
||||
fragment = None
|
||||
|
Loading…
x
Reference in New Issue
Block a user