Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix keepalive logic #50

Merged
merged 3 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ extension ChannelPipeline.SynchronousOperations {
var http2HandlerStreamConfiguration = NIOHTTP2Handler.StreamConfiguration()
http2HandlerStreamConfiguration.targetWindowSize = clampedTargetWindowSize

let boundConnectionManagementHandler = NIOLoopBound(
serverConnectionHandler.syncView,
eventLoop: self.eventLoop
)
let streamMultiplexer = try self.configureAsyncHTTP2Pipeline(
mode: .server,
streamDelegate: serverConnectionHandler.http2StreamDelegate,
Expand All @@ -86,7 +90,8 @@ extension ChannelPipeline.SynchronousOperations {
acceptedEncodings: compressionConfig.enabledAlgorithms,
maxPayloadSize: rpcConfig.maxRequestPayloadSize,
methodDescriptorPromise: methodDescriptorPromise,
eventLoop: streamChannel.eventLoop
eventLoop: streamChannel.eventLoop,
connectionManagementHandler: boundConnectionManagementHandler.value
)
try streamChannel.pipeline.syncOperations.addHandler(streamHandler)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ package final class ServerConnectionManagementHandler: ChannelDuplexHandler {
}

/// Stats about recently written frames. Used to determine whether to reset keep-alive state.
private var frameStats: FrameStats
package var frameStats: FrameStats

struct FrameStats {
package struct FrameStats {
private(set) var didWriteHeadersOrData = false

/// Mark that a HEADERS frame has been written.
Expand Down Expand Up @@ -609,7 +609,13 @@ extension ServerConnectionManagementHandler {

context.write(self.wrapOutboundOut(goAway), promise: nil)
self.maybeFlush(context: context)
context.close(promise: nil)

// We must delay the channel close after sending the GOAWAY packet by a tick to make sure it
// gets flushed and delivered to the client before the connection is closed.
let loopBound = NIOLoopBound(context, eventLoop: context.eventLoop)
gjcairo marked this conversation as resolved.
Show resolved Hide resolved
context.eventLoop.execute {
loopBound.value.close(promise: nil)
}

case .sendAck:
() // ACKs are sent by NIO's HTTP/2 handler, don't double ack.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ package final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChan

private var cancellationHandle: Optional<ServerContext.RPCCancellationHandle>

package let connectionManagementHandler: ServerConnectionManagementHandler.SyncView

// Existential errors unconditionally allocate, avoid this per-use allocation by doing it
// statically.
private static let handlerRemovedBeforeDescriptorResolved: any Error = RPCError(
Expand All @@ -55,6 +57,7 @@ package final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChan
maxPayloadSize: Int,
methodDescriptorPromise: EventLoopPromise<MethodDescriptor>,
eventLoop: any EventLoop,
connectionManagementHandler: ServerConnectionManagementHandler.SyncView,
cancellationHandler: ServerContext.RPCCancellationHandle? = nil,
skipStateMachineAssertions: Bool = false
) {
Expand All @@ -66,6 +69,7 @@ package final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChan
self.methodDescriptorPromise = methodDescriptorPromise
self.cancellationHandle = cancellationHandler
self.eventLoop = eventLoop
self.connectionManagementHandler = connectionManagementHandler
}

package func setCancellationHandle(_ handle: ServerContext.RPCCancellationHandle) {
Expand Down Expand Up @@ -136,13 +140,16 @@ extension GRPCServerStreamHandler {
switch self.stateMachine.nextInboundMessage() {
case .receiveMessage(let message):
context.fireChannelRead(self.wrapInboundOut(.message(message)))

case .awaitMoreMessages:
break loop

case .noMoreMessages:
context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)
break loop
}
}

case .doNothing:
()
}
Expand Down Expand Up @@ -261,6 +268,7 @@ extension GRPCServerStreamHandler {
self.flushPending = true
let headers = try self.stateMachine.send(metadata: metadata)
context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: promise)
self.connectionManagementHandler.wroteHeadersFrame()
} catch let invalidState {
let error = RPCError(invalidState)
promise?.fail(error)
Expand All @@ -270,6 +278,7 @@ extension GRPCServerStreamHandler {
case .message(let message):
do {
try self.stateMachine.send(message: message, promise: promise)
self.connectionManagementHandler.wroteDataFrame()
} catch let invalidState {
let error = RPCError(invalidState)
promise?.fail(error)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,24 @@ extension ConnectionTest {
let h2 = NIOHTTP2Handler(mode: .server)
let mux = HTTP2StreamMultiplexer(mode: .server, channel: channel) { stream in
let sync = stream.pipeline.syncOperations
let connectionManagementHandler = ServerConnectionManagementHandler(
eventLoop: stream.eventLoop,
maxIdleTime: nil,
maxAge: nil,
maxGraceTime: nil,
keepaliveTime: nil,
keepaliveTimeout: nil,
allowKeepaliveWithoutCalls: false,
minPingIntervalWithoutCalls: .minutes(5),
requireALPN: false
)
let handler = GRPCServerStreamHandler(
scheme: .http,
acceptedEncodings: .none,
maxPayloadSize: .max,
methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self),
eventLoop: stream.eventLoop
eventLoop: stream.eventLoop,
connectionManagementHandler: connectionManagementHandler.syncView
)

return stream.eventLoop.makeCompletedFuture {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,24 @@ final class TestServer: Sendable {
let sync = channel.pipeline.syncOperations
let multiplexer = try sync.configureAsyncHTTP2Pipeline(mode: .server) { stream in
stream.eventLoop.makeCompletedFuture {
let connectionManagementHandler = ServerConnectionManagementHandler(
eventLoop: stream.eventLoop,
maxIdleTime: nil,
maxAge: nil,
maxGraceTime: nil,
keepaliveTime: nil,
keepaliveTimeout: nil,
allowKeepaliveWithoutCalls: false,
minPingIntervalWithoutCalls: .minutes(5),
requireALPN: false
)
let handler = GRPCServerStreamHandler(
scheme: .http,
acceptedEncodings: .all,
maxPayloadSize: .max,
methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self),
eventLoop: stream.eventLoop
eventLoop: stream.eventLoop,
connectionManagementHandler: connectionManagementHandler.syncView
)

try stream.pipeline.syncOperations.addHandlers(handler)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,25 @@ final class GRPCServerStreamHandlerTests: XCTestCase {
descriptorPromise: EventLoopPromise<MethodDescriptor>? = nil,
disableAssertions: Bool = false
) -> GRPCServerStreamHandler {
let serverConnectionManagementHandler = ServerConnectionManagementHandler(
eventLoop: channel.eventLoop,
maxIdleTime: nil,
maxAge: nil,
maxGraceTime: nil,
keepaliveTime: nil,
keepaliveTimeout: nil,
allowKeepaliveWithoutCalls: false,
minPingIntervalWithoutCalls: .minutes(5),
requireALPN: false
)

return GRPCServerStreamHandler(
scheme: scheme,
acceptedEncodings: acceptedEncodings,
maxPayloadSize: maxPayloadSize,
methodDescriptorPromise: descriptorPromise ?? channel.eventLoop.makePromise(),
eventLoop: channel.eventLoop,
connectionManagementHandler: serverConnectionManagementHandler.syncView,
skipStateMachineAssertions: disableAssertions
)
}
Expand Down Expand Up @@ -974,28 +987,50 @@ final class GRPCServerStreamHandlerTests: XCTestCase {
}

struct ServerStreamHandlerTests {
private func makeServerStreamHandler(
struct ConnectionAndStreamHandlers {
let streamHandler: GRPCServerStreamHandler
let connectionHandler: ServerConnectionManagementHandler
}

private func makeServerConnectionAndStreamHandlers(
channel: any Channel,
scheme: Scheme = .http,
acceptedEncodings: CompressionAlgorithmSet = [],
maxPayloadSize: Int = .max,
descriptorPromise: EventLoopPromise<MethodDescriptor>? = nil,
disableAssertions: Bool = false
) -> GRPCServerStreamHandler {
return GRPCServerStreamHandler(
) -> ConnectionAndStreamHandlers {
let connectionManagementHandler = ServerConnectionManagementHandler(
eventLoop: channel.eventLoop,
maxIdleTime: nil,
maxAge: nil,
maxGraceTime: nil,
keepaliveTime: nil,
keepaliveTimeout: nil,
allowKeepaliveWithoutCalls: false,
minPingIntervalWithoutCalls: .minutes(5),
requireALPN: false
)
let streamHandler = GRPCServerStreamHandler(
scheme: scheme,
acceptedEncodings: acceptedEncodings,
maxPayloadSize: maxPayloadSize,
methodDescriptorPromise: descriptorPromise ?? channel.eventLoop.makePromise(),
eventLoop: channel.eventLoop,
connectionManagementHandler: connectionManagementHandler.syncView,
skipStateMachineAssertions: disableAssertions
)

return ConnectionAndStreamHandlers(
streamHandler: streamHandler,
connectionHandler: connectionManagementHandler
)
}

@Test("ChannelShouldQuiesceEvent is buffered and turns into RPC cancellation")
func shouldQuiesceEventIsBufferedBeforeHandleIsSet() async throws {
let channel = EmbeddedChannel()
let handler = self.makeServerStreamHandler(channel: channel)
let handler = self.makeServerConnectionAndStreamHandlers(channel: channel).streamHandler
try channel.pipeline.syncOperations.addHandler(handler)
channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())

Expand All @@ -1011,7 +1046,7 @@ struct ServerStreamHandlerTests {
@Test("ChannelShouldQuiesceEvent turns into RPC cancellation")
func shouldQuiesceEventTriggersCancellation() async throws {
let channel = EmbeddedChannel()
let handler = self.makeServerStreamHandler(channel: channel)
let handler = self.makeServerConnectionAndStreamHandlers(channel: channel).streamHandler
try channel.pipeline.syncOperations.addHandler(handler)

await withServerContextRPCCancellationHandle { handle in
Expand All @@ -1028,7 +1063,7 @@ struct ServerStreamHandlerTests {
@Test("RST_STREAM turns into RPC cancellation")
func rstStreamTriggersCancellation() async throws {
let channel = EmbeddedChannel()
let handler = self.makeServerStreamHandler(channel: channel)
let handler = self.makeServerConnectionAndStreamHandlers(channel: channel).streamHandler
try channel.pipeline.syncOperations.addHandler(handler)

await withServerContextRPCCancellationHandle { handle in
Expand All @@ -1045,6 +1080,51 @@ struct ServerStreamHandlerTests {
_ = try? channel.finish()
}

@Test("Connection FrameStats are updated when writing headers or data frames")
func connectionFrameStatsAreUpdatedAccordingly() async throws {
let channel = EmbeddedChannel()
let handlers = self.makeServerConnectionAndStreamHandlers(channel: channel)
try channel.pipeline.syncOperations.addHandler(handlers.streamHandler)

// We have written nothing yet, so expect FrameStats/didWriteHeadersOrData to be false
#expect(!handlers.connectionHandler.frameStats.didWriteHeadersOrData)

// FrameStats aren't affected by pings received
channel.pipeline.fireChannelRead(
NIOAny(HTTP2Frame.FramePayload.ping(.init(withInteger: 42), ack: false))
)
#expect(!handlers.connectionHandler.frameStats.didWriteHeadersOrData)

// Now write back headers and make sure FrameStats are updated accordingly:
// To do that, we first need to receive client's initial metadata...
let clientInitialMetadata: HPACKHeaders = [
GRPCHTTP2Keys.path.rawValue: "/SomeService/SomeMethod",
GRPCHTTP2Keys.scheme.rawValue: "http",
GRPCHTTP2Keys.method.rawValue: "POST",
GRPCHTTP2Keys.contentType.rawValue: "application/grpc",
GRPCHTTP2Keys.te.rawValue: "trailers",
]
try channel.writeInbound(
HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata))
)

// Now we write back server's initial metadata...
let serverInitialMetadata = RPCResponsePart.metadata([:])
try channel.writeOutbound(serverInitialMetadata)

// And this should have updated the FrameStats
#expect(handlers.connectionHandler.frameStats.didWriteHeadersOrData)

// Manually reset the FrameStats to make sure that writing data also updates it correctly.
handlers.connectionHandler.frameStats.reset()
#expect(!handlers.connectionHandler.frameStats.didWriteHeadersOrData)
try channel.writeOutbound(RPCResponsePart.message([42]))
#expect(handlers.connectionHandler.frameStats.didWriteHeadersOrData)

// Clean up.
// Throwing is fine: the channel is closed abruptly, errors are expected.
_ = try? channel.finish()
}
}

extension EmbeddedChannel {
Expand Down
Loading