diff --git a/modules/restserver/src/main/scala/docspell/restserver/auth/CookieData.scala b/modules/restserver/src/main/scala/docspell/restserver/auth/CookieData.scala index 8a43843d..b427970f 100644 --- a/modules/restserver/src/main/scala/docspell/restserver/auth/CookieData.scala +++ b/modules/restserver/src/main/scala/docspell/restserver/auth/CookieData.scala @@ -2,7 +2,7 @@ package docspell.restserver.auth import docspell.backend.auth._ import docspell.common.AccountId -import docspell.restserver.Config +import docspell.common.LenientUri import org.http4s._ import org.http4s.util._ @@ -11,14 +11,13 @@ case class CookieData(auth: AuthToken) { def accountId: AccountId = auth.account def asString: String = auth.asString - def asCookie(cfg: Config, host: Option[String]): ResponseCookie = { - val domain = CookieData.getDomain(cfg, host) - val sec = cfg.baseUrl.scheme.exists(_.endsWith("s")) - val path = cfg.baseUrl.path / "api" / "v1" / "sec" + def asCookie(baseUrl: LenientUri): ResponseCookie = { + val sec = baseUrl.scheme.exists(_.endsWith("s")) + val path = baseUrl.path / "api" / "v1" / "sec" ResponseCookie( CookieData.cookieName, asString, - domain = domain, + domain = None, path = Some(path.asString), httpOnly = true, secure = sec @@ -29,10 +28,6 @@ object CookieData { val cookieName = "docspell_auth" val headerName = "X-Docspell-Auth" - private def getDomain(cfg: Config, remote: Option[String]): Option[String] = - if (cfg.baseUrl.isLocal) remote.orElse(cfg.baseUrl.host) - else cfg.baseUrl.host - def authenticator[F[_]](r: Request[F]): Either[String, String] = fromCookie(r).orElse(fromHeader(r)) @@ -51,14 +46,14 @@ object CookieData { .map(_.value) .toRight("Couldn't find an authenticator") - def deleteCookie(cfg: Config, remoteHost: Option[String]): ResponseCookie = + def deleteCookie(baseUrl: LenientUri): ResponseCookie = ResponseCookie( cookieName, "", - domain = getDomain(cfg, remoteHost), - path = Some(cfg.baseUrl.path / "api" / "v1" / "sec").map(_.asString), + domain = None, + path = Some(baseUrl.path / "api" / "v1" / "sec").map(_.asString), httpOnly = true, - secure = cfg.baseUrl.scheme.exists(_.endsWith("s")), + secure = baseUrl.scheme.exists(_.endsWith("s")), maxAge = Some(-1) ) diff --git a/modules/restserver/src/main/scala/docspell/restserver/http4s/ClientHost.scala b/modules/restserver/src/main/scala/docspell/restserver/http4s/ClientHost.scala deleted file mode 100644 index 2c06dd15..00000000 --- a/modules/restserver/src/main/scala/docspell/restserver/http4s/ClientHost.scala +++ /dev/null @@ -1,29 +0,0 @@ -package docspell.restserver.http4s - -import org.http4s._ -import org.http4s.headers._ -import org.http4s.util.CaseInsensitiveString - -/** Obtain the host name of the client from the request. - */ -object ClientHost { - - def get[F[_]](req: Request[F]): Option[String] = - xForwardedFor(req) - .orElse(xForwardedHost(req)) - .orElse(host(req)) - - private def host[F[_]](req: Request[F]): Option[String] = - req.headers.get(Host).map(_.host) - - private def xForwardedFor[F[_]](req: Request[F]): Option[String] = - req.headers - .get(`X-Forwarded-For`) - .flatMap(_.values.head) - .flatMap(inet => Option(inet.getHostName).orElse(Option(inet.getHostAddress))) - - private def xForwardedHost[F[_]](req: Request[F]): Option[String] = - req.headers - .get(CaseInsensitiveString("X-Forwarded-Host")) - .map(_.value) -} diff --git a/modules/restserver/src/main/scala/docspell/restserver/http4s/ClientRequestInfo.scala b/modules/restserver/src/main/scala/docspell/restserver/http4s/ClientRequestInfo.scala new file mode 100644 index 00000000..a98926a0 --- /dev/null +++ b/modules/restserver/src/main/scala/docspell/restserver/http4s/ClientRequestInfo.scala @@ -0,0 +1,68 @@ +package docspell.restserver.http4s + +import cats.data.NonEmptyList +import cats.implicits._ + +import docspell.common._ +import docspell.restserver.Config + +import org.http4s._ +import org.http4s.headers._ +import org.http4s.util.CaseInsensitiveString + +/** Obtain information about the client by inspecting the request. + */ +object ClientRequestInfo { + + def getBaseUrl[F[_]](cfg: Config, req: Request[F]): LenientUri = + if (cfg.baseUrl.isLocal) getBaseUrl(req, cfg.bind.port).getOrElse(cfg.baseUrl) + else cfg.baseUrl + + private def getBaseUrl[F[_]](req: Request[F], serverPort: Int): Option[LenientUri] = + for { + scheme <- NonEmptyList.fromList(getProtocol(req).toList) + host <- getHostname(req) + port = xForwardedPort(req).getOrElse(serverPort) + hostPort = if (port == 80 || port == 443) host else s"${host}:${port}" + } yield LenientUri(scheme, Some(hostPort), LenientUri.EmptyPath, None, None) + + def getHostname[F[_]](req: Request[F]): Option[String] = + xForwardedHost(req) + .orElse(xForwardedFor(req)) + .orElse(host(req)) + + def getProtocol[F[_]](req: Request[F]): Option[String] = + xForwardedProto(req).orElse(clientConnectionProto(req)) + + private def host[F[_]](req: Request[F]): Option[String] = + req.headers.get(Host).map(_.host) + + private def xForwardedFor[F[_]](req: Request[F]): Option[String] = + req.headers + .get(`X-Forwarded-For`) + .flatMap(_.values.head) + .flatMap(inet => Option(inet.getHostName).orElse(Option(inet.getHostAddress))) + + private def xForwardedHost[F[_]](req: Request[F]): Option[String] = + req.headers + .get(CaseInsensitiveString("X-Forwarded-Host")) + .map(_.value) + + private def xForwardedProto[F[_]](req: Request[F]): Option[String] = + req.headers + .get(CaseInsensitiveString("X-Forwarded-Proto")) + .map(_.value) + + private def clientConnectionProto[F[_]](req: Request[F]): Option[String] = + req.isSecure.map { + case true => "https" + case false => "http" + } + + private def xForwardedPort[F[_]](req: Request[F]): Option[Int] = + req.headers + .get(CaseInsensitiveString("X-Forwarded-Port")) + .map(_.value) + .flatMap(str => Either.catchNonFatal(str.toInt).toOption) + +} diff --git a/modules/restserver/src/main/scala/docspell/restserver/routes/LoginRoutes.scala b/modules/restserver/src/main/scala/docspell/restserver/routes/LoginRoutes.scala index 78b39b8b..5b06472e 100644 --- a/modules/restserver/src/main/scala/docspell/restserver/routes/LoginRoutes.scala +++ b/modules/restserver/src/main/scala/docspell/restserver/routes/LoginRoutes.scala @@ -4,10 +4,11 @@ import cats.effect._ import cats.implicits._ import docspell.backend.auth._ +import docspell.common._ import docspell.restapi.model._ import docspell.restserver._ import docspell.restserver.auth._ -import docspell.restserver.http4s.ClientHost +import docspell.restserver.http4s.ClientRequestInfo import org.http4s._ import org.http4s.circe.CirceEntityDecoder._ @@ -22,10 +23,9 @@ object LoginRoutes { HttpRoutes.of[F] { case req @ POST -> Root / "login" => for { - up <- req.as[UserPass] - res <- S.loginUserPass(cfg.auth)(Login.UserPass(up.account, up.password)) - remote = ClientHost.get(req) - resp <- makeResponse(dsl, cfg, remote, res, up.account) + up <- req.as[UserPass] + res <- S.loginUserPass(cfg.auth)(Login.UserPass(up.account, up.password)) + resp <- makeResponse(dsl, cfg, req, res, up.account) } yield resp } } @@ -38,17 +38,20 @@ object LoginRoutes { case req @ POST -> Root / "session" => Authenticate .authenticateRequest(S.loginSession(cfg.auth))(req) - .flatMap(res => makeResponse(dsl, cfg, ClientHost.get(req), res, "")) + .flatMap(res => makeResponse(dsl, cfg, req, res, "")) case req @ POST -> Root / "logout" => - Ok().map(_.addCookie(CookieData.deleteCookie(cfg, ClientHost.get(req)))) + Ok().map(_.addCookie(CookieData.deleteCookie(getBaseUrl(cfg, req)))) } } - def makeResponse[F[_]: Effect]( + private def getBaseUrl[F[_]](cfg: Config, req: Request[F]): LenientUri = + ClientRequestInfo.getBaseUrl(cfg, req) + + private def makeResponse[F[_]: Effect]( dsl: Http4sDsl[F], cfg: Config, - remoteHost: Option[String], + req: Request[F], res: Login.Result, account: String ): F[Response[F]] = { @@ -66,7 +69,7 @@ object LoginRoutes { Some(cd.asString), cfg.auth.sessionValid.millis ) - ).map(_.addCookie(cd.asCookie(cfg, remoteHost))) + ).map(_.addCookie(cd.asCookie(getBaseUrl(cfg, req)))) } yield resp case _ => Ok(AuthResult("", account, false, "Login failed.", None, 0L)) diff --git a/modules/restserver/src/main/scala/docspell/restserver/routes/NotifyDueItemsRoutes.scala b/modules/restserver/src/main/scala/docspell/restserver/routes/NotifyDueItemsRoutes.scala index 7b38e4f1..240dcc43 100644 --- a/modules/restserver/src/main/scala/docspell/restserver/routes/NotifyDueItemsRoutes.scala +++ b/modules/restserver/src/main/scala/docspell/restserver/routes/NotifyDueItemsRoutes.scala @@ -10,6 +10,7 @@ import docspell.common._ import docspell.restapi.model._ import docspell.restserver.Config import docspell.restserver.conv.Conversions +import docspell.restserver.http4s.ClientRequestInfo import docspell.store.usertask._ import org.http4s._ @@ -40,7 +41,7 @@ object NotifyDueItemsRoutes { for { data <- req.as[NotificationSettings] newId <- Ident.randomId[F] - task <- makeTask(newId, cfg, user.account, data) + task <- makeTask(newId, getBaseUrl(cfg, req), user.account, data) res <- ut.executeNow(user.account, task) .attempt @@ -60,7 +61,7 @@ object NotifyDueItemsRoutes { case req @ PUT -> Root => def run(data: NotificationSettings) = for { - task <- makeTask(data.id, cfg, user.account, data) + task <- makeTask(data.id, getBaseUrl(cfg, req), user.account, data) res <- ut.submitNotifyDueItems(user.account, task) .attempt @@ -78,7 +79,7 @@ object NotifyDueItemsRoutes { for { data <- req.as[NotificationSettings] newId <- Ident.randomId[F] - task <- makeTask(newId, cfg, user.account, data) + task <- makeTask(newId, getBaseUrl(cfg, req), user.account, data) res <- ut.submitNotifyDueItems(user.account, task) .attempt @@ -96,9 +97,12 @@ object NotifyDueItemsRoutes { } } + private def getBaseUrl[F[_]](cfg: Config, req: Request[F]) = + ClientRequestInfo.getBaseUrl(cfg, req) + def makeTask[F[_]: Sync]( id: Ident, - cfg: Config, + baseUrl: LenientUri, user: AccountId, settings: NotificationSettings ): F[UserTask[NotifyDueItemsArgs]] = @@ -112,7 +116,7 @@ object NotifyDueItemsRoutes { user, settings.smtpConnection, settings.recipients, - Some(cfg.baseUrl / "app" / "item"), + Some(baseUrl / "app" / "item"), settings.remindDays, if (settings.capOverdue) Some(settings.remindDays) else None,