diff --git a/zio-http/src/main/scala/zhttp/http/Http.scala b/zio-http/src/main/scala/zhttp/http/Http.scala index 4686fac9b..1906e5209 100644 --- a/zio-http/src/main/scala/zhttp/http/Http.scala +++ b/zio-http/src/main/scala/zhttp/http/Http.scala @@ -3,7 +3,6 @@ package zhttp.http import io.netty.channel.ChannelHandler import zhttp.html.Html import zhttp.http.headers.HeaderModifier -import zhttp.service.server.ServerTimeGenerator import zhttp.service.{Handler, HttpRuntime, Server} import zio._ import zio.clock.Clock @@ -393,11 +392,10 @@ object Http { private[zhttp] def compile[R1 <: R]( zExec: HttpRuntime[R1], settings: Server.Config[R1, Throwable], - serverTime: ServerTimeGenerator, )(implicit evE: E <:< Throwable, ): ChannelHandler = - Handler(http.asInstanceOf[HttpApp[R1, Throwable]], zExec, settings, serverTime) + Handler(http.asInstanceOf[HttpApp[R1, Throwable]], zExec, settings) } /** diff --git a/zio-http/src/main/scala/zhttp/service/Handler.scala b/zio-http/src/main/scala/zhttp/service/Handler.scala index f7acbe018..c6e4761c0 100644 --- a/zio-http/src/main/scala/zhttp/service/Handler.scala +++ b/zio-http/src/main/scala/zhttp/service/Handler.scala @@ -1,16 +1,13 @@ package zhttp.service -import io.netty.buffer.{ByteBuf, Unpooled} +import io.netty.buffer.ByteBuf import io.netty.channel.ChannelHandler.Sharable -import io.netty.channel.{ChannelHandlerContext, DefaultFileRegion, SimpleChannelInboundHandler} +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import io.netty.handler.codec.http._ -import zhttp.core.Util import zhttp.http._ -import zhttp.service.server.{ServerTimeGenerator, WebSocketUpgrade} -import zio.stream.ZStream +import zhttp.service.server.WebSocketUpgrade import zio.{Task, UIO, ZIO} -import java.io.File import java.net.{InetAddress, InetSocketAddress} @Sharable @@ -18,10 +15,8 @@ private[zhttp] final case class Handler[R]( app: HttpApp[R, Throwable], runtime: HttpRuntime[R], config: Server.Config[R, Throwable], - serverTime: ServerTimeGenerator, ) extends SimpleChannelInboundHandler[FullHttpRequest](false) - with WebSocketUpgrade[R] { - self => + with WebSocketUpgrade[R] { self => type Ctx = ChannelHandlerContext @@ -50,42 +45,6 @@ private[zhttp] final case class Handler[R]( ) } - override def exceptionCaught(ctx: Ctx, cause: Throwable): Unit = { - config.error.fold(super.exceptionCaught(ctx, cause))(f => runtime.unsafeRun(ctx)(f(cause))) - } - - /** - * Checks if an encoded version of the response exists, uses it if it does. Otherwise, it will return a fresh - * response. It will also set the server time if requested by the client. - */ - private def encodeResponse(res: Response): HttpResponse = { - - val jResponse = res.attribute.encoded match { - - // Check if the encoded response exists and/or was modified. - case Some((oRes, jResponse)) if oRes eq res => - jResponse match { - // Duplicate the response without allocating much memory - case response: FullHttpResponse => - response.retainedDuplicate() - - case response => - response - } - - case _ => res.unsafeEncode() - } - // Identify if the server time should be set and update if required. - if (res.attribute.serverTime) jResponse.headers().set(HttpHeaderNames.DATE, serverTime.refreshAndGet()) - jResponse - } - - private def notFoundResponse: HttpResponse = { - val response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.NOT_FOUND, false) - response.headers().setInt(HttpHeaderNames.CONTENT_LENGTH, 0) - response - } - /** * Releases the FullHttpRequest safely. */ @@ -95,17 +54,6 @@ private[zhttp] final case class Handler[R]( } } - private def serverErrorResponse(cause: Throwable): HttpResponse = { - val content = Util.prettyPrintHtml(cause) - val response = new DefaultFullHttpResponse( - HttpVersion.HTTP_1_1, - HttpResponseStatus.INTERNAL_SERVER_ERROR, - Unpooled.copiedBuffer(content, HTTP_CHARSET), - ) - response.headers().set(HttpHeaderNames.CONTENT_LENGTH, content.length) - response - } - /** * Executes http apps */ @@ -121,12 +69,12 @@ private[zhttp] final case class Handler[R]( { case Some(cause) => UIO { - unsafeWriteAndFlushErrorResponse(cause) + ctx.fireChannelRead(Response.fromHttpError(HttpError.InternalServerError(cause = Some(cause)))) releaseRequest(jReq) } case None => UIO { - unsafeWriteAndFlushEmptyResponse() + ctx.fireChannelRead(Response.status(Status.NOT_FOUND)) releaseRequest(jReq) } }, @@ -135,16 +83,7 @@ private[zhttp] final case class Handler[R]( else { for { _ <- UIO { - // Write the initial line and the header. - unsafeWriteAndFlushAnyResponse(res) - } - _ <- res.data match { - case HttpData.BinaryStream(stream) => writeStreamContent(stream) - case HttpData.File(file) => - UIO { - unsafeWriteFileContent(file) - } - case _ => UIO(ctx.flush()) + ctx.fireChannelRead(res) } _ <- Task(releaseRequest(jReq)) } yield () @@ -156,22 +95,18 @@ private[zhttp] final case class Handler[R]( if (self.isWebSocket(res)) { self.upgradeToWebSocket(ctx, jReq, res) } else { - // Write the initial line and the header. - unsafeWriteAndFlushAnyResponse(res) - res.data match { - case HttpData.BinaryStream(stream) => unsafeRunZIO(writeStreamContent(stream) *> Task(releaseRequest(jReq))) - case HttpData.File(file) => - unsafeWriteFileContent(file) - case _ => releaseRequest(jReq) - } + ctx.fireChannelRead(res) + releaseRequest(jReq) } - case HExit.Failure(e) => - unsafeWriteAndFlushErrorResponse(e) + + case HExit.Failure(e) => + ctx.fireChannelRead(e) releaseRequest(jReq) - case HExit.Empty => - unsafeWriteAndFlushEmptyResponse() + case HExit.Empty => + ctx.fireChannelRead(Response.status(Status.NOT_FOUND)) releaseRequest(jReq) } + } /** @@ -181,51 +116,4 @@ private[zhttp] final case class Handler[R]( runtime.unsafeRun(ctx) { program } - - /** - * Writes any response to the Channel - */ - private def unsafeWriteAndFlushAnyResponse[A](res: Response)(implicit ctx: Ctx): Unit = { - ctx.writeAndFlush(encodeResponse(res)): Unit - } - - /** - * Writes not found error response to the Channel - */ - private def unsafeWriteAndFlushEmptyResponse()(implicit ctx: Ctx): Unit = { - ctx.writeAndFlush(notFoundResponse): Unit - } - - /** - * Writes error response to the Channel - */ - private def unsafeWriteAndFlushErrorResponse(cause: Throwable)(implicit ctx: Ctx): Unit = { - ctx.writeAndFlush(serverErrorResponse(cause)): Unit - } - - /** - * Writes Binary Stream data to the Channel - */ - private def writeStreamContent( - stream: ZStream[Any, Throwable, ByteBuf], - )(implicit ctx: Ctx): ZIO[Any, Throwable, Unit] = { - for { - _ <- stream.foreach(c => UIO(ctx.writeAndFlush(c))) - _ <- ChannelFuture.unit(ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT)) - } yield () - } - - /** - * Writes file content to the Channel. Does not use Chunked transfer encoding - */ - private def unsafeWriteFileContent(file: File)(implicit ctx: ChannelHandlerContext): Unit = { - import java.io.RandomAccessFile - - val raf = new RandomAccessFile(file, "r") - val fileLength = raf.length() - // Write the content. - ctx.write(new DefaultFileRegion(raf.getChannel, 0, fileLength)) - // Write the end marker. - ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT): Unit - } } diff --git a/zio-http/src/main/scala/zhttp/service/Server.scala b/zio-http/src/main/scala/zhttp/service/Server.scala index c728ac1db..0823b8617 100644 --- a/zio-http/src/main/scala/zhttp/service/Server.scala +++ b/zio-http/src/main/scala/zhttp/service/Server.scala @@ -6,6 +6,7 @@ import zhttp.http.Http._ import zhttp.http.{Http, HttpApp} import zhttp.service.server.ServerSSLHandler._ import zhttp.service.server._ +import zhttp.service.server.content.handlers.ServerResponseHandler import zio.{ZManaged, _} import java.net.{InetAddress, InetSocketAddress} @@ -127,8 +128,9 @@ object Server { channelFactory <- ZManaged.access[ServerChannelFactory](_.get) eventLoopGroup <- ZManaged.access[EventLoopGroup](_.get) zExec <- HttpRuntime.default[R].toManaged_ - reqHandler = settings.app.compile(zExec, settings, ServerTimeGenerator.make) - init = ServerChannelInitializer(zExec, settings, reqHandler) + reqHandler = settings.app.compile(zExec, settings) + respHandler = ServerResponseHandler(zExec, settings, ServerTimeGenerator.make) + init = ServerChannelInitializer(zExec, settings, reqHandler, respHandler) serverBootstrap = new ServerBootstrap().channelFactory(channelFactory).group(eventLoopGroup) chf <- ZManaged.effect(serverBootstrap.childHandler(init).bind(settings.address)) _ <- ChannelFuture.asManaged(chf) diff --git a/zio-http/src/main/scala/zhttp/service/package.scala b/zio-http/src/main/scala/zhttp/service/package.scala index da5bee035..ac3fbcdbe 100644 --- a/zio-http/src/main/scala/zhttp/service/package.scala +++ b/zio-http/src/main/scala/zhttp/service/package.scala @@ -8,6 +8,7 @@ package object service { private[service] val SERVER_CODEC_HANDLER = "SERVER_CODEC" private[service] val OBJECT_AGGREGATOR = "OBJECT_AGGREGATOR" private[service] val HTTP_REQUEST_HANDLER = "HTTP_REQUEST" + private[service] val HTTP_RESPONSE_HANDLER = "HTTP_RESPONSE" private[service] val HTTP_KEEPALIVE_HANDLER = "HTTP_KEEPALIVE" private[service] val FLOW_CONTROL_HANDLER = "FLOW_CONTROL_HANDLER" private[service] val WEB_SOCKET_HANDLER = "WEB_SOCKET_HANDLER" diff --git a/zio-http/src/main/scala/zhttp/service/server/ServerChannelInitializer.scala b/zio-http/src/main/scala/zhttp/service/server/ServerChannelInitializer.scala index 7427fe24c..055bf7aac 100644 --- a/zio-http/src/main/scala/zhttp/service/server/ServerChannelInitializer.scala +++ b/zio-http/src/main/scala/zhttp/service/server/ServerChannelInitializer.scala @@ -21,6 +21,7 @@ final case class ServerChannelInitializer[R]( zExec: HttpRuntime[R], cfg: Config[R, Throwable], reqHandler: ChannelHandler, + respHandler: ChannelHandler, ) extends ChannelInitializer[Channel] { override def initChannel(channel: Channel): Unit = { // !! IMPORTANT !! @@ -61,6 +62,9 @@ final case class ServerChannelInitializer[R]( // Always add ZIO Http Request Handler pipeline.addLast(HTTP_REQUEST_HANDLER, reqHandler) + // ServerResponseHandler - transforms Response to HttpResponse + pipeline.addLast(HTTP_RESPONSE_HANDLER, respHandler) + () } diff --git a/zio-http/src/main/scala/zhttp/service/server/content/handlers/ServerResponseHandler.scala b/zio-http/src/main/scala/zhttp/service/server/content/handlers/ServerResponseHandler.scala new file mode 100644 index 000000000..8914770ae --- /dev/null +++ b/zio-http/src/main/scala/zhttp/service/server/content/handlers/ServerResponseHandler.scala @@ -0,0 +1,89 @@ +package zhttp.service.server.content.handlers + +import io.netty.buffer.ByteBuf +import io.netty.channel.ChannelHandler.Sharable +import io.netty.channel.{ChannelHandlerContext, DefaultFileRegion, SimpleChannelInboundHandler} +import io.netty.handler.codec.http._ +import zhttp.http.{HttpData, Response} +import zhttp.service.server.ServerTimeGenerator +import zhttp.service.{ChannelFuture, HttpRuntime, Server} +import zio.stream.ZStream +import zio.{UIO, ZIO} + +import java.io.File + +@Sharable +private[zhttp] case class ServerResponseHandler[R]( + runtime: HttpRuntime[R], + config: Server.Config[R, Throwable], + serverTime: ServerTimeGenerator, +) extends SimpleChannelInboundHandler[Response](false) { + + type Ctx = ChannelHandlerContext + + override def channelRead0(ctx: Ctx, response: Response): Unit = { + implicit val iCtx: ChannelHandlerContext = ctx + + ctx.write(encodeResponse(response)) + response.data match { + case HttpData.BinaryStream(stream) => runtime.unsafeRun(ctx) { writeStreamContent(stream) } + case HttpData.File(file) => unsafeWriteFileContent(file) + case _ => ctx.flush() + } + () + } + + override def exceptionCaught(ctx: Ctx, cause: Throwable): Unit = { + config.error.fold(super.exceptionCaught(ctx, cause))(f => runtime.unsafeRun(ctx)(f(cause))) + } + + /** + * Checks if an encoded version of the response exists, uses it if it does. Otherwise, it will return a fresh + * response. It will also set the server time if requested by the client. + */ + private def encodeResponse(res: Response): HttpResponse = { + + val jResponse = res.attribute.encoded match { + + // Check if the encoded response exists and/or was modified. + case Some((oRes, jResponse)) if oRes eq res => + jResponse match { + // Duplicate the response without allocating much memory + case response: FullHttpResponse => response.retainedDuplicate() + + case response => response + } + + case _ => res.unsafeEncode() + } + // Identify if the server time should be set and update if required. + if (res.attribute.serverTime) jResponse.headers().set(HttpHeaderNames.DATE, serverTime.refreshAndGet()) + jResponse + } + + /** + * Writes Binary Stream data to the Channel + */ + private def writeStreamContent[A]( + stream: ZStream[R, Throwable, ByteBuf], + )(implicit ctx: Ctx): ZIO[R, Throwable, Unit] = { + for { + _ <- stream.foreach(c => UIO(ctx.writeAndFlush(c))) + _ <- ChannelFuture.unit(ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT)) + } yield () + } + + /** + * Writes file content to the Channel. Does not use Chunked transfer encoding + */ + private def unsafeWriteFileContent(file: File)(implicit ctx: ChannelHandlerContext): Unit = { + import java.io.RandomAccessFile + + val raf = new RandomAccessFile(file, "r") + val fileLength = raf.length() + // Write the content. + ctx.write(new DefaultFileRegion(raf.getChannel, 0, fileLength)) + // Write the end marker. + ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT): Unit + } +}