Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

s2a,netty: S2AHandshakerServiceChannel doesn't use custom event loop. #11539

Merged
merged 7 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
package io.grpc.netty;

import io.grpc.ChannelLogger;
import io.grpc.internal.ObjectPool;
import io.grpc.netty.ProtocolNegotiators.ClientTlsHandler;
import io.grpc.netty.ProtocolNegotiators.GrpcNegotiationHandler;
import io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler;
import io.netty.channel.ChannelHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.util.AsciiString;
import java.util.concurrent.Executor;

/**
* Internal accessor for {@link ProtocolNegotiators}.
Expand All @@ -31,6 +33,36 @@

private InternalProtocolNegotiators() {}

ejona86 marked this conversation as resolved.
Show resolved Hide resolved
/**
* Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will
* be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel}
* may happen immediately, even before the TLS Handshake is complete.
* @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks
ejona86 marked this conversation as resolved.
Show resolved Hide resolved
*/
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext,
ObjectPool<? extends Executor> executorPool) {
final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext);
final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {

Check warning on line 45 in netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java

View check run for this annotation

Codecov / codecov/patch

netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java#L44-L45

Added lines #L44 - L45 were not covered by tests

@Override
public AsciiString scheme() {
return negotiator.scheme();

Check warning on line 49 in netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java

View check run for this annotation

Codecov / codecov/patch

netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java#L49

Added line #L49 was not covered by tests
}

@Override
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
return negotiator.newHandler(grpcHandler);

Check warning on line 54 in netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java

View check run for this annotation

Codecov / codecov/patch

netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java#L54

Added line #L54 was not covered by tests
}

@Override
public void close() {
negotiator.close();
}

Check warning on line 60 in netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java

View check run for this annotation

Codecov / codecov/patch

netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java#L59-L60

Added lines #L59 - L60 were not covered by tests
}

return new TlsNegotiator();

Check warning on line 63 in netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java

View check run for this annotation

Codecov / codecov/patch

netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java#L63

Added line #L63 was not covered by tests
}

/**
* Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will
ejona86 marked this conversation as resolved.
Show resolved Hide resolved
* be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@
import io.grpc.MethodDescriptor;
import io.grpc.internal.SharedResourceHolder.Resource;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.concurrent.DefaultThreadFactory;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.ConcurrentMap;
Expand Down Expand Up @@ -61,7 +57,6 @@
public final class S2AHandshakerServiceChannel {
private static final ConcurrentMap<String, Resource<Channel>> SHARED_RESOURCE_CHANNELS =
Maps.newConcurrentMap();
private static final Duration DELEGATE_TERMINATION_TIMEOUT = Duration.ofSeconds(2);
private static final Duration CHANNEL_SHUTDOWN_TIMEOUT = Duration.ofSeconds(10);

/**
Expand Down Expand Up @@ -95,41 +90,34 @@ public ChannelResource(String targetAddress, Optional<ChannelCredentials> channe
}

/**
* Creates a {@code EventLoopHoldingChannel} instance to the service running at {@code
* targetAddress}. This channel uses a dedicated thread pool for its {@code EventLoopGroup}
* instance to avoid blocking.
* Creates a {@code HandshakerServiceChannel} instance to the service running at {@code
* targetAddress}.
*/
@Override
public Channel create() {
EventLoopGroup eventLoopGroup =
new NioEventLoopGroup(1, new DefaultThreadFactory("S2A channel pool", true));
ManagedChannel channel = null;
if (channelCredentials.isPresent()) {
// Create a secure channel.
channel =
NettyChannelBuilder.forTarget(targetAddress, channelCredentials.get())
.channelType(NioSocketChannel.class)
.directExecutor()
.eventLoopGroup(eventLoopGroup)
.build();
} else {
// Create a plaintext channel.
channel =
NettyChannelBuilder.forTarget(targetAddress)
.channelType(NioSocketChannel.class)
.directExecutor()
.eventLoopGroup(eventLoopGroup)
.usePlaintext()
.build();
}
return EventLoopHoldingChannel.create(channel, eventLoopGroup);
return HandshakerServiceChannel.create(channel);
}

/** Destroys a {@code EventLoopHoldingChannel} instance. */
/** Destroys a {@code HandshakerServiceChannel} instance. */
@Override
public void close(Channel instanceChannel) {
checkNotNull(instanceChannel);
EventLoopHoldingChannel channel = (EventLoopHoldingChannel) instanceChannel;
HandshakerServiceChannel channel = (HandshakerServiceChannel) instanceChannel;
channel.close();
}

Expand All @@ -140,23 +128,19 @@ public String toString() {
}

/**
* Manages a channel using a {@link ManagedChannel} instance that belong to the {@code
* EventLoopGroup} thread pool.
* Manages a channel using a {@link ManagedChannel} instance.
*/
@VisibleForTesting
static class EventLoopHoldingChannel extends Channel {
static class HandshakerServiceChannel extends Channel {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ChannelResource doesn't need this class any longer. Delete it. (If you want to do it in a follow-up, that's fine.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in a15421b

private final ManagedChannel delegate;
private final EventLoopGroup eventLoopGroup;

static EventLoopHoldingChannel create(ManagedChannel delegate, EventLoopGroup eventLoopGroup) {
static HandshakerServiceChannel create(ManagedChannel delegate) {
checkNotNull(delegate);
checkNotNull(eventLoopGroup);
return new EventLoopHoldingChannel(delegate, eventLoopGroup);
return new HandshakerServiceChannel(delegate);
}

private EventLoopHoldingChannel(ManagedChannel delegate, EventLoopGroup eventLoopGroup) {
private HandshakerServiceChannel(ManagedChannel delegate) {
this.delegate = delegate;
this.eventLoopGroup = eventLoopGroup;
}

/**
Expand All @@ -178,16 +162,11 @@ public <ReqT, RespT> ClientCall<ReqT, RespT> newCall(
@SuppressWarnings("FutureReturnValueIgnored")
public void close() {
delegate.shutdownNow();
boolean isDelegateTerminated;
try {
isDelegateTerminated =
delegate.awaitTermination(DELEGATE_TERMINATION_TIMEOUT.getSeconds(), SECONDS);
delegate.awaitTermination(CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS);
} catch (InterruptedException e) {
isDelegateTerminated = false;
Thread.currentThread().interrupt();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if shutdownNow didn't complete, you'll probably want to know about it for debugging, so I'd suggest logging a warning here.

Copy link
Contributor Author

@rmehta19 rmehta19 Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 653ac1e

}
long quietPeriodSeconds = isDelegateTerminated ? 0 : 1;
eventLoopGroup.shutdownGracefully(
quietPeriodSeconds, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import com.google.common.util.concurrent.MoreExecutors;
import com.google.errorprone.annotations.ThreadSafe;
import io.grpc.Channel;
import io.grpc.internal.FixedObjectPool;
import io.grpc.internal.ObjectPool;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.InternalProtocolNegotiator;
Expand Down Expand Up @@ -227,7 +228,9 @@ protected void handlerAdded0(ChannelHandlerContext ctx) {
@Override
public void onSuccess(SslContext sslContext) {
ChannelHandler handler =
InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler);
InternalProtocolNegotiators.tls(
sslContext, new FixedObjectPool<>(Executors.newFixedThreadPool(1)))
ejona86 marked this conversation as resolved.
Show resolved Hide resolved
.newHandler(grpcHandler);

// Remove the bufferReads handler and delegate the rest of the handshake to the TLS
// handler.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,7 @@

import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import io.grpc.CallOptions;
import io.grpc.Channel;
Expand All @@ -39,15 +35,13 @@
import io.grpc.benchmarks.Utils;
import io.grpc.internal.SharedResourceHolder.Resource;
import io.grpc.netty.NettyServerBuilder;
import io.grpc.s2a.channel.S2AHandshakerServiceChannel.EventLoopHoldingChannel;
import io.grpc.s2a.channel.S2AHandshakerServiceChannel.HandshakerServiceChannel;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.GrpcCleanupRule;
import io.grpc.testing.protobuf.SimpleRequest;
import io.grpc.testing.protobuf.SimpleResponse;
import io.grpc.testing.protobuf.SimpleServiceGrpc;
import io.netty.channel.EventLoopGroup;
import java.io.File;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
Expand All @@ -60,8 +54,6 @@
@RunWith(JUnit4.class)
public final class S2AHandshakerServiceChannelTest {
@ClassRule public static final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
private static final Duration CHANNEL_SHUTDOWN_TIMEOUT = Duration.ofSeconds(10);
private final EventLoopGroup mockEventLoopGroup = mock(EventLoopGroup.class);
private Server mtlsServer;
private Server plaintextServer;

Expand Down Expand Up @@ -191,7 +183,7 @@ public void close_mtlsSuccess() throws Exception {
}

/**
* Verifies that an {@code EventLoopHoldingChannel}'s {@code newCall} method can be used to
* Verifies that an {@code HandshakerServiceChannel}'s {@code newCall} method can be used to
* perform a simple RPC.
*/
@Test
Expand All @@ -201,7 +193,7 @@ public void newCall_performSimpleRpcSuccess() {
"localhost:" + plaintextServer.getPort(),
/* s2aChannelCredentials= */ Optional.empty());
Channel channel = resource.create();
assertThat(channel).isInstanceOf(EventLoopHoldingChannel.class);
assertThat(channel).isInstanceOf(HandshakerServiceChannel.class);
assertThat(
SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance()))
.isEqualToDefaultInstance();
Expand All @@ -214,53 +206,49 @@ public void newCall_mtlsPerformSimpleRpcSuccess() throws Exception {
S2AHandshakerServiceChannel.getChannelResource(
"localhost:" + mtlsServer.getPort(), getTlsChannelCredentials());
Channel channel = resource.create();
assertThat(channel).isInstanceOf(EventLoopHoldingChannel.class);
assertThat(channel).isInstanceOf(HandshakerServiceChannel.class);
assertThat(
SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance()))
.isEqualToDefaultInstance();
}

/** Creates a {@code EventLoopHoldingChannel} instance and verifies its authority. */
/** Creates a {@code HandshakerServiceChannel} instance and verifies its authority. */
@Test
public void authority_success() throws Exception {
ManagedChannel channel = new FakeManagedChannel(true);
EventLoopHoldingChannel eventLoopHoldingChannel =
EventLoopHoldingChannel.create(channel, mockEventLoopGroup);
HandshakerServiceChannel eventLoopHoldingChannel =
HandshakerServiceChannel.create(channel);
assertThat(eventLoopHoldingChannel.authority()).isEqualTo("FakeManagedChannel");
}

/**
* Creates and closes a {@code EventLoopHoldingChannel} when its {@code ManagedChannel} terminates
* successfully.
* Creates and closes a {@code HandshakerServiceChannel} when its {@code ManagedChannel}
* terminates successfully.
*/
@Test
public void close_withDelegateTerminatedSuccess() throws Exception {
ManagedChannel channel = new FakeManagedChannel(true);
EventLoopHoldingChannel eventLoopHoldingChannel =
EventLoopHoldingChannel.create(channel, mockEventLoopGroup);
HandshakerServiceChannel eventLoopHoldingChannel =
HandshakerServiceChannel.create(channel);
eventLoopHoldingChannel.close();
assertThat(channel.isShutdown()).isTrue();
verify(mockEventLoopGroup, times(1))
.shutdownGracefully(0, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS);
}

/**
* Creates and closes a {@code EventLoopHoldingChannel} when its {@code ManagedChannel} does not
* Creates and closes a {@code HandshakerServiceChannel} when its {@code ManagedChannel} does not
* terminate successfully.
*/
@Test
public void close_withDelegateTerminatedFailure() throws Exception {
ManagedChannel channel = new FakeManagedChannel(false);
EventLoopHoldingChannel eventLoopHoldingChannel =
EventLoopHoldingChannel.create(channel, mockEventLoopGroup);
HandshakerServiceChannel eventLoopHoldingChannel =
HandshakerServiceChannel.create(channel);
eventLoopHoldingChannel.close();
assertThat(channel.isShutdown()).isTrue();
verify(mockEventLoopGroup, times(1))
.shutdownGracefully(1, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS);
}

/**
* Creates and closes a {@code EventLoopHoldingChannel}, creates a new channel from the same
* Creates and closes a {@code HandshakerServiceChannel}, creates a new channel from the same
* resource, and verifies that this second channel is useable.
*/
@Test
Expand All @@ -273,7 +261,7 @@ public void create_succeedsAfterCloseIsCalledOnce() throws Exception {
resource.close(channelOne);

Channel channelTwo = resource.create();
assertThat(channelTwo).isInstanceOf(EventLoopHoldingChannel.class);
assertThat(channelTwo).isInstanceOf(HandshakerServiceChannel.class);
assertThat(
SimpleServiceGrpc.newBlockingStub(channelTwo)
.unaryRpc(SimpleRequest.getDefaultInstance()))
Expand All @@ -291,7 +279,7 @@ public void create_mtlsSucceedsAfterCloseIsCalledOnce() throws Exception {
resource.close(channelOne);

Channel channelTwo = resource.create();
assertThat(channelTwo).isInstanceOf(EventLoopHoldingChannel.class);
assertThat(channelTwo).isInstanceOf(HandshakerServiceChannel.class);
assertThat(
SimpleServiceGrpc.newBlockingStub(channelTwo)
.unaryRpc(SimpleRequest.getDefaultInstance()))
Expand Down