From 661fe0591da11fc3136b48b69d1a1b063bb90d7c Mon Sep 17 00:00:00 2001 From: Nabil Abdel-Hafeez <7283535+987Nabil@users.noreply.github.com> Date: Thu, 9 Jan 2025 21:31:48 +0100 Subject: [PATCH] Schema based header codecs, unified with query codecs (#3232) Fix for publish CI job --- .github/workflows/ci.yml | 4 + build.sbt | 1 + .../ServerInboundHandlerBenchmark.scala | 2 +- .../zio/http/endpoint/cli/CliEndpoint.scala | 39 +- .../zio/http/endpoint/cli/HttpOptions.scala | 8 +- .../scala/zio/http/endpoint/cli/CliSpec.scala | 2 +- .../zio/http/endpoint/cli/CommandGen.scala | 18 +- .../zio/http/endpoint/cli/EndpointGen.scala | 10 +- .../zio/http/endpoint/cli/OptionsGen.scala | 15 +- .../zio/http/gen/scala/CodeGenSpec.scala | 2 +- .../scala/zio/http/endpoint/HeaderSpec.scala | 204 ++++++ .../http/endpoint/QueryParameterSpec.scala | 50 +- .../scala/zio/http/endpoint/RequestSpec.scala | 54 +- .../endpoint/openapi/OpenAPIGenSpec.scala | 135 +++- .../src/main/scala/zio/http/Header.scala | 1 + .../scala/zio/http/codec/HeaderCodecs.scala | 34 +- .../main/scala/zio/http/codec/HttpCodec.scala | 354 ++++++---- .../scala/zio/http/codec/HttpCodecError.scala | 10 + .../scala/zio/http/codec/QueryCodecs.scala | 126 +--- .../scala/zio/http/codec/StringCodec.scala | 394 +++++++++++ .../zio/http/codec/TextBinaryCodec.scala | 8 +- .../zio/http/codec/internal/Atomized.scala | 30 +- .../http/codec/internal/AtomizedCodecs.scala | 7 +- .../http/codec/internal/EncoderDecoder.scala | 619 +++++++++++------- .../zio/http/endpoint/http/HttpGen.scala | 24 +- .../http/endpoint/openapi/JsonSchema.scala | 20 +- .../http/endpoint/openapi/OpenAPIGen.scala | 55 +- .../zio/http/internal/HeaderGetters.scala | 7 + 28 files changed, 1647 insertions(+), 586 deletions(-) create mode 100644 zio-http/jvm/src/test/scala/zio/http/endpoint/HeaderSpec.scala create mode 100644 zio-http/shared/src/main/scala/zio/http/codec/StringCodec.scala diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 31c78e038..10adabc07 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -196,6 +196,10 @@ jobs: tar xf targets.tar rm targets.tar + - uses: coursier/setup-action@v1 + with: + apps: sbt + - name: Release env: PGP_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }} diff --git a/build.sbt b/build.sbt index 71f68af84..6a5520882 100644 --- a/build.sbt +++ b/build.sbt @@ -58,6 +58,7 @@ ThisBuild / githubWorkflowTargetTags ++= Seq("v*") ThisBuild / githubWorkflowPublishTargetBranches += RefPredicate.StartsWith(Ref.Tag("v")) ThisBuild / githubWorkflowPublish := Seq( + WorkflowStep.Use(UseRef.Public("coursier", "setup-action", "v1"), Map("apps" -> "sbt")), WorkflowStep.Sbt( List("ci-release"), name = Some("Release"), diff --git a/zio-http-benchmarks/src/main/scala/zhttp.benchmarks/ServerInboundHandlerBenchmark.scala b/zio-http-benchmarks/src/main/scala/zhttp.benchmarks/ServerInboundHandlerBenchmark.scala index d28c665e5..8f594614c 100644 --- a/zio-http-benchmarks/src/main/scala/zhttp.benchmarks/ServerInboundHandlerBenchmark.scala +++ b/zio-http-benchmarks/src/main/scala/zhttp.benchmarks/ServerInboundHandlerBenchmark.scala @@ -18,7 +18,7 @@ class ServerInboundHandlerBenchmark { private val largeString = random.alphanumeric.take(100000).mkString private val baseUrl = "http://localhost:8080" - private val headers = Headers(Header.ContentType(MediaType.text.`plain`).untyped) + private val headers = Headers(Header.ContentType(MediaType.text.`plain`)) private val arrayEndpoint = "array" private val arrayResponse = ZIO.succeed( diff --git a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala index 566289ea3..6c446338a 100644 --- a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala +++ b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala @@ -1,7 +1,6 @@ package zio.http.endpoint.cli import zio.http._ -import zio.http.codec.HttpCodec.Query.QueryType import zio.http.codec._ import zio.http.endpoint._ @@ -112,13 +111,11 @@ private[cli] object CliEndpoint { } CliEndpoint(body = HttpOptions.Body(name, codec.defaultMediaType, codec.defaultSchema) :: List()) - case HttpCodec.Header(name, textCodec, _) if textCodec.isInstanceOf[TextCodec.Constant] => - CliEndpoint(headers = - HttpOptions.HeaderConstant(name, textCodec.asInstanceOf[TextCodec.Constant].string) :: List(), - ) - case HttpCodec.Header(name, textCodec, _) => - CliEndpoint(headers = HttpOptions.Header(name, textCodec) :: List()) - case HttpCodec.Method(codec, _) => + case HttpCodec.Header(headerType, _) => + CliEndpoint(headers = HttpOptions.Header(headerType.name, TextCodec.string) :: List()) + case HttpCodec.HeaderCustom(codec, _) => + ??? // todo + case HttpCodec.Method(codec, _) => codec.asInstanceOf[SimpleCodec[_, _]] match { case SimpleCodec.Specified(method: Method) => CliEndpoint(methods = method) @@ -128,22 +125,16 @@ private[cli] object CliEndpoint { case HttpCodec.Path(pathCodec, _) => CliEndpoint(url = HttpOptions.Path(pathCodec) :: List()) - case HttpCodec.Query(queryType, _) => - queryType match { - case QueryType.Primitive(name, codec) => - CliEndpoint(url = HttpOptions.Query(name, codec) :: List()) - case record @ QueryType.Record(_) => - val queryOptions = record.fieldAndCodecs.map { case (field, codec) => - HttpOptions.Query(field.name, codec) - } - CliEndpoint(url = queryOptions.toList) - case QueryType.Collection(_, elements, _) => - val queryOptions = - HttpOptions.Query(elements.name, elements.codec) - CliEndpoint(url = queryOptions :: List()) - } - - case HttpCodec.Status(_, _) => CliEndpoint.empty + case HttpCodec.Query(codec, _) => + if (codec.isPrimitive) + CliEndpoint(url = HttpOptions.Query(codec) :: List()) + else if (codec.isRecord) + CliEndpoint(url = codec.recordFields.map { case (_, codec) => + HttpOptions.Query(codec) + }.toList) + else + CliEndpoint(url = HttpOptions.Query(codec) :: List()) + case HttpCodec.Status(_, _) => CliEndpoint.empty } } diff --git a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala index 191194864..2abb5704b 100644 --- a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala +++ b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala @@ -11,6 +11,7 @@ import zio.schema._ import zio.schema.annotation.description import zio.http._ +import zio.http.codec.HttpCodec.SchemaCodec import zio.http.codec._ /* @@ -264,10 +265,9 @@ private[cli] object HttpOptions { } - final case class Query(override val name: String, codec: BinaryCodecWithSchema[_], doc: Doc = Doc.empty) - extends URLOptions { + final case class Query(codec: SchemaCodec[_], doc: Doc = Doc.empty) extends URLOptions { self => - + override val name = codec.name.get override val tag = "?" + name def options: Options[_] = optionsFromSchema(codec)(name) @@ -293,7 +293,7 @@ private[cli] object HttpOptions { } - private[cli] def optionsFromSchema[A](codec: BinaryCodecWithSchema[A]): String => Options[A] = + private[cli] def optionsFromSchema[A](codec: SchemaCodec[A]): String => Options[A] = codec.schema match { case Schema.Primitive(standardType, _) => standardType match { diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CliSpec.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CliSpec.scala index 5153f1ab4..f26145431 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CliSpec.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CliSpec.scala @@ -27,7 +27,7 @@ object CliSpec extends ZIOSpecDefault { val bodyStream = ContentCodec.contentStream[BigInt]("bodyStream") - val headerCodec = HttpCodec.Header("header", TextCodec.string) + val headerCodec = HttpCodec.headerAs[String]("header") val path1 = PathCodec.bool("path1") diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala index ad9989721..9d84e04b9 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala @@ -47,20 +47,20 @@ object CommandGen { case _: HttpOptions.Constant => false case _ => true }.map { - case HttpOptions.Path(pathCodec, _) => - pathCodec.segments.toList.flatMap { case segment => + case HttpOptions.Path(pathCodec, _) => + pathCodec.segments.toList.flatMap { segment => getSegment(segment) match { case (_, "") => Nil case (name, "boolean") => s"[${getName(name, "")}]" :: Nil case (name, codec) => s"${getName(name, "")} $codec" :: Nil } } - case HttpOptions.Query(name, codec, _) => - getType(codec) match { - case "" => s"[${getName(name, "")}]" :: Nil - case codec => s"${getName(name, "")} $codec" :: Nil + case HttpOptions.Query(codec, _) if codec.isPrimitive => + getType(codec.schema) match { + case "" => s"[${getName(codec.name.get, "")}]" :: Nil + case tpy => s"${getName(codec.name.get, "")} $tpy" :: Nil } - case _ => Nil + case _ => Nil }.foldRight(List[String]())(_ ++ _) val headersOptions = cliEndpoint.headers.filter { @@ -121,8 +121,8 @@ object CommandGen { case _ => "" } - def getType[A](codec: BinaryCodecWithSchema[A]): String = - codec.schema match { + def getType[A](schema: Schema[A]): String = + schema match { case Schema.Primitive(standardType, _) => standardType match { case StandardType.UnitType => "" diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala index d868a86cc..380953327 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala @@ -5,7 +5,9 @@ import zio.test._ import zio.schema.Schema +import zio.http.Header.HeaderType import zio.http._ +import zio.http.codec.HttpCodec.SchemaCodec import zio.http.codec._ import zio.http.endpoint._ import zio.http.endpoint.cli.AuxGen._ @@ -78,7 +80,7 @@ object EndpointGen { lazy val anyHeader: Gen[Any, CliReprOf[Codec[_]]] = Gen.alphaNumericStringBounded(1, 30).zip(anyTextCodec).map { case (name, codec) => CliRepr( - HttpCodec.Header(name, codec), + HttpCodec.Header(Header.Custom(name, "").headerType), // todo use schema bases header codec match { case TextCodec.Constant(value) => CliEndpoint(headers = HttpOptions.HeaderConstant(name, value) :: Nil) case _ => CliEndpoint(headers = HttpOptions.Header(name, codec) :: Nil) @@ -102,10 +104,10 @@ object EndpointGen { lazy val anyQuery: Gen[Any, CliReprOf[Codec[_]]] = Gen.alphaNumericStringBounded(1, 30).zip(anyStandardType).map { case (name, schema0) => val schema = schema0.asInstanceOf[Schema[Any]] - val codec = BinaryCodecWithSchema(TextBinaryCodec.fromSchema(schema), schema) + val codec = SchemaCodec(Some(name), schema) CliRepr( - HttpCodec.Query(HttpCodec.Query.QueryType.Primitive(name, codec)), - CliEndpoint(url = HttpOptions.Query(name, codec) :: Nil), + HttpCodec.Query(codec), + CliEndpoint(url = HttpOptions.Query(codec) :: Nil), ) } diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/OptionsGen.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/OptionsGen.scala index 58fe22aa8..9e349cff1 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/OptionsGen.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/OptionsGen.scala @@ -7,6 +7,7 @@ import zio.test.Gen import zio.schema.Schema import zio.http._ +import zio.http.codec.HttpCodec.SchemaCodec import zio.http.codec._ import zio.http.endpoint.cli.AuxGen._ import zio.http.endpoint.cli.CliRepr._ @@ -32,10 +33,10 @@ object OptionsGen { .optionsFromTextCodec(textCodec)(name) .map(value => textCodec.encode(value)) - def encodeOptions[A](name: String, codec: BinaryCodecWithSchema[A]): Options[String] = + def encodeOptions[A](name: String, codec: SchemaCodec[A]): Options[String] = HttpOptions .optionsFromSchema(codec)(name) - .map(value => codec.codec(CodecConfig.defaultConfig).encode(value).asString) + .map(value => codec.stringCodec.encode(value)) lazy val anyBodyOption: Gen[Any, CliReprOf[Options[Retriever]]] = Gen @@ -83,14 +84,12 @@ object OptionsGen { }, Gen .alphaNumericStringBounded(1, 30) - .zip(anyStandardType.map { s => - val schema = s.asInstanceOf[Schema[Any]] - BinaryCodecWithSchema(TextBinaryCodec.fromSchema(schema), schema) - }) - .map { case (name, codec) => + .zip(anyStandardType) + .map { case (name, schema) => + val codec = SchemaCodec(Some(name), schema) CliRepr( encodeOptions(name, codec), - CliEndpoint(url = HttpOptions.Query(name, codec) :: Nil), + CliEndpoint(url = HttpOptions.Query(codec) :: Nil), ) }, ) diff --git a/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala b/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala index fe0aa3f37..d7c052f13 100644 --- a/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala +++ b/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala @@ -155,7 +155,7 @@ object CodeGenSpec extends ZIOSpecDefault { Endpoint(Method.GET / "api" / "v1" / "users") .header(HeaderCodec.accept) .header(HeaderCodec.contentType) - .header(HeaderCodec.name[String]("Token")) + .header(HeaderCodec.headerAs[String]("Token")) val openAPI = OpenAPIGen.fromEndpoints(endpoint) codeGenFromOpenAPI(openAPI) { testDir => diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/HeaderSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/HeaderSpec.scala new file mode 100644 index 000000000..ee5a7896e --- /dev/null +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/HeaderSpec.scala @@ -0,0 +1,204 @@ +package zio.http.endpoint + +import zio.test._ +import zio.{NonEmptyChunk, Scope} + +import zio.schema.Schema +import zio.schema.annotation.fieldName + +import zio.http._ +import zio.http.codec.HttpCodec +import zio.http.endpoint.EndpointSpec.testEndpointWithHeaders + +object HeaderSpec extends ZIOHttpSpec { + case class MyHeaders(age: String, @fieldName("content-type") cType: String = "application", xApiKey: Option[String]) + + object MyHeaders { + implicit val schema: Schema[MyHeaders] = zio.schema.DeriveSchema.gen[MyHeaders] + } + + override def spec: Spec[TestEnvironment with Scope, Any] = + suite("HeaderCodec")( + test("Headers from case class") { + check( + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + ) { (age, cType, apiKey) => + val testRoutes = testEndpointWithHeaders( + Routes( + Endpoint(Method.GET / "users") + .header(HttpCodec.headers[MyHeaders]) + .out[String] + .implementPurely(_.toString), + ), + ) _ + + testRoutes( + s"/users", + List( + "age" -> age, + "content-type" -> cType, + "x-api-key" -> apiKey, + ), + MyHeaders(age, cType, Some(apiKey)).toString, + ) && + testRoutes( + s"/users", + List( + "age" -> age, + "content-type" -> cType, + "x-api-key" -> "", + ), + MyHeaders(age, cType, Some("")).toString, + ) && + testRoutes( + s"/users", + List( + "age" -> age, + ), + MyHeaders(age, "application", None).toString, + ) + } + }, + test("Optional Headers from case class") { + check( + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + ) { (age, cType, apiKey) => + val testRoutes = testEndpointWithHeaders( + Routes( + Endpoint(Method.GET / "users") + .header(HttpCodec.headers[MyHeaders].optional) + .out[String] + .implementPurely(_.toString), + ), + ) _ + + testRoutes( + s"/users", + List( + "content-type" -> cType, + ), + None.toString, + ) && + testRoutes( + s"/users", + List( + "age" -> age, + "content-type" -> cType, + "x-api-key" -> apiKey, + ), + Some(MyHeaders(age, cType, Some(apiKey))).toString, + ) && testRoutes( + s"/users", + List( + "age" -> age, + ), + Some(MyHeaders(age, "application", None)).toString, + ) + } + }, + test("Multiple Header values") { + check( + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + ) { (age, age2, age3) => + val testRoutes = testEndpointWithHeaders( + Routes( + Endpoint(Method.GET / "users") + .header(HttpCodec.headerAs[List[String]]("age")) + .out[String] + .implementPurely(_.toString), + ), + ) _ + + testRoutes( + s"/users", + List( + "age" -> age, + ), + List(age).toString, + ) && testRoutes( + s"/users", + List( + "age" -> age, + "age" -> age2, + ), + List(age, age2).toString, + ) && testRoutes( + s"/users", + List( + "age" -> age, + "age" -> age2, + "age" -> age3, + ), + List(age, age2, age3).toString, + ) + } + }, + test("Multiple Header values non empty") { + check( + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + ) { (age, age2, age3) => + val testRoutes = testEndpointWithHeaders( + Routes( + Endpoint(Method.GET / "users") + .header(HttpCodec.headerAs[NonEmptyChunk[String]]("age")) + .out[String] + .implementPurely(_.toString), + ), + ) _ + + testRoutes( + s"/users", + List( + "age" -> age, + ), + NonEmptyChunk(age).toString, + ) && testRoutes( + s"/users", + List( + "age" -> age, + "age" -> age2, + ), + NonEmptyChunk(age, age2).toString, + ) && testRoutes( + s"/users", + List( + "age" -> age, + "age" -> age2, + "age" -> age3, + ), + NonEmptyChunk(age, age2, age3).toString, + ) + } + }, + test("Header from transformed schema") { + case class Wrapper(age: Int) + implicit val schema: Schema[Wrapper] = zio.schema.Schema[Int].transform[Wrapper](Wrapper(_), _.age) + check(Gen.int) { age => + val testRoutes = testEndpointWithHeaders( + Routes( + Endpoint(Method.GET / "users") + .header(HttpCodec.headerAs[Wrapper]("age")) + .out[String] + .implementPurely(_.toString), + ), + ) _ + + testRoutes( + s"/users", + List( + "age" -> age.toString, + ), + Wrapper(age).toString, + ) + } + }, + ) +} diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/QueryParameterSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/QueryParameterSpec.scala index 0c5b73248..af0a2eca1 100644 --- a/zio-http/jvm/src/test/scala/zio/http/endpoint/QueryParameterSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/QueryParameterSpec.scala @@ -55,16 +55,6 @@ object QueryParameterSpec extends ZIOHttpSpec { testRoutes( s"/users?int=$int&optInt=${optInt.mkString}&string=$string&strings=${strings.mkString(",")}", Params(int, optInt, string, strings).toString, - ) && - testRoutes( - s"/users?int=$int&string=$string&strings=${strings.mkString(",")}", - Params(int, None, string, strings).toString, - ) && testRoutes( - s"/users?int=$int&optInt=${optInt.mkString}&strings=${strings.mkString(",")}", - Params(int, optInt, "", strings).toString, - ) && testRoutes( - s"/users?int=$int&optInt=${optInt.mkString}&string=$string", - Params(int, optInt, string, Chunk("defaultString")).toString, ) } }, @@ -110,8 +100,8 @@ object QueryParameterSpec extends ZIOHttpSpec { }, ), ) _ - // testRoutes(s"/users/$userId", s"path(users, $userId, None)") && - // testRoutes(s"/users/$userId?details=", s"path(users, $userId, None)") && + testRoutes(s"/users/$userId", s"path(users, $userId, None)") && + testRoutes(s"/users/$userId?details=", s"path(users, $userId, Some())") && testRoutes(s"/users/$userId?details=$details", s"path(users, $userId, Some($details))") } }, @@ -168,6 +158,38 @@ object QueryParameterSpec extends ZIOHttpSpec { ) } }, + test("query parameters with multiple values non empty") { + check(Gen.int, Gen.listOfN(3)(Gen.alphaNumericString)) { (userId, keys) => + val routes = Routes( + Endpoint(GET / "users" / int("userId")) + .query(HttpCodec.query[NonEmptyChunk[String]]("key")) + .out[String] + .implementHandler { + Handler.fromFunction { case (userId, keys) => + s"""path(users, $userId, ${keys.mkString(", ")})""" + } + }, + ) + val testRoutes = testEndpoint( + routes, + ) _ + + testRoutes( + s"/users/$userId?key=${keys(0)}&key=${keys(1)}&key=${keys(2)}", + s"path(users, $userId, ${keys.mkString(", ")})", + ) && + testRoutes( + s"/users/$userId?key=${keys(0)}&key=${keys(1)}", + s"path(users, $userId, ${keys.take(2).mkString(", ")})", + ) && + testRoutes( + s"/users/$userId?key=${keys(0)}", + s"path(users, $userId, ${keys.take(1).mkString(", ")})", + ) && routes + .runZIO(Request.get(s"/users/$userId")) + .map(resp => assertTrue(resp.status == Status.BadRequest)) + } + }, test("optional query parameters with multiple values") { check(Gen.int, Gen.listOfN(3)(Gen.alphaNumericString)) { (userId, keys) => val testRoutes = testEndpoint( @@ -341,7 +363,7 @@ object QueryParameterSpec extends ZIOHttpSpec { test("query parameters keys without values for multi value query") { val routes = Routes( Endpoint(GET / "users") - .query(HttpCodec.query[Chunk[RuntimeFlags]]("ints")) + .query(HttpCodec.query[Chunk[Int]]("ints")) .out[String] .implementHandler { Handler.fromFunction { queryParams => s"path(users, $queryParams)" } @@ -438,6 +460,6 @@ object QueryParameterSpec extends ZIOHttpSpec { assertTrue(response.status == Status.Ok) } }, - ) + ).provide(ErrorResponseConfig.debugLayer) } diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/RequestSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/RequestSpec.scala index 7e6e1fb6c..345c63dd2 100644 --- a/zio-http/jvm/src/test/scala/zio/http/endpoint/RequestSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/RequestSpec.scala @@ -41,7 +41,7 @@ object RequestSpec extends ZIOHttpSpec { val testRoutes = testEndpointWithHeaders( Routes( Endpoint(GET / "users" / int("userId")) - .header(HeaderCodec.name[java.util.UUID]("X-Correlation-ID")) + .header(HeaderCodec.headerAs[java.util.UUID]("X-Correlation-ID")) .out[String] .implementHandler { Handler.fromFunction { case (userId, correlationId) => @@ -49,7 +49,7 @@ object RequestSpec extends ZIOHttpSpec { } }, Endpoint(GET / "users" / int("userId") / "posts" / int("postId")) - .header(HeaderCodec.name[java.util.UUID]("X-Correlation-ID")) + .header(HeaderCodec.headerAs[java.util.UUID]("X-Correlation-ID")) .out[String] .implementHandler { Handler.fromFunction { case (userId, postId, correlationId) => @@ -70,6 +70,48 @@ object RequestSpec extends ZIOHttpSpec { ) } }, + test("simple request with header with multiple values") { + check(Gen.int, Gen.listOfN(3)(Gen.uuid)) { (userId, correlationId) => + val testRoutes = testEndpointWithHeaders( + Routes( + Endpoint(GET / "users" / int("userId")) + .header(HeaderCodec.headerAs[Chunk[java.util.UUID]]("X-Correlation-ID")) + .out[String] + .implementHandler { + Handler.fromFunction { case (userId, correlationId) => + s"path(users, $userId) header(correlationId=${correlationId.mkString(",")})" + } + }, + ), + ) _ + testRoutes( + s"/users/$userId", + correlationId.map(uuid => "X-Correlation-ID" -> uuid.toString), + s"path(users, $userId) header(correlationId=${correlationId.mkString(",")})", + ) + } + }, + test("simple request with header with multiple values non empty") { + check(Gen.int, Gen.listOfN(3)(Gen.uuid)) { (userId, correlationId) => + val testRoutes = testEndpointWithHeaders( + Routes( + Endpoint(GET / "users" / int("userId")) + .header(HeaderCodec.headerAs[NonEmptyChunk[java.util.UUID]]("X-Correlation-ID")) + .out[String] + .implementHandler { + Handler.fromFunction { case (userId, correlationId) => + s"path(users, $userId) header(correlationId=${correlationId.mkString(",")})" + } + }, + ), + ) _ + testRoutes( + s"/users/$userId", + correlationId.map(uuid => "X-Correlation-ID" -> uuid.toString), + s"path(users, $userId) header(correlationId=${correlationId.mkString(",")})", + ) + } + }, test("custom content type") { check(Gen.int) { id => val endpoint = @@ -200,7 +242,7 @@ object RequestSpec extends ZIOHttpSpec { check(Gen.int, Gen.alphaNumericString) { (id, notACorrelationId) => val endpoint = Endpoint(GET / "posts") - .header(HeaderCodec.name[java.util.UUID]("X-Correlation-ID")) + .header(HeaderCodec.headerAs[java.util.UUID]("X-Correlation-ID")) .out[Int] val routes = endpoint.implementHandler { @@ -219,7 +261,7 @@ object RequestSpec extends ZIOHttpSpec { check(Gen.int) { id => val endpoint = Endpoint(GET / "posts") - .header(HeaderCodec.name[java.util.UUID]("X-Correlation-ID")) + .header(HeaderCodec.headerAs[java.util.UUID]("X-Correlation-ID")) .out[Int] val routes = endpoint.implementHandler { @@ -453,7 +495,7 @@ object RequestSpec extends ZIOHttpSpec { }, test("composite in codecs") { check(Gen.alphaNumericString, Gen.alphaNumericString) { (queryValue, headerValue) => - val headerOrQuery = HeaderCodec.name[String]("X-Header") | HttpCodec.query[String]("header") + val headerOrQuery = HeaderCodec.headerAs[String]("X-Header") | HttpCodec.query[String]("header") val endpoint = Endpoint(GET / "test").out[String].inCodec(headerOrQuery) val routes = endpoint.implementHandler(Handler.identity).toRoutes val request = Request.get( @@ -487,7 +529,7 @@ object RequestSpec extends ZIOHttpSpec { } }, test("composite out codecs") { - val headerOrQuery = HeaderCodec.name[String]("X-Header") | StatusCodec.status(Status.Created) + val headerOrQuery = HeaderCodec.headerAs[String]("X-Header") | StatusCodec.status(Status.Created) val endpoint = Endpoint(GET / "test").query(HttpCodec.query[Boolean]("Created")).outCodec(headerOrQuery) val routes = endpoint.implementHandler { diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala index 73ceede3a..0b8316905 100644 --- a/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala @@ -2,7 +2,7 @@ package zio.http.endpoint.openapi import zio.json.ast.Json import zio.test._ -import zio.{Chunk, Scope, ZIO} +import zio.{Chunk, NonEmptyChunk, Scope, ZIO} import zio.schema.annotation._ import zio.schema.validation.Validation @@ -194,6 +194,13 @@ object OpenAPIGenSpec extends ZIOSpecDefault { .out[SimpleOutputBody] .outError[NotFoundError](Status.NotFound) + private val queryParamNonEmptyCollectionEndpoint = + Endpoint(GET / "withQuery") + .in[SimpleInputBody] + .query(HttpCodec.query[NonEmptyChunk[String]]("query")) + .out[SimpleOutputBody] + .outError[NotFoundError](Status.NotFound) + private val queryParamValidationEndpoint = Endpoint(GET / "withQuery") .in[SimpleInputBody] @@ -648,6 +655,132 @@ object OpenAPIGenSpec extends ZIOSpecDefault { test("with query parameter with multiple values") { val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", queryParamCollectionEndpoint) val json = toJsonAst(generated) + val expectedJson = """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/withQuery" : { + | "get" : { + | "parameters" : [ + | + | { + | "name" : "query", + | "in" : "query", + | "schema" : + | { + | "type" : + | "string" + | }, + | "allowReserved" : false, + | "style" : "form" + | } + | ], + | "requestBody" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/SimpleInputBody" + | } + | } + | }, + | "required" : true + | }, + | "responses" : { + | "200" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/SimpleOutputBody" + | } + | } + | } + | }, + | "404" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/NotFoundError" + | } + | } + | } + | } + | } + | } + | } + | }, + | "components" : { + | "schemas" : { + | "NotFoundError" : + | { + | "type" : + | "object", + | "properties" : { + | "message" : { + | "type" : + | "string" + | } + | }, + | "required" : [ + | "message" + | ] + | }, + | "SimpleInputBody" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | }, + | "age" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "required" : [ + | "name", + | "age" + | ] + | }, + | "SimpleOutputBody" : + | { + | "type" : + | "object", + | "properties" : { + | "userName" : { + | "type" : + | "string" + | }, + | "score" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "required" : [ + | "userName", + | "score" + | ] + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expectedJson)) + }, + test("with query parameter with multiple values - non empty") { + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", queryParamNonEmptyCollectionEndpoint) + val json = toJsonAst(generated) val expectedJson = """{ | "openapi" : "3.1.0", | "info" : { diff --git a/zio-http/shared/src/main/scala/zio/http/Header.scala b/zio-http/shared/src/main/scala/zio/http/Header.scala index 48a87bb8a..d264ea62c 100644 --- a/zio-http/shared/src/main/scala/zio/http/Header.scala +++ b/zio-http/shared/src/main/scala/zio/http/Header.scala @@ -64,6 +64,7 @@ object Header { type Typed[HV] = HeaderType { type HeaderValue = HV } } + // @deprecated("Use Schema based header codecs instead", "3.1.0") final case class Custom(customName: CharSequence, value: CharSequence) extends Header { override type Self = Custom override def self: Self = this diff --git a/zio-http/shared/src/main/scala/zio/http/codec/HeaderCodecs.scala b/zio-http/shared/src/main/scala/zio/http/codec/HeaderCodecs.scala index a6ced4dec..ce21ab126 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/HeaderCodecs.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/HeaderCodecs.scala @@ -16,24 +16,48 @@ package zio.http.codec +import java.util.UUID + import scala.util.Try import zio.stacktracer.TracingImplicits.disableAutoTrace +import zio.schema._ + import zio.http.Header.HeaderType import zio.http._ private[codec] trait HeaderCodecs { - private[http] def headerCodec[A](name: String, value: TextCodec[A]): HeaderCodec[A] = - HttpCodec.Header(name, value) + private[http] def headerCodec[A](name: String, value: TextCodec[A]): HeaderCodec[A] = { + val schema = value match { + case TextCodec.Constant(string) => + Schema[String].transformOrFail[Unit]( + s => if (s == string) Right(()) else Left(s"Header $name was not $string"), + (_: Unit) => Right(string), + ) + case TextCodec.StringCodec => Schema[String] + case TextCodec.IntCodec => Schema[Int] + case TextCodec.LongCodec => Schema[Long] + case TextCodec.BooleanCodec => Schema[Boolean] + case TextCodec.UUIDCodec => Schema[UUID] + } + HttpCodec.HeaderCustom(name, schema.asInstanceOf[Schema[A]]) + } def header(headerType: HeaderType): HeaderCodec[headerType.HeaderValue] = - headerCodec(headerType.name, TextCodec.string) - .transformOrFailLeft(headerType.parse(_))(headerType.render(_)) + HttpCodec.Header(headerType) + + def headerAs[A](name: String)(implicit schema: Schema[A]): HeaderCodec[A] = + HttpCodec.HeaderCustom(name, schema) + + def headers[A](implicit schema: Schema[A]): HeaderCodec[A] = + HttpCodec.HeaderCustom(schema) + @deprecated("Use Schema based headerAs instead", "3.1.0") def name[A](name: String)(implicit codec: TextCodec[A]): HeaderCodec[A] = headerCodec(name, codec) + @deprecated("Use Schema based API instead", "3.1.0") def nameTransform[A, B]( name: String, parse: B => A, @@ -43,11 +67,13 @@ private[codec] trait HeaderCodecs { Try(parse(s)).toEither.left.map(e => s"Failed to parse header $name: ${e.getMessage}"), )(render) + @deprecated("Use Schema based API instead", "3.1.0") def nameTransformOption[A, B](name: String, parse: B => Option[A], render: A => B)(implicit codec: TextCodec[B], ): HeaderCodec[A] = headerCodec(name, codec).transformOrFailLeft(parse(_).toRight(s"Failed to parse header $name"))(render) + @deprecated("Use Schema based API instead", "3.1.0") def nameTransformOrFail[A, B](name: String, parse: B => Either[String, A], render: A => B)(implicit codec: TextCodec[B], ): HeaderCodec[A] = diff --git a/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala b/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala index 278f18536..673dcc242 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala @@ -18,18 +18,22 @@ package zio.http.codec import scala.annotation.tailrec import scala.reflect.ClassTag +import scala.util.Try import zio._ -import zio.stream.ZStream +import zio.stream.{ZPipeline, ZStream} import zio.schema.Schema -import zio.schema.annotation._ +import zio.schema.codec.DecodeError +import zio.schema.validation.{Validation, ValidationError} import zio.http.Header.Accept.MediaTypeWithQFactor +import zio.http.Header.HeaderType import zio.http._ -import zio.http.codec.HttpCodec.Query.QueryType +import zio.http.codec.HttpCodec.SchemaCodec.camelToKebab import zio.http.codec.HttpCodec.{Annotated, Metadata} +import zio.http.codec.StringCodec.StringCodec import zio.http.codec.internal._ /** @@ -337,12 +341,13 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with private[http] sealed trait AtomTag private[http] object AtomTag { - case object Status extends AtomTag - case object Path extends AtomTag - case object Content extends AtomTag - case object Query extends AtomTag - case object Header extends AtomTag - case object Method extends AtomTag + case object Status extends AtomTag + case object Path extends AtomTag + case object Content extends AtomTag + case object Query extends AtomTag + case object Header extends AtomTag + case object HeaderCustom extends AtomTag + case object Method extends AtomTag } def empty: HttpCodec[Any, Unit] = @@ -2264,140 +2269,220 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with def index(index: Int): ContentStream[A] = copy(index = index) } private[http] final case class Query[A, Out]( - queryType: Query.QueryType[A], + codec: SchemaCodec[A], index: Int = 0, ) extends Atom[HttpCodecType.Query, Out] { self => def erase: Query[Any, Any] = self.asInstanceOf[Query[Any, Any]] - def tag: AtomTag = AtomTag.Query - def index(index: Int): Query[A, Out] = copy(index = index) - def isOptional: Boolean = - queryType match { - case QueryType.Primitive(_, BinaryCodecWithSchema(_, schema)) if schema.isInstanceOf[Schema.Optional[_]] => - true - case QueryType.Record(recordSchema) => - recordSchema match { - case s if s.isInstanceOf[Schema.Optional[_]] => true - case record: Schema.Record[_] if record.fields.forall(_.optional) => true - case _ => false - } - case _ => false - } + def isCollection: Boolean = codec.isCollection + + def isOptional: Boolean = codec.isOptional + + def isOptionalSchema: Boolean = codec.isOptionalSchema + + def isPrimitive: Boolean = codec.isPrimitive + + def isRecord: Boolean = codec.isRecord + + def nameUnsafe: String = codec.name.get /** * Returns a new codec, where the value produced by this one is optional. */ override def optional: HttpCodec[HttpCodecType.Query, Option[Out]] = - queryType match { - case QueryType.Primitive(name, codec) if codec.schema.isInstanceOf[Schema.Optional[_]] => - throw new IllegalArgumentException( - s"Cannot make an optional query parameter optional. Name: $name schema: ${codec.schema}", - ) - case QueryType.Primitive(name, codec) => - val optionalSchema = codec.schema.optional - copy(queryType = - QueryType.Primitive(name, BinaryCodecWithSchema(TextBinaryCodec.fromSchema(optionalSchema), optionalSchema)), - ) - case QueryType.Record(recordSchema) if recordSchema.isInstanceOf[Schema.Optional[_]] => - throw new IllegalArgumentException(s"Cannot make an optional query parameter optional") - case QueryType.Record(recordSchema) => - val optionalSchema = recordSchema.optional - copy(queryType = QueryType.Record(optionalSchema)) - case queryType @ QueryType.Collection(_, _, false) => - copy(queryType = QueryType.Collection(queryType.colSchema, queryType.elements, optional = true)) - case queryType @ QueryType.Collection(_, _, true) => - throw new IllegalArgumentException(s"Cannot make an optional query parameter optional: $queryType") - + if (isOptionalSchema) { + throw new IllegalArgumentException("Query is already optional") + } else { + Annotated(Query(codec.optional, index), Metadata.Optional()) } + def tag: AtomTag = AtomTag.Query + } - private[http] object Query { - sealed trait QueryType[A] - object QueryType { - case class Primitive[A](name: String, codec: BinaryCodecWithSchema[A]) extends QueryType[A] - case class Collection[A](colSchema: Schema.Collection[_, _], elements: QueryType.Primitive[A], optional: Boolean) - extends QueryType[A] { - def toCollection(values: Chunk[Any]): A = - colSchema match { - case Schema.Sequence(_, fromChunk, _, _, _) => - fromChunk.asInstanceOf[Chunk[Any] => Any](values).asInstanceOf[A] - case Schema.Set(_, _) => - values.toSet.asInstanceOf[A] - case _ => - throw new IllegalArgumentException( - s"Unsupported collection schema for query object field of type: $colSchema", - ) - } - } - case class Record[A](recordSchema: Schema[A]) extends QueryType[A] { - private var namesAndCodecs: Chunk[(Schema.Field[_, _], BinaryCodecWithSchema[Any])] = _ - private[http] def fieldAndCodecs: Chunk[(Schema.Field[_, _], BinaryCodecWithSchema[Any])] = - if (namesAndCodecs == null) { - namesAndCodecs = recordSchema match { - case record: Schema.Record[A] => - record.fields.map { field => - validateSchema(field.name, field.schema) - val codec = binaryCodecForField(field.annotations.foldLeft(field.schema)(_ annotate _)) - (unlazy(field.asInstanceOf[Schema.Field[Any, Any]]), codec) - } - case s if s.isInstanceOf[Schema.Optional[_]] => - val record = s.asInstanceOf[Schema.Optional[A]].schema.asInstanceOf[Schema.Record[A]] - record.fields.map { field => - validateSchema(field.name, field.annotations.foldLeft(field.schema)(_ annotate _)) - val codec = binaryCodecForField(field.schema) - (field, codec) - } - case s => throw new IllegalArgumentException(s"Unsupported schema for query object field of type: $s") - } - namesAndCodecs - } else { - namesAndCodecs - } + object Query { + def apply[A](name: String, schema: Schema[A]): Query[A, A] = Query(SchemaCodec(Some(name), schema)) + def apply[A](schema: Schema[A]): Query[A, A] = Query(SchemaCodec(None, schema)) + } + + final case class SchemaCodec[A](name: Option[String], schema: Schema[A], kebabCase: Boolean = false) { + + def erasedSchema: Schema[Any] = schema.asInstanceOf[Schema[Any]] + + val isCollection: Boolean = schema match { + case _: Schema.Collection[_, _] => true + case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Collection[_, _]] => true + case _ => false + } + + val isOptional: Boolean = schema match { + case _: Schema.Optional[_] => + true + case record: Schema.Record[_] => + record.fields.forall(_.optional) || record.defaultValue.isRight + case d: Schema.Collection[_, _] => + Try(d.empty).isSuccess || d.defaultValue.isRight + case _ => + false + } + + val isOptionalSchema: Boolean = + schema match { + case _: Schema.Optional[_] => true + case s: Schema.Transform[_, _, _] if s.schema.isInstanceOf[Schema.Optional[_]] => true + case _ => false } - private def unlazy(field: Schema.Field[Any, Any]): Schema.Field[Any, Any] = field.schema match { - case Schema.Lazy(schema) => - Schema.Field( - field.name, - schema(), - field.annotations, - field.validation, - field.get, - field.set, + val isPrimitive: Boolean = schema match { + case _: Schema.Primitive[_] => true + case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Primitive[_]] => true + case s: Schema.Transform[_, _, _] if s.schema.isInstanceOf[Schema.Primitive[_]] => true + case _ => false + } + + val isRecord: Boolean = schema match { + case _: Schema.Record[_] => true + case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Record[_]] => true + case s: Schema.Transform[_, _, _] if s.schema.isInstanceOf[Schema.Record[_]] => true + case _ => false + } + + def optional: SchemaCodec[Option[A]] = copy(schema = schema.optional) + + val recordFields: Chunk[(Schema.Field[_, _], SchemaCodec[Any])] = { + val fields = schema match { + case record: Schema.Record[A] => + record.fields + case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Record[_]] => + s.schema.asInstanceOf[Schema.Record[A]].fields + case s: Schema.Transform[_, _, _] if s.schema.isInstanceOf[Schema.Record[_]] => + s.schema.asInstanceOf[Schema.Record[A]].fields + case _ => Chunk.empty + } + fields.map(unlazyField).map { + case field if field.schema.isInstanceOf[Schema.Collection[_, _]] => + val elementSchema = field.schema.asInstanceOf[Schema.Collection[_, _]] match { + case s: Schema.NonEmptySequence[_, _, _] => s.elementSchema + case s: Schema.Sequence[_, _, _] => s.elementSchema + case s: Schema.Set[_] => s.elementSchema + case _: Schema.Map[_, _] => throw new IllegalArgumentException("Maps are not supported") + case _: Schema.NonEmptyMap[_, _] => throw new IllegalArgumentException("Maps are not supported") + } + val codec = SchemaCodec(Some(if (!kebabCase) field.name else camelToKebab(field.name)), elementSchema) + (field, codec.asInstanceOf[SchemaCodec[Any]]) + case field => + val codec = SchemaCodec( + Some(if (!kebabCase) field.name else camelToKebab(field.name)), + field.annotations.foldLeft(field.schema)(_ annotate _), ) - case _ => field + (field, codec.asInstanceOf[SchemaCodec[Any]]) } + } - private def binaryCodecForField[A](schema: Schema[A]): BinaryCodecWithSchema[Any] = (schema match { - case schema @ Schema.Primitive(_, _) => BinaryCodecWithSchema(TextBinaryCodec.fromSchema(schema), schema) - case Schema.Transform(_, _, _, _, _) => BinaryCodecWithSchema(TextBinaryCodec.fromSchema(schema), schema) - case Schema.Optional(_, _) => BinaryCodecWithSchema(TextBinaryCodec.fromSchema(schema), schema) - case e: Schema.Enum[_] if isSimple(e) => BinaryCodecWithSchema(TextBinaryCodec.fromSchema(schema), schema) - case l @ Schema.Lazy(_) => binaryCodecForField(l.schema) - case Schema.Set(schema, _) => binaryCodecForField(schema) - case Schema.Sequence(schema, _, _, _, _) => binaryCodecForField(schema) - case schema => throw new IllegalArgumentException(s"Unsupported schema for query object field of type: $schema") - }).asInstanceOf[BinaryCodecWithSchema[Any]] - - def isSimple(schema: Schema.Enum[_]): Boolean = - schema.annotations.exists(_.isInstanceOf[simpleEnum]) - - @tailrec - private def validateSchema[A](name: String, schema: Schema[A]): Unit = schema match { - case _: Schema.Primitive[A] => () - case Schema.Transform(schema, _, _, _, _) => validateSchema(name, schema) - case Schema.Optional(schema, _) => validateSchema(name, schema) - case Schema.Lazy(schema) => validateSchema(name, schema()) - case Schema.Set(schema, _) => validateSchema(name, schema) - case Schema.Sequence(schema, _, _, _, _) => validateSchema(name, schema) - case s => throw new IllegalArgumentException(s"Unsupported schema for query object field of type: $s") - } + val recordSchema: Schema.Record[Any] = schema match { + case record: Schema.Record[_] => + record.asInstanceOf[Schema.Record[Any]] + case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Record[_]] => + s.schema.asInstanceOf[Schema.Record[Any]] + case _ => null + } + val stringCodec: StringCodec[Any] = + stringCodecForSchema(schema.asInstanceOf[Schema[Any]]) + + private def stringCodecForSchema(s: Schema[_]): StringCodec[Any] = { + (s match { + case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Primitive[_]] => + StringCodec.fromSchema(schema) + case s: Schema.Optional[_] => + stringCodecForSchema(s.schema) + case s: Schema.Collection[_, _] => + s match { + case schema: Schema.NonEmptySequence[_, _, _] => StringCodec.fromSchema(schema.elementSchema) + case schema: Schema.Sequence[_, _, _] => StringCodec.fromSchema(schema.elementSchema) + case schema: Schema.Set[_] => StringCodec.fromSchema(schema.elementSchema) + case _: Schema.Map[_, _] => StringCodec.fromSchema(s) + case _: Schema.NonEmptyMap[_, _] => StringCodec.fromSchema(s) + } + case s: Schema.Lazy[_] => StringCodec.fromSchema(s.schema) + case s: Schema.Transform[Any, Any, _] @unchecked => + val stringCodec = StringCodec.fromSchema(s.schema) + new StringCodec[Any] { + override def decode(whole: String): Either[DecodeError, Any] = + stringCodec.decode(whole).flatMap(s.f(_).left.map(DecodeError.ReadError(Cause.empty, _))) + + override def streamDecoder: ZPipeline[Any, DecodeError, Char, Any] = + stringCodec.streamDecoder >>> ZPipeline.map(s.f(_).left.map(DecodeError.ReadError(Cause.empty, _))) + + override def encode(value: Any): String = + stringCodec.encode(s.g(value).fold(msg => throw new Exception(msg), identity)) + + override def streamEncoder: ZPipeline[Any, Nothing, Any, Char] = + ZPipeline.map[Any, Any]( + s.g(_).fold(msg => throw new Exception(msg), identity), + ) >>> stringCodec.streamEncoder + } + case schema: Schema[_] => StringCodec.fromSchema(schema) + }).asInstanceOf[StringCodec[Any]] } + + private def unlazyField(field: Schema.Field[_, _]): Schema.Field[_, _] = field match { + case f if f.schema.isInstanceOf[Schema.Lazy[_]] => + Schema.Field( + f.name, + f.schema.asInstanceOf[Schema.Lazy[_]].schema.asInstanceOf[Schema[Any]], + f.annotations, + f.validation.asInstanceOf[Validation[Any]], + f.get.asInstanceOf[Any => Any], + f.set.asInstanceOf[(Any, Any) => Any], + ) + case f => f + } + + def validate(value: Any): Chunk[ValidationError] = + schema.asInstanceOf[Schema[_]] match { + case Schema.Optional(schema: Schema[Any], _) => + schema.validate(value)(schema) + case schema: Schema[_] => + schema.asInstanceOf[Schema[Any]].validate(value)(schema.asInstanceOf[Schema[Any]]) + } + val defaultValue: A = + if (schema.isInstanceOf[Schema.Collection[_, _]]) { + Try(schema.asInstanceOf[Schema.Collection[A, _]].empty).fold( + _ => null.asInstanceOf[A], + identity, + ) + } else { + schema.defaultValue match { + case Right(value) => value + case Left(_) => + schema match { + case _: Schema.Optional[_] => None.asInstanceOf[A] + case collection: Schema.Collection[A, _] => + Try(collection.empty).fold( + _ => null.asInstanceOf[A], + identity, + ) + case _ => null.asInstanceOf[A] + } + } + } + + } + + object SchemaCodec { + private def camelToKebab(s: String): String = + if (s.isEmpty) "" + else if (s.head.isUpper) s.head.toLower.toString + camelToKebab(s.tail) + else if (s.contains('-')) s + else + s.foldLeft("") { (acc, c) => + if (c.isUpper) acc + "-" + c.toLower + else acc + c + } } private[http] final case class Method[A](codec: SimpleCodec[zio.http.Method, A], index: Int = 0) @@ -2409,7 +2494,34 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with def index(index: Int): Method[A] = copy(index = index) } - private[http] final case class Header[A](name: String, textCodec: TextCodec[A], index: Int = 0) + private[http] final case class HeaderCustom[A](codec: SchemaCodec[A], index: Int = 0) + extends Atom[HttpCodecType.Header, A] { + self => + def erase: HeaderCustom[Any] = self.asInstanceOf[HeaderCustom[Any]] + + override def optional: HttpCodec[HttpCodecType.Header, Option[A]] = + if (codec.isOptionalSchema) { + throw new IllegalArgumentException("Header is already optional") + } else { + Annotated( + HeaderCustom(codec.optional, index), + Metadata.Optional(), + ) + } + + def tag: AtomTag = AtomTag.HeaderCustom + + def index(index: Int): HeaderCustom[A] = copy(index = index) + } + + object HeaderCustom { + def apply[A](name: String, schema: Schema[A]): HeaderCustom[A] = + HeaderCustom(SchemaCodec(Some(name), schema, kebabCase = true)) + def apply[A](schema: Schema[A]): HeaderCustom[A] = + HeaderCustom(SchemaCodec(None, schema, kebabCase = true)) + } + + private[http] final case class Header[A](headerType: HeaderType.Typed[A], index: Int = 0) extends Atom[HttpCodecType.Header, A] { self => def erase: Header[Any] = self.asInstanceOf[Header[Any]] diff --git a/zio-http/shared/src/main/scala/zio/http/codec/HttpCodecError.scala b/zio-http/shared/src/main/scala/zio/http/codec/HttpCodecError.scala index bcd97223d..3df1973ab 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/HttpCodecError.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/HttpCodecError.scala @@ -23,6 +23,7 @@ import zio.{Cause, Chunk} import zio.schema.codec.DecodeError import zio.schema.validation.ValidationError +import zio.http.Header.HeaderType import zio.http.{Path, Status} sealed trait HttpCodecError extends Exception with NoStackTrace with Product with Serializable { @@ -33,6 +34,9 @@ object HttpCodecError { final case class MissingHeader(headerName: String) extends HttpCodecError { def message = s"Missing header $headerName" } + final case class MissingHeaders(headerNames: Chunk[String]) extends HttpCodecError { + def message = s"Missing headers ${headerNames.mkString(", ")}" + } final case class MalformedMethod(expected: zio.http.Method, actual: zio.http.Method) extends HttpCodecError { def message = s"Expected $expected but found $actual" } @@ -48,6 +52,12 @@ object HttpCodecError { final case class MalformedHeader(headerName: String, textCodec: TextCodec[_]) extends HttpCodecError { def message = s"Malformed header $headerName failed to decode using $textCodec" } + final case class MalformedCustomHeader(headerName: String, cause: DecodeError) extends HttpCodecError { + def message = s"Malformed custom header $headerName could not be decoded: $cause" + } + final case class MalformedTypedHeader(headerName: String) extends HttpCodecError { + def message = s"Malformed header $headerName" + } final case class MissingQueryParam(queryParamName: String) extends HttpCodecError { def message = s"Missing query parameter $queryParamName" } diff --git a/zio-http/shared/src/main/scala/zio/http/codec/QueryCodecs.scala b/zio-http/shared/src/main/scala/zio/http/codec/QueryCodecs.scala index 4bc203f5e..4f98ec8e4 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/QueryCodecs.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/QueryCodecs.scala @@ -26,100 +26,33 @@ private[codec] trait QueryCodecs { def query[A](name: String)(implicit schema: Schema[A]): QueryCodec[A] = schema match { - case s @ Schema.Primitive(_, _) => - HttpCodec.Query( - HttpCodec.Query.QueryType - .Primitive(name, BinaryCodecWithSchema.fromBinaryCodec(TextBinaryCodec.fromSchema(s))(s)), - ) - case c @ Schema.Sequence(elementSchema, _, _, _, _) => - if (supportedElementSchema(elementSchema.asInstanceOf[Schema[Any]])) { - HttpCodec.Query( - HttpCodec.Query.QueryType.Collection( - c, - HttpCodec.Query.QueryType.Primitive( - name, - BinaryCodecWithSchema(TextBinaryCodec.fromSchema(elementSchema), elementSchema), - ), - optional = false, - ), - ) - } else { - throw new IllegalArgumentException("Only primitive types can be elements of sequences") - } - case c @ Schema.Set(elementSchema, _) => - if (supportedElementSchema(elementSchema.asInstanceOf[Schema[Any]])) { - HttpCodec.Query( - HttpCodec.Query.QueryType.Collection( - c, - HttpCodec.Query.QueryType.Primitive( - name, - BinaryCodecWithSchema(TextBinaryCodec.fromSchema(elementSchema), elementSchema), - ), - optional = false, - ), - ) - } else { - throw new IllegalArgumentException("Only primitive types can be elements of sets") - } - case Schema.Optional(Schema.Primitive(_, _), _) => - HttpCodec.Query( - HttpCodec.Query.QueryType - .Primitive(name, BinaryCodecWithSchema.fromBinaryCodec(TextBinaryCodec.fromSchema(schema))(schema)), - ) - case Schema.Optional(c @ Schema.Sequence(elementSchema, _, _, _, _), _) => - if (supportedElementSchema(elementSchema.asInstanceOf[Schema[Any]])) { - HttpCodec.Query( - HttpCodec.Query.QueryType.Collection( - c, - HttpCodec.Query.QueryType.Primitive( - name, - BinaryCodecWithSchema(TextBinaryCodec.fromSchema(elementSchema), elementSchema), - ), - optional = true, - ), - ) - } else { - throw new IllegalArgumentException("Only primitive types can be elements of sequences") - } - case Schema.Optional(inner, _) if inner.isInstanceOf[Schema.Set[_]] => - val elementSchema = inner.asInstanceOf[Schema.Set[Any]].elementSchema - if (supportedElementSchema(elementSchema)) { - HttpCodec.Query( - HttpCodec.Query.QueryType.Collection( - inner.asInstanceOf[Schema.Set[_]], - HttpCodec.Query.QueryType.Primitive( - name, - BinaryCodecWithSchema(TextBinaryCodec.fromSchema(inner), inner), - ), - optional = true, - ), - ) - } else { - throw new IllegalArgumentException("Only primitive types can be elements of sets") - } - case enum0: Schema.Enum[_] if enum0.annotations.exists(_.isInstanceOf[simpleEnum]) => - HttpCodec.Query( - HttpCodec.Query.QueryType - .Primitive(name, BinaryCodecWithSchema.fromBinaryCodec(TextBinaryCodec.fromSchema(schema))(schema)), - ) - case record: Schema.Record[A] if record.fields.size == 1 => - val field = record.fields.head - if (supportedElementSchema(field.schema.asInstanceOf[Schema[Any]])) { - HttpCodec.Query( - HttpCodec.Query.QueryType.Primitive( - name, - BinaryCodecWithSchema(TextBinaryCodec.fromSchema(record), record), - ), - ) - } else { - throw new IllegalArgumentException("Only primitive types can be elements of records") - } - case other => + case c: Schema.Collection[_, _] if !supportedCollection(c) => + throw new IllegalArgumentException(s"Collection schema $c is not supported for query codecs") + case enum0: Schema.Enum[_] if !enum0.annotations.exists(_.isInstanceOf[simpleEnum]) => + throw new IllegalArgumentException(s"Enum schema $enum0 is not supported. All cases must be objects.") + case record: Schema.Record[A] if record.fields.size != 1 => + throw new IllegalArgumentException("Use queryAll[A] for records with more than one field") + case record: Schema.Record[A] if !supportedElementSchema(record.fields.head.schema.asInstanceOf[Schema[Any]]) => throw new IllegalArgumentException( - s"Only primitive types, sequences, sets, optional, enums and records with a single field can be used to infer query codecs, but got $other", + s"Only primitive types and simple enums can be used in single field records, but got ${record.fields.head.schema}", ) + case other => + HttpCodec.Query(name, other) } + private def supportedCollection(schema: Schema.Collection[_, _]): Boolean = schema match { + case Schema.Map(_, _, _) => + false + case Schema.NonEmptyMap(_, _, _) => + false + case Schema.Sequence(elementSchema, _, _, _, _) => + supportedElementSchema(elementSchema.asInstanceOf[Schema[Any]]) + case Schema.NonEmptySequence(elementSchema, _, _, _, _) => + supportedElementSchema(elementSchema.asInstanceOf[Schema[Any]]) + case Schema.Set(elementSchema, _) => + supportedElementSchema(elementSchema.asInstanceOf[Schema[Any]]) + } + @tailrec private def supportedElementSchema(elementSchema: Schema[Any]): Boolean = elementSchema match { case Schema.Lazy(schema0) => supportedElementSchema(schema0()) @@ -131,11 +64,16 @@ private[codec] trait QueryCodecs { def queryAll[A](implicit schema: Schema[A]): QueryCodec[A] = schema match { - case _: Schema.Primitive[A] => + case _: Schema.Primitive[A] => throw new IllegalArgumentException("Use query[A](name: String) for primitive types") - case record: Schema.Record[A] => HttpCodec.Query(HttpCodec.Query.QueryType.Record(record)) - case Schema.Optional(_, _) => HttpCodec.Query(HttpCodec.Query.QueryType.Record(schema)) - case _ => throw new IllegalArgumentException("Only case classes can be used to infer query codecs") + case record: Schema.Record[A] => + HttpCodec.Query(record) + case Schema.Optional(s, _) if s.isInstanceOf[Schema.Record[_]] => + HttpCodec.Query(schema) + case _ => + throw new IllegalArgumentException( + "Only case classes can be used with queryAll. Maybe you wanted to use query[A](name: String)?", + ) } } diff --git a/zio-http/shared/src/main/scala/zio/http/codec/StringCodec.scala b/zio-http/shared/src/main/scala/zio/http/codec/StringCodec.scala new file mode 100644 index 000000000..ae49864c5 --- /dev/null +++ b/zio-http/shared/src/main/scala/zio/http/codec/StringCodec.scala @@ -0,0 +1,394 @@ +package zio.http.codec + +import java.time._ +import java.util.{Currency, UUID} + +import scala.annotation.tailrec + +import zio._ + +import zio.stream._ + +import zio.schema._ +import zio.schema.annotation.simpleEnum +import zio.schema.codec._ + +import zio.http.Charsets + +object StringCodec { + type StringCodec[A] = Codec[String, Char, A] + private def errorCodec[A](schema: Schema[A]) = + new Codec[String, Char, A] { + override def decode(whole: String): Either[DecodeError, A] = throw new IllegalArgumentException( + s"Schema $schema is not supported by StringCodec.", + ) + + override def streamDecoder: ZPipeline[Any, DecodeError, Char, A] = throw new IllegalArgumentException( + s"Schema $schema is not supported by StringCodec.", + ) + + override def encode(value: A): String = throw new IllegalArgumentException( + s"Schema $schema is not supported by StringCodec.", + ) + + override def streamEncoder: ZPipeline[Any, Nothing, A, Char] = throw new IllegalArgumentException( + s"Schema $schema is not supported by StringCodec.", + ) + } + + @tailrec + private def emptyStringIsValue(schema: Schema[_]): Boolean = { + schema match { + case value: Schema.Optional[_] => + val innerSchema = value.schema + emptyStringIsValue(innerSchema) + case _ => + schema.asInstanceOf[Schema.Primitive[_]].standardType match { + case StandardType.UnitType => true + case StandardType.StringType => true + case StandardType.BinaryType => true + case StandardType.CharType => true + case _ => false + } + } + } + + implicit def fromSchema[A](implicit schema: Schema[A]): Codec[String, Char, A] = { + schema match { + case Schema.Optional(schema, _) => + val codec = fromSchema(schema).asInstanceOf[Codec[String, Char, Any]] + new Codec[String, Char, A] { + override def encode(a: A): String = { + a match { + case Some(value) => codec.encode(value) + case None => "" + } + } + + override def decode(c: String): Either[DecodeError, A] = { + if (c.isEmpty && !emptyStringIsValue(schema)) Right(None.asInstanceOf[A]) + else { + codec.decode(c).map(Some(_)).asInstanceOf[Either[DecodeError, A]] + } + } + + override def streamEncoder: ZPipeline[Any, Nothing, A, Char] = + ZPipeline.map((a: A) => encode(a).toSeq).flattenIterables + override def streamDecoder: ZPipeline[Any, DecodeError, Char, A] = + codec.streamDecoder.map(v => Some(v).asInstanceOf[A]) + } + case enum0: Schema.Enum[_] if enum0.annotations.exists(_.isInstanceOf[simpleEnum]) => + val stringCodec = fromSchema(Schema.Primitive(StandardType.StringType)) + val caseMap = enum0.nonTransientCases + .map(case_ => + case_.schema.asInstanceOf[Schema.CaseClass0[A]].defaultConstruct() -> + case_.caseName, + ) + .toMap + val reverseCaseMap = caseMap.map(_.swap) + new Codec[String, Char, A] { + override def encode(a: A): String = { + val caseName = caseMap(a.asInstanceOf[A]) + stringCodec.encode(caseName) + } + + override def decode(c: String): Either[DecodeError, A] = + stringCodec.decode(c).flatMap { caseName => + reverseCaseMap.get(caseName) match { + case Some(value) => Right(value.asInstanceOf[A]) + case None => Left(DecodeError.MissingCase(caseName, enum0)) + } + } + override def streamEncoder: ZPipeline[Any, Nothing, A, Char] = + ZPipeline.map((a: A) => encode(a).toSeq).flattenIterables + override def streamDecoder: ZPipeline[Any, DecodeError, Char, A] = + stringCodec.streamDecoder.mapZIO { caseName => + reverseCaseMap.get(caseName) match { + case Some(value) => ZIO.succeed(value.asInstanceOf[A]) + case None => ZIO.fail(DecodeError.MissingCase(caseName, enum0)) + } + } + } + + case enum0: Schema.Enum[_] => errorCodec(enum0) + case record: Schema.Record[_] if record.fields.size == 1 => + val fieldSchema = record.fields.head.schema + val codec = fromSchema(fieldSchema).asInstanceOf[Codec[String, Char, A]] + new Codec[String, Char, A] { + override def encode(a: A): String = + codec.encode(record.deconstruct(a)(Unsafe.unsafe).head.get.asInstanceOf[A]) + override def decode(c: String): Either[DecodeError, A] = + codec + .decode(c) + .flatMap(a => + record.construct(Chunk(a))(Unsafe.unsafe).left.map(s => DecodeError.ReadError(Cause.empty, s)), + ) + override def streamEncoder: ZPipeline[Any, Nothing, A, Char] = + ZPipeline.map((a: A) => encode(a).toSeq).flattenIterables + override def streamDecoder: ZPipeline[Any, DecodeError, Char, A] = + codec.streamDecoder.mapZIO(a => + ZIO.fromEither( + record.construct(Chunk(a))(Unsafe.unsafe).left.map(s => DecodeError.ReadError(Cause.empty, s)), + ), + ) + } + case record: Schema.Record[_] => errorCodec(record) + case collection: Schema.Collection[_, _] => errorCodec(collection) + case Schema.Transform(schema, f, g, _, _) => + val codec = fromSchema(schema) + new Codec[String, Char, A] { + override def encode(a: A): String = codec.encode(g(a).fold(e => throw new Exception(e), identity)) + override def decode(c: String): Either[DecodeError, A] = codec + .decode(c) + .flatMap(x => + f(x).left + .map(DecodeError.ReadError(Cause.fail(new Exception("Error during decoding")), _)), + ) + override def streamEncoder: ZPipeline[Any, Nothing, A, Char] = + ZPipeline.mapChunks(_.flatMap(encode)) + override def streamDecoder: ZPipeline[Any, DecodeError, Char, A] = codec.streamDecoder.map { x => + f(x) match { + case Left(value) => throw DecodeError.ReadError(Cause.fail(new Exception("Error in decoding")), value) + case Right(a) => a + } + } + } + case Schema.Primitive(_, _) => + new Codec[String, Char, A] { + val decode0: String => Either[DecodeError, Any] = + schema match { + case Schema.Primitive(standardType, _) => + standardType match { + case StandardType.UnitType => + val result = Right("") + (_: String) => result + case StandardType.StringType => + (s: String) => Right(s) + case StandardType.BoolType => + (s: String) => + s.toLowerCase match { + case "true" | "on" | "yes" | "1" => Right(true) + case "false" | "off" | "no" | "0" => Right(false) + case _ => Left(DecodeError.ReadError(Cause.fail(new Exception("Invalid boolean value")), s)) + } + case StandardType.ByteType => + (s: String) => + try { + Right(s.toByte) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.ShortType => + (s: String) => + try { + Right(s.toShort) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.IntType => + (s: String) => + try { + Right(s.toInt) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.LongType => + (s: String) => + try { + Right(s.toLong) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.FloatType => + (s: String) => + try { + Right(s.toFloat) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.DoubleType => + (s: String) => + try { + Right(s.toDouble) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.BinaryType => + val result = Left(DecodeError.UnsupportedSchema(schema, "TextCodec")) + (_: String) => result + case StandardType.CharType => + (s: String) => Right(s.charAt(0)) + case StandardType.UUIDType => + (s: String) => + try { + Right(UUID.fromString(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.BigDecimalType => + (s: String) => + try { + Right(BigDecimal(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.BigIntegerType => + (s: String) => + try { + Right(BigInt(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.DayOfWeekType => + (s: String) => + try { + Right(DayOfWeek.valueOf(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.MonthType => + (s: String) => + try { + Right(Month.valueOf(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.MonthDayType => + (s: String) => + try { + Right(MonthDay.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.PeriodType => + (s: String) => + try { + Right(Period.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.YearType => + (s: String) => + try { + Right(Year.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.YearMonthType => + (s: String) => + try { + Right(YearMonth.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.ZoneIdType => + (s: String) => + try { + Right(ZoneId.of(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.ZoneOffsetType => + (s: String) => + try { + Right(ZoneOffset.of(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.DurationType => + (s: String) => + try { + Right(java.time.Duration.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.InstantType => + (s: String) => + try { + Right(Instant.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.LocalDateType => + (s: String) => + try { + Right(LocalDate.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.LocalTimeType => + (s: String) => + try { + Right(LocalTime.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.LocalDateTimeType => + (s: String) => + try { + Right(LocalDateTime.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.OffsetTimeType => + (s: String) => + try { + Right(OffsetTime.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.OffsetDateTimeType => + (s: String) => + try { + Right(OffsetDateTime.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.ZonedDateTimeType => + (s: String) => + try { + Right(ZonedDateTime.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.CurrencyType => + (s: String) => + try { + Right(Currency.getInstance(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + } + case schema => + val result = Left( + DecodeError.UnsupportedSchema(schema, "Only primitive types are supported for text decoding."), + ) + (_: String) => result + } + override def encode(a: A): String = + schema match { + case Schema.Primitive(_, _) => a.toString + case _ => + throw new IllegalArgumentException( + s"Cannot encode $a of type ${a.getClass} with schema $schema", + ) + } + + override def decode(c: String): Either[DecodeError, A] = + decode0(c).map(_.asInstanceOf[A]) + + override def streamEncoder: ZPipeline[Any, Nothing, A, Char] = + ZPipeline.map((a: A) => a.toString.toSeq).flattenIterables + + override def streamDecoder: ZPipeline[Any, DecodeError, Char, A] = + ZPipeline + .chunks[Char] + .map(_.asString) + .mapZIO(s => ZIO.fromEither(decode(s))) + .mapErrorCause(e => Cause.fail(DecodeError.ReadError(e, e.squash.getMessage))) + } + case Schema.Lazy(schema0) => fromSchema(schema0()) + case _ => errorCodec(schema) + } + } +} diff --git a/zio-http/shared/src/main/scala/zio/http/codec/TextBinaryCodec.scala b/zio-http/shared/src/main/scala/zio/http/codec/TextBinaryCodec.scala index cd552ad81..9059501f9 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/TextBinaryCodec.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/TextBinaryCodec.scala @@ -120,10 +120,10 @@ object TextBinaryCodec { ) override def streamEncoder: ZPipeline[Any, Nothing, A, Byte] = ZPipeline.mapChunks(_.flatMap(encode)) - override def streamDecoder: ZPipeline[Any, DecodeError, Byte, A] = codec.streamDecoder.map { x => + override def streamDecoder: ZPipeline[Any, DecodeError, Byte, A] = codec.streamDecoder.mapZIO { x => f(x) match { - case Left(value) => throw DecodeError.ReadError(Cause.fail(new Exception("Error in decoding")), value) - case Right(a) => a + case Left(value) => ZIO.fail(DecodeError.ReadError(Cause.fail(new Exception("Error in decoding")), value)) + case Right(a) => ZIO.succeed(a) } } } @@ -356,7 +356,7 @@ object TextBinaryCodec { override def streamDecoder: ZPipeline[Any, DecodeError, Byte, A] = (ZPipeline[Byte] >>> ZPipeline.utf8Decode) - .map(s => decode(Chunk.fromArray(s.getBytes)).fold(throw _, identity)) + .mapZIO(s => ZIO.fromEither(decode(Chunk.fromArray(s.getBytes)))) .mapErrorCause(e => Cause.fail(DecodeError.ReadError(e, e.squash.getMessage))) } case Schema.Lazy(schema0) => fromSchema(schema0()) diff --git a/zio-http/shared/src/main/scala/zio/http/codec/internal/Atomized.scala b/zio-http/shared/src/main/scala/zio/http/codec/internal/Atomized.scala index 9a53a41ad..3557b776d 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/internal/Atomized.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/internal/Atomized.scala @@ -25,30 +25,34 @@ private[http] final case class Atomized[A]( path: A, query: A, header: A, + headerCustom: A, content: A, ) { def get(tag: HttpCodec.AtomTag): A = { tag match { - case HttpCodec.AtomTag.Status => status - case HttpCodec.AtomTag.Path => path - case HttpCodec.AtomTag.Content => content - case HttpCodec.AtomTag.Query => query - case HttpCodec.AtomTag.Header => header - case HttpCodec.AtomTag.Method => method + case HttpCodec.AtomTag.Status => status + case HttpCodec.AtomTag.Path => path + case HttpCodec.AtomTag.Content => content + case HttpCodec.AtomTag.Query => query + case HttpCodec.AtomTag.Header => header + case HttpCodec.AtomTag.HeaderCustom => headerCustom + case HttpCodec.AtomTag.Method => method } } def update(tag: HttpCodec.AtomTag)(f: A => A): Atomized[A] = { tag match { - case HttpCodec.AtomTag.Status => copy(status = f(status)) - case HttpCodec.AtomTag.Path => copy(path = f(path)) - case HttpCodec.AtomTag.Content => copy(content = f(content)) - case HttpCodec.AtomTag.Query => copy(query = f(query)) - case HttpCodec.AtomTag.Header => copy(header = f(header)) - case HttpCodec.AtomTag.Method => copy(method = f(method)) + case HttpCodec.AtomTag.Status => copy(status = f(status)) + case HttpCodec.AtomTag.Path => copy(path = f(path)) + case HttpCodec.AtomTag.Content => copy(content = f(content)) + case HttpCodec.AtomTag.Query => copy(query = f(query)) + case HttpCodec.AtomTag.Header => copy(header = f(header)) + case HttpCodec.AtomTag.HeaderCustom => copy(headerCustom = f(header)) + case HttpCodec.AtomTag.Method => copy(method = f(method)) } } } private[http] object Atomized { - def apply[A](defValue: => A): Atomized[A] = Atomized(defValue, defValue, defValue, defValue, defValue, defValue) + def apply[A](defValue: => A): Atomized[A] = + Atomized(defValue, defValue, defValue, defValue, defValue, defValue, defValue) } diff --git a/zio-http/shared/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala b/zio-http/shared/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala index 52cce2c84..af296b3cf 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala @@ -27,6 +27,7 @@ private[http] final case class AtomizedCodecs( path: Chunk[PathCodec[_]], query: Chunk[Query[_, _]], header: Chunk[Header[_]], + headerCustom: Chunk[HeaderCustom[_]], content: Chunk[BodyCodec[_]], status: Chunk[SimpleCodec[zio.http.Status, _]], ) { self => @@ -35,9 +36,10 @@ private[http] final case class AtomizedCodecs( case method0: Method[_] => self.copy(method = method :+ method0.codec) case query0: Query[_, _] => self.copy(query = query :+ query0) case header0: Header[_] => self.copy(header = header :+ header0) + case header0: HeaderCustom[_] => self.copy(headerCustom = headerCustom :+ header0) + case status0: Status[_] => self.copy(status = status :+ status0.codec) case content0: Content[_] => self.copy(content = content :+ BodyCodec.Single(content0.codec, content0.name)) - case status0: Status[_] => self.copy(status = status :+ status0.codec) case stream0: ContentStream[_] => self.copy(content = content :+ BodyCodec.Multiple(stream0.codec, stream0.name)) } @@ -48,6 +50,7 @@ private[http] final case class AtomizedCodecs( path = Array.ofDim(path.length), query = Array.ofDim(query.length), header = Array.ofDim(header.length), + headerCustom = Array.ofDim(headerCustom.length), content = Array.ofDim(content.length), status = Array.ofDim(status.length), ) @@ -59,6 +62,7 @@ private[http] final case class AtomizedCodecs( path = path.materialize, query = query.materialize, header = header.materialize, + headerCustom = headerCustom.materialize, content = content.materialize, status = status.materialize, ) @@ -71,6 +75,7 @@ private[http] object AtomizedCodecs { path = Chunk.empty, query = Chunk.empty, header = Chunk.empty, + headerCustom = Chunk.empty, content = Chunk.empty, status = Chunk.empty, ) diff --git a/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala b/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala index 44d99b72d..61d858d99 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala @@ -16,16 +16,17 @@ package zio.http.codec.internal +import scala.annotation.tailrec import scala.util.Try import zio._ -import zio.schema.codec.{BinaryCodec, DecodeError} +import zio.schema.codec.DecodeError import zio.schema.{Schema, StandardType} import zio.http.Header.Accept.MediaTypeWithQFactor import zio.http._ -import zio.http.codec.HttpCodec.Query.QueryType +import zio.http.codec.StringCodec.StringCodec import zio.http.codec._ private[codec] trait EncoderDecoder[-AtomTypes, Value] { self => @@ -46,7 +47,7 @@ private[codec] object EncoderDecoder { val flattened = httpCodec.alternatives flattened.length match { - case 0 => Undefined() + case 0 => Undefined.asInstanceOf[EncoderDecoder[AtomTypes, Value]] case 1 => Single(flattened.head._1) case _ => Multiple(flattened) } @@ -109,7 +110,7 @@ private[codec] object EncoderDecoder { } } - private final case class Undefined[-AtomTypes, Value]() extends EncoderDecoder[AtomTypes, Value] { + private object Undefined extends EncoderDecoder[Any, Any] { val encodeWithErrorMessage = """ @@ -125,7 +126,7 @@ private[codec] object EncoderDecoder { override def encodeWith[Z]( config: CodecConfig, - value: Value, + value: Any, outputTypes: Chunk[MediaTypeWithQFactor], )(f: (zio.http.URL, Option[zio.http.Status], Option[zio.http.Method], zio.http.Headers, zio.http.Body) => Z): Z = { throw new IllegalStateException(encodeWithErrorMessage) @@ -138,7 +139,7 @@ private[codec] object EncoderDecoder { method: zio.http.Method, headers: zio.http.Headers, body: zio.http.Body, - )(implicit trace: zio.Trace): zio.Task[Value] = { + )(implicit trace: zio.Trace): zio.Task[Any] = { ZIO.fail(new IllegalStateException(decodeErrorMessage)) } } @@ -158,6 +159,12 @@ private[codec] object EncoderDecoder { }.toMap private lazy val nameByIndex = indexByName.map(_.swap) + private def canEmpty(schema: Schema.Collection[_, _]): Boolean = + schema match { + case _: Schema.NonEmptyMap[_, _] | _: Schema.NonEmptySequence[_, _, _] => false + case _ => true + } + override def decode(config: CodecConfig, url: URL, status: Status, method: Method, headers: Headers, body: Body)( implicit trace: Trace, ): Task[Value] = ZIO.suspendSucceed { @@ -168,6 +175,7 @@ private[codec] object EncoderDecoder { decodeStatus(status, inputsBuilder.status) decodeMethod(method, inputsBuilder.method) decodeHeaders(headers, inputsBuilder.header) + decodeCustomHeaders(headers, inputsBuilder.headerCustom) decodeBody(config, body, inputsBuilder.content).as(constructor(inputsBuilder)) } @@ -180,7 +188,7 @@ private[codec] object EncoderDecoder { val query = encodeQuery(config, inputs.query) val status = encodeStatus(inputs.status) val method = encodeMethod(inputs.method) - val headers = encodeHeaders(inputs.header) + val headers = encodeHeaders(inputs.header) ++ encodeCustomHeaders(inputs.headerCustom) def contentTypeHeaders = encodeContentType(inputs.content, outputTypes) val body = encodeBody(config, inputs.content, outputTypes) @@ -220,156 +228,276 @@ private[codec] object EncoderDecoder { inputs, (codec, queryParams) => { val query = codec.erase - val isOptional = query.isOptional - query.queryType match { - case QueryType.Primitive(name, bc @ BinaryCodecWithSchema(_, schema)) => - val count = queryParams.valueCount(name) - val hasParam = queryParams.hasQueryParam(name) - if (!hasParam && isOptional) None - else if (!hasParam) throw HttpCodecError.MissingQueryParam(name) - else if (count != 1) throw HttpCodecError.InvalidQueryParamCount(name, 1, count) - else { - val decoded = bc - .codec(config) - .decode( - Chunk.fromArray(queryParams.unsafeQueryParam(name).getBytes(Charsets.Utf8)), - ) match { + val optional = query.isOptionalSchema + val hasDefault = query.codec.defaultValue != null && query.isOptional + val default = query.codec.defaultValue + if (codec.isPrimitive) { + val name = query.nameUnsafe + val hasParam = queryParams.hasQueryParam(name) + if ( + (!hasParam || (queryParams + .unsafeQueryParam(name) == "" && !emptyStringIsValue(codec.codec.schema))) && hasDefault + ) + default + else if (!hasParam) + throw HttpCodecError.MissingQueryParam(name) + else if (queryParams.valueCount(name) != 1) + throw HttpCodecError.InvalidQueryParamCount(name, 1, queryParams.valueCount(name)) + else { + val decoded = + codec.codec.stringCodec.decode(queryParams.unsafeQueryParam(name)) match { + case Left(error) => throw HttpCodecError.MalformedQueryParam(name, error) + case Right(value) => value + } + val validationErrors = codec.codec.erasedSchema.validate(decoded)(codec.codec.erasedSchema) + if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) + else decoded + } + + } else if (codec.isCollection) { + val name = query.nameUnsafe + val hasParam = queryParams.hasQueryParam(name) + + if (!hasParam) { + if (query.codec.defaultValue != null) query.codec.defaultValue + else throw HttpCodecError.MissingQueryParam(name) + } else { + val decoded = queryParams.queryParams(name).map { value => + query.codec.stringCodec.decode(value) match { case Left(error) => throw HttpCodecError.MalformedQueryParam(name, error) case Right(value) => value } - val validationErrors = schema.validate(decoded)(schema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - if (isOptional && decoded == None && emptyStringIsValue(schema.asInstanceOf[Schema.Optional[_]].schema)) - Some("") - else decoded } - case c @ QueryType.Collection(_, QueryType.Primitive(name, bc), optional) => - if (!queryParams.hasQueryParam(name)) { - if (!optional) c.toCollection(Chunk.empty) - else None + if (optional) + Some( + createAndValidateCollection( + query.codec.schema.asInstanceOf[Schema.Optional[_]].schema.asInstanceOf[Schema.Collection[_, _]], + decoded, + ), + ) + else createAndValidateCollection(query.codec.schema.asInstanceOf[Schema.Collection[_, _]], decoded) + } + } else { + val recordSchema = query.codec.recordSchema + val fields = query.codec.recordFields + val hasAllParams = fields.forall { case (field, codec) => + queryParams.hasQueryParam(field.fieldName) || field.optional || codec.isOptional + } + if (!hasAllParams && hasDefault) default + else if (!hasAllParams) throw HttpCodecError.MissingQueryParams { + fields.collect { + case (field, codec) + if !(queryParams.hasQueryParam(field.fieldName) || field.optional || codec.isOptional) => + field.fieldName + } + } + else { + val decoded = fields.map { + case (field, codec) if field.schema.isInstanceOf[Schema.Collection[_, _]] => + val schema = field.schema.asInstanceOf[Schema.Collection[_, _]] + if (!queryParams.hasQueryParam(field.fieldName)) { + if (field.defaultValue.isDefined) field.defaultValue.get + else throw HttpCodecError.MissingQueryParam(field.fieldName) + } else { + val values = queryParams.queryParams(field.fieldName) + val decoded = + values.map(decodeAndUnwrap(field, codec, _, HttpCodecError.MalformedQueryParam)) + createAndValidateCollection(schema, decoded) + + } + case (field, codec) => + val value = queryParams.queryParamOrElse(field.fieldName, null) + val decoded = { + if (value == null || (value == "" && !emptyStringIsValue(codec.schema))) codec.defaultValue + else decodeAndUnwrap(field, codec, value, HttpCodecError.MalformedQueryParam) + } + validateDecoded(codec, decoded) + } + if (optional) { + val constructed = recordSchema.construct(decoded)(Unsafe.unsafe) + constructed match { + case Left(value) => + throw HttpCodecError.MalformedQueryParam( + s"${recordSchema.id}", + DecodeError.ReadError(Cause.empty, value), + ) + case Right(value) => + recordSchema.validate(value)(recordSchema) match { + case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) + case _ => Some(value) + } + } } else { - val values = queryParams.queryParams(name) - val decoded = c.toCollection { - values.map { value => - bc.codec(config).decode(Chunk.fromArray(value.getBytes(Charsets.Utf8))) match { - case Left(error) => throw HttpCodecError.MalformedQueryParam(name, error) - case Right(value) => value + val constructed = recordSchema.construct(decoded)(Unsafe.unsafe) + constructed match { + case Left(value) => + throw HttpCodecError.MalformedQueryParam( + s"${recordSchema.id}", + DecodeError.ReadError(Cause.empty, value), + ) + case Right(value) => + recordSchema.validate(value)(recordSchema) match { + case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) + case _ => value } - } } - val erasedSchema = c.colSchema.asInstanceOf[Schema[Any]] - val validationErrors = erasedSchema.validate(decoded)(erasedSchema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - if (optional) Some(decoded) - else decoded } - case query @ QueryType.Record(recordSchema) => - val hasAllParams = query.fieldAndCodecs.forall { case (field, _) => - queryParams.hasQueryParam(field.name) || field.optional || field.defaultValue.isDefined + } + } + }, + ) + + private def createAndValidateCollection(schema: Schema.Collection[_, _], decoded: Chunk[Any]) = { + val collection = schema.fromChunk.asInstanceOf[Chunk[Any] => Any](decoded) + val erasedSchema = schema.asInstanceOf[Schema[Any]] + val validationErrors = erasedSchema.validate(collection)(erasedSchema) + if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) + collection + } + + @tailrec + private def emptyStringIsValue(schema: Schema[_]): Boolean = { + schema match { + case value: Schema.Optional[_] => + val innerSchema = value.schema + emptyStringIsValue(innerSchema) + case _ => + schema.asInstanceOf[Schema.Primitive[_]].standardType match { + case StandardType.UnitType => true + case StandardType.StringType => true + case StandardType.BinaryType => true + case StandardType.CharType => true + case _ => false + } + } + } + + private def decodeCustomHeaders(headers: Headers, inputs: Array[Any]): Unit = + genericDecode[Headers, HttpCodec.HeaderCustom[_]]( + headers, + flattened.headerCustom, + inputs, + (header, headers) => { + val optional = header.codec.isOptionalSchema + if (header.codec.isPrimitive) { + val schema = header.erase.codec.schema + val name = header.codec.name.get + val value = headers.getUnsafe(name) + if (value ne null) { + val decoded = header.codec.stringCodec.decode(value) match { + case Left(error) => throw HttpCodecError.MalformedCustomHeader(name, error) + case Right(value) => value } - if (!hasAllParams && recordSchema.isInstanceOf[Schema.Optional[_]]) None - else if (!hasAllParams && isOptional) { - recordSchema.defaultValue match { - case Left(err) => - throw new IllegalStateException(s"Cannot compute default value for $recordSchema. Error was: $err") - case Right(value) => value - } - } else if (!hasAllParams) throw HttpCodecError.MissingQueryParams { - query.fieldAndCodecs.collect { - case (field, _) - if !(queryParams.hasQueryParam(field.name) || field.optional || field.defaultValue.isDefined) => - field.name + val validationErrors = schema.validate(decoded)(schema) + if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) + else decoded + } else { + if (optional) None + else throw HttpCodecError.MissingHeader(name) + } + } else if (header.codec.isCollection) { + val name = header.codec.name.get + val values = headers.rawHeaders(name) + val decoded = values.map { value => + header.codec.stringCodec.decode(value) match { + case Left(error) => throw HttpCodecError.MalformedCustomHeader(name, error) + case Right(value) => value + } + } + if (optional) + Some( + createAndValidateCollection( + header.codec.schema.asInstanceOf[Schema.Optional[_]].schema.asInstanceOf[Schema.Collection[_, _]], + decoded, + ), + ) + else createAndValidateCollection(header.codec.schema.asInstanceOf[Schema.Collection[_, _]], decoded) + } else { + val recordSchema = header.codec.recordSchema + val fields = header.codec.recordFields + val hasAllParams = fields.forall { case (field, codec) => + headers.contains(field.fieldName) || field.optional || codec.isOptional + } + if (!hasAllParams) { + if (header.codec.defaultValue != null && header.codec.isOptional) header.codec.defaultValue + else + throw HttpCodecError.MissingHeaders { + fields.collect { + case (field, codec) if !(headers.contains(field.fieldName) || field.optional || codec.isOptional) => + field.fieldName + } } + } else { + val decoded = fields.map { + case (field, codec) if field.schema.isInstanceOf[Schema.Collection[_, _]] => + if (!headers.contains(codec.name.get)) { + if (codec.defaultValue != null) codec.defaultValue + else throw HttpCodecError.MissingHeader(codec.name.get) + } else { + val schema = field.schema.asInstanceOf[Schema.Collection[_, _]] + val values = headers.rawHeaders(codec.name.get) + val decoded = + values.map(decodeAndUnwrap(field, codec, _, HttpCodecError.MalformedCustomHeader)) + createAndValidateCollection(schema, decoded) + } + case (field, codec) => + val value = headers.getUnsafe(codec.name.get) + val decoded = + if (value == null || (value == "" && !emptyStringIsValue(codec.schema))) codec.defaultValue + else decodeAndUnwrap(field, codec, value, HttpCodecError.MalformedCustomHeader) + validateDecoded(codec, decoded) } - else { - val decoded = query.fieldAndCodecs.map { - case (field, codec) if field.schema.isInstanceOf[Schema.Collection[_, _]] => - if (!queryParams.hasQueryParam(field.name) && field.defaultValue.nonEmpty) field.defaultValue.get - else { - val values = queryParams.queryParams(field.name) - val decoded = values.map { value => - codec.codec(config).decode(Chunk.fromArray(value.getBytes(Charsets.Utf8))) match { - case Left(error) => throw HttpCodecError.MalformedQueryParam(field.name, error) - case Right(value) => value - } - } - val decodedCollection = - field.schema match { - case s @ Schema.Sequence(_, fromChunk, _, _, _) => - val collection = fromChunk.asInstanceOf[Chunk[Any] => Any](decoded) - val erasedSchema = s.asInstanceOf[Schema[Any]] - val validationErrors = erasedSchema.validate(collection)(erasedSchema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - collection - case s @ Schema.Set(_, _) => - val collection = decoded.toSet[Any] - val erasedSchema = s.asInstanceOf[Schema.Set[Any]] - val validationErrors = erasedSchema.validate(collection)(erasedSchema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - collection - case _ => throw new IllegalStateException("Only Sequence and Set are supported.") - } - decodedCollection - } - case (field, codec) => - val value = queryParams.queryParamOrElse(field.name, null) - val decoded = { - if (value == null) field.defaultValue.get - else { - codec.codec(config).decode(Chunk.fromArray(value.getBytes(Charsets.Utf8))) match { - case Left(error) => throw HttpCodecError.MalformedQueryParam(field.name, error) - case Right(value) => value - } - } + if (optional) { + val constructed = recordSchema.construct(decoded)(Unsafe.unsafe) + constructed match { + case Left(value) => + throw HttpCodecError.MalformedCustomHeader( + s"${recordSchema.id}", + DecodeError.ReadError(Cause.empty, value), + ) + case Right(value) => + recordSchema.validate(value)(recordSchema) match { + case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) + case _ => Some(value) } - val validationErrors = codec.schema.validate(decoded)(codec.schema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - decoded } - if (recordSchema.isInstanceOf[Schema.Optional[_]]) { - val schema = recordSchema.asInstanceOf[Schema.Optional[_]].schema.asInstanceOf[Schema.Record[Any]] - val constructed = schema.construct(decoded)(Unsafe.unsafe) - constructed match { - case Left(value) => - throw HttpCodecError.MalformedQueryParam( - s"${schema.id}", - DecodeError.ReadError(Cause.empty, value), - ) - case Right(value) => - schema.validate(value)(schema) match { - case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) - case _ => Some(value) - } - } - } else { - val schema = recordSchema.asInstanceOf[Schema.Record[Any]] - val constructed = schema.construct(decoded)(Unsafe.unsafe) - constructed match { - case Left(value) => - throw HttpCodecError.MalformedQueryParam( - s"${schema.id}", - DecodeError.ReadError(Cause.empty, value), - ) - case Right(value) => - schema.validate(value)(schema) match { - case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) - case _ => value - } - } + } else { + val constructed = recordSchema.construct(decoded)(Unsafe.unsafe) + constructed match { + case Left(value) => + throw HttpCodecError.MalformedCustomHeader( + s"${recordSchema.id}", + DecodeError.ReadError(Cause.empty, value), + ) + case Right(value) => + recordSchema.validate(value)(recordSchema) match { + case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) + case _ => value + } } } + } } }, ) - private def emptyStringIsValue(schema: Schema[_]): Boolean = - schema.asInstanceOf[Schema.Primitive[_]].standardType match { - case StandardType.UnitType => true - case StandardType.StringType => true - case StandardType.BinaryType => true - case StandardType.CharType => true - case _ => false + private def validateDecoded(codec: HttpCodec.SchemaCodec[Any], decoded: Any) = { + val validationErrors = codec.schema.validate(decoded)(codec.schema) + if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) + decoded + } + + private def decodeAndUnwrap( + field: Schema.Field[_, _], + codec: HttpCodec.SchemaCodec[Any], + value: String, + ex: (String, DecodeError) => HttpCodecError, + ) = { + codec.stringCodec.decode(value) match { + case Left(error) => throw ex(codec.name.get, error) + case Right(value) => value } + } private def decodeHeaders(headers: Headers, inputs: Array[Any]): Unit = genericDecode[Headers, HttpCodec.Header[_]]( @@ -377,14 +505,14 @@ private[codec] object EncoderDecoder { flattened.header, inputs, (codec, headers) => - headers.get(codec.name) match { + headers.get(codec.headerType.name) match { case Some(value) => - codec.erase.textCodec - .decode(value) - .getOrElse(throw HttpCodecError.MalformedHeader(codec.name, codec.textCodec)) + codec.erase.headerType + .parse(value) + .getOrElse(throw HttpCodecError.MalformedTypedHeader(codec.headerType.name)) case None => - throw HttpCodecError.MissingHeader(codec.name) + throw HttpCodecError.MissingHeader(codec.headerType.name) }, ) @@ -513,111 +641,154 @@ private[codec] object EncoderDecoder { inputs, QueryParams.empty, (codec, input, queryParams) => { - val query = codec.erase - - query.queryType match { - case QueryType.Primitive(name, codec) => - val schema = codec.schema - if (schema.isInstanceOf[Schema.Primitive[_]]) { - if (schema.asInstanceOf[Schema.Primitive[_]].standardType.isInstanceOf[StandardType.UnitType.type]) { - queryParams.addQueryParams(name, Chunk.empty[String]) - } else { - val encoded = codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(input).asString - queryParams.addQueryParams(name, Chunk(encoded)) - } - } else if (schema.isInstanceOf[Schema.Optional[_]]) { - val encoded = codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(input).asString - if (encoded.nonEmpty) queryParams.addQueryParams(name, Chunk(encoded)) else queryParams + val query = codec.erase + val optional = query.isOptionalSchema + val stringCodec = codec.codec.stringCodec.asInstanceOf[StringCodec[Any]] + + if (query.isPrimitive) { + val schema = codec.codec.schema + val name = query.nameUnsafe + if (schema.isInstanceOf[Schema.Primitive[_]]) { + if (schema.asInstanceOf[Schema.Primitive[_]].standardType.isInstanceOf[StandardType.UnitType.type]) { + queryParams.addQueryParams(name, Chunk.empty[String]) } else { - throw new IllegalStateException( - "Only primitive schema is supported for query parameters of type Primitive", - ) + val encoded = stringCodec.encode(input) + queryParams.addQueryParams(name, Chunk(encoded)) } - case QueryType.Collection(_, QueryType.Primitive(name, codec), optional) => - var in: Any = input - if (optional) { - in = input.asInstanceOf[Option[Any]].getOrElse(Chunk.empty) - } - val values = input.asInstanceOf[Iterable[Any]] - if (values.nonEmpty) { - queryParams.addQueryParams( - name, - Chunk.fromIterable( - values.map { value => - codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(value).asString - }, - ), - ) - } else queryParams - case query @ QueryType.Record(recordSchema) if recordSchema.isInstanceOf[Schema.Optional[_]] => - input match { - case None => queryParams - case Some(value) => - val innerSchema = - recordSchema.asInstanceOf[Schema.Optional[_]].schema.asInstanceOf[Schema.Record[Any]] - val fieldValues = innerSchema.deconstruct(value)(Unsafe.unsafe) - var j = 0 - var qp = queryParams - while (j < fieldValues.size) { - val (field, codec) = query.fieldAndCodecs(j) - val name = field.name - val value = fieldValues(j) match { - case Some(value) => value - case None => field.defaultValue - } - value match { - case values: Iterable[_] => - qp = qp.addQueryParams( - name, - Chunk.fromIterable(values.map { v => - codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(v).asString - }), - ) - case _ => - val encoded = codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(value).asString - qp = qp.addQueryParam(name, encoded) - } - j = j + 1 - } - qp - } - case query @ QueryType.Record(recordSchema) => - val innerSchema = recordSchema.asInstanceOf[Schema.Record[Any]] - val fieldValues = innerSchema.deconstruct(input)(Unsafe.unsafe) - var j = 0 - var qp = queryParams - while (j < fieldValues.size) { - val (field, codec) = query.fieldAndCodecs(j) - val name = field.name - val value = fieldValues(j) match { + } else if (schema.isInstanceOf[Schema.Optional[_]]) { + val encoded = stringCodec.encode(input) + if (encoded.nonEmpty) queryParams.addQueryParams(name, Chunk(encoded)) else queryParams + } else { + throw new IllegalStateException( + "Only primitive schema is supported for query parameters of type Primitive", + ) + } + } else if (query.isCollection) { + val name = query.nameUnsafe + var in: Any = input + if (optional) { + in = input.asInstanceOf[Option[Any]].getOrElse(Chunk.empty) + } + val values = input.asInstanceOf[Iterable[Any]] + if (values.nonEmpty) { + queryParams.addQueryParams( + name, + Chunk.fromIterable(values.map { value => stringCodec.encode(value) }), + ) + } else queryParams + } else if (query.isRecord) { + val value = input match { + case None => null + case Some(value) => value + case value => value + } + if (value == null) queryParams + else { + val innerSchema = query.codec.recordSchema + val fieldValues = innerSchema.deconstruct(value)(Unsafe.unsafe) + var qp = queryParams + val fieldIt = query.codec.recordFields.iterator + val fieldValuesIt = fieldValues.iterator + while (fieldIt.hasNext) { + val (field, codec) = fieldIt.next() + val name = field.fieldName + val value = fieldValuesIt.next() match { case Some(value) => value case None => field.defaultValue } value match { - case values if values.isInstanceOf[Iterable[_]] => + case values: Iterable[_] => qp = qp.addQueryParams( name, - Chunk.fromIterable(values.asInstanceOf[Iterable[Any]].map { v => - codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(v).asString + Chunk.fromIterable(values.map { v => + codec.stringCodec.encode(v) }), ) - case _ => - val encoded = codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(value).asString + case _ => + val encoded = codec.stringCodec.encode(value) qp = qp.addQueryParam(name, encoded) } - j = j + 1 } qp + } + } else { + queryParams + } + }, + ) + + private def encodeCustomHeaders(inputs: Array[Any]): Headers = { + genericEncode[Headers, HttpCodec.HeaderCustom[_]]( + flattened.headerCustom, + inputs, + Headers.empty, + (codec, input, headers) => { + val optional = codec.codec.isOptionalSchema + val stringCodec = codec.erase.codec.stringCodec + if (codec.codec.isPrimitive) { + val name = codec.codec.name.get + val value = input + if (optional && value == None) headers + else { + val encoded = stringCodec.encode(value) + headers ++ Headers(name, encoded) + } + } else if (codec.codec.isCollection) { + val name = codec.codec.name.get + val values = input.asInstanceOf[Iterable[Any]] + if (values.nonEmpty) { + headers ++ Headers.FromIterable( + values.map { value => + Header.Custom(name, stringCodec.encode(value)) + }, + ) + } else headers + } else { + val recordSchema = codec.codec.recordSchema + val fields = codec.codec.recordFields + val value = input match { + case None => null + case Some(value) => value + case value => value + } + if (value == null) headers + else { + val fieldValues = recordSchema.deconstruct(value)(Unsafe.unsafe) + var hs = headers + val fieldIt = fields.iterator + val fieldValuesIt = fieldValues.iterator + while (fieldIt.hasNext) { + val (field, codec) = fieldIt.next() + val name = field.fieldName + val value = fieldValuesIt.next() match { + case Some(value) => value + case None => field.defaultValue + } + value match { + case values: Iterable[_] => + hs = hs ++ Headers.FromIterable( + values.map { v => + Header.Custom(name, codec.stringCodec.encode(v)) + }, + ) + case _ => + val encoded = codec.stringCodec.encode(value) + hs = hs ++ Headers(name, encoded) + } + } + hs + } } }, ) + } private def encodeHeaders(inputs: Array[Any]): Headers = genericEncode[Headers, HttpCodec.Header[_]]( flattened.header, inputs, Headers.empty, - (codec, input, headers) => headers ++ Headers(codec.name, codec.erase.textCodec.encode(input)), + (codec, input, headers) => headers ++ Headers(codec.headerType.name, codec.erase.headerType.render(input)), ) private def encodeStatus(inputs: Array[Any]): Option[Status] = diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/http/HttpGen.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/http/HttpGen.scala index eae8f5aba..3cc9b02f5 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/http/HttpGen.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/http/HttpGen.scala @@ -125,38 +125,36 @@ object HttpGen { private def getName(name: Option[String]) = { name.getOrElse(throw new IllegalArgumentException("name is required")) } def headersVariables(inAtoms: AtomizedMetaCodecs): Seq[HttpVariable] = - inAtoms.header.collect { case mc @ MetaCodec(HttpCodec.Header(name, codec, _), _) => + inAtoms.header.collect { case mc @ MetaCodec(HttpCodec.Header(headerType, _), _) => HttpVariable( - name.capitalize, - mc.examples.values.headOption.map(e => codec.asInstanceOf[TextCodec[Any]].encode(e)), + headerType.name.capitalize, + mc.examples.values.headOption.map(e => headerType.render(e)), ) } def queryVariables(config: CodecConfig, inAtoms: AtomizedMetaCodecs): Seq[HttpVariable] = { inAtoms.query.collect { - case mc @ MetaCodec(HttpCodec.Query(HttpCodec.Query.QueryType.Primitive(name, codec), _), _) => + case mc @ MetaCodec(HttpCodec.Query(codec, _), _) if codec.isPrimitive => HttpVariable( - name, - mc.examples.values.headOption.map((e: Any) => - codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(e).asString, - ), + codec.name.get, + mc.examples.values.headOption.map((e: Any) => codec.stringCodec.encode(e)), ) :: Nil - case mc @ MetaCodec(HttpCodec.Query(record @ HttpCodec.Query.QueryType.Record(schema), _), _) => - val recordSchema = (schema match { + case mc @ MetaCodec(HttpCodec.Query(codec, _), _) if codec.isRecord => + val recordSchema = (codec.schema match { case value if value.isInstanceOf[Schema.Optional[_]] => value.asInstanceOf[Schema.Optional[Any]].schema - case _ => schema + case _ => codec.schema }).asInstanceOf[Schema.Record[Any]] val examples = mc.examples.values.headOption.map { ex => recordSchema.deconstruct(ex)(Unsafe.unsafe) } - record.fieldAndCodecs.zipWithIndex.map { case ((field, codec), index) => + codec.recordFields.zipWithIndex.map { case ((field, codec), index) => HttpVariable( field.name, examples.map(values => { val fieldValue = values(index) .orElse(field.defaultValue) .getOrElse(throw new Exception(s"No value or default value for field ${field.name}")) - codec.codec(config).encode(fieldValue).asString + codec.stringCodec.encode(fieldValue) }), ) } diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala index 3ccac01aa..4586ab5d0 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala @@ -253,6 +253,16 @@ object JsonSchema { .toOption .get + def fromTextCodec(codec: TextCodec[_]): JsonSchema = + codec match { + case TextCodec.Constant(string) => JsonSchema.Enum(Chunk(EnumValue.Str(string))) + case TextCodec.StringCodec => JsonSchema.String() + case TextCodec.IntCodec => JsonSchema.Integer(JsonSchema.IntegerFormat.Int32) + case TextCodec.LongCodec => JsonSchema.Integer(JsonSchema.IntegerFormat.Int64) + case TextCodec.BooleanCodec => JsonSchema.Boolean + case TextCodec.UUIDCodec => JsonSchema.String(JsonSchema.StringFormat.UUID) + } + private[openapi] def fromSerializableSchema(schema: SerializableJsonSchema): JsonSchema = { val additionalProperties = schema.additionalProperties match { case Some(BoolOrSchema.BooleanWrapper(bool)) => Left(bool) @@ -361,16 +371,6 @@ object JsonSchema { jsonSchema } - def fromTextCodec(codec: TextCodec[_]): JsonSchema = - codec match { - case TextCodec.Constant(string) => JsonSchema.Enum(Chunk(EnumValue.Str(string))) - case TextCodec.StringCodec => JsonSchema.String() - case TextCodec.IntCodec => JsonSchema.Integer(JsonSchema.IntegerFormat.Int32) - case TextCodec.LongCodec => JsonSchema.Integer(JsonSchema.IntegerFormat.Int64) - case TextCodec.BooleanCodec => JsonSchema.Boolean - case TextCodec.UUIDCodec => JsonSchema.String(JsonSchema.StringFormat.UUID) - } - def fromSegmentCodec(codec: SegmentCodec[_]): JsonSchema = codec match { case SegmentCodec.BoolSeg(_) => JsonSchema.Boolean diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala index 64ba05780..cbeb35c04 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala @@ -634,10 +634,10 @@ object OpenAPIGen { def queryParams: Set[OpenAPI.ReferenceOr[OpenAPI.Parameter]] = { inAtoms.query.collect { - case mc @ MetaCodec(q @ HttpCodec.Query(HttpCodec.Query.QueryType.Primitive(name, codec), _), _) => + case mc @ MetaCodec(q @ HttpCodec.Query(codec, _), _) if codec.isPrimitive => OpenAPI.ReferenceOr.Or( OpenAPI.Parameter.queryParameter( - name = name, + name = q.nameUnsafe, description = mc.docsOpt, schema = Some(OpenAPI.ReferenceOr.Or(JsonSchema.fromZSchema(codec.schema))), deprecated = mc.deprecated, @@ -650,15 +650,15 @@ object OpenAPIGen { required = mc.required && !q.isOptional, ), ) :: Nil - case mc @ MetaCodec(HttpCodec.Query(record @ HttpCodec.Query.QueryType.Record(schema), _), _) => - val recordSchema = (schema match { + case mc @ MetaCodec(HttpCodec.Query(codec, _), _) if codec.isRecord => + val recordSchema = (codec.schema match { case schema if schema.isInstanceOf[Schema.Optional[_]] => schema.asInstanceOf[Schema.Optional[_]].schema - case _ => schema + case _ => codec.schema }).asInstanceOf[Schema.Record[Any]] val examples = mc.examples.map { case (exName, ex) => exName -> recordSchema.deconstruct(ex)(Unsafe.unsafe) } - record.fieldAndCodecs.zipWithIndex.map { case ((field, codec), index) => + codec.recordFields.zipWithIndex.map { case ((field, codec), index) => OpenAPI.ReferenceOr.Or( OpenAPI.Parameter.queryParameter( name = field.name, @@ -675,9 +675,7 @@ object OpenAPIGen { throw new Exception(s"No value or default value found for field ${exName}_${field.name}"), ) s"${exName}_${field.name}" -> OpenAPI.ReferenceOr.Or( - OpenAPI.Example(value = - Json.Str(codec.codec(CodecConfig.defaultConfig).encode(fieldValue).asString), - ), + OpenAPI.Example(value = Json.Str(codec.stringCodec.encode(fieldValue))), ) }, required = mc.required, @@ -685,22 +683,22 @@ object OpenAPIGen { ) } - case mc @ MetaCodec( - HttpCodec.Query( - HttpCodec.Query.QueryType.Collection( - _, - HttpCodec.Query.QueryType.Primitive(name, codec), - optional, - ), - _, - ), - _, - ) => + case mc @ MetaCodec(q @ HttpCodec.Query(codec, _), _) if codec.isCollection => + var required = false + val schema = codec.schema.asInstanceOf[Schema.Collection[_, _]] match { + case s: Schema.Sequence[_, _, _] => s.elementSchema + case _: Schema.Map[_, _] => throw new Exception("Map query parameters not supported") + case _: Schema.NonEmptyMap[_, _] => throw new Exception("Map query parameters not supported") + case s: Schema.NonEmptySequence[_, _, _] => + required = true + s.elementSchema + case s: Schema.Set[_] => s.elementSchema + } OpenAPI.ReferenceOr.Or( OpenAPI.Parameter.queryParameter( - name = name, + name = q.nameUnsafe, description = mc.docsOpt, - schema = Some(OpenAPI.ReferenceOr.Or(JsonSchema.fromZSchema(codec.schema))), + schema = Some(OpenAPI.ReferenceOr.Or(JsonSchema.fromZSchema(schema))), deprecated = mc.deprecated, style = OpenAPI.Parameter.Style.Form, explode = false, @@ -708,7 +706,7 @@ object OpenAPIGen { examples = mc.examples.map { case (exName, value) => exName -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(value = Json.Str(value.toString))) }, - required = mc.required && !optional, + required = required, ), ) :: Nil } @@ -737,13 +735,12 @@ object OpenAPIGen { .map { case mc @ MetaCodec(codec, _) => OpenAPI.ReferenceOr.Or( OpenAPI.Parameter.headerParameter( - name = mc.name.getOrElse(codec.name), + name = mc.name.getOrElse(codec.headerType.name), description = mc.docsOpt, - definition = - Some(OpenAPI.ReferenceOr.Or(JsonSchema.fromTextCodec(codec.textCodec).nullable(!mc.required))), + definition = Some(OpenAPI.ReferenceOr.Or(JsonSchema.String().nullable(!mc.required))), deprecated = mc.deprecated, examples = mc.examples.map { case (name, value) => - name -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(codec.textCodec.encode(value).toJsonAST.toOption.get)) + name -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(codec.headerType.render(value).toJsonAST.toOption.get)) }, required = mc.required, ), @@ -1014,13 +1011,13 @@ object OpenAPIGen { private def headersFrom(codec: AtomizedMetaCodecs) = { codec.header.map { case mc @ MetaCodec(codec, _) => - codec.name -> OpenAPI.ReferenceOr.Or( + codec.headerType.name -> OpenAPI.ReferenceOr.Or( OpenAPI.Header( description = mc.docsOpt, required = true, deprecated = mc.deprecated, allowEmptyValue = false, - schema = Some(JsonSchema.fromTextCodec(codec.textCodec)), + schema = Some(JsonSchema.String().nullable(!mc.required)), ), ) }.toMap diff --git a/zio-http/shared/src/main/scala/zio/http/internal/HeaderGetters.scala b/zio-http/shared/src/main/scala/zio/http/internal/HeaderGetters.scala index 8b6ad2b16..1a9841c32 100644 --- a/zio-http/shared/src/main/scala/zio/http/internal/HeaderGetters.scala +++ b/zio-http/shared/src/main/scala/zio/http/internal/HeaderGetters.scala @@ -67,6 +67,13 @@ trait HeaderGetters { self => /** Gets the raw unparsed header value */ final def rawHeader(name: CharSequence): Option[String] = headers.get(name) + final def rawHeaders(name: CharSequence): Chunk[String] = + Chunk.fromIterator( + headers.iterator + .filter(header => CharSequenceExtensions.equals(header.headerNameAsCharSequence, name, CaseMode.Insensitive)) + .map(_.renderedValue), + ) + /** Gets the raw unparsed header value */ final def rawHeader(headerType: HeaderType): Option[String] = rawHeader(headerType.name)