Skip to content

Commit

Permalink
Validate jwt token using the public key from keycloak (#128)
Browse files Browse the repository at this point in the history
Validate jwt using public key from keycloak

- When validating fails the first time, fetch the public key from
  keycloak and try again
- Fail permanently, if uptaded JWKS still yields validation errors
- JWKS is cached until next validation error, requests to keycloak are throttled to one per minute
  • Loading branch information
eikek authored May 17, 2024
1 parent 192ad4c commit 0c17939
Show file tree
Hide file tree
Showing 38 changed files with 1,720 additions and 40 deletions.
24 changes: 22 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ lazy val root = project
.aggregate(
commons,
jwt,
openidKeycloak,
httpClient,
events,
redisClient,
Expand Down Expand Up @@ -149,6 +150,23 @@ lazy val httpClient = project
http4sBorer % "compile->compile;test->test"
)

lazy val openidKeycloak = project
.in(file("modules/openid-keycloak"))
.enablePlugins(AutomateHeaderPlugin)
.disablePlugins(DbTestPlugin, RevolverPlugin)
.settings(commonSettings)
.settings(
name := "openid-keycloak",
description := "OpenID configuration with keycloak",
libraryDependencies ++= Dependencies.http4sDsl.map(_ % Test)
)
.dependsOn(
http4sBorer % "compile->compile;test->test",
httpClient % "compile->compile;test->test",
jwt % "compile->compile;test->test",
commons % "test->test"
)

lazy val http4sMetrics = project
.in(file("modules/http4s-metrics"))
.enablePlugins(AutomateHeaderPlugin)
Expand Down Expand Up @@ -307,7 +325,8 @@ lazy val configValues = project
events % "compile->compile;test->test",
http4sCommons % "compile->compile;test->test",
renkuRedisClient % "compile->compile;test->test",
searchSolrClient % "compile->compile;test->test"
searchSolrClient % "compile->compile;test->test",
openidKeycloak % "compile->compile;test->test"
)

lazy val searchQuery = project
Expand Down Expand Up @@ -374,7 +393,8 @@ lazy val searchApi = project
searchSolrClient % "compile->compile;test->test",
configValues % "compile->compile;test->test",
searchQueryDocs % "compile->compile;test->test",
jwt % "compile->compile;test->test"
jwt % "compile->compile;test->test",
openidKeycloak % "compile->compile;test->test"
)
.enablePlugins(AutomateHeaderPlugin, DockerImagePlugin, RevolverPlugin)

Expand Down
62 changes: 62 additions & 0 deletions modules/commons/src/test/scala/io/renku/search/TestClock.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright 2024 Swiss Data Science Center (SDSC)
* A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and
* Eidgenössische Technische Hochschule Zürich (ETHZ).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.renku.search

import java.time.{Clock as _, *}

import scala.concurrent.duration.FiniteDuration

import cats.Applicative
import cats.effect.*

object TestClock:
extension (i: Instant)
def toDuration: FiniteDuration = FiniteDuration(i.toEpochMilli(), "ms")

def fixedAt(fixed: Instant): Clock[IO] =
new Clock[IO] {
val applicative: Applicative[IO] = Applicative[IO]
val realTime: IO[FiniteDuration] = IO.pure(fixed.toDuration)
val monotonic: IO[FiniteDuration] = realTime
override def toString = s"FixedClock($fixed)"
}

/** Clock initially returning `start` and adding `interval` to subsequent calls. */
def advanceBy(start: Instant, interval: FiniteDuration): Clock[IO] =
new Clock[IO] {
val counter = Ref.unsafe[IO, Long](0)
val applicative: Applicative[IO] = Applicative[IO]
val realTime: IO[FiniteDuration] =
counter.getAndUpdate(_ + 1).map { n =>
(interval * n) + start.toDuration
}
val monotonic: IO[FiniteDuration] = realTime
override def toString = s"AdvanceByClock(start=$start, interval=$interval)"
}

def sequence(start: Instant, more: Instant*): Clock[IO] =
val times = start :: more.toList
new Clock[IO] {
val counter = Ref.unsafe[IO, Int](0)
val applicative: Applicative[IO] = Applicative[IO]
val realTime: IO[FiniteDuration] =
counter.getAndUpdate(_ + 1).map(n => times(n).toDuration)
val monotonic: IO[FiniteDuration] = realTime
override def toString = s"SequenceClock(start=$start, more=$more)"
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package io.renku.search.config
import cats.syntax.all.*
import ciris.*
import com.comcast.ip4s.{Ipv4Address, Port}
import io.renku.openid.keycloak.JwtVerifyConfig
import io.renku.redis.client.*
import io.renku.search.http.HttpServerConfig
import io.renku.solr.client.{SolrConfig, SolrUser}
Expand Down Expand Up @@ -85,3 +86,16 @@ object ConfigValues extends ConfigDecoders:
val port =
renv(s"${prefix}_HTTP_SERVER_PORT").default(defaultPort.value.toString).as[Port]
(bindAddress, port).mapN(HttpServerConfig.apply)

val jwtVerifyConfig: ConfigValue[Effect, JwtVerifyConfig] = {
val defaults = JwtVerifyConfig.default
val enableSigCheck = renv("JWT_ENABLE_SIGNATURE_CHECK")
.as[Boolean]
.default(defaults.enableSignatureValidation)
val requestDelay = renv("JWT_KEYCLOAK_REQUEST_DELAY")
.as[FiniteDuration]
.default(defaults.minRequestDelay)
val openIdConfigPath =
renv("JWT_OPENID_CONFIG_PATH").default(defaults.openIdConfigPath)
(requestDelay, enableSigCheck, openIdConfigPath).mapN(JwtVerifyConfig.apply)
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@ import org.http4s.*

trait Http4sJsonCodec:
given Encoder[Uri] = Encoder.forString.contramap(_.renderString)
given Decoder[Uri] = Decoder.forString.mapEither(s => Uri.fromString(s).left.map(_.getMessage))

object Http4sJsonCodec extends Http4sJsonCodec
17 changes: 17 additions & 0 deletions modules/jwt/src/main/scala/io/renku/search/jwt/JwtBorer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,20 @@ class JwtBorer(override val clock: Clock)

object JwtBorer extends JwtBorer(Clock.systemUTC()):
def apply(clock: Clock): JwtBorer = new JwtBorer(clock)

def create[F[_]: cats.effect.Clock]: F[JwtBorer] =
val c = cats.effect.Clock[F]
c.applicative.map(c.realTimeInstant) { rt =>
new JwtBorer(new Clock {
def instant(): java.time.Instant = rt
def getZone(): java.time.ZoneId = java.time.ZoneId.of("UTC")
override def withZone(zone: java.time.ZoneId): Clock = this
})
}

def readHeader(token: String): Either[Throwable, JwtHeader] =
val h64 = token.takeWhile(_ != '.')
Json
.decode(JwtBase64.decode(h64))
.to[JwtHeader]
.valueEither
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright 2024 Swiss Data Science Center (SDSC)
* A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and
* Eidgenössische Technische Hochschule Zürich (ETHZ).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.renku.openid.keycloak

import scodec.bits.Bases.Alphabets
import scodec.bits.ByteVector

private object BigIntDecode:

def apply(num: String): Either[String, BigInt] =
ByteVector
.fromBase64Descriptive(num, Alphabets.Base64UrlNoPad)
.map(bv => BigInt(1, bv.toArray))

def decode(num: String): Either[JwtError, BigInt] =
apply(num).left.map(msg => JwtError.BigIntDecodeError(num, msg))
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright 2024 Swiss Data Science Center (SDSC)
* A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and
* Eidgenössische Technische Hochschule Zürich (ETHZ).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.renku.openid.keycloak

import io.bullet.borer.{Decoder, Encoder}

enum Curve(val nameShort: String, val name: String):
case P256 extends Curve("P-256", "secp256r1")
case P384 extends Curve("P-384", "secp384r1")
case P521 extends Curve("P-521", "secp521r1")

object Curve:
def fromString(s: String): Either[String, Curve] =
Curve.values.find(_.name.equalsIgnoreCase(s)).toRight(s"Unsupported curve: $s")

given Decoder[Curve] = Decoder.forString.mapEither(fromString)
given Encoder[Curve] = Encoder.forString.contramap(_.name)
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Copyright 2024 Swiss Data Science Center (SDSC)
* A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and
* Eidgenössische Technische Hochschule Zürich (ETHZ).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.renku.openid.keycloak

import scala.concurrent.duration.*

import cats.data.EitherT
import cats.effect.*
import cats.syntax.all.*

import io.renku.openid.keycloak.DefaultJwtVerify.State
import io.renku.search.http.borer.BorerEntityJsonCodec
import io.renku.search.jwt.JwtBorer
import org.http4s.Method.GET
import org.http4s.Uri
import org.http4s.client.Client
import org.http4s.client.dsl.Http4sClientDsl
import pdi.jwt.JwtClaim

final class DefaultJwtVerify[F[_]: Async](
client: Client[F],
state: Ref[F, State],
clock: Clock[F],
config: JwtVerifyConfig
) extends JwtVerify[F]
with Http4sClientDsl[F]
with BorerEntityJsonCodec:

private val logger = scribe.cats.effect[F]

def tryDecode(token: String) =
EitherT(state.get.flatMap(_.jwks.validate(clock)(token)))

def tryDecodeOnly(token: String): F[Either[JwtError, JwtClaim]] =
JwtBorer.create[F](using clock).map { jwtb =>
jwtb
.decodeNoSignatureCheck(token)
.toEither
.leftMap(ex => JwtError.JwtValidationError(token, None, None, ex))
}

def verify(token: String): F[Either[JwtError, JwtClaim]] =
if (!config.enableSignatureValidation) tryDecodeOnly(token)
else tryDecode(token).foldF(updateCache(token), _.asRight.pure[F])

def updateCache(token: String)(jwtError: JwtError): F[Either[JwtError, JwtClaim]] =
jwtError match
case JwtError.JwtValidationError(_, _, Some(claim), _) =>
(for
_ <- EitherT.right(
logger.info(
s"Token validation failed, fetch JWKS from keycloak and try again: ${jwtError.getMessage()}"
)
)
jwks <- fetchJWKSGuarded(claim)
result <- EitherT(jwks.validate(clock)(token))
yield result).value
case e => Left(e).pure[F]

def fetchJWKSGuarded(claim: JwtClaim): EitherT[F, JwtError, Jwks] =
for
_ <- checkLastUpdateDelay(config.minRequestDelay)
result <- fetchJWKS(claim)
yield result

def checkLastUpdateDelay(min: FiniteDuration): EitherT[F, JwtError, Unit] =
EitherT(
clock.monotonic.flatMap(ct => state.modify(_.lastUpdateDelay(ct))).map {
case delay if delay > min => Right(())
case _ => Left(JwtError.TooManyValidationRequests(min))
}
)

def fetchJWKS(claim: JwtClaim): EitherT[F, JwtError, Jwks] =
for
_ <- EitherT.right(
clock.monotonic.flatMap(t => state.update(_.copy(lastUpdate = t)))
)
issuerUri <- EitherT.fromEither(
Uri
.fromString(claim.issuer.getOrElse(""))
.leftMap(ex => JwtError.InvalidIssuerUrl(claim.issuer.getOrElse(""), ex))
)
configUri = issuerUri.addPath(config.openIdConfigPath)

_ <- EitherT.right(logger.debug(s"Fetch openid config from $configUri"))
openIdCfg <- EitherT(client.expect[OpenIdConfig](GET(configUri)).attempt)
.leftMap(ex => JwtError.OpenIdConfigError(configUri, ex))
_ <- EitherT.right(logger.trace(s"Got openid-config response: $openIdCfg"))

_ <- EitherT.right(logger.debug(s"Fetch jwks config from ${openIdCfg.jwksUri}"))
jwks <- EitherT(client.expect[Jwks](GET(openIdCfg.jwksUri)).attempt)
.leftMap(ex => JwtError.JwksError(openIdCfg.jwksUri, ex))

_ <- EitherT.right(state.update(_.copy(jwks = jwks)))
_ <- EitherT.right(
logger.debug(s"Updated JWKS with keys: ${jwks.keys.map(_.keyId)}")
)
yield jwks

object DefaultJwtVerify:
final case class State(
jwks: Jwks = Jwks.empty,
lastUpdate: FiniteDuration = Duration.Zero,
lastAccess: FiniteDuration = Duration.Zero
):
def lastUpdateDelay(now: FiniteDuration): (State, FiniteDuration) =
(copy(lastAccess = now), now - lastUpdate)

def apply[F[_]: Async](
client: Client[F],
clock: Clock[F],
config: JwtVerifyConfig
): F[JwtVerify[F]] =
Ref[F].of(State()).map(state => new DefaultJwtVerify(client, state, clock, config))
Loading

0 comments on commit 0c17939

Please sign in to comment.