Skip to content

Commit

Permalink
Merge pull request #133 from SwissDataScienceCenter/improve-jwt-verify
Browse files Browse the repository at this point in the history
Improves jwt verify by adding an issuer url whitelist and scoping jwks data by their issuer
  • Loading branch information
eikek authored May 22, 2024
2 parents 0b795e8 + d19921a commit 9cfd435
Show file tree
Hide file tree
Showing 17 changed files with 815 additions and 34 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ lazy val openidKeycloak = project
http4sBorer % "compile->compile;test->test",
httpClient % "compile->compile;test->test",
jwt % "compile->compile;test->test",
commons % "test->test"
commons % "compile->compile;test->test"
)

lazy val http4sMetrics = project
Expand Down
111 changes: 111 additions & 0 deletions modules/commons/src/main/scala/io/renku/search/common/UrlPattern.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright 2024 Swiss Data Science Center (SDSC)
* A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and
* Eidgenössische Technische Hochschule Zürich (ETHZ).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.renku.search.common

import io.renku.search.common.UrlPattern.Segment

final case class UrlPattern(
scheme: Option[Segment],
host: List[Segment],
port: Option[Segment],
path: List[Segment]
):
def matches(url: String): Boolean =
val parts = UrlPattern.splitUrl(url)
scheme.forall(s => parts.scheme.exists(s.matches)) &&
(host.isEmpty || host.length == parts.host.length) &&
host.zip(parts.host).forall { case (s, h) => s.matches(h) } &&
port.forall(p => parts.port.exists(p.matches)) &&
(path.isEmpty || path.length == parts.path.length) &&
path.zip(parts.path).forall { case (s, p) => s.matches(p) }

def render: String =
scheme.map(s => s"${s.render}://").getOrElse("") +
host.map(_.render).mkString(".") +
port.map(p => s":${p.render}").getOrElse("") +
(if (path.isEmpty) "" else path.map(_.render).mkString("/", "/", ""))

object UrlPattern:
val all: UrlPattern = UrlPattern(None, Nil, None, Nil)

final private[common] case class UrlParts(
scheme: Option[String],
host: List[String],
port: Option[String],
path: List[String]
)
private[common] def splitUrl(url: String): UrlParts = {
def readScheme(s: String): (Option[String], String) =
s.split("://").filter(_.nonEmpty).toList match
case s :: rest :: Nil => (Some(s), rest)
case rest => (None, rest.mkString)

def readHostPort(s: String): (List[String], Option[String]) =
s.split(':').toList match
case h :: p :: _ =>
(h.split('.').filter(_.nonEmpty).toList, Option(p).filter(_.nonEmpty))
case rest =>
(s.split('.').filter(_.nonEmpty).toList, None)

val (scheme, rest0) = readScheme(url)
rest0.split('/').toList match
case hp :: rest =>
val (host, port) = readHostPort(hp)
UrlParts(scheme, host, port, rest)
case _ =>
val (host, port) = readHostPort(rest0)
UrlParts(scheme, host, port, Nil)
}

def fromString(str: String): UrlPattern =
if (str == "*" || str.isEmpty) UrlPattern.all
else {
val parts = splitUrl(str)
UrlPattern(
parts.scheme.map(Segment.fromString),
parts.host.map(Segment.fromString),
parts.port.map(Segment.fromString),
parts.path.map(Segment.fromString)
)
}

enum Segment:
case Literal(value: String)
case Prefix(value: String)
case Suffix(value: String)
case MatchAll

def matches(value: String): Boolean = this match
case Literal(v) => v.equalsIgnoreCase(value)
case Prefix(v) => value.startsWith(v)
case Suffix(v) => value.endsWith(v)
case MatchAll => true

def render: String = this match
case Literal(v) => v
case Prefix(v) => s"${v}*"
case Suffix(v) => s"*${v}"
case MatchAll => "*"

object Segment:
def fromString(s: String): Segment = s match
case "*" => MatchAll
case x if x.startsWith("*") => Suffix(x.drop(1))
case x if x.endsWith("*") => Prefix(x.dropRight(1))
case _ => Literal(s)
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ trait GeneratorSyntax:
def asListOfN(min: Int = 1, max: Int = 8): Gen[List[A]] =
Gen.choose(min, max).flatMap(Gen.listOfN(_, self))

def asOption: Gen[Option[A]] =
Gen.option(self)

def asSome: Gen[Option[A]] =
self.map(Some(_))

extension [A](self: Stream[Gen, A])
def toIO: Stream[IO, A] =
self.translate(FunctionK.lift[Gen, IO]([X] => (gx: Gen[X]) => IO(gx.generateOne)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
* limitations under the License.
*/

package io.renku.search.model
package io.renku.search.common

import cats.data.NonEmptyList

import io.renku.search.GeneratorSyntax.*
import org.scalacheck.Gen

object CommonGenerators:
Expand All @@ -28,3 +29,27 @@ object CommonGenerators:
e0 <- gen
en <- Gen.listOfN(n - 1, gen)
} yield NonEmptyList(e0, en)

def urlPatternGen: Gen[UrlPattern] =
def segmentGen(inner: Gen[String]): Gen[UrlPattern.Segment] =
Gen.oneOf(
inner.map(s => UrlPattern.Segment.Prefix(s)),
inner.map(s => UrlPattern.Segment.Suffix(s)),
inner.map(s => UrlPattern.Segment.Literal(s)),
Gen.const(UrlPattern.Segment.MatchAll)
)

val schemes = segmentGen(Gen.oneOf("http", "https"))
val ports = segmentGen(Gen.oneOf("123", "8080", "8145", "487", "11"))
val hosts = segmentGen(
Gen.oneOf("test", "com", "ch", "de", "dev", "renku", "penny", "cycle")
).asListOfN(0, 5)
val paths = segmentGen(
Gen.oneOf("auth", "authenticate", "doAuth", "me", "run", "api")
).asListOfN(0, 5)
for
scheme <- schemes.asOption
host <- hosts
port <- ports.asOption
path <- paths
yield UrlPattern(scheme, host, port, path)
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
/*
* Copyright 2024 Swiss Data Science Center (SDSC)
* A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and
* Eidgenössische Technische Hochschule Zürich (ETHZ).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.renku.search.common

import io.renku.search.common.UrlPattern.{Segment, UrlParts}
import munit.FunSuite
import munit.ScalaCheckSuite
import org.scalacheck.Prop

class UrlPatternSpec extends ScalaCheckSuite:

test("read parts"):
assertEquals(UrlPattern.splitUrl(""), UrlParts(None, Nil, None, Nil))
assertEquals(
UrlPattern.splitUrl("test.com"),
UrlParts(None, List("test", "com"), None, Nil)
)
assertEquals(
UrlPattern.splitUrl("*.test.com"),
UrlParts(None, List("*", "test", "com"), None, Nil)
)
assertEquals(
UrlPattern.splitUrl("test.com:123"),
UrlParts(None, List("test", "com"), Some("123"), Nil)
)
assertEquals(
UrlPattern.splitUrl("test.com:123/auth"),
UrlParts(None, List("test", "com"), Some("123"), List("auth"))
)
assertEquals(
UrlPattern.splitUrl("https://test.com:123/auth"),
UrlParts(Some("https"), List("test", "com"), Some("123"), List("auth"))
)
assertEquals(
UrlPattern.splitUrl("/auth/exec"),
UrlParts(None, Nil, None, List("auth", "exec"))
)

test("fromString"):
assertEquals(UrlPattern.fromString("*"), UrlPattern.all)
assertEquals(UrlPattern.fromString(""), UrlPattern.all)
assertEquals(
UrlPattern.fromString("*.*"),
UrlPattern(
None,
List(Segment.MatchAll, Segment.MatchAll),
None,
Nil
)
)
assertEquals(
UrlPattern.fromString("*.test.com"),
UrlPattern(
None,
List(Segment.MatchAll, Segment.Literal("test"), Segment.Literal("com")),
None,
Nil
)
)
assertEquals(
UrlPattern.fromString("*test.com"),
UrlPattern(
None,
List(Segment.Suffix("test"), Segment.Literal("com")),
None,
Nil
)
)
assertEquals(
UrlPattern.fromString("*test.com/auth*"),
UrlPattern(
None,
List(Segment.Suffix("test"), Segment.Literal("com")),
None,
List(Segment.Prefix("auth"))
)
)
assertEquals(
UrlPattern.fromString("https://test.com:15*/auth/sign"),
UrlPattern(
Some(Segment.Literal("https")),
List(Segment.Literal("test"), Segment.Literal("com")),
Some(Segment.Prefix("15")),
List(Segment.Literal("auth"), Segment.Literal("sign"))
)
)

property("read valid url pattern") {
Prop.forAll(CommonGenerators.urlPatternGen) { pattern =>
val parsed = UrlPattern.fromString(pattern.render)
val result = parsed == pattern
if (!result) {
println(s"Given: $pattern Parsed: ${parsed} Rendered: ${pattern.render}")
}
result
}
}

property("match all for all patterns") {
Prop.forAll(CommonGenerators.urlPatternGen) { pattern =>
val result = UrlPattern.all.matches(pattern.render)
if (!result) {
println(s"Failed pattern: ${pattern.render}")
}
result
}
}

test("matches successful"):
List(
UrlPattern.fromString("*.test.com") -> List(
"dev.test.com",
"http://sub.test.com/ab/cd"
),
UrlPattern.fromString("/auth/renku") -> List(
"dev.test.com/auth/renku",
"http://sub.test.com/auth/renku"
),
UrlPattern.fromString("*.test.com/auth/renku") -> List(
"http://dev.test.com/auth/renku",
"sub1.test.com/auth/renku"
)
).foreach { case (pattern, values) =>
values.foreach(v =>
assert(
pattern.matches(v),
s"Pattern ${pattern.render} did not match $v, but it should"
)
)
}

test("matches not successful"):
List(
UrlPattern.fromString("*.test.com") -> List(
"fest.com",
"http://sub.fest.com/ab/cd"
),
UrlPattern.fromString("/auth/renku") -> List(
"fest.com/tauth/renku",
"http://sub.test.com/auth/renkuu"
),
UrlPattern.fromString("*.test.com/auth/renku") -> List(
"http://dev.test.com/auth",
"sub1.sub2.test.com/auth/renku"
)
).foreach { case (pattern, values) =>
values.foreach(v =>
assert(
!pattern.matches(v),
s"Pattern ${pattern.render} matched $v, but it should not"
)
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import io.renku.redis.client.*
import org.http4s.Uri

import scala.concurrent.duration.{Duration, FiniteDuration}
import io.renku.search.common.UrlPattern

trait ConfigDecoders:

Expand Down Expand Up @@ -60,3 +61,8 @@ trait ConfigDecoders:
given ConfigDecoder[String, Port] =
ConfigDecoder[String]
.mapOption(Port.getClass.getSimpleName)(Port.fromString)

given ConfigDecoder[String, List[UrlPattern]] =
ConfigDecoder[String].map { str =>
str.split(',').toList.map(UrlPattern.fromString)
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package io.renku.search.config
import cats.syntax.all.*
import ciris.*
import com.comcast.ip4s.{Ipv4Address, Port}
import io.renku.search.common.UrlPattern
import io.renku.openid.keycloak.JwtVerifyConfig
import io.renku.redis.client.*
import io.renku.search.http.HttpServerConfig
Expand Down Expand Up @@ -97,5 +98,11 @@ object ConfigValues extends ConfigDecoders:
.default(defaults.minRequestDelay)
val openIdConfigPath =
renv("JWT_OPENID_CONFIG_PATH").default(defaults.openIdConfigPath)
(requestDelay, enableSigCheck, openIdConfigPath).mapN(JwtVerifyConfig.apply)
val allowedIssuers =
renv("JWT_ALLOWED_ISSUER_URL_PATTERNS")
.as[List[UrlPattern]]
.default(defaults.allowedIssuerUrls)
(requestDelay, enableSigCheck, openIdConfigPath, allowedIssuers).mapN(
JwtVerifyConfig.apply
)
}
Loading

0 comments on commit 9cfd435

Please sign in to comment.