-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Validate jwt token using the public key from keycloak (#128)
Validate jwt using public key from keycloak - When validating fails the first time, fetch the public key from keycloak and try again - Fail permanently, if uptaded JWKS still yields validation errors - JWKS is cached until next validation error, requests to keycloak are throttled to one per minute
- Loading branch information
Showing
38 changed files
with
1,720 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
62 changes: 62 additions & 0 deletions
62
modules/commons/src/test/scala/io/renku/search/TestClock.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
/* | ||
* Copyright 2024 Swiss Data Science Center (SDSC) | ||
* A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and | ||
* Eidgenössische Technische Hochschule Zürich (ETHZ). | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package io.renku.search | ||
|
||
import java.time.{Clock as _, *} | ||
|
||
import scala.concurrent.duration.FiniteDuration | ||
|
||
import cats.Applicative | ||
import cats.effect.* | ||
|
||
object TestClock: | ||
extension (i: Instant) | ||
def toDuration: FiniteDuration = FiniteDuration(i.toEpochMilli(), "ms") | ||
|
||
def fixedAt(fixed: Instant): Clock[IO] = | ||
new Clock[IO] { | ||
val applicative: Applicative[IO] = Applicative[IO] | ||
val realTime: IO[FiniteDuration] = IO.pure(fixed.toDuration) | ||
val monotonic: IO[FiniteDuration] = realTime | ||
override def toString = s"FixedClock($fixed)" | ||
} | ||
|
||
/** Clock initially returning `start` and adding `interval` to subsequent calls. */ | ||
def advanceBy(start: Instant, interval: FiniteDuration): Clock[IO] = | ||
new Clock[IO] { | ||
val counter = Ref.unsafe[IO, Long](0) | ||
val applicative: Applicative[IO] = Applicative[IO] | ||
val realTime: IO[FiniteDuration] = | ||
counter.getAndUpdate(_ + 1).map { n => | ||
(interval * n) + start.toDuration | ||
} | ||
val monotonic: IO[FiniteDuration] = realTime | ||
override def toString = s"AdvanceByClock(start=$start, interval=$interval)" | ||
} | ||
|
||
def sequence(start: Instant, more: Instant*): Clock[IO] = | ||
val times = start :: more.toList | ||
new Clock[IO] { | ||
val counter = Ref.unsafe[IO, Int](0) | ||
val applicative: Applicative[IO] = Applicative[IO] | ||
val realTime: IO[FiniteDuration] = | ||
counter.getAndUpdate(_ + 1).map(n => times(n).toDuration) | ||
val monotonic: IO[FiniteDuration] = realTime | ||
override def toString = s"SequenceClock(start=$start, more=$more)" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
32 changes: 32 additions & 0 deletions
32
modules/openid-keycloak/src/main/scala/io/renku/openid/keycloak/BigIntDecode.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
/* | ||
* Copyright 2024 Swiss Data Science Center (SDSC) | ||
* A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and | ||
* Eidgenössische Technische Hochschule Zürich (ETHZ). | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package io.renku.openid.keycloak | ||
|
||
import scodec.bits.Bases.Alphabets | ||
import scodec.bits.ByteVector | ||
|
||
private object BigIntDecode: | ||
|
||
def apply(num: String): Either[String, BigInt] = | ||
ByteVector | ||
.fromBase64Descriptive(num, Alphabets.Base64UrlNoPad) | ||
.map(bv => BigInt(1, bv.toArray)) | ||
|
||
def decode(num: String): Either[JwtError, BigInt] = | ||
apply(num).left.map(msg => JwtError.BigIntDecodeError(num, msg)) |
33 changes: 33 additions & 0 deletions
33
modules/openid-keycloak/src/main/scala/io/renku/openid/keycloak/Curve.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
/* | ||
* Copyright 2024 Swiss Data Science Center (SDSC) | ||
* A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and | ||
* Eidgenössische Technische Hochschule Zürich (ETHZ). | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package io.renku.openid.keycloak | ||
|
||
import io.bullet.borer.{Decoder, Encoder} | ||
|
||
enum Curve(val nameShort: String, val name: String): | ||
case P256 extends Curve("P-256", "secp256r1") | ||
case P384 extends Curve("P-384", "secp384r1") | ||
case P521 extends Curve("P-521", "secp521r1") | ||
|
||
object Curve: | ||
def fromString(s: String): Either[String, Curve] = | ||
Curve.values.find(_.name.equalsIgnoreCase(s)).toRight(s"Unsupported curve: $s") | ||
|
||
given Decoder[Curve] = Decoder.forString.mapEither(fromString) | ||
given Encoder[Curve] = Encoder.forString.contramap(_.name) |
131 changes: 131 additions & 0 deletions
131
modules/openid-keycloak/src/main/scala/io/renku/openid/keycloak/DefaultJwtVerify.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
/* | ||
* Copyright 2024 Swiss Data Science Center (SDSC) | ||
* A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and | ||
* Eidgenössische Technische Hochschule Zürich (ETHZ). | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package io.renku.openid.keycloak | ||
|
||
import scala.concurrent.duration.* | ||
|
||
import cats.data.EitherT | ||
import cats.effect.* | ||
import cats.syntax.all.* | ||
|
||
import io.renku.openid.keycloak.DefaultJwtVerify.State | ||
import io.renku.search.http.borer.BorerEntityJsonCodec | ||
import io.renku.search.jwt.JwtBorer | ||
import org.http4s.Method.GET | ||
import org.http4s.Uri | ||
import org.http4s.client.Client | ||
import org.http4s.client.dsl.Http4sClientDsl | ||
import pdi.jwt.JwtClaim | ||
|
||
final class DefaultJwtVerify[F[_]: Async]( | ||
client: Client[F], | ||
state: Ref[F, State], | ||
clock: Clock[F], | ||
config: JwtVerifyConfig | ||
) extends JwtVerify[F] | ||
with Http4sClientDsl[F] | ||
with BorerEntityJsonCodec: | ||
|
||
private val logger = scribe.cats.effect[F] | ||
|
||
def tryDecode(token: String) = | ||
EitherT(state.get.flatMap(_.jwks.validate(clock)(token))) | ||
|
||
def tryDecodeOnly(token: String): F[Either[JwtError, JwtClaim]] = | ||
JwtBorer.create[F](using clock).map { jwtb => | ||
jwtb | ||
.decodeNoSignatureCheck(token) | ||
.toEither | ||
.leftMap(ex => JwtError.JwtValidationError(token, None, None, ex)) | ||
} | ||
|
||
def verify(token: String): F[Either[JwtError, JwtClaim]] = | ||
if (!config.enableSignatureValidation) tryDecodeOnly(token) | ||
else tryDecode(token).foldF(updateCache(token), _.asRight.pure[F]) | ||
|
||
def updateCache(token: String)(jwtError: JwtError): F[Either[JwtError, JwtClaim]] = | ||
jwtError match | ||
case JwtError.JwtValidationError(_, _, Some(claim), _) => | ||
(for | ||
_ <- EitherT.right( | ||
logger.info( | ||
s"Token validation failed, fetch JWKS from keycloak and try again: ${jwtError.getMessage()}" | ||
) | ||
) | ||
jwks <- fetchJWKSGuarded(claim) | ||
result <- EitherT(jwks.validate(clock)(token)) | ||
yield result).value | ||
case e => Left(e).pure[F] | ||
|
||
def fetchJWKSGuarded(claim: JwtClaim): EitherT[F, JwtError, Jwks] = | ||
for | ||
_ <- checkLastUpdateDelay(config.minRequestDelay) | ||
result <- fetchJWKS(claim) | ||
yield result | ||
|
||
def checkLastUpdateDelay(min: FiniteDuration): EitherT[F, JwtError, Unit] = | ||
EitherT( | ||
clock.monotonic.flatMap(ct => state.modify(_.lastUpdateDelay(ct))).map { | ||
case delay if delay > min => Right(()) | ||
case _ => Left(JwtError.TooManyValidationRequests(min)) | ||
} | ||
) | ||
|
||
def fetchJWKS(claim: JwtClaim): EitherT[F, JwtError, Jwks] = | ||
for | ||
_ <- EitherT.right( | ||
clock.monotonic.flatMap(t => state.update(_.copy(lastUpdate = t))) | ||
) | ||
issuerUri <- EitherT.fromEither( | ||
Uri | ||
.fromString(claim.issuer.getOrElse("")) | ||
.leftMap(ex => JwtError.InvalidIssuerUrl(claim.issuer.getOrElse(""), ex)) | ||
) | ||
configUri = issuerUri.addPath(config.openIdConfigPath) | ||
|
||
_ <- EitherT.right(logger.debug(s"Fetch openid config from $configUri")) | ||
openIdCfg <- EitherT(client.expect[OpenIdConfig](GET(configUri)).attempt) | ||
.leftMap(ex => JwtError.OpenIdConfigError(configUri, ex)) | ||
_ <- EitherT.right(logger.trace(s"Got openid-config response: $openIdCfg")) | ||
|
||
_ <- EitherT.right(logger.debug(s"Fetch jwks config from ${openIdCfg.jwksUri}")) | ||
jwks <- EitherT(client.expect[Jwks](GET(openIdCfg.jwksUri)).attempt) | ||
.leftMap(ex => JwtError.JwksError(openIdCfg.jwksUri, ex)) | ||
|
||
_ <- EitherT.right(state.update(_.copy(jwks = jwks))) | ||
_ <- EitherT.right( | ||
logger.debug(s"Updated JWKS with keys: ${jwks.keys.map(_.keyId)}") | ||
) | ||
yield jwks | ||
|
||
object DefaultJwtVerify: | ||
final case class State( | ||
jwks: Jwks = Jwks.empty, | ||
lastUpdate: FiniteDuration = Duration.Zero, | ||
lastAccess: FiniteDuration = Duration.Zero | ||
): | ||
def lastUpdateDelay(now: FiniteDuration): (State, FiniteDuration) = | ||
(copy(lastAccess = now), now - lastUpdate) | ||
|
||
def apply[F[_]: Async]( | ||
client: Client[F], | ||
clock: Clock[F], | ||
config: JwtVerifyConfig | ||
): F[JwtVerify[F]] = | ||
Ref[F].of(State()).map(state => new DefaultJwtVerify(client, state, clock, config)) |
Oops, something went wrong.