From 138fb195424a4ad9ed1bbf0dc82249ce0ad66de4 Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Fri, 4 Oct 2024 17:00:28 -0700 Subject: [PATCH 1/6] Add S2AStub cleanup handler. --- .../S2AProtocolNegotiatorFactory.java | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java index 188faf63435..5fa245faaa3 100644 --- a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java @@ -29,6 +29,7 @@ import com.google.common.util.concurrent.MoreExecutors; import com.google.errorprone.annotations.ThreadSafe; import io.grpc.Channel; +import io.grpc.ChannelLogger; import io.grpc.internal.GrpcUtil; import io.grpc.internal.ObjectPool; import io.grpc.internal.SharedResourcePool; @@ -179,6 +180,26 @@ public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { } } + private static final class S2AStubCleanupNegotiationHandler extends ProtocolNegotiationHandler { + private final S2AStub s2aStub; + + private S2AStubCleanupNegotiationHandler( + ChannelHandler next, + ChannelLogger logger, + S2AStub s2aStub) { + super(next, logger); + this.s2aStub = s2aStub; + } + + @Override + protected void protocolNegotiationEventTriggered(ChannelHandlerContext ctx) { + s2aStub.close(); + fireProtocolNegotiationEvent(ctx); + ctx.pipeline().remove(this); + } + + } + private static final class S2AProtocolNegotiationHandler extends ProtocolNegotiationHandler { private final Channel channel; private final Optional localIdentity; @@ -233,6 +254,9 @@ public void onSuccess(SslContext sslContext) { SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR)) .newHandler(grpcHandler); + ctx.pipeline().addAfter(ctx.name(), /* name= */ null, + new S2AStubCleanupNegotiationHandler(handler, + grpcHandler.getNegotiationLogger(), s2aStub)); // Remove the bufferReads handler and delegate the rest of the handshake to the TLS // handler. ctx.pipeline().addAfter(ctx.name(), /* name= */ null, handler); From a7fd45f77f1cb93baa6a47d6e845e8cc0fecdd40 Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Wed, 9 Oct 2024 10:36:16 -0700 Subject: [PATCH 2/6] Give TLS and Cleanup handlers name + update comment. --- .../internal/handshaker/S2AProtocolNegotiatorFactory.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java index 5fa245faaa3..bb160186738 100644 --- a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java @@ -254,12 +254,12 @@ public void onSuccess(SslContext sslContext) { SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR)) .newHandler(grpcHandler); - ctx.pipeline().addAfter(ctx.name(), /* name= */ null, + // Delegate the rest of the handshake to the TLS handler. and remove the + // bufferReads handler. + ctx.pipeline().addAfter(ctx.name(), "tlsHandler", handler); + ctx.pipeline().addAfter("tlsHandler", "cleanupHandler", new S2AStubCleanupNegotiationHandler(handler, grpcHandler.getNegotiationLogger(), s2aStub)); - // Remove the bufferReads handler and delegate the rest of the handshake to the TLS - // handler. - ctx.pipeline().addAfter(ctx.name(), /* name= */ null, handler); fireProtocolNegotiationEvent(ctx); ctx.pipeline().remove(bufferReads); } From 8256c1eea91b4870f556546acef21a1d79e59f3a Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Wed, 9 Oct 2024 18:20:20 -0700 Subject: [PATCH 3/6] Don't add TLS handler twice. --- .../internal/handshaker/S2AProtocolNegotiatorFactory.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java index bb160186738..4d143d16b1e 100644 --- a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java @@ -257,11 +257,11 @@ public void onSuccess(SslContext sslContext) { // Delegate the rest of the handshake to the TLS handler. and remove the // bufferReads handler. ctx.pipeline().addAfter(ctx.name(), "tlsHandler", handler); - ctx.pipeline().addAfter("tlsHandler", "cleanupHandler", - new S2AStubCleanupNegotiationHandler(handler, - grpcHandler.getNegotiationLogger(), s2aStub)); fireProtocolNegotiationEvent(ctx); ctx.pipeline().remove(bufferReads); + ctx.pipeline().addLast( + new S2AStubCleanupNegotiationHandler(null, + null, s2aStub)); } @Override From d9d43175fc141f632236475e0a5fe4d54d9dd202 Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Wed, 9 Oct 2024 18:22:59 -0700 Subject: [PATCH 4/6] Don't remove explicitly, since done by fireProtocolNegotiationEvent. --- .../s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java | 1 - 1 file changed, 1 deletion(-) diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java index 4d143d16b1e..1497b779078 100644 --- a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java @@ -195,7 +195,6 @@ private S2AStubCleanupNegotiationHandler( protected void protocolNegotiationEventTriggered(ChannelHandlerContext ctx) { s2aStub.close(); fireProtocolNegotiationEvent(ctx); - ctx.pipeline().remove(this); } } From 1698b54622602a30f9c3654ec8841e8619f7191d Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Thu, 10 Oct 2024 15:22:32 -0700 Subject: [PATCH 5/6] plumb S2AStub close to handshake end + add integration test. --- .../netty/InternalProtocolNegotiators.java | 11 ++- .../io/grpc/netty/NettyChannelBuilder.java | 3 +- .../io/grpc/netty/ProtocolNegotiators.java | 24 +++-- .../grpc/netty/NettyClientTransportTest.java | 4 +- .../grpc/netty/ProtocolNegotiatorsTest.java | 12 +-- .../io/grpc/s2a/S2AChannelCredentials.java | 14 ++- .../S2AProtocolNegotiatorFactory.java | 87 ++++++++++--------- .../grpc/s2a/internal/handshaker/S2AStub.java | 8 +- .../internal/handshaker/IntegrationTest.java | 23 +++++ .../S2AProtocolNegotiatorFactoryTest.java | 8 +- 10 files changed, 130 insertions(+), 64 deletions(-) diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java index b9c6a77982a..8aeb44d0fc2 100644 --- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java @@ -24,6 +24,7 @@ import io.netty.channel.ChannelHandler; import io.netty.handler.ssl.SslContext; import io.netty.util.AsciiString; +import java.util.Optional; import java.util.concurrent.Executor; /** @@ -40,9 +41,10 @@ private InternalProtocolNegotiators() {} * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks */ public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext, - ObjectPool executorPool) { + ObjectPool executorPool, + Optional handshakeCompleteRunnable) { final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext, - executorPool); + executorPool, handshakeCompleteRunnable); final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator { @Override @@ -70,7 +72,7 @@ public void close() { * may happen immediately, even before the TLS Handshake is complete. */ public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) { - return tls(sslContext, null); + return tls(sslContext, null, Optional.empty()); } /** @@ -167,7 +169,8 @@ public static ChannelHandler grpcNegotiationHandler(GrpcHttp2ConnectionHandler n public static ChannelHandler clientTlsHandler( ChannelHandler next, SslContext sslContext, String authority, ChannelLogger negotiationLogger) { - return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger); + return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger, + Optional.empty()); } public static class ProtocolNegotiationHandler diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java index 305ad128454..d1d9810b485 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java @@ -63,6 +63,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.Optional; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -604,7 +605,7 @@ static ProtocolNegotiator createProtocolNegotiatorByType( case PLAINTEXT_UPGRADE: return ProtocolNegotiators.plaintextUpgrade(); case TLS: - return ProtocolNegotiators.tls(sslContext, executorPool); + return ProtocolNegotiators.tls(sslContext, executorPool, Optional.empty()); default: throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType); } diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index 80df5d0e3c7..eece60455ec 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -72,6 +72,7 @@ import java.nio.channels.ClosedChannelException; import java.util.Arrays; import java.util.EnumSet; +import java.util.Optional; import java.util.Set; import java.util.concurrent.Executor; import java.util.logging.Level; @@ -543,16 +544,18 @@ protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws static final class ClientTlsProtocolNegotiator implements ProtocolNegotiator { public ClientTlsProtocolNegotiator(SslContext sslContext, - ObjectPool executorPool) { + ObjectPool executorPool, Optional handshakeCompleteRunnable) { this.sslContext = checkNotNull(sslContext, "sslContext"); this.executorPool = executorPool; if (this.executorPool != null) { this.executor = this.executorPool.getObject(); } + this.handshakeCompleteRunnable = handshakeCompleteRunnable; } private final SslContext sslContext; private final ObjectPool executorPool; + private final Optional handshakeCompleteRunnable; private Executor executor; @Override @@ -565,7 +568,7 @@ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { ChannelHandler gnh = new GrpcNegotiationHandler(grpcHandler); ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger(); ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, grpcHandler.getAuthority(), - this.executor, negotiationLogger); + this.executor, negotiationLogger, handshakeCompleteRunnable); return new WaitUntilActiveHandler(cth, negotiationLogger); } @@ -583,15 +586,18 @@ static final class ClientTlsHandler extends ProtocolNegotiationHandler { private final String host; private final int port; private Executor executor; + private final Optional handshakeCompleteRunnable; ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority, - Executor executor, ChannelLogger negotiationLogger) { + Executor executor, ChannelLogger negotiationLogger, + Optional handshakeCompleteRunnable) { super(next, negotiationLogger); this.sslContext = checkNotNull(sslContext, "sslContext"); HostPort hostPort = parseAuthority(authority); this.host = hostPort.host; this.port = hostPort.port; this.executor = executor; + this.handshakeCompleteRunnable = handshakeCompleteRunnable; } @Override @@ -634,6 +640,9 @@ protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws .withCause(t) .asRuntimeException(); } + if (handshakeCompleteRunnable.isPresent()) { + handshakeCompleteRunnable.get().run(); + } ctx.fireExceptionCaught(t); } } else { @@ -649,6 +658,9 @@ private void propagateTlsComplete(ChannelHandlerContext ctx, SSLSession session) .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session) .build(); replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs).withSecurity(security)); + if (handshakeCompleteRunnable.isPresent()) { + handshakeCompleteRunnable.get().run(); + } fireProtocolNegotiationEvent(ctx); } } @@ -683,8 +695,8 @@ static HostPort parseAuthority(String authority) { * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks */ public static ProtocolNegotiator tls(SslContext sslContext, - ObjectPool executorPool) { - return new ClientTlsProtocolNegotiator(sslContext, executorPool); + ObjectPool executorPool, Optional handshakeCompleteRunnable) { + return new ClientTlsProtocolNegotiator(sslContext, executorPool, handshakeCompleteRunnable); } /** @@ -693,7 +705,7 @@ public static ProtocolNegotiator tls(SslContext sslContext, * may happen immediately, even before the TLS Handshake is complete. */ public static ProtocolNegotiator tls(SslContext sslContext) { - return tls(sslContext, null); + return tls(sslContext, null, Optional.empty()); } public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext) { diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index 9777bb0926c..7cb269b7d62 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -105,6 +105,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; @@ -766,7 +767,8 @@ public void tlsNegotiationServerExecutorShouldSucceed() throws Exception { .trustManager(caCert) .keyManager(clientCert, clientKey) .build(); - ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool); + ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool, + Optional.empty()); // after starting the client, the Executor in the client pool should be used assertEquals(true, clientExecutorPool.isInUse()); final NettyClientTransport transport = newTransport(negotiator); diff --git a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java index 6939d835892..2ccdb2de543 100644 --- a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java +++ b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java @@ -120,6 +120,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.Queue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -876,7 +877,7 @@ public String applicationProtocol() { DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1); ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, - "authority", elg, noopLogger); + "authority", elg, noopLogger, Optional.empty()); pipeline.addLast(handler); pipeline.replace(SslHandler.class, null, goodSslHandler); pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); @@ -914,7 +915,7 @@ public String applicationProtocol() { .applicationProtocolConfig(apn).build(); ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, - "authority", elg, noopLogger); + "authority", elg, noopLogger, Optional.empty()); pipeline.addLast(handler); pipeline.replace(SslHandler.class, null, goodSslHandler); pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); @@ -938,7 +939,7 @@ public String applicationProtocol() { DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1); ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, - "authority", elg, noopLogger); + "authority", elg, noopLogger, Optional.empty()); pipeline.addLast(handler); final AtomicReference error = new AtomicReference<>(); @@ -966,7 +967,7 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { @Test public void clientTlsHandler_closeDuringNegotiation() throws Exception { ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, - "authority", null, noopLogger); + "authority", null, noopLogger, Optional.empty()); pipeline.addLast(new WriteBufferingAndExceptionHandler(handler)); ChannelFuture pendingWrite = channel.writeAndFlush(NettyClientHandler.NOOP_MESSAGE); @@ -1228,7 +1229,8 @@ public void clientTlsHandler_firesNegotiation() throws Exception { serverSslContext = GrpcSslContexts.forServer(server1Chain, server1Key).build(); } FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler(); - ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext, null); + ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext, + null, Optional.empty()); WriteBufferingAndExceptionHandler clientWbaeh = new WriteBufferingAndExceptionHandler(pn.newHandler(gh)); diff --git a/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java b/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java index 2e040964dfa..2cbdf7e4c5f 100644 --- a/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java +++ b/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java @@ -31,6 +31,7 @@ import io.grpc.s2a.internal.channel.S2AHandshakerServiceChannel; import io.grpc.s2a.internal.handshaker.S2AIdentity; import io.grpc.s2a.internal.handshaker.S2AProtocolNegotiatorFactory; +import io.grpc.s2a.internal.handshaker.S2AStub; import javax.annotation.concurrent.NotThreadSafe; import org.checkerframework.checker.nullness.qual.Nullable; @@ -59,6 +60,7 @@ public static final class Builder { private final String s2aAddress; private final ChannelCredentials s2aChannelCredentials; private @Nullable S2AIdentity localIdentity = null; + private @Nullable S2AStub stub = null; Builder(String s2aAddress, ChannelCredentials s2aChannelCredentials) { this.s2aAddress = s2aAddress; @@ -104,6 +106,16 @@ public Builder setLocalUid(String localUid) { return this; } + /** + * Sets the stub to use to communicate with S2A. This is only used for testing that the + * stream to S2A gets closed. + */ + public Builder setStub(S2AStub stub) { + checkNotNull(stub); + this.stub = stub; + return this; + } + public ChannelCredentials build() { return InternalNettyChannelCredentials.create(buildProtocolNegotiatorFactory()); } @@ -113,7 +125,7 @@ InternalProtocolNegotiator.ClientFactory buildProtocolNegotiatorFactory() { SharedResourcePool.forResource( S2AHandshakerServiceChannel.getChannelResource(s2aAddress, s2aChannelCredentials)); checkNotNull(s2aChannelPool, "s2aChannelPool"); - return S2AProtocolNegotiatorFactory.createClientFactory(localIdentity, s2aChannelPool); + return S2AProtocolNegotiatorFactory.createClientFactory(localIdentity, s2aChannelPool, stub); } } diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java index 1497b779078..7ad9de991cf 100644 --- a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java @@ -29,7 +29,6 @@ import com.google.common.util.concurrent.MoreExecutors; import com.google.errorprone.annotations.ThreadSafe; import io.grpc.Channel; -import io.grpc.ChannelLogger; import io.grpc.internal.GrpcUtil; import io.grpc.internal.ObjectPool; import io.grpc.internal.SharedResourcePool; @@ -64,28 +63,35 @@ public final class S2AProtocolNegotiatorFactory { * @param localIdentity the identity of the client; if none is provided, the S2A will use the * client's default identity. * @param s2aChannelPool a pool of shared channels that can be used to connect to the S2A. + * @param stub the stub to use to communicate with S2A. If none is provided the channelPool + * will be used to create the stub. This is exposed for verifying the stream to S2A gets + * closed in tests. * @return a factory for creating a client-side protocol negotiator. */ public static InternalProtocolNegotiator.ClientFactory createClientFactory( - @Nullable S2AIdentity localIdentity, ObjectPool s2aChannelPool) { + @Nullable S2AIdentity localIdentity, ObjectPool s2aChannelPool, + @Nullable S2AStub stub) { checkNotNull(s2aChannelPool, "S2A channel pool should not be null."); - return new S2AClientProtocolNegotiatorFactory(localIdentity, s2aChannelPool); + return new S2AClientProtocolNegotiatorFactory(localIdentity, s2aChannelPool, stub); } static final class S2AClientProtocolNegotiatorFactory implements InternalProtocolNegotiator.ClientFactory { private final @Nullable S2AIdentity localIdentity; private final ObjectPool channelPool; + private final @Nullable S2AStub stub; S2AClientProtocolNegotiatorFactory( - @Nullable S2AIdentity localIdentity, ObjectPool channelPool) { + @Nullable S2AIdentity localIdentity, ObjectPool channelPool, + @Nullable S2AStub stub) { this.localIdentity = localIdentity; this.channelPool = channelPool; + this.stub = stub; } @Override public ProtocolNegotiator newNegotiator() { - return S2AProtocolNegotiator.createForClient(channelPool, localIdentity); + return S2AProtocolNegotiator.createForClient(channelPool, localIdentity, stub); } @Override @@ -99,18 +105,20 @@ public int getDefaultPort() { static final class S2AProtocolNegotiator implements ProtocolNegotiator { private final ObjectPool channelPool; - private final Channel channel; + private @Nullable Channel channel = null; private final Optional localIdentity; + private final @Nullable S2AStub stub; private final ListeningExecutorService service = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1)); static S2AProtocolNegotiator createForClient( - ObjectPool channelPool, @Nullable S2AIdentity localIdentity) { + ObjectPool channelPool, @Nullable S2AIdentity localIdentity, + @Nullable S2AStub stub) { checkNotNull(channelPool, "Channel pool should not be null."); if (localIdentity == null) { - return new S2AProtocolNegotiator(channelPool, Optional.empty()); + return new S2AProtocolNegotiator(channelPool, Optional.empty(), stub); } else { - return new S2AProtocolNegotiator(channelPool, Optional.of(localIdentity)); + return new S2AProtocolNegotiator(channelPool, Optional.of(localIdentity), stub); } } @@ -123,10 +131,13 @@ static S2AProtocolNegotiator createForClient( } private S2AProtocolNegotiator(ObjectPool channelPool, - Optional localIdentity) { + Optional localIdentity, @Nullable S2AStub stub) { this.channelPool = channelPool; this.localIdentity = localIdentity; - this.channel = channelPool.getObject(); + this.stub = stub; + if (this.stub == null) { + this.channel = channelPool.getObject(); + } } @Override @@ -140,13 +151,15 @@ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { String hostname = getHostNameFromAuthority(grpcHandler.getAuthority()); checkArgument(!isNullOrEmpty(hostname), "hostname should not be null or empty."); return new S2AProtocolNegotiationHandler( - grpcHandler, channel, localIdentity, hostname, service); + grpcHandler, channel, localIdentity, hostname, service, stub); } @Override public void close() { service.shutdown(); - channelPool.returnObject(channel); + if (channel != null) { + channelPool.returnObject(channel); + } } } @@ -180,38 +193,21 @@ public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { } } - private static final class S2AStubCleanupNegotiationHandler extends ProtocolNegotiationHandler { - private final S2AStub s2aStub; - - private S2AStubCleanupNegotiationHandler( - ChannelHandler next, - ChannelLogger logger, - S2AStub s2aStub) { - super(next, logger); - this.s2aStub = s2aStub; - } - - @Override - protected void protocolNegotiationEventTriggered(ChannelHandlerContext ctx) { - s2aStub.close(); - fireProtocolNegotiationEvent(ctx); - } - - } - private static final class S2AProtocolNegotiationHandler extends ProtocolNegotiationHandler { - private final Channel channel; + private final @Nullable Channel channel; private final Optional localIdentity; private final String hostname; private final GrpcHttp2ConnectionHandler grpcHandler; private final ListeningExecutorService service; + private final @Nullable S2AStub stub; private S2AProtocolNegotiationHandler( GrpcHttp2ConnectionHandler grpcHandler, Channel channel, Optional localIdentity, String hostname, - ListeningExecutorService service) { + ListeningExecutorService service, + @Nullable S2AStub stub) { super( // superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next' // handler but we don't have a next handler _yet_. So we "disable" superclass's behavior @@ -229,6 +225,7 @@ public void handlerAdded(ChannelHandlerContext ctx) { this.hostname = hostname; checkNotNull(service, "service should not be null."); this.service = service; + this.stub = stub; } @Override @@ -237,8 +234,13 @@ protected void handlerAdded0(ChannelHandlerContext ctx) { BufferReadsHandler bufferReads = new BufferReadsHandler(); ctx.pipeline().addBefore(ctx.name(), /* name= */ null, bufferReads); - S2AServiceGrpc.S2AServiceStub stub = S2AServiceGrpc.newStub(channel); - S2AStub s2aStub = S2AStub.newInstance(stub); + S2AStub s2aStub; + if (this.stub == null) { + checkNotNull(channel, "Channel to S2A should not be null"); + s2aStub = S2AStub.newInstance(S2AServiceGrpc.newStub(channel)); + } else { + s2aStub = this.stub; + } ListenableFuture sslContextFuture = service.submit(() -> SslContextFactory.createForClient(s2aStub, hostname, localIdentity)); @@ -250,17 +252,20 @@ public void onSuccess(SslContext sslContext) { ChannelHandler handler = InternalProtocolNegotiators.tls( sslContext, - SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR)) + SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR), + Optional.of(new Runnable() { + @Override + public void run() { + s2aStub.close(); + } + })) .newHandler(grpcHandler); // Delegate the rest of the handshake to the TLS handler. and remove the // bufferReads handler. - ctx.pipeline().addAfter(ctx.name(), "tlsHandler", handler); + ctx.pipeline().addAfter(ctx.name(), /* name= */ null, handler); fireProtocolNegotiationEvent(ctx); ctx.pipeline().remove(bufferReads); - ctx.pipeline().addLast( - new S2AStubCleanupNegotiationHandler(null, - null, s2aStub)); } @Override diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AStub.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AStub.java index 0bfa3b4dac2..c5ac8f96d96 100644 --- a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AStub.java +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AStub.java @@ -33,7 +33,7 @@ /** Reads and writes messages to and from the S2A. */ @NotThreadSafe -class S2AStub implements AutoCloseable { +public class S2AStub implements AutoCloseable { private static final Logger logger = Logger.getLogger(S2AStub.class.getName()); private static final long HANDSHAKE_RPC_DEADLINE_SECS = 20; private final StreamObserver reader = new Reader(); @@ -42,6 +42,7 @@ class S2AStub implements AutoCloseable { private StreamObserver writer; private boolean doneReading = false; private boolean doneWriting = false; + private boolean isClosed = false; static S2AStub newInstance(S2AServiceGrpc.S2AServiceStub serviceStub) { checkNotNull(serviceStub); @@ -136,6 +137,11 @@ public void close() { if (writer != null) { writer.onCompleted(); } + isClosed = true; + } + + public boolean isClosed() { + return isClosed; } /** Create a new writer if the writer is null. */ diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/IntegrationTest.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/IntegrationTest.java index e1ad3d278c3..d842fde8dea 100644 --- a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/IntegrationTest.java +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/IntegrationTest.java @@ -19,6 +19,7 @@ import static com.google.common.truth.Truth.assertThat; import static java.util.concurrent.TimeUnit.SECONDS; +import io.grpc.Channel; import io.grpc.ChannelCredentials; import io.grpc.Grpc; import io.grpc.InsecureChannelCredentials; @@ -29,9 +30,12 @@ import io.grpc.TlsChannelCredentials; import io.grpc.TlsServerCredentials; import io.grpc.benchmarks.Utils; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.NettyServerBuilder; import io.grpc.s2a.S2AChannelCredentials; +import io.grpc.s2a.internal.channel.S2AHandshakerServiceChannel; import io.grpc.s2a.internal.handshaker.FakeS2AServer; import io.grpc.stub.StreamObserver; import io.grpc.testing.protobuf.SimpleRequest; @@ -141,6 +145,25 @@ public void clientCommunicateUsingS2ACredentialsNoLocalIdentity_succeeds() throw assertThat(doUnaryRpc(channel)).isTrue(); } + @Test + public void clientCommunicateUsingS2ACredentialsSucceeds_verifyStreamToS2AClosed() + throws Exception { + ObjectPool s2aChannelPool = + SharedResourcePool.forResource( + S2AHandshakerServiceChannel.getChannelResource(s2aAddress, + InsecureChannelCredentials.create())); + Channel ch = s2aChannelPool.getObject(); + S2AStub stub = S2AStub.newInstance(S2AServiceGrpc.newStub(ch)); + ChannelCredentials credentials = + S2AChannelCredentials.newBuilder(s2aAddress, InsecureChannelCredentials.create()) + .setLocalSpiffeId("test-spiffe-id").setStub(stub).build(); + ManagedChannel channel = Grpc.newChannelBuilder(serverAddress, credentials).build(); + + s2aChannelPool.returnObject(ch); + assertThat(doUnaryRpc(channel)).isTrue(); + assertThat(stub.isClosed()).isTrue(); + } + @Test public void clientCommunicateUsingMtlsToS2ACredentials_succeeds() throws Exception { String privateKeyPath = "src/test/resources/client_key.pem"; diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactoryTest.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactoryTest.java index 48c512c4e5c..e537687c287 100644 --- a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactoryTest.java +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactoryTest.java @@ -122,7 +122,7 @@ public void createProtocolNegotiator_nullArgument() throws Exception { @Test public void createProtocolNegotiatorFactory_getsDefaultPort_succeeds() throws Exception { InternalProtocolNegotiator.ClientFactory clientFactory = - S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool); + S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool, null); assertThat(clientFactory.getDefaultPort()).isEqualTo(S2AProtocolNegotiatorFactory.DEFAULT_PORT); } @@ -146,7 +146,7 @@ public void s2aProtocolNegotiator_getHostNameOnValidAuthority_returnsValidHostna public void createProtocolNegotiatorFactory_buildsAnS2AProtocolNegotiatorOnClientSide_succeeds() throws Exception { InternalProtocolNegotiator.ClientFactory clientFactory = - S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool); + S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool, null); ProtocolNegotiator clientNegotiator = clientFactory.newNegotiator(); @@ -158,7 +158,7 @@ public void createProtocolNegotiatorFactory_buildsAnS2AProtocolNegotiatorOnClien public void closeProtocolNegotiator_verifyProtocolNegotiatorIsClosedOnClientSide() throws Exception { InternalProtocolNegotiator.ClientFactory clientFactory = - S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool); + S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool, null); ProtocolNegotiator clientNegotiator = clientFactory.newNegotiator(); clientNegotiator.close(); @@ -170,7 +170,7 @@ public void closeProtocolNegotiator_verifyProtocolNegotiatorIsClosedOnClientSide public void createChannelHandler_addHandlerToMockContext() throws Exception { ProtocolNegotiator clientNegotiator = S2AProtocolNegotiatorFactory.S2AProtocolNegotiator.createForClient( - channelPool, LOCAL_IDENTITY); + channelPool, LOCAL_IDENTITY, null); ChannelHandler channelHandler = clientNegotiator.newHandler(fakeConnectionHandler); From 6cce1561c84905d3435b9e76195dd7b9bb25b93a Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Thu, 10 Oct 2024 15:39:50 -0700 Subject: [PATCH 6/6] close stub when TLS negotiation fails. --- netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index eece60455ec..84370ac8153 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -626,6 +626,9 @@ protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws Exception ex = unavailableException("Failed ALPN negotiation: Unable to find compatible protocol"); logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed.", ex); + if (handshakeCompleteRunnable.isPresent()) { + handshakeCompleteRunnable.get().run(); + } ctx.fireExceptionCaught(ex); } } else {