From 7e76ce337b2f8aec406bbe8f163062226b6127a8 Mon Sep 17 00:00:00 2001 From: Paul Lysak Date: Mon, 29 Aug 2022 05:22:24 +0300 Subject: [PATCH] Fix: Host header must include port (#1373) * Fix for #1372 - Host header must include port * Fix for #1372 - code review fixes * Fix for #1372 - fix formatting * Fix for #1372 - disambiguate priorities in test * Fix for #1372 - formatting * Fix for #1372 - Scala 2.12 compatibility for the tests --- zio-http/src/main/scala/zhttp/service/Client.scala | 8 ++------ .../src/main/scala/zhttp/service/EncodeRequest.scala | 4 ++-- .../src/test/scala/zhttp/http/EncodeRequestSpec.scala | 10 ++++++++-- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/zio-http/src/main/scala/zhttp/service/Client.scala b/zio-http/src/main/scala/zhttp/service/Client.scala index ea2634ece..de44181fd 100644 --- a/zio-http/src/main/scala/zhttp/service/Client.scala +++ b/zio-http/src/main/scala/zhttp/service/Client.scala @@ -19,7 +19,7 @@ import zhttp.service.client.{ClientInboundHandler, ClientSSLHandler} import zhttp.socket.SocketApp import zio.{Promise, Scope, Task, ZIO} -import java.net.{InetSocketAddress, URI} +import java.net.InetSocketAddress final case class Client[R](rtm: HttpRuntime[R], cf: JChannelFactory[JChannel], el: JEventLoopGroup) extends HttpMessageCodec { @@ -78,11 +78,7 @@ final case class Client[R](rtm: HttpRuntime[R], cf: JChannelFactory[JChannel], e ): JChannelFuture = { try { - val uri = new URI(jReq.uri()) - val host = if (uri.getHost == null) jReq.headers().get(HeaderNames.host) else uri.getHost - - assert(host != null, "Host name is required") - + val host = req.url.host.getOrElse { assert(false, "Host name is required"); "" } val port = req.url.port.getOrElse(80) val isWebSocket = req.url.scheme.exists(_.isWebSocket) diff --git a/zio-http/src/main/scala/zhttp/service/EncodeRequest.scala b/zio-http/src/main/scala/zhttp/service/EncodeRequest.scala index 325bd5637..c9577f77a 100644 --- a/zio-http/src/main/scala/zhttp/service/EncodeRequest.scala +++ b/zio-http/src/main/scala/zhttp/service/EncodeRequest.scala @@ -23,8 +23,8 @@ trait EncodeRequest { val encodedReqHeaders = req.headers.encode val headers = req.url.host match { - case Some(value) => encodedReqHeaders.set(HttpHeaderNames.HOST, value) - case None => encodedReqHeaders + case Some(host) => encodedReqHeaders.set(HttpHeaderNames.HOST, req.url.port.fold(host)(port => s"$host:$port")) + case _ => encodedReqHeaders } val writerIndex = content.writerIndex() diff --git a/zio-http/src/test/scala/zhttp/http/EncodeRequestSpec.scala b/zio-http/src/test/scala/zhttp/http/EncodeRequestSpec.scala index 4989ebc20..21e2f0d02 100644 --- a/zio-http/src/test/scala/zhttp/http/EncodeRequestSpec.scala +++ b/zio-http/src/test/scala/zhttp/http/EncodeRequestSpec.scala @@ -67,14 +67,20 @@ object EncodeRequestSpec extends ZIOSpecDefault with EncodeRequest { check(anyClientParam) { params => val req = encode(params).map(i => Option(i.headers().get(HttpHeaderNames.HOST))) - assertZIO(req)(equalTo(params.url.host)) + assertZIO(req)(equalTo((params.url.host, params.url.port) match { + case (Some(host), Some(port)) => Some(s"$host:$port") + case _ => params.url.host + })) } }, test("host header when absolute url") { check(clientParamWithAbsoluteUrl) { params => val req = encode(params) .map(i => Option(i.headers().get(HttpHeaderNames.HOST))) - assertZIO(req)(equalTo(params.url.host)) + assertZIO(req)(equalTo((params.url.host, params.url.port) match { + case (Some(host), Some(port)) => Some(s"$host:$port") + case _ => params.url.host + })) } }, test("only one host header exists") {