diff --git a/modules/analysis/src/main/scala/docspell/analysis/contact/Contact.scala b/modules/analysis/src/main/scala/docspell/analysis/contact/Contact.scala index 2a9f87a9..cbdf0522 100644 --- a/modules/analysis/src/main/scala/docspell/analysis/contact/Contact.scala +++ b/modules/analysis/src/main/scala/docspell/analysis/contact/Contact.scala @@ -50,7 +50,7 @@ object Contact { p match { case LenientUri.RootPath => false case LenientUri.EmptyPath => false - case LenientUri.NonEmptyPath(segs) => + case LenientUri.NonEmptyPath(segs, _) => Ident.fromString(segs.last).isRight && segs.init.takeRight(3) == List("open", "upload", "item") } diff --git a/modules/backend/src/main/scala/docspell/backend/auth/TokenUtil.scala b/modules/backend/src/main/scala/docspell/backend/auth/TokenUtil.scala index 329e8e4a..060bc439 100644 --- a/modules/backend/src/main/scala/docspell/backend/auth/TokenUtil.scala +++ b/modules/backend/src/main/scala/docspell/backend/auth/TokenUtil.scala @@ -6,11 +6,10 @@ package docspell.backend.auth -import javax.crypto.Mac -import javax.crypto.spec.SecretKeySpec - import cats.implicits._ +import docspell.common.util.SignUtil + import scodec.bits._ private[auth] object TokenUtil { @@ -34,11 +33,8 @@ private[auth] object TokenUtil { signRaw(raw, key) } - private def signRaw(data: String, key: ByteVector): String = { - val mac = Mac.getInstance("HmacSHA1") - mac.init(new SecretKeySpec(key.toArray, "HmacSHA1")) - ByteVector.view(mac.doFinal(data.getBytes(utf8))).toBase64 - } + private def signRaw(data: String, key: ByteVector): String = + SignUtil.signString(data, key).toBase64 def b64enc(s: String): String = ByteVector.view(s.getBytes(utf8)).toBase64 @@ -52,5 +48,4 @@ private[auth] object TokenUtil { def constTimeEq(s1: String, s2: String): Boolean = s1.zip(s2) .foldLeft(true) { case (r, (c1, c2)) => r & c1 == c2 } & s1.length == s2.length - } diff --git a/modules/backend/src/main/scala/docspell/backend/signup/OSignup.scala b/modules/backend/src/main/scala/docspell/backend/signup/OSignup.scala index c59bc773..92314b8c 100644 --- a/modules/backend/src/main/scala/docspell/backend/signup/OSignup.scala +++ b/modules/backend/src/main/scala/docspell/backend/signup/OSignup.scala @@ -20,8 +20,10 @@ trait OSignup[F[_]] { def register(cfg: Config)(data: RegisterData): F[SignupResult] - /** Creates the given account if it doesn't exist. */ - def setupExternal(cfg: Config)(data: ExternalAccount): F[SignupResult] + /** Creates the given account if it doesn't exist. This is independent from signup + * configuration. + */ + def setupExternal(data: ExternalAccount): F[SignupResult] def newInvite(cfg: Config)(password: Password): F[NewInviteResult] } @@ -77,36 +79,31 @@ object OSignup { } } - def setupExternal(cfg: Config)(data: ExternalAccount): F[SignupResult] = - cfg.mode match { - case Config.Mode.Closed => - SignupResult.signupClosed.pure[F] - case _ => - if (data.source == AccountSource.Local) - SignupResult - .failure(new Exception("Account source must not be LOCAL!")) - .pure[F] - else - for { - recs <- makeRecords(data.collName, data.login, Password(""), data.source) - cres <- store.add( - RCollective.insert(recs._1), - RCollective.existsById(data.collName) - ) - ures <- store.add(RUser.insert(recs._2), RUser.exists(data.login)) - res = cres match { + def setupExternal(data: ExternalAccount): F[SignupResult] = + if (data.source == AccountSource.Local) + SignupResult + .failure(new Exception("Account source must not be LOCAL!")) + .pure[F] + else + for { + recs <- makeRecords(data.collName, data.login, Password(""), data.source) + cres <- store.add( + RCollective.insert(recs._1), + RCollective.existsById(data.collName) + ) + ures <- store.add(RUser.insert(recs._2), RUser.exists(data.login)) + res = cres match { + case AddResult.Failure(ex) => + SignupResult.failure(ex) + case _ => + ures match { case AddResult.Failure(ex) => SignupResult.failure(ex) case _ => - ures match { - case AddResult.Failure(ex) => - SignupResult.failure(ex) - case _ => - SignupResult.success - } + SignupResult.success } - } yield res - } + } + } yield res private def retryInvite(res: SignupResult): Boolean = res match { diff --git a/modules/common/src/main/scala/docspell/common/LenientUri.scala b/modules/common/src/main/scala/docspell/common/LenientUri.scala index 4061969d..c9f797ab 100644 --- a/modules/common/src/main/scala/docspell/common/LenientUri.scala +++ b/modules/common/src/main/scala/docspell/common/LenientUri.scala @@ -121,7 +121,7 @@ object LenientUri { val isRoot = true val isEmpty = false def /(seg: String): Path = - NonEmptyPath(NonEmptyList.of(seg)) + NonEmptyPath(NonEmptyList.of(seg), false) def asString = "/" } case object EmptyPath extends Path { @@ -129,20 +129,22 @@ object LenientUri { val isRoot = false val isEmpty = true def /(seg: String): Path = - NonEmptyPath(NonEmptyList.of(seg)) + NonEmptyPath(NonEmptyList.of(seg), false) def asString = "" } - case class NonEmptyPath(segs: NonEmptyList[String]) extends Path { + case class NonEmptyPath(segs: NonEmptyList[String], trailingSlash: Boolean) + extends Path { def segments = segs.toList val isEmpty = false val isRoot = false + private val slashSuffix = if (trailingSlash) "/" else "" def /(seg: String): Path = copy(segs = segs.append(seg)) def asString = segs.head match { - case "." => segments.map(percentEncode).mkString("/") - case ".." => segments.map(percentEncode).mkString("/") - case _ => "/" + segments.map(percentEncode).mkString("/") + case "." => segments.map(percentEncode).mkString("/") + slashSuffix + case ".." => segments.map(percentEncode).mkString("/") + slashSuffix + case _ => "/" + segments.map(percentEncode).mkString("/") + slashSuffix } } @@ -157,14 +159,14 @@ object LenientUri { str.trim match { case "/" => Right(RootPath) case "" => Right(EmptyPath) - case _ => + case uriStr => Either.fromOption( - stripLeading(str, '/') + stripLeading(uriStr, '/') .split('/') .toList .traverse(percentDecode) .flatMap(NonEmptyList.fromList) - .map(NonEmptyPath.apply), + .map(NonEmptyPath(_, uriStr.endsWith("/"))), s"Invalid path: $str" ) } diff --git a/modules/common/src/main/scala/docspell/common/UrlMatcher.scala b/modules/common/src/main/scala/docspell/common/UrlMatcher.scala index c8fd393e..de978dd9 100644 --- a/modules/common/src/main/scala/docspell/common/UrlMatcher.scala +++ b/modules/common/src/main/scala/docspell/common/UrlMatcher.scala @@ -62,7 +62,7 @@ object UrlMatcher { // strip path to only match prefixes val mPath: LenientUri.Path = NonEmptyList.fromList(url.path.segments.take(pathSegmentCount)) match { - case Some(nel) => LenientUri.NonEmptyPath(nel) + case Some(nel) => LenientUri.NonEmptyPath(nel, false) case None => LenientUri.RootPath } diff --git a/modules/common/src/main/scala/docspell/common/util/SignUtil.scala b/modules/common/src/main/scala/docspell/common/util/SignUtil.scala new file mode 100644 index 00000000..04ccd455 --- /dev/null +++ b/modules/common/src/main/scala/docspell/common/util/SignUtil.scala @@ -0,0 +1,39 @@ +/* + * Copyright 2020 Eike K. & Contributors + * + * SPDX-License-Identifier: AGPL-3.0-or-later + */ + +package docspell.common.util + +import javax.crypto.Mac +import javax.crypto.spec.SecretKeySpec + +import scodec.bits.ByteVector + +object SignUtil { + private val utf8 = java.nio.charset.StandardCharsets.UTF_8 + + private val macAlgo = "HmacSHA1" + + private def getMac(key: ByteVector) = { + val mac = Mac.getInstance(macAlgo) + mac.init(new SecretKeySpec(key.toArray, macAlgo)) + mac + } + + def signString(data: String, key: ByteVector): ByteVector = { + val mac = getMac(key) + ByteVector.view(mac.doFinal(data.getBytes(utf8))) + } + + def signBytes(data: ByteVector, key: ByteVector): ByteVector = { + val mac = getMac(key) + ByteVector.view(mac.doFinal(data.toArray)) + } + + def isEqual(sig1: ByteVector, sig2: ByteVector): Boolean = + sig1 + .zipWith(sig2)((b1, b2) => (b1 - b2).toByte) + .foldLeft(true)(_ && _ == 0) && sig1.length == sig2.length +} diff --git a/modules/common/src/test/scala/docspell/common/LenientUriTest.scala b/modules/common/src/test/scala/docspell/common/LenientUriTest.scala index 4bd69582..a81e7452 100644 --- a/modules/common/src/test/scala/docspell/common/LenientUriTest.scala +++ b/modules/common/src/test/scala/docspell/common/LenientUriTest.scala @@ -29,4 +29,11 @@ class LenientUriTest extends FunSuite { ) assertEquals(LenientUri.percentDecode("a%25b%5Cc%7Cd%23e"), "a%b\\c|d#e".some) } + + test("parse with trailing slash") { + assertEquals(LenientUri.unsafe("http://a.com/").asString, "http://a.com/") + assertEquals(LenientUri.unsafe("http://a.com").asString, "http://a.com") + assertEquals(LenientUri.unsafe("http://a.com/path").asString, "http://a.com/path") + assertEquals(LenientUri.unsafe("http://a.com/path/").asString, "http://a.com/path/") + } } diff --git a/modules/common/src/test/scala/docspell/common/util/SignUtilTest.scala b/modules/common/src/test/scala/docspell/common/util/SignUtilTest.scala new file mode 100644 index 00000000..d985f014 --- /dev/null +++ b/modules/common/src/test/scala/docspell/common/util/SignUtilTest.scala @@ -0,0 +1,20 @@ +/* + * Copyright 2020 Eike K. & Contributors + * + * SPDX-License-Identifier: AGPL-3.0-or-later + */ + +package docspell.common.util + +import munit.FunSuite +import scodec.bits.ByteVector + +class SignUtilTest extends FunSuite { + + private val key = ByteVector.fromValidHex("caffee") + + test("create and validate") { + val sig = SignUtil.signString("hello", key) + assert(SignUtil.isEqual(sig, SignUtil.signString("hello", key))) + } +} diff --git a/modules/oidc/src/main/scala/docspell/oidc/CodeFlow.scala b/modules/oidc/src/main/scala/docspell/oidc/CodeFlow.scala index c1a01995..ff5c06a2 100644 --- a/modules/oidc/src/main/scala/docspell/oidc/CodeFlow.scala +++ b/modules/oidc/src/main/scala/docspell/oidc/CodeFlow.scala @@ -111,15 +111,19 @@ object CodeFlow { token <- r.attemptAs[AccessToken].value _ <- token match { case Right(t) => - logger.trace(s"Got token response: $t") + logger.trace(s"Got token response (status=${r.status.code}): $t") case Left(err) => - logger.error(err)(s"Error decoding access token: ${err.getMessage}") + logger.error(err)( + s"Error decoding access token (status=${r.status.code}): ${err.getMessage}" + ) } } yield token.toOption case r => - logger - .error(s"Error obtaining access token '${r.status.code}' / ${r.as[String]}") - .map(_ => None) + for { + body <- r.bodyText.compile.string + _ <- logger + .error(s"Error obtaining access token status=${r.status.code}, body=$body") + } yield None }) } @@ -177,5 +181,4 @@ object CodeFlow { logAction = Some((msg: String) => logger.trace(msg)) )(c) } - } diff --git a/modules/oidc/src/main/scala/docspell/oidc/CodeFlowConfig.scala b/modules/oidc/src/main/scala/docspell/oidc/CodeFlowConfig.scala index 1be95021..8d626e3e 100644 --- a/modules/oidc/src/main/scala/docspell/oidc/CodeFlowConfig.scala +++ b/modules/oidc/src/main/scala/docspell/oidc/CodeFlowConfig.scala @@ -9,6 +9,7 @@ package docspell.oidc import docspell.common._ import org.http4s.Request +import scodec.bits.ByteVector trait CodeFlowConfig[F[_]] { @@ -22,17 +23,20 @@ trait CodeFlowConfig[F[_]] { */ def findProvider(id: Ident): Option[ProviderConfig] + def serverSecret: ByteVector } object CodeFlowConfig { def apply[F[_]]( url: Request[F] => LenientUri, - provider: Ident => Option[ProviderConfig] + provider: Ident => Option[ProviderConfig], + secret: ByteVector ): CodeFlowConfig[F] = new CodeFlowConfig[F] { def getEndpointUrl(req: Request[F]): LenientUri = url(req) def findProvider(id: Ident): Option[ProviderConfig] = provider(id) + val serverSecret = secret } private[oidc] def resumeUri[F[_]]( @@ -41,5 +45,4 @@ object CodeFlowConfig { cfg: CodeFlowConfig[F] ): LenientUri = cfg.getEndpointUrl(req) / prov.providerId.id / "resume" - } diff --git a/modules/oidc/src/main/scala/docspell/oidc/CodeFlowRoutes.scala b/modules/oidc/src/main/scala/docspell/oidc/CodeFlowRoutes.scala index c0adcdd2..af383a37 100644 --- a/modules/oidc/src/main/scala/docspell/oidc/CodeFlowRoutes.scala +++ b/modules/oidc/src/main/scala/docspell/oidc/CodeFlowRoutes.scala @@ -41,19 +41,23 @@ object CodeFlowRoutes { case req @ GET -> Root / Ident(id) => config.findProvider(id) match { case Some(cfg) => - val uri = cfg.authorizeUrl - .withQuery("client_id", cfg.clientId) - .withQuery("scope", cfg.scope) - .withQuery( - "redirect_uri", - CodeFlowConfig.resumeUri(req, cfg, config).asString + for { + state <- StateParam.generate[F](config.serverSecret) + uri = cfg.authorizeUrl + .withQuery("client_id", cfg.clientId) + .withQuery("scope", cfg.scope) + .withQuery( + "redirect_uri", + CodeFlowConfig.resumeUri(req, cfg, config).asString + ) + .withQuery("response_type", "code") + .withQuery("state", state.asString) + _ <- logger.debug( + s"Redirecting to OAuth/OIDC provider ${cfg.providerId.id}: ${uri.asString}" ) - .withQuery("response_type", "code") - .withQuery("state", cfg.clientId) - logger.debug( - s"Redirecting to OAuth/OIDC provider ${cfg.providerId.id}: ${uri.asString}" - ) *> - Found(Location(Uri.unsafeFromString(uri.asString))) + resp <- Found(Location(Uri.unsafeFromString(uri.asString))) + } yield resp + case None => logger.debug(s"No OAuth/OIDC provider found with id '$id'") *> NotFound() @@ -66,9 +70,20 @@ object CodeFlowRoutes { NotFound() case Some(provider) => val codeFromReq = OptionT.fromOption[F](req.params.get("code")) + val stateParamValid = req.params + .get("state") + .exists(state => StateParam.isValidStateParam(state, config.serverSecret)) val userInfo = for { _ <- OptionT.liftF(logger.info(s"Resume OAuth/OIDC flow for ${id.id}")) + _ <- + if (stateParamValid) OptionT.pure[F](()) + else + OptionT( + logger + .warn(s"Invalid state parameter returned from Idp!") + .as(Option.empty[Unit]) + ) code <- codeFromReq _ <- OptionT.liftF( logger.trace( @@ -91,7 +106,7 @@ object CodeFlowRoutes { s"$err$descr" } .map(err => s": $err") - .getOrElse("") + .getOrElse(": ") logger.warn(s"Error resuming code flow from '${id.id}'$reason") *> onUserInfo.handle(req, provider, None) diff --git a/modules/oidc/src/main/scala/docspell/oidc/StateParam.scala b/modules/oidc/src/main/scala/docspell/oidc/StateParam.scala new file mode 100644 index 00000000..cbf70334 --- /dev/null +++ b/modules/oidc/src/main/scala/docspell/oidc/StateParam.scala @@ -0,0 +1,49 @@ +/* + * Copyright 2020 Eike K. & Contributors + * + * SPDX-License-Identifier: AGPL-3.0-or-later + */ + +package docspell.oidc + +import cats.effect._ +import cats.syntax.all._ + +import docspell.common.util.{Random, SignUtil} + +import scodec.bits.Bases.Alphabets +import scodec.bits.ByteVector + +final case class StateParam(value: String, sig: ByteVector) { + def asString: String = + s"$value$$${sig.toBase64UrlNoPad}" + + def isValid(key: ByteVector): Boolean = { + val actual = SignUtil.signString(value, key) + SignUtil.isEqual(actual, sig) + } +} + +object StateParam { + + def generate[F[_]: Sync](key: ByteVector): F[StateParam] = + Random[F].string(8).map { v => + val sig = SignUtil.signString(v, key) + StateParam(v, sig) + } + + def fromString(str: String, key: ByteVector): Either[String, StateParam] = + str.split('$') match { + case Array(v, sig) => + ByteVector + .fromBase64Descriptive(sig, Alphabets.Base64UrlNoPad) + .map(s => StateParam(v, s)) + .filterOrElse(_.isValid(key), s"Invalid signature in state param: $str") + + case _ => + Left(s"Invalid state parameter: $str") + } + + def isValidStateParam(state: String, key: ByteVector) = + fromString(state, key).isRight +} diff --git a/modules/oidc/src/test/scala/docspell/oidc/StateParamTest.scala b/modules/oidc/src/test/scala/docspell/oidc/StateParamTest.scala new file mode 100644 index 00000000..a189d935 --- /dev/null +++ b/modules/oidc/src/test/scala/docspell/oidc/StateParamTest.scala @@ -0,0 +1,47 @@ +/* + * Copyright 2020 Eike K. & Contributors + * + * SPDX-License-Identifier: AGPL-3.0-or-later + */ + +package docspell.oidc + +import cats.effect._ + +import munit.CatsEffectSuite +import scodec.bits.ByteVector + +class StateParamTest extends CatsEffectSuite { + + private val key = ByteVector.fromValidHex("caffee") + private val key2 = key ++ key + + test("generate") { + for { + p <- StateParam.generate[IO](key) + _ = { + assert(p.value.length > 8) + assert(p.sig.nonEmpty) + assert(p.isValid(key)) + assert(!p.isValid(key2)) + } + } yield () + } + + test("fromString") { + for { + p <- StateParam.generate[IO](key) + str = p.asString + p2 = StateParam.fromString(str, key).fold(sys.error, identity) + p3 = StateParam.fromString(str, key2) + p4 = StateParam.fromString("uiaeuiaeue", key) + p5 = StateParam.fromString(str + "$" + str, key) + _ = { + assertEquals(p2, p) + assert(p3.isLeft) + assert(p4.isLeft) + assert(p5.isLeft) + } + } yield () + } +} diff --git a/modules/restserver/src/main/scala/docspell/restserver/auth/OpenId.scala b/modules/restserver/src/main/scala/docspell/restserver/auth/OpenId.scala index 9ec8e7ea..b65841ef 100644 --- a/modules/restserver/src/main/scala/docspell/restserver/auth/OpenId.scala +++ b/modules/restserver/src/main/scala/docspell/restserver/auth/OpenId.scala @@ -31,7 +31,8 @@ object OpenId { ClientRequestInfo .getBaseUrl(config, req) / "api" / "v1" / "open" / "auth" / "openid", id => - config.openid.filter(_.enabled).find(_.provider.providerId == id).map(_.provider) + config.openid.filter(_.enabled).find(_.provider.providerId == id).map(_.provider), + config.auth.serverSecret ) def handle[F[_]: Async](backend: BackendApp[F], config: Config): OnUserInfo[F] = @@ -104,9 +105,7 @@ object OpenId { import dsl._ for { - setup <- backend.signup.setupExternal(cfg.backend.signup)( - ExternalAccount(accountId) - ) + setup <- backend.signup.setupExternal(ExternalAccount(accountId)) res <- setup match { case SignupResult.Failure(ex) => logger.error(ex)(s"Error when creating external account!") *> diff --git a/modules/store/src/main/scala/docspell/store/file/FileUrlReader.scala b/modules/store/src/main/scala/docspell/store/file/FileUrlReader.scala index df69f421..47b8cc52 100644 --- a/modules/store/src/main/scala/docspell/store/file/FileUrlReader.scala +++ b/modules/store/src/main/scala/docspell/store/file/FileUrlReader.scala @@ -24,7 +24,8 @@ object FileUrlReader { scheme = Nel.of(scheme), authority = Some(""), path = LenientUri.NonEmptyPath( - Nel.of(key.collective.id, key.category.id.id, key.id.id) + Nel.of(key.collective.id, key.category.id.id, key.id.id), + false ), query = None, fragment = None