mirror of
https://github.com/TheAnachronism/docspell.git
synced 2025-04-05 10:59:33 +00:00
Merge pull request #1633 from eikek/oidc-improvements
OIDC improvements
This commit is contained in:
commit
d0d8a8fbe7
@ -50,7 +50,7 @@ object Contact {
|
|||||||
p match {
|
p match {
|
||||||
case LenientUri.RootPath => false
|
case LenientUri.RootPath => false
|
||||||
case LenientUri.EmptyPath => false
|
case LenientUri.EmptyPath => false
|
||||||
case LenientUri.NonEmptyPath(segs) =>
|
case LenientUri.NonEmptyPath(segs, _) =>
|
||||||
Ident.fromString(segs.last).isRight &&
|
Ident.fromString(segs.last).isRight &&
|
||||||
segs.init.takeRight(3) == List("open", "upload", "item")
|
segs.init.takeRight(3) == List("open", "upload", "item")
|
||||||
}
|
}
|
||||||
|
@ -6,11 +6,10 @@
|
|||||||
|
|
||||||
package docspell.backend.auth
|
package docspell.backend.auth
|
||||||
|
|
||||||
import javax.crypto.Mac
|
|
||||||
import javax.crypto.spec.SecretKeySpec
|
|
||||||
|
|
||||||
import cats.implicits._
|
import cats.implicits._
|
||||||
|
|
||||||
|
import docspell.common.util.SignUtil
|
||||||
|
|
||||||
import scodec.bits._
|
import scodec.bits._
|
||||||
|
|
||||||
private[auth] object TokenUtil {
|
private[auth] object TokenUtil {
|
||||||
@ -34,11 +33,8 @@ private[auth] object TokenUtil {
|
|||||||
signRaw(raw, key)
|
signRaw(raw, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
private def signRaw(data: String, key: ByteVector): String = {
|
private def signRaw(data: String, key: ByteVector): String =
|
||||||
val mac = Mac.getInstance("HmacSHA1")
|
SignUtil.signString(data, key).toBase64
|
||||||
mac.init(new SecretKeySpec(key.toArray, "HmacSHA1"))
|
|
||||||
ByteVector.view(mac.doFinal(data.getBytes(utf8))).toBase64
|
|
||||||
}
|
|
||||||
|
|
||||||
def b64enc(s: String): String =
|
def b64enc(s: String): String =
|
||||||
ByteVector.view(s.getBytes(utf8)).toBase64
|
ByteVector.view(s.getBytes(utf8)).toBase64
|
||||||
@ -52,5 +48,4 @@ private[auth] object TokenUtil {
|
|||||||
def constTimeEq(s1: String, s2: String): Boolean =
|
def constTimeEq(s1: String, s2: String): Boolean =
|
||||||
s1.zip(s2)
|
s1.zip(s2)
|
||||||
.foldLeft(true) { case (r, (c1, c2)) => r & c1 == c2 } & s1.length == s2.length
|
.foldLeft(true) { case (r, (c1, c2)) => r & c1 == c2 } & s1.length == s2.length
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -20,8 +20,10 @@ trait OSignup[F[_]] {
|
|||||||
|
|
||||||
def register(cfg: Config)(data: RegisterData): F[SignupResult]
|
def register(cfg: Config)(data: RegisterData): F[SignupResult]
|
||||||
|
|
||||||
/** Creates the given account if it doesn't exist. */
|
/** Creates the given account if it doesn't exist. This is independent from signup
|
||||||
def setupExternal(cfg: Config)(data: ExternalAccount): F[SignupResult]
|
* configuration.
|
||||||
|
*/
|
||||||
|
def setupExternal(data: ExternalAccount): F[SignupResult]
|
||||||
|
|
||||||
def newInvite(cfg: Config)(password: Password): F[NewInviteResult]
|
def newInvite(cfg: Config)(password: Password): F[NewInviteResult]
|
||||||
}
|
}
|
||||||
@ -77,36 +79,31 @@ object OSignup {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
def setupExternal(cfg: Config)(data: ExternalAccount): F[SignupResult] =
|
def setupExternal(data: ExternalAccount): F[SignupResult] =
|
||||||
cfg.mode match {
|
if (data.source == AccountSource.Local)
|
||||||
case Config.Mode.Closed =>
|
SignupResult
|
||||||
SignupResult.signupClosed.pure[F]
|
.failure(new Exception("Account source must not be LOCAL!"))
|
||||||
case _ =>
|
.pure[F]
|
||||||
if (data.source == AccountSource.Local)
|
else
|
||||||
SignupResult
|
for {
|
||||||
.failure(new Exception("Account source must not be LOCAL!"))
|
recs <- makeRecords(data.collName, data.login, Password(""), data.source)
|
||||||
.pure[F]
|
cres <- store.add(
|
||||||
else
|
RCollective.insert(recs._1),
|
||||||
for {
|
RCollective.existsById(data.collName)
|
||||||
recs <- makeRecords(data.collName, data.login, Password(""), data.source)
|
)
|
||||||
cres <- store.add(
|
ures <- store.add(RUser.insert(recs._2), RUser.exists(data.login))
|
||||||
RCollective.insert(recs._1),
|
res = cres match {
|
||||||
RCollective.existsById(data.collName)
|
case AddResult.Failure(ex) =>
|
||||||
)
|
SignupResult.failure(ex)
|
||||||
ures <- store.add(RUser.insert(recs._2), RUser.exists(data.login))
|
case _ =>
|
||||||
res = cres match {
|
ures match {
|
||||||
case AddResult.Failure(ex) =>
|
case AddResult.Failure(ex) =>
|
||||||
SignupResult.failure(ex)
|
SignupResult.failure(ex)
|
||||||
case _ =>
|
case _ =>
|
||||||
ures match {
|
SignupResult.success
|
||||||
case AddResult.Failure(ex) =>
|
|
||||||
SignupResult.failure(ex)
|
|
||||||
case _ =>
|
|
||||||
SignupResult.success
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} yield res
|
}
|
||||||
}
|
} yield res
|
||||||
|
|
||||||
private def retryInvite(res: SignupResult): Boolean =
|
private def retryInvite(res: SignupResult): Boolean =
|
||||||
res match {
|
res match {
|
||||||
|
@ -121,7 +121,7 @@ object LenientUri {
|
|||||||
val isRoot = true
|
val isRoot = true
|
||||||
val isEmpty = false
|
val isEmpty = false
|
||||||
def /(seg: String): Path =
|
def /(seg: String): Path =
|
||||||
NonEmptyPath(NonEmptyList.of(seg))
|
NonEmptyPath(NonEmptyList.of(seg), false)
|
||||||
def asString = "/"
|
def asString = "/"
|
||||||
}
|
}
|
||||||
case object EmptyPath extends Path {
|
case object EmptyPath extends Path {
|
||||||
@ -129,20 +129,22 @@ object LenientUri {
|
|||||||
val isRoot = false
|
val isRoot = false
|
||||||
val isEmpty = true
|
val isEmpty = true
|
||||||
def /(seg: String): Path =
|
def /(seg: String): Path =
|
||||||
NonEmptyPath(NonEmptyList.of(seg))
|
NonEmptyPath(NonEmptyList.of(seg), false)
|
||||||
def asString = ""
|
def asString = ""
|
||||||
}
|
}
|
||||||
case class NonEmptyPath(segs: NonEmptyList[String]) extends Path {
|
case class NonEmptyPath(segs: NonEmptyList[String], trailingSlash: Boolean)
|
||||||
|
extends Path {
|
||||||
def segments = segs.toList
|
def segments = segs.toList
|
||||||
val isEmpty = false
|
val isEmpty = false
|
||||||
val isRoot = false
|
val isRoot = false
|
||||||
|
private val slashSuffix = if (trailingSlash) "/" else ""
|
||||||
def /(seg: String): Path =
|
def /(seg: String): Path =
|
||||||
copy(segs = segs.append(seg))
|
copy(segs = segs.append(seg))
|
||||||
def asString =
|
def asString =
|
||||||
segs.head match {
|
segs.head match {
|
||||||
case "." => segments.map(percentEncode).mkString("/")
|
case "." => segments.map(percentEncode).mkString("/") + slashSuffix
|
||||||
case ".." => segments.map(percentEncode).mkString("/")
|
case ".." => segments.map(percentEncode).mkString("/") + slashSuffix
|
||||||
case _ => "/" + segments.map(percentEncode).mkString("/")
|
case _ => "/" + segments.map(percentEncode).mkString("/") + slashSuffix
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -157,14 +159,14 @@ object LenientUri {
|
|||||||
str.trim match {
|
str.trim match {
|
||||||
case "/" => Right(RootPath)
|
case "/" => Right(RootPath)
|
||||||
case "" => Right(EmptyPath)
|
case "" => Right(EmptyPath)
|
||||||
case _ =>
|
case uriStr =>
|
||||||
Either.fromOption(
|
Either.fromOption(
|
||||||
stripLeading(str, '/')
|
stripLeading(uriStr, '/')
|
||||||
.split('/')
|
.split('/')
|
||||||
.toList
|
.toList
|
||||||
.traverse(percentDecode)
|
.traverse(percentDecode)
|
||||||
.flatMap(NonEmptyList.fromList)
|
.flatMap(NonEmptyList.fromList)
|
||||||
.map(NonEmptyPath.apply),
|
.map(NonEmptyPath(_, uriStr.endsWith("/"))),
|
||||||
s"Invalid path: $str"
|
s"Invalid path: $str"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -62,7 +62,7 @@ object UrlMatcher {
|
|||||||
// strip path to only match prefixes
|
// strip path to only match prefixes
|
||||||
val mPath: LenientUri.Path =
|
val mPath: LenientUri.Path =
|
||||||
NonEmptyList.fromList(url.path.segments.take(pathSegmentCount)) match {
|
NonEmptyList.fromList(url.path.segments.take(pathSegmentCount)) match {
|
||||||
case Some(nel) => LenientUri.NonEmptyPath(nel)
|
case Some(nel) => LenientUri.NonEmptyPath(nel, false)
|
||||||
case None => LenientUri.RootPath
|
case None => LenientUri.RootPath
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,39 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2020 Eike K. & Contributors
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
*/
|
||||||
|
|
||||||
|
package docspell.common.util
|
||||||
|
|
||||||
|
import javax.crypto.Mac
|
||||||
|
import javax.crypto.spec.SecretKeySpec
|
||||||
|
|
||||||
|
import scodec.bits.ByteVector
|
||||||
|
|
||||||
|
object SignUtil {
|
||||||
|
private val utf8 = java.nio.charset.StandardCharsets.UTF_8
|
||||||
|
|
||||||
|
private val macAlgo = "HmacSHA1"
|
||||||
|
|
||||||
|
private def getMac(key: ByteVector) = {
|
||||||
|
val mac = Mac.getInstance(macAlgo)
|
||||||
|
mac.init(new SecretKeySpec(key.toArray, macAlgo))
|
||||||
|
mac
|
||||||
|
}
|
||||||
|
|
||||||
|
def signString(data: String, key: ByteVector): ByteVector = {
|
||||||
|
val mac = getMac(key)
|
||||||
|
ByteVector.view(mac.doFinal(data.getBytes(utf8)))
|
||||||
|
}
|
||||||
|
|
||||||
|
def signBytes(data: ByteVector, key: ByteVector): ByteVector = {
|
||||||
|
val mac = getMac(key)
|
||||||
|
ByteVector.view(mac.doFinal(data.toArray))
|
||||||
|
}
|
||||||
|
|
||||||
|
def isEqual(sig1: ByteVector, sig2: ByteVector): Boolean =
|
||||||
|
sig1
|
||||||
|
.zipWith(sig2)((b1, b2) => (b1 - b2).toByte)
|
||||||
|
.foldLeft(true)(_ && _ == 0) && sig1.length == sig2.length
|
||||||
|
}
|
@ -29,4 +29,11 @@ class LenientUriTest extends FunSuite {
|
|||||||
)
|
)
|
||||||
assertEquals(LenientUri.percentDecode("a%25b%5Cc%7Cd%23e"), "a%b\\c|d#e".some)
|
assertEquals(LenientUri.percentDecode("a%25b%5Cc%7Cd%23e"), "a%b\\c|d#e".some)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("parse with trailing slash") {
|
||||||
|
assertEquals(LenientUri.unsafe("http://a.com/").asString, "http://a.com/")
|
||||||
|
assertEquals(LenientUri.unsafe("http://a.com").asString, "http://a.com")
|
||||||
|
assertEquals(LenientUri.unsafe("http://a.com/path").asString, "http://a.com/path")
|
||||||
|
assertEquals(LenientUri.unsafe("http://a.com/path/").asString, "http://a.com/path/")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,20 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2020 Eike K. & Contributors
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
*/
|
||||||
|
|
||||||
|
package docspell.common.util
|
||||||
|
|
||||||
|
import munit.FunSuite
|
||||||
|
import scodec.bits.ByteVector
|
||||||
|
|
||||||
|
class SignUtilTest extends FunSuite {
|
||||||
|
|
||||||
|
private val key = ByteVector.fromValidHex("caffee")
|
||||||
|
|
||||||
|
test("create and validate") {
|
||||||
|
val sig = SignUtil.signString("hello", key)
|
||||||
|
assert(SignUtil.isEqual(sig, SignUtil.signString("hello", key)))
|
||||||
|
}
|
||||||
|
}
|
@ -111,15 +111,19 @@ object CodeFlow {
|
|||||||
token <- r.attemptAs[AccessToken].value
|
token <- r.attemptAs[AccessToken].value
|
||||||
_ <- token match {
|
_ <- token match {
|
||||||
case Right(t) =>
|
case Right(t) =>
|
||||||
logger.trace(s"Got token response: $t")
|
logger.trace(s"Got token response (status=${r.status.code}): $t")
|
||||||
case Left(err) =>
|
case Left(err) =>
|
||||||
logger.error(err)(s"Error decoding access token: ${err.getMessage}")
|
logger.error(err)(
|
||||||
|
s"Error decoding access token (status=${r.status.code}): ${err.getMessage}"
|
||||||
|
)
|
||||||
}
|
}
|
||||||
} yield token.toOption
|
} yield token.toOption
|
||||||
case r =>
|
case r =>
|
||||||
logger
|
for {
|
||||||
.error(s"Error obtaining access token '${r.status.code}' / ${r.as[String]}")
|
body <- r.bodyText.compile.string
|
||||||
.map(_ => None)
|
_ <- logger
|
||||||
|
.error(s"Error obtaining access token status=${r.status.code}, body=$body")
|
||||||
|
} yield None
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -177,5 +181,4 @@ object CodeFlow {
|
|||||||
logAction = Some((msg: String) => logger.trace(msg))
|
logAction = Some((msg: String) => logger.trace(msg))
|
||||||
)(c)
|
)(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -9,6 +9,7 @@ package docspell.oidc
|
|||||||
import docspell.common._
|
import docspell.common._
|
||||||
|
|
||||||
import org.http4s.Request
|
import org.http4s.Request
|
||||||
|
import scodec.bits.ByteVector
|
||||||
|
|
||||||
trait CodeFlowConfig[F[_]] {
|
trait CodeFlowConfig[F[_]] {
|
||||||
|
|
||||||
@ -22,17 +23,20 @@ trait CodeFlowConfig[F[_]] {
|
|||||||
*/
|
*/
|
||||||
def findProvider(id: Ident): Option[ProviderConfig]
|
def findProvider(id: Ident): Option[ProviderConfig]
|
||||||
|
|
||||||
|
def serverSecret: ByteVector
|
||||||
}
|
}
|
||||||
|
|
||||||
object CodeFlowConfig {
|
object CodeFlowConfig {
|
||||||
|
|
||||||
def apply[F[_]](
|
def apply[F[_]](
|
||||||
url: Request[F] => LenientUri,
|
url: Request[F] => LenientUri,
|
||||||
provider: Ident => Option[ProviderConfig]
|
provider: Ident => Option[ProviderConfig],
|
||||||
|
secret: ByteVector
|
||||||
): CodeFlowConfig[F] =
|
): CodeFlowConfig[F] =
|
||||||
new CodeFlowConfig[F] {
|
new CodeFlowConfig[F] {
|
||||||
def getEndpointUrl(req: Request[F]): LenientUri = url(req)
|
def getEndpointUrl(req: Request[F]): LenientUri = url(req)
|
||||||
def findProvider(id: Ident): Option[ProviderConfig] = provider(id)
|
def findProvider(id: Ident): Option[ProviderConfig] = provider(id)
|
||||||
|
val serverSecret = secret
|
||||||
}
|
}
|
||||||
|
|
||||||
private[oidc] def resumeUri[F[_]](
|
private[oidc] def resumeUri[F[_]](
|
||||||
@ -41,5 +45,4 @@ object CodeFlowConfig {
|
|||||||
cfg: CodeFlowConfig[F]
|
cfg: CodeFlowConfig[F]
|
||||||
): LenientUri =
|
): LenientUri =
|
||||||
cfg.getEndpointUrl(req) / prov.providerId.id / "resume"
|
cfg.getEndpointUrl(req) / prov.providerId.id / "resume"
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -41,19 +41,23 @@ object CodeFlowRoutes {
|
|||||||
case req @ GET -> Root / Ident(id) =>
|
case req @ GET -> Root / Ident(id) =>
|
||||||
config.findProvider(id) match {
|
config.findProvider(id) match {
|
||||||
case Some(cfg) =>
|
case Some(cfg) =>
|
||||||
val uri = cfg.authorizeUrl
|
for {
|
||||||
.withQuery("client_id", cfg.clientId)
|
state <- StateParam.generate[F](config.serverSecret)
|
||||||
.withQuery("scope", cfg.scope)
|
uri = cfg.authorizeUrl
|
||||||
.withQuery(
|
.withQuery("client_id", cfg.clientId)
|
||||||
"redirect_uri",
|
.withQuery("scope", cfg.scope)
|
||||||
CodeFlowConfig.resumeUri(req, cfg, config).asString
|
.withQuery(
|
||||||
|
"redirect_uri",
|
||||||
|
CodeFlowConfig.resumeUri(req, cfg, config).asString
|
||||||
|
)
|
||||||
|
.withQuery("response_type", "code")
|
||||||
|
.withQuery("state", state.asString)
|
||||||
|
_ <- logger.debug(
|
||||||
|
s"Redirecting to OAuth/OIDC provider ${cfg.providerId.id}: ${uri.asString}"
|
||||||
)
|
)
|
||||||
.withQuery("response_type", "code")
|
resp <- Found(Location(Uri.unsafeFromString(uri.asString)))
|
||||||
.withQuery("state", cfg.clientId)
|
} yield resp
|
||||||
logger.debug(
|
|
||||||
s"Redirecting to OAuth/OIDC provider ${cfg.providerId.id}: ${uri.asString}"
|
|
||||||
) *>
|
|
||||||
Found(Location(Uri.unsafeFromString(uri.asString)))
|
|
||||||
case None =>
|
case None =>
|
||||||
logger.debug(s"No OAuth/OIDC provider found with id '$id'") *>
|
logger.debug(s"No OAuth/OIDC provider found with id '$id'") *>
|
||||||
NotFound()
|
NotFound()
|
||||||
@ -66,9 +70,20 @@ object CodeFlowRoutes {
|
|||||||
NotFound()
|
NotFound()
|
||||||
case Some(provider) =>
|
case Some(provider) =>
|
||||||
val codeFromReq = OptionT.fromOption[F](req.params.get("code"))
|
val codeFromReq = OptionT.fromOption[F](req.params.get("code"))
|
||||||
|
val stateParamValid = req.params
|
||||||
|
.get("state")
|
||||||
|
.exists(state => StateParam.isValidStateParam(state, config.serverSecret))
|
||||||
|
|
||||||
val userInfo = for {
|
val userInfo = for {
|
||||||
_ <- OptionT.liftF(logger.info(s"Resume OAuth/OIDC flow for ${id.id}"))
|
_ <- OptionT.liftF(logger.info(s"Resume OAuth/OIDC flow for ${id.id}"))
|
||||||
|
_ <-
|
||||||
|
if (stateParamValid) OptionT.pure[F](())
|
||||||
|
else
|
||||||
|
OptionT(
|
||||||
|
logger
|
||||||
|
.warn(s"Invalid state parameter returned from Idp!")
|
||||||
|
.as(Option.empty[Unit])
|
||||||
|
)
|
||||||
code <- codeFromReq
|
code <- codeFromReq
|
||||||
_ <- OptionT.liftF(
|
_ <- OptionT.liftF(
|
||||||
logger.trace(
|
logger.trace(
|
||||||
@ -91,7 +106,7 @@ object CodeFlowRoutes {
|
|||||||
s"$err$descr"
|
s"$err$descr"
|
||||||
}
|
}
|
||||||
.map(err => s": $err")
|
.map(err => s": $err")
|
||||||
.getOrElse("")
|
.getOrElse(": <no reason>")
|
||||||
|
|
||||||
logger.warn(s"Error resuming code flow from '${id.id}'$reason") *>
|
logger.warn(s"Error resuming code flow from '${id.id}'$reason") *>
|
||||||
onUserInfo.handle(req, provider, None)
|
onUserInfo.handle(req, provider, None)
|
||||||
|
49
modules/oidc/src/main/scala/docspell/oidc/StateParam.scala
Normal file
49
modules/oidc/src/main/scala/docspell/oidc/StateParam.scala
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2020 Eike K. & Contributors
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
*/
|
||||||
|
|
||||||
|
package docspell.oidc
|
||||||
|
|
||||||
|
import cats.effect._
|
||||||
|
import cats.syntax.all._
|
||||||
|
|
||||||
|
import docspell.common.util.{Random, SignUtil}
|
||||||
|
|
||||||
|
import scodec.bits.Bases.Alphabets
|
||||||
|
import scodec.bits.ByteVector
|
||||||
|
|
||||||
|
final case class StateParam(value: String, sig: ByteVector) {
|
||||||
|
def asString: String =
|
||||||
|
s"$value$$${sig.toBase64UrlNoPad}"
|
||||||
|
|
||||||
|
def isValid(key: ByteVector): Boolean = {
|
||||||
|
val actual = SignUtil.signString(value, key)
|
||||||
|
SignUtil.isEqual(actual, sig)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
object StateParam {
|
||||||
|
|
||||||
|
def generate[F[_]: Sync](key: ByteVector): F[StateParam] =
|
||||||
|
Random[F].string(8).map { v =>
|
||||||
|
val sig = SignUtil.signString(v, key)
|
||||||
|
StateParam(v, sig)
|
||||||
|
}
|
||||||
|
|
||||||
|
def fromString(str: String, key: ByteVector): Either[String, StateParam] =
|
||||||
|
str.split('$') match {
|
||||||
|
case Array(v, sig) =>
|
||||||
|
ByteVector
|
||||||
|
.fromBase64Descriptive(sig, Alphabets.Base64UrlNoPad)
|
||||||
|
.map(s => StateParam(v, s))
|
||||||
|
.filterOrElse(_.isValid(key), s"Invalid signature in state param: $str")
|
||||||
|
|
||||||
|
case _ =>
|
||||||
|
Left(s"Invalid state parameter: $str")
|
||||||
|
}
|
||||||
|
|
||||||
|
def isValidStateParam(state: String, key: ByteVector) =
|
||||||
|
fromString(state, key).isRight
|
||||||
|
}
|
@ -0,0 +1,47 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2020 Eike K. & Contributors
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
*/
|
||||||
|
|
||||||
|
package docspell.oidc
|
||||||
|
|
||||||
|
import cats.effect._
|
||||||
|
|
||||||
|
import munit.CatsEffectSuite
|
||||||
|
import scodec.bits.ByteVector
|
||||||
|
|
||||||
|
class StateParamTest extends CatsEffectSuite {
|
||||||
|
|
||||||
|
private val key = ByteVector.fromValidHex("caffee")
|
||||||
|
private val key2 = key ++ key
|
||||||
|
|
||||||
|
test("generate") {
|
||||||
|
for {
|
||||||
|
p <- StateParam.generate[IO](key)
|
||||||
|
_ = {
|
||||||
|
assert(p.value.length > 8)
|
||||||
|
assert(p.sig.nonEmpty)
|
||||||
|
assert(p.isValid(key))
|
||||||
|
assert(!p.isValid(key2))
|
||||||
|
}
|
||||||
|
} yield ()
|
||||||
|
}
|
||||||
|
|
||||||
|
test("fromString") {
|
||||||
|
for {
|
||||||
|
p <- StateParam.generate[IO](key)
|
||||||
|
str = p.asString
|
||||||
|
p2 = StateParam.fromString(str, key).fold(sys.error, identity)
|
||||||
|
p3 = StateParam.fromString(str, key2)
|
||||||
|
p4 = StateParam.fromString("uiaeuiaeue", key)
|
||||||
|
p5 = StateParam.fromString(str + "$" + str, key)
|
||||||
|
_ = {
|
||||||
|
assertEquals(p2, p)
|
||||||
|
assert(p3.isLeft)
|
||||||
|
assert(p4.isLeft)
|
||||||
|
assert(p5.isLeft)
|
||||||
|
}
|
||||||
|
} yield ()
|
||||||
|
}
|
||||||
|
}
|
@ -31,7 +31,8 @@ object OpenId {
|
|||||||
ClientRequestInfo
|
ClientRequestInfo
|
||||||
.getBaseUrl(config, req) / "api" / "v1" / "open" / "auth" / "openid",
|
.getBaseUrl(config, req) / "api" / "v1" / "open" / "auth" / "openid",
|
||||||
id =>
|
id =>
|
||||||
config.openid.filter(_.enabled).find(_.provider.providerId == id).map(_.provider)
|
config.openid.filter(_.enabled).find(_.provider.providerId == id).map(_.provider),
|
||||||
|
config.auth.serverSecret
|
||||||
)
|
)
|
||||||
|
|
||||||
def handle[F[_]: Async](backend: BackendApp[F], config: Config): OnUserInfo[F] =
|
def handle[F[_]: Async](backend: BackendApp[F], config: Config): OnUserInfo[F] =
|
||||||
@ -104,9 +105,7 @@ object OpenId {
|
|||||||
import dsl._
|
import dsl._
|
||||||
|
|
||||||
for {
|
for {
|
||||||
setup <- backend.signup.setupExternal(cfg.backend.signup)(
|
setup <- backend.signup.setupExternal(ExternalAccount(accountId))
|
||||||
ExternalAccount(accountId)
|
|
||||||
)
|
|
||||||
res <- setup match {
|
res <- setup match {
|
||||||
case SignupResult.Failure(ex) =>
|
case SignupResult.Failure(ex) =>
|
||||||
logger.error(ex)(s"Error when creating external account!") *>
|
logger.error(ex)(s"Error when creating external account!") *>
|
||||||
|
@ -24,7 +24,8 @@ object FileUrlReader {
|
|||||||
scheme = Nel.of(scheme),
|
scheme = Nel.of(scheme),
|
||||||
authority = Some(""),
|
authority = Some(""),
|
||||||
path = LenientUri.NonEmptyPath(
|
path = LenientUri.NonEmptyPath(
|
||||||
Nel.of(key.collective.id, key.category.id.id, key.id.id)
|
Nel.of(key.collective.id, key.category.id.id, key.id.id),
|
||||||
|
false
|
||||||
),
|
),
|
||||||
query = None,
|
query = None,
|
||||||
fragment = None
|
fragment = None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user