Skip to content

Commit

Permalink
s2a,netty: S2AHandshakerServiceChannel doesn't use custom event loop. (
Browse files Browse the repository at this point in the history
…grpc#11539)

* S2AHandshakerServiceChannel doesn't use custom event loop.

* use executorPool.

* log when channel not shutdown.

* use a cached threadpool.

* update non-executor version.
  • Loading branch information
rmehta19 authored and kannanjgithub committed Oct 23, 2024
1 parent 1aea7ba commit 48d35c9
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 65 deletions.
18 changes: 16 additions & 2 deletions netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
package io.grpc.netty;

import io.grpc.ChannelLogger;
import io.grpc.internal.ObjectPool;
import io.grpc.netty.ProtocolNegotiators.ClientTlsHandler;
import io.grpc.netty.ProtocolNegotiators.GrpcNegotiationHandler;
import io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler;
import io.netty.channel.ChannelHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.util.AsciiString;
import java.util.concurrent.Executor;

/**
* Internal accessor for {@link ProtocolNegotiators}.
Expand All @@ -35,9 +37,12 @@ private InternalProtocolNegotiators() {}
* Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will
* be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel}
* may happen immediately, even before the TLS Handshake is complete.
* @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks
*/
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) {
final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext);
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext,
ObjectPool<? extends Executor> executorPool) {
final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext,
executorPool);
final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {

@Override
Expand All @@ -58,6 +63,15 @@ public void close() {

return new TlsNegotiator();
}

/**
* Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will
* be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel}
* may happen immediately, even before the TLS Handshake is complete.
*/
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) {
return tls(sslContext, null);
}

/**
* Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will be
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,11 @@
import io.grpc.MethodDescriptor;
import io.grpc.internal.SharedResourceHolder.Resource;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.concurrent.DefaultThreadFactory;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.ConcurrentMap;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.concurrent.ThreadSafe;

/**
Expand All @@ -61,7 +59,6 @@
public final class S2AHandshakerServiceChannel {
private static final ConcurrentMap<String, Resource<Channel>> SHARED_RESOURCE_CHANNELS =
Maps.newConcurrentMap();
private static final Duration DELEGATE_TERMINATION_TIMEOUT = Duration.ofSeconds(2);
private static final Duration CHANNEL_SHUTDOWN_TIMEOUT = Duration.ofSeconds(10);

/**
Expand Down Expand Up @@ -95,41 +92,34 @@ public ChannelResource(String targetAddress, Optional<ChannelCredentials> channe
}

/**
* Creates a {@code EventLoopHoldingChannel} instance to the service running at {@code
* targetAddress}. This channel uses a dedicated thread pool for its {@code EventLoopGroup}
* instance to avoid blocking.
* Creates a {@code HandshakerServiceChannel} instance to the service running at {@code
* targetAddress}.
*/
@Override
public Channel create() {
EventLoopGroup eventLoopGroup =
new NioEventLoopGroup(1, new DefaultThreadFactory("S2A channel pool", true));
ManagedChannel channel = null;
if (channelCredentials.isPresent()) {
// Create a secure channel.
channel =
NettyChannelBuilder.forTarget(targetAddress, channelCredentials.get())
.channelType(NioSocketChannel.class)
.directExecutor()
.eventLoopGroup(eventLoopGroup)
.build();
} else {
// Create a plaintext channel.
channel =
NettyChannelBuilder.forTarget(targetAddress)
.channelType(NioSocketChannel.class)
.directExecutor()
.eventLoopGroup(eventLoopGroup)
.usePlaintext()
.build();
}
return EventLoopHoldingChannel.create(channel, eventLoopGroup);
return HandshakerServiceChannel.create(channel);
}

/** Destroys a {@code EventLoopHoldingChannel} instance. */
/** Destroys a {@code HandshakerServiceChannel} instance. */
@Override
public void close(Channel instanceChannel) {
checkNotNull(instanceChannel);
EventLoopHoldingChannel channel = (EventLoopHoldingChannel) instanceChannel;
HandshakerServiceChannel channel = (HandshakerServiceChannel) instanceChannel;
channel.close();
}

Expand All @@ -140,23 +130,21 @@ public String toString() {
}

/**
* Manages a channel using a {@link ManagedChannel} instance that belong to the {@code
* EventLoopGroup} thread pool.
* Manages a channel using a {@link ManagedChannel} instance.
*/
@VisibleForTesting
static class EventLoopHoldingChannel extends Channel {
static class HandshakerServiceChannel extends Channel {
private static final Logger logger =
Logger.getLogger(S2AHandshakerServiceChannel.class.getName());
private final ManagedChannel delegate;
private final EventLoopGroup eventLoopGroup;

static EventLoopHoldingChannel create(ManagedChannel delegate, EventLoopGroup eventLoopGroup) {
static HandshakerServiceChannel create(ManagedChannel delegate) {
checkNotNull(delegate);
checkNotNull(eventLoopGroup);
return new EventLoopHoldingChannel(delegate, eventLoopGroup);
return new HandshakerServiceChannel(delegate);
}

private EventLoopHoldingChannel(ManagedChannel delegate, EventLoopGroup eventLoopGroup) {
private HandshakerServiceChannel(ManagedChannel delegate) {
this.delegate = delegate;
this.eventLoopGroup = eventLoopGroup;
}

/**
Expand All @@ -178,16 +166,12 @@ public <ReqT, RespT> ClientCall<ReqT, RespT> newCall(
@SuppressWarnings("FutureReturnValueIgnored")
public void close() {
delegate.shutdownNow();
boolean isDelegateTerminated;
try {
isDelegateTerminated =
delegate.awaitTermination(DELEGATE_TERMINATION_TIMEOUT.getSeconds(), SECONDS);
delegate.awaitTermination(CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS);
} catch (InterruptedException e) {
isDelegateTerminated = false;
Thread.currentThread().interrupt();
logger.log(Level.WARNING, "Channel to S2A was not shutdown.");
}
long quietPeriodSeconds = isDelegateTerminated ? 0 : 1;
eventLoopGroup.shutdownGracefully(
quietPeriodSeconds, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
import com.google.common.util.concurrent.MoreExecutors;
import com.google.errorprone.annotations.ThreadSafe;
import io.grpc.Channel;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.ObjectPool;
import io.grpc.internal.SharedResourcePool;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.InternalProtocolNegotiator;
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
Expand Down Expand Up @@ -227,7 +229,10 @@ protected void handlerAdded0(ChannelHandlerContext ctx) {
@Override
public void onSuccess(SslContext sslContext) {
ChannelHandler handler =
InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler);
InternalProtocolNegotiators.tls(
sslContext,
SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR))
.newHandler(grpcHandler);

// Remove the bufferReads handler and delegate the rest of the handshake to the TLS
// handler.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,7 @@

import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import io.grpc.CallOptions;
import io.grpc.Channel;
Expand All @@ -39,15 +35,13 @@
import io.grpc.benchmarks.Utils;
import io.grpc.internal.SharedResourceHolder.Resource;
import io.grpc.netty.NettyServerBuilder;
import io.grpc.s2a.channel.S2AHandshakerServiceChannel.EventLoopHoldingChannel;
import io.grpc.s2a.channel.S2AHandshakerServiceChannel.HandshakerServiceChannel;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.GrpcCleanupRule;
import io.grpc.testing.protobuf.SimpleRequest;
import io.grpc.testing.protobuf.SimpleResponse;
import io.grpc.testing.protobuf.SimpleServiceGrpc;
import io.netty.channel.EventLoopGroup;
import java.io.File;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
Expand All @@ -60,8 +54,6 @@
@RunWith(JUnit4.class)
public final class S2AHandshakerServiceChannelTest {
@ClassRule public static final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
private static final Duration CHANNEL_SHUTDOWN_TIMEOUT = Duration.ofSeconds(10);
private final EventLoopGroup mockEventLoopGroup = mock(EventLoopGroup.class);
private Server mtlsServer;
private Server plaintextServer;

Expand Down Expand Up @@ -191,7 +183,7 @@ public void close_mtlsSuccess() throws Exception {
}

/**
* Verifies that an {@code EventLoopHoldingChannel}'s {@code newCall} method can be used to
* Verifies that an {@code HandshakerServiceChannel}'s {@code newCall} method can be used to
* perform a simple RPC.
*/
@Test
Expand All @@ -201,7 +193,7 @@ public void newCall_performSimpleRpcSuccess() {
"localhost:" + plaintextServer.getPort(),
/* s2aChannelCredentials= */ Optional.empty());
Channel channel = resource.create();
assertThat(channel).isInstanceOf(EventLoopHoldingChannel.class);
assertThat(channel).isInstanceOf(HandshakerServiceChannel.class);
assertThat(
SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance()))
.isEqualToDefaultInstance();
Expand All @@ -214,53 +206,49 @@ public void newCall_mtlsPerformSimpleRpcSuccess() throws Exception {
S2AHandshakerServiceChannel.getChannelResource(
"localhost:" + mtlsServer.getPort(), getTlsChannelCredentials());
Channel channel = resource.create();
assertThat(channel).isInstanceOf(EventLoopHoldingChannel.class);
assertThat(channel).isInstanceOf(HandshakerServiceChannel.class);
assertThat(
SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance()))
.isEqualToDefaultInstance();
}

/** Creates a {@code EventLoopHoldingChannel} instance and verifies its authority. */
/** Creates a {@code HandshakerServiceChannel} instance and verifies its authority. */
@Test
public void authority_success() throws Exception {
ManagedChannel channel = new FakeManagedChannel(true);
EventLoopHoldingChannel eventLoopHoldingChannel =
EventLoopHoldingChannel.create(channel, mockEventLoopGroup);
HandshakerServiceChannel eventLoopHoldingChannel =
HandshakerServiceChannel.create(channel);
assertThat(eventLoopHoldingChannel.authority()).isEqualTo("FakeManagedChannel");
}

/**
* Creates and closes a {@code EventLoopHoldingChannel} when its {@code ManagedChannel} terminates
* successfully.
* Creates and closes a {@code HandshakerServiceChannel} when its {@code ManagedChannel}
* terminates successfully.
*/
@Test
public void close_withDelegateTerminatedSuccess() throws Exception {
ManagedChannel channel = new FakeManagedChannel(true);
EventLoopHoldingChannel eventLoopHoldingChannel =
EventLoopHoldingChannel.create(channel, mockEventLoopGroup);
HandshakerServiceChannel eventLoopHoldingChannel =
HandshakerServiceChannel.create(channel);
eventLoopHoldingChannel.close();
assertThat(channel.isShutdown()).isTrue();
verify(mockEventLoopGroup, times(1))
.shutdownGracefully(0, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS);
}

/**
* Creates and closes a {@code EventLoopHoldingChannel} when its {@code ManagedChannel} does not
* Creates and closes a {@code HandshakerServiceChannel} when its {@code ManagedChannel} does not
* terminate successfully.
*/
@Test
public void close_withDelegateTerminatedFailure() throws Exception {
ManagedChannel channel = new FakeManagedChannel(false);
EventLoopHoldingChannel eventLoopHoldingChannel =
EventLoopHoldingChannel.create(channel, mockEventLoopGroup);
HandshakerServiceChannel eventLoopHoldingChannel =
HandshakerServiceChannel.create(channel);
eventLoopHoldingChannel.close();
assertThat(channel.isShutdown()).isTrue();
verify(mockEventLoopGroup, times(1))
.shutdownGracefully(1, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS);
}

/**
* Creates and closes a {@code EventLoopHoldingChannel}, creates a new channel from the same
* Creates and closes a {@code HandshakerServiceChannel}, creates a new channel from the same
* resource, and verifies that this second channel is useable.
*/
@Test
Expand All @@ -273,7 +261,7 @@ public void create_succeedsAfterCloseIsCalledOnce() throws Exception {
resource.close(channelOne);

Channel channelTwo = resource.create();
assertThat(channelTwo).isInstanceOf(EventLoopHoldingChannel.class);
assertThat(channelTwo).isInstanceOf(HandshakerServiceChannel.class);
assertThat(
SimpleServiceGrpc.newBlockingStub(channelTwo)
.unaryRpc(SimpleRequest.getDefaultInstance()))
Expand All @@ -291,7 +279,7 @@ public void create_mtlsSucceedsAfterCloseIsCalledOnce() throws Exception {
resource.close(channelOne);

Channel channelTwo = resource.create();
assertThat(channelTwo).isInstanceOf(EventLoopHoldingChannel.class);
assertThat(channelTwo).isInstanceOf(HandshakerServiceChannel.class);
assertThat(
SimpleServiceGrpc.newBlockingStub(channelTwo)
.unaryRpc(SimpleRequest.getDefaultInstance()))
Expand Down

0 comments on commit 48d35c9

Please sign in to comment.