Skip to content

Commit

Permalink
Refactor: Add Response Handler (#727)
Browse files Browse the repository at this point in the history
* added response handler that is responsible of transofrming internal zio-http Response in netty HTTP response.

* added response handler that is responsible of transforming internal zio-http Response in netty HTTP response.

* removed redundant code

* merged and fixed code.

* formatting

* simplified the channelRead0 implementation avoiding patter matching on status. Updated comments

* simplified more the code

* formatting

* more formatting
  • Loading branch information
gciuloaica authored Jan 11, 2022
1 parent 83792d1 commit 2f0eb65
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 132 deletions.
4 changes: 1 addition & 3 deletions zio-http/src/main/scala/zhttp/http/Http.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package zhttp.http
import io.netty.channel.ChannelHandler
import zhttp.html.Html
import zhttp.http.headers.HeaderModifier
import zhttp.service.server.ServerTimeGenerator
import zhttp.service.{Handler, HttpRuntime, Server}
import zio._
import zio.clock.Clock
Expand Down Expand Up @@ -393,11 +392,10 @@ object Http {
private[zhttp] def compile[R1 <: R](
zExec: HttpRuntime[R1],
settings: Server.Config[R1, Throwable],
serverTime: ServerTimeGenerator,
)(implicit
evE: E <:< Throwable,
): ChannelHandler =
Handler(http.asInstanceOf[HttpApp[R1, Throwable]], zExec, settings, serverTime)
Handler(http.asInstanceOf[HttpApp[R1, Throwable]], zExec, settings)
}

/**
Expand Down
142 changes: 15 additions & 127 deletions zio-http/src/main/scala/zhttp/service/Handler.scala
Original file line number Diff line number Diff line change
@@ -1,27 +1,22 @@
package zhttp.service

import io.netty.buffer.{ByteBuf, Unpooled}
import io.netty.buffer.ByteBuf
import io.netty.channel.ChannelHandler.Sharable
import io.netty.channel.{ChannelHandlerContext, DefaultFileRegion, SimpleChannelInboundHandler}
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
import io.netty.handler.codec.http._
import zhttp.core.Util
import zhttp.http._
import zhttp.service.server.{ServerTimeGenerator, WebSocketUpgrade}
import zio.stream.ZStream
import zhttp.service.server.WebSocketUpgrade
import zio.{Task, UIO, ZIO}

import java.io.File
import java.net.{InetAddress, InetSocketAddress}

@Sharable
private[zhttp] final case class Handler[R](
app: HttpApp[R, Throwable],
runtime: HttpRuntime[R],
config: Server.Config[R, Throwable],
serverTime: ServerTimeGenerator,
) extends SimpleChannelInboundHandler[FullHttpRequest](false)
with WebSocketUpgrade[R] {
self =>
with WebSocketUpgrade[R] { self =>

type Ctx = ChannelHandlerContext

Expand Down Expand Up @@ -50,42 +45,6 @@ private[zhttp] final case class Handler[R](
)
}

override def exceptionCaught(ctx: Ctx, cause: Throwable): Unit = {
config.error.fold(super.exceptionCaught(ctx, cause))(f => runtime.unsafeRun(ctx)(f(cause)))
}

/**
* Checks if an encoded version of the response exists, uses it if it does. Otherwise, it will return a fresh
* response. It will also set the server time if requested by the client.
*/
private def encodeResponse(res: Response): HttpResponse = {

val jResponse = res.attribute.encoded match {

// Check if the encoded response exists and/or was modified.
case Some((oRes, jResponse)) if oRes eq res =>
jResponse match {
// Duplicate the response without allocating much memory
case response: FullHttpResponse =>
response.retainedDuplicate()

case response =>
response
}

case _ => res.unsafeEncode()
}
// Identify if the server time should be set and update if required.
if (res.attribute.serverTime) jResponse.headers().set(HttpHeaderNames.DATE, serverTime.refreshAndGet())
jResponse
}

private def notFoundResponse: HttpResponse = {
val response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.NOT_FOUND, false)
response.headers().setInt(HttpHeaderNames.CONTENT_LENGTH, 0)
response
}

/**
* Releases the FullHttpRequest safely.
*/
Expand All @@ -95,17 +54,6 @@ private[zhttp] final case class Handler[R](
}
}

private def serverErrorResponse(cause: Throwable): HttpResponse = {
val content = Util.prettyPrintHtml(cause)
val response = new DefaultFullHttpResponse(
HttpVersion.HTTP_1_1,
HttpResponseStatus.INTERNAL_SERVER_ERROR,
Unpooled.copiedBuffer(content, HTTP_CHARSET),
)
response.headers().set(HttpHeaderNames.CONTENT_LENGTH, content.length)
response
}

/**
* Executes http apps
*/
Expand All @@ -121,12 +69,12 @@ private[zhttp] final case class Handler[R](
{
case Some(cause) =>
UIO {
unsafeWriteAndFlushErrorResponse(cause)
ctx.fireChannelRead(Response.fromHttpError(HttpError.InternalServerError(cause = Some(cause))))
releaseRequest(jReq)
}
case None =>
UIO {
unsafeWriteAndFlushEmptyResponse()
ctx.fireChannelRead(Response.status(Status.NOT_FOUND))
releaseRequest(jReq)
}
},
Expand All @@ -135,16 +83,7 @@ private[zhttp] final case class Handler[R](
else {
for {
_ <- UIO {
// Write the initial line and the header.
unsafeWriteAndFlushAnyResponse(res)
}
_ <- res.data match {
case HttpData.BinaryStream(stream) => writeStreamContent(stream)
case HttpData.File(file) =>
UIO {
unsafeWriteFileContent(file)
}
case _ => UIO(ctx.flush())
ctx.fireChannelRead(res)
}
_ <- Task(releaseRequest(jReq))
} yield ()
Expand All @@ -156,22 +95,18 @@ private[zhttp] final case class Handler[R](
if (self.isWebSocket(res)) {
self.upgradeToWebSocket(ctx, jReq, res)
} else {
// Write the initial line and the header.
unsafeWriteAndFlushAnyResponse(res)
res.data match {
case HttpData.BinaryStream(stream) => unsafeRunZIO(writeStreamContent(stream) *> Task(releaseRequest(jReq)))
case HttpData.File(file) =>
unsafeWriteFileContent(file)
case _ => releaseRequest(jReq)
}
ctx.fireChannelRead(res)
releaseRequest(jReq)
}
case HExit.Failure(e) =>
unsafeWriteAndFlushErrorResponse(e)

case HExit.Failure(e) =>
ctx.fireChannelRead(e)
releaseRequest(jReq)
case HExit.Empty =>
unsafeWriteAndFlushEmptyResponse()
case HExit.Empty =>
ctx.fireChannelRead(Response.status(Status.NOT_FOUND))
releaseRequest(jReq)
}

}

/**
Expand All @@ -181,51 +116,4 @@ private[zhttp] final case class Handler[R](
runtime.unsafeRun(ctx) {
program
}

/**
* Writes any response to the Channel
*/
private def unsafeWriteAndFlushAnyResponse[A](res: Response)(implicit ctx: Ctx): Unit = {
ctx.writeAndFlush(encodeResponse(res)): Unit
}

/**
* Writes not found error response to the Channel
*/
private def unsafeWriteAndFlushEmptyResponse()(implicit ctx: Ctx): Unit = {
ctx.writeAndFlush(notFoundResponse): Unit
}

/**
* Writes error response to the Channel
*/
private def unsafeWriteAndFlushErrorResponse(cause: Throwable)(implicit ctx: Ctx): Unit = {
ctx.writeAndFlush(serverErrorResponse(cause)): Unit
}

/**
* Writes Binary Stream data to the Channel
*/
private def writeStreamContent(
stream: ZStream[Any, Throwable, ByteBuf],
)(implicit ctx: Ctx): ZIO[Any, Throwable, Unit] = {
for {
_ <- stream.foreach(c => UIO(ctx.writeAndFlush(c)))
_ <- ChannelFuture.unit(ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT))
} yield ()
}

/**
* Writes file content to the Channel. Does not use Chunked transfer encoding
*/
private def unsafeWriteFileContent(file: File)(implicit ctx: ChannelHandlerContext): Unit = {
import java.io.RandomAccessFile

val raf = new RandomAccessFile(file, "r")
val fileLength = raf.length()
// Write the content.
ctx.write(new DefaultFileRegion(raf.getChannel, 0, fileLength))
// Write the end marker.
ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT): Unit
}
}
6 changes: 4 additions & 2 deletions zio-http/src/main/scala/zhttp/service/Server.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import zhttp.http.Http._
import zhttp.http.{Http, HttpApp}
import zhttp.service.server.ServerSSLHandler._
import zhttp.service.server._
import zhttp.service.server.content.handlers.ServerResponseHandler
import zio.{ZManaged, _}

import java.net.{InetAddress, InetSocketAddress}
Expand Down Expand Up @@ -127,8 +128,9 @@ object Server {
channelFactory <- ZManaged.access[ServerChannelFactory](_.get)
eventLoopGroup <- ZManaged.access[EventLoopGroup](_.get)
zExec <- HttpRuntime.default[R].toManaged_
reqHandler = settings.app.compile(zExec, settings, ServerTimeGenerator.make)
init = ServerChannelInitializer(zExec, settings, reqHandler)
reqHandler = settings.app.compile(zExec, settings)
respHandler = ServerResponseHandler(zExec, settings, ServerTimeGenerator.make)
init = ServerChannelInitializer(zExec, settings, reqHandler, respHandler)
serverBootstrap = new ServerBootstrap().channelFactory(channelFactory).group(eventLoopGroup)
chf <- ZManaged.effect(serverBootstrap.childHandler(init).bind(settings.address))
_ <- ChannelFuture.asManaged(chf)
Expand Down
1 change: 1 addition & 0 deletions zio-http/src/main/scala/zhttp/service/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package object service {
private[service] val SERVER_CODEC_HANDLER = "SERVER_CODEC"
private[service] val OBJECT_AGGREGATOR = "OBJECT_AGGREGATOR"
private[service] val HTTP_REQUEST_HANDLER = "HTTP_REQUEST"
private[service] val HTTP_RESPONSE_HANDLER = "HTTP_RESPONSE"
private[service] val HTTP_KEEPALIVE_HANDLER = "HTTP_KEEPALIVE"
private[service] val FLOW_CONTROL_HANDLER = "FLOW_CONTROL_HANDLER"
private[service] val WEB_SOCKET_HANDLER = "WEB_SOCKET_HANDLER"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ final case class ServerChannelInitializer[R](
zExec: HttpRuntime[R],
cfg: Config[R, Throwable],
reqHandler: ChannelHandler,
respHandler: ChannelHandler,
) extends ChannelInitializer[Channel] {
override def initChannel(channel: Channel): Unit = {
// !! IMPORTANT !!
Expand Down Expand Up @@ -61,6 +62,9 @@ final case class ServerChannelInitializer[R](
// Always add ZIO Http Request Handler
pipeline.addLast(HTTP_REQUEST_HANDLER, reqHandler)

// ServerResponseHandler - transforms Response to HttpResponse
pipeline.addLast(HTTP_RESPONSE_HANDLER, respHandler)

()
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package zhttp.service.server.content.handlers

import io.netty.buffer.ByteBuf
import io.netty.channel.ChannelHandler.Sharable
import io.netty.channel.{ChannelHandlerContext, DefaultFileRegion, SimpleChannelInboundHandler}
import io.netty.handler.codec.http._
import zhttp.http.{HttpData, Response}
import zhttp.service.server.ServerTimeGenerator
import zhttp.service.{ChannelFuture, HttpRuntime, Server}
import zio.stream.ZStream
import zio.{UIO, ZIO}

import java.io.File

@Sharable
private[zhttp] case class ServerResponseHandler[R](
runtime: HttpRuntime[R],
config: Server.Config[R, Throwable],
serverTime: ServerTimeGenerator,
) extends SimpleChannelInboundHandler[Response](false) {

type Ctx = ChannelHandlerContext

override def channelRead0(ctx: Ctx, response: Response): Unit = {
implicit val iCtx: ChannelHandlerContext = ctx

ctx.write(encodeResponse(response))
response.data match {
case HttpData.BinaryStream(stream) => runtime.unsafeRun(ctx) { writeStreamContent(stream) }
case HttpData.File(file) => unsafeWriteFileContent(file)
case _ => ctx.flush()
}
()
}

override def exceptionCaught(ctx: Ctx, cause: Throwable): Unit = {
config.error.fold(super.exceptionCaught(ctx, cause))(f => runtime.unsafeRun(ctx)(f(cause)))
}

/**
* Checks if an encoded version of the response exists, uses it if it does. Otherwise, it will return a fresh
* response. It will also set the server time if requested by the client.
*/
private def encodeResponse(res: Response): HttpResponse = {

val jResponse = res.attribute.encoded match {

// Check if the encoded response exists and/or was modified.
case Some((oRes, jResponse)) if oRes eq res =>
jResponse match {
// Duplicate the response without allocating much memory
case response: FullHttpResponse => response.retainedDuplicate()

case response => response
}

case _ => res.unsafeEncode()
}
// Identify if the server time should be set and update if required.
if (res.attribute.serverTime) jResponse.headers().set(HttpHeaderNames.DATE, serverTime.refreshAndGet())
jResponse
}

/**
* Writes Binary Stream data to the Channel
*/
private def writeStreamContent[A](
stream: ZStream[R, Throwable, ByteBuf],
)(implicit ctx: Ctx): ZIO[R, Throwable, Unit] = {
for {
_ <- stream.foreach(c => UIO(ctx.writeAndFlush(c)))
_ <- ChannelFuture.unit(ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT))
} yield ()
}

/**
* Writes file content to the Channel. Does not use Chunked transfer encoding
*/
private def unsafeWriteFileContent(file: File)(implicit ctx: ChannelHandlerContext): Unit = {
import java.io.RandomAccessFile

val raf = new RandomAccessFile(file, "r")
val fileLength = raf.length()
// Write the content.
ctx.write(new DefaultFileRegion(raf.getChannel, 0, fileLength))
// Write the end marker.
ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT): Unit
}
}

0 comments on commit 2f0eb65

Please sign in to comment.