Skip to content

Commit

Permalink
s2a: Address comments on PR#11113 (#11534)
Browse files Browse the repository at this point in the history
* Mark S2A public APIs as experimental.

* Rename S2AChannelCredentials createBuilder API to newBuilder.

* Remove usage of AdvancedTls.

* Use InsecureChannelCredentials.create instead of Optional.

* Invoke Thread.currentThread().interrupt() in a InterruptedException block.
  • Loading branch information
rmehta19 authored Sep 20, 2024
1 parent e75a044 commit d8f73e0
Show file tree
Hide file tree
Showing 10 changed files with 73 additions and 87 deletions.
21 changes: 7 additions & 14 deletions s2a/src/main/java/io/grpc/s2a/MtlsToS2AChannelCredentials.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,16 @@
import static com.google.common.base.Strings.isNullOrEmpty;

import io.grpc.ChannelCredentials;
import io.grpc.ExperimentalApi;
import io.grpc.TlsChannelCredentials;
import io.grpc.util.AdvancedTlsX509KeyManager;
import io.grpc.util.AdvancedTlsX509TrustManager;
import java.io.File;
import java.io.IOException;
import java.security.GeneralSecurityException;

/**
* Configures an {@code S2AChannelCredentials.Builder} instance with credentials used to establish a
* connection with the S2A to support talking to the S2A over mTLS.
*/
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/11533")
public final class MtlsToS2AChannelCredentials {
/**
* Creates a {@code S2AChannelCredentials.Builder} builder, that talks to the S2A over mTLS.
Expand All @@ -42,7 +41,7 @@ public final class MtlsToS2AChannelCredentials {
* @param trustBundlePath the path to the trust bundle PEM.
* @return a {@code MtlsToS2AChannelCredentials.Builder} instance.
*/
public static Builder createBuilder(
public static Builder newBuilder(
String s2aAddress, String privateKeyPath, String certChainPath, String trustBundlePath) {
checkArgument(!isNullOrEmpty(s2aAddress), "S2A address must not be null or empty.");
checkArgument(!isNullOrEmpty(privateKeyPath), "privateKeyPath must not be null or empty.");
Expand All @@ -66,7 +65,7 @@ public static final class Builder {
this.trustBundlePath = trustBundlePath;
}

public S2AChannelCredentials.Builder build() throws GeneralSecurityException, IOException {
public S2AChannelCredentials.Builder build() throws IOException {
checkState(!isNullOrEmpty(s2aAddress), "S2A address must not be null or empty.");
checkState(!isNullOrEmpty(privateKeyPath), "privateKeyPath must not be null or empty.");
checkState(!isNullOrEmpty(certChainPath), "certChainPath must not be null or empty.");
Expand All @@ -75,19 +74,13 @@ public S2AChannelCredentials.Builder build() throws GeneralSecurityException, IO
File certChainFile = new File(certChainPath);
File trustBundleFile = new File(trustBundlePath);

AdvancedTlsX509KeyManager keyManager = new AdvancedTlsX509KeyManager();
keyManager.updateIdentityCredentials(certChainFile, privateKeyFile);

AdvancedTlsX509TrustManager trustManager = AdvancedTlsX509TrustManager.newBuilder().build();
trustManager.updateTrustCredentials(trustBundleFile);

ChannelCredentials channelToS2ACredentials =
TlsChannelCredentials.newBuilder()
.keyManager(keyManager)
.trustManager(trustManager)
.keyManager(certChainFile, privateKeyFile)
.trustManager(trustBundleFile)
.build();

return S2AChannelCredentials.createBuilder(s2aAddress)
return S2AChannelCredentials.newBuilder(s2aAddress)
.setS2AChannelCredentials(channelToS2ACredentials);
}
}
Expand Down
12 changes: 7 additions & 5 deletions s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,31 @@
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import io.grpc.Channel;
import io.grpc.ChannelCredentials;
import io.grpc.ExperimentalApi;
import io.grpc.InsecureChannelCredentials;
import io.grpc.internal.ObjectPool;
import io.grpc.internal.SharedResourcePool;
import io.grpc.netty.InternalNettyChannelCredentials;
import io.grpc.netty.InternalProtocolNegotiator;
import io.grpc.s2a.channel.S2AHandshakerServiceChannel;
import io.grpc.s2a.handshaker.S2AIdentity;
import io.grpc.s2a.handshaker.S2AProtocolNegotiatorFactory;
import java.util.Optional;
import javax.annotation.concurrent.NotThreadSafe;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
* Configures gRPC to use S2A for transport security when establishing a secure channel. Only for
* use on the client side of a gRPC connection.
*/
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/11533")
public final class S2AChannelCredentials {
/**
* Creates a channel credentials builder for establishing an S2A-secured connection.
*
* @param s2aAddress the address of the S2A server used to secure the connection.
* @return a {@code S2AChannelCredentials.Builder} instance.
*/
public static Builder createBuilder(String s2aAddress) {
public static Builder newBuilder(String s2aAddress) {
checkArgument(!isNullOrEmpty(s2aAddress), "S2A address must not be null or empty.");
return new Builder(s2aAddress);
}
Expand All @@ -56,13 +58,13 @@ public static Builder createBuilder(String s2aAddress) {
public static final class Builder {
private final String s2aAddress;
private ObjectPool<Channel> s2aChannelPool;
private Optional<ChannelCredentials> s2aChannelCredentials;
private ChannelCredentials s2aChannelCredentials;
private @Nullable S2AIdentity localIdentity = null;

Builder(String s2aAddress) {
this.s2aAddress = s2aAddress;
this.s2aChannelPool = null;
this.s2aChannelCredentials = Optional.empty();
this.s2aChannelCredentials = InsecureChannelCredentials.create();
}

/**
Expand Down Expand Up @@ -107,7 +109,7 @@ public Builder setLocalUid(String localUid) {
/** Sets the credentials to be used when connecting to the S2A. */
@CanIgnoreReturnValue
public Builder setS2AChannelCredentials(ChannelCredentials s2aChannelCredentials) {
this.s2aChannelCredentials = Optional.of(s2aChannelCredentials);
this.s2aChannelCredentials = s2aChannelCredentials;
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import io.grpc.internal.SharedResourceHolder.Resource;
import io.grpc.netty.NettyChannelBuilder;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.ConcurrentMap;
import java.util.logging.Level;
import java.util.logging.Logger;
Expand Down Expand Up @@ -71,8 +70,9 @@ public final class S2AHandshakerServiceChannel {
* running at {@code s2aAddress}.
*/
public static Resource<Channel> getChannelResource(
String s2aAddress, Optional<ChannelCredentials> s2aChannelCredentials) {
String s2aAddress, ChannelCredentials s2aChannelCredentials) {
checkNotNull(s2aAddress);
checkNotNull(s2aChannelCredentials);
return SHARED_RESOURCE_CHANNELS.computeIfAbsent(
s2aAddress, channelResource -> new ChannelResource(s2aAddress, s2aChannelCredentials));
}
Expand All @@ -84,9 +84,9 @@ public static Resource<Channel> getChannelResource(
*/
private static class ChannelResource implements Resource<Channel> {
private final String targetAddress;
private final Optional<ChannelCredentials> channelCredentials;
private final ChannelCredentials channelCredentials;

public ChannelResource(String targetAddress, Optional<ChannelCredentials> channelCredentials) {
public ChannelResource(String targetAddress, ChannelCredentials channelCredentials) {
this.targetAddress = targetAddress;
this.channelCredentials = channelCredentials;
}
Expand All @@ -97,21 +97,10 @@ public ChannelResource(String targetAddress, Optional<ChannelCredentials> channe
*/
@Override
public Channel create() {
ManagedChannel channel = null;
if (channelCredentials.isPresent()) {
// Create a secure channel.
channel =
NettyChannelBuilder.forTarget(targetAddress, channelCredentials.get())
.directExecutor()
.build();
} else {
// Create a plaintext channel.
channel =
NettyChannelBuilder.forTarget(targetAddress)
.directExecutor()
.usePlaintext()
.build();
}
ManagedChannel channel =
NettyChannelBuilder.forTarget(targetAddress, channelCredentials)
.directExecutor()
.build();
return HandshakerServiceChannel.create(channel);
}

Expand Down
3 changes: 3 additions & 0 deletions s2a/src/main/java/io/grpc/s2a/handshaker/S2ATrustManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ private void checkPeerTrusted(X509Certificate[] chain, boolean isCheckingClientC
try {
resp = stub.send(reqBuilder.build());
} catch (IOException | InterruptedException e) {
if (e instanceof InterruptedException) {
Thread.currentThread().interrupt();
}
throw new CertificateException("Failed to send request to S2A.", e);
}
if (resp.hasStatus() && resp.getStatus().getCode() != 0) {
Expand Down
34 changes: 17 additions & 17 deletions s2a/src/test/java/io/grpc/s2a/MtlsToS2AChannelCredentialsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,95 +26,95 @@
@RunWith(JUnit4.class)
public final class MtlsToS2AChannelCredentialsTest {
@Test
public void createBuilder_nullAddress_throwsException() throws Exception {
public void newBuilder_nullAddress_throwsException() throws Exception {
assertThrows(
IllegalArgumentException.class,
() ->
MtlsToS2AChannelCredentials.createBuilder(
MtlsToS2AChannelCredentials.newBuilder(
/* s2aAddress= */ null,
/* privateKeyPath= */ "src/test/resources/client_key.pem",
/* certChainPath= */ "src/test/resources/client_cert.pem",
/* trustBundlePath= */ "src/test/resources/root_cert.pem"));
}

@Test
public void createBuilder_nullPrivateKeyPath_throwsException() throws Exception {
public void newBuilder_nullPrivateKeyPath_throwsException() throws Exception {
assertThrows(
IllegalArgumentException.class,
() ->
MtlsToS2AChannelCredentials.createBuilder(
MtlsToS2AChannelCredentials.newBuilder(
/* s2aAddress= */ "s2a_address",
/* privateKeyPath= */ null,
/* certChainPath= */ "src/test/resources/client_cert.pem",
/* trustBundlePath= */ "src/test/resources/root_cert.pem"));
}

@Test
public void createBuilder_nullCertChainPath_throwsException() throws Exception {
public void newBuilder_nullCertChainPath_throwsException() throws Exception {
assertThrows(
IllegalArgumentException.class,
() ->
MtlsToS2AChannelCredentials.createBuilder(
MtlsToS2AChannelCredentials.newBuilder(
/* s2aAddress= */ "s2a_address",
/* privateKeyPath= */ "src/test/resources/client_key.pem",
/* certChainPath= */ null,
/* trustBundlePath= */ "src/test/resources/root_cert.pem"));
}

@Test
public void createBuilder_nullTrustBundlePath_throwsException() throws Exception {
public void newBuilder_nullTrustBundlePath_throwsException() throws Exception {
assertThrows(
IllegalArgumentException.class,
() ->
MtlsToS2AChannelCredentials.createBuilder(
MtlsToS2AChannelCredentials.newBuilder(
/* s2aAddress= */ "s2a_address",
/* privateKeyPath= */ "src/test/resources/client_key.pem",
/* certChainPath= */ "src/test/resources/client_cert.pem",
/* trustBundlePath= */ null));
}

@Test
public void createBuilder_emptyAddress_throwsException() throws Exception {
public void newBuilder_emptyAddress_throwsException() throws Exception {
assertThrows(
IllegalArgumentException.class,
() ->
MtlsToS2AChannelCredentials.createBuilder(
MtlsToS2AChannelCredentials.newBuilder(
/* s2aAddress= */ "",
/* privateKeyPath= */ "src/test/resources/client_key.pem",
/* certChainPath= */ "src/test/resources/client_cert.pem",
/* trustBundlePath= */ "src/test/resources/root_cert.pem"));
}

@Test
public void createBuilder_emptyPrivateKeyPath_throwsException() throws Exception {
public void newBuilder_emptyPrivateKeyPath_throwsException() throws Exception {
assertThrows(
IllegalArgumentException.class,
() ->
MtlsToS2AChannelCredentials.createBuilder(
MtlsToS2AChannelCredentials.newBuilder(
/* s2aAddress= */ "s2a_address",
/* privateKeyPath= */ "",
/* certChainPath= */ "src/test/resources/client_cert.pem",
/* trustBundlePath= */ "src/test/resources/root_cert.pem"));
}

@Test
public void createBuilder_emptyCertChainPath_throwsException() throws Exception {
public void newBuilder_emptyCertChainPath_throwsException() throws Exception {
assertThrows(
IllegalArgumentException.class,
() ->
MtlsToS2AChannelCredentials.createBuilder(
MtlsToS2AChannelCredentials.newBuilder(
/* s2aAddress= */ "s2a_address",
/* privateKeyPath= */ "src/test/resources/client_key.pem",
/* certChainPath= */ "",
/* trustBundlePath= */ "src/test/resources/root_cert.pem"));
}

@Test
public void createBuilder_emptyTrustBundlePath_throwsException() throws Exception {
public void newBuilder_emptyTrustBundlePath_throwsException() throws Exception {
assertThrows(
IllegalArgumentException.class,
() ->
MtlsToS2AChannelCredentials.createBuilder(
MtlsToS2AChannelCredentials.newBuilder(
/* s2aAddress= */ "s2a_address",
/* privateKeyPath= */ "src/test/resources/client_key.pem",
/* certChainPath= */ "src/test/resources/client_cert.pem",
Expand All @@ -124,7 +124,7 @@ public void createBuilder_emptyTrustBundlePath_throwsException() throws Exceptio
@Test
public void build_s2AChannelCredentials_success() throws Exception {
assertThat(
MtlsToS2AChannelCredentials.createBuilder(
MtlsToS2AChannelCredentials.newBuilder(
/* s2aAddress= */ "s2a_address",
/* privateKeyPath= */ "src/test/resources/client_key.pem",
/* certChainPath= */ "src/test/resources/client_cert.pem",
Expand Down
24 changes: 12 additions & 12 deletions s2a/src/test/java/io/grpc/s2a/S2AChannelCredentialsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,40 +30,40 @@
@RunWith(JUnit4.class)
public final class S2AChannelCredentialsTest {
@Test
public void createBuilder_nullArgument_throwsException() throws Exception {
assertThrows(IllegalArgumentException.class, () -> S2AChannelCredentials.createBuilder(null));
public void newBuilder_nullArgument_throwsException() throws Exception {
assertThrows(IllegalArgumentException.class, () -> S2AChannelCredentials.newBuilder(null));
}

@Test
public void createBuilder_emptyAddress_throwsException() throws Exception {
assertThrows(IllegalArgumentException.class, () -> S2AChannelCredentials.createBuilder(""));
public void newBuilder_emptyAddress_throwsException() throws Exception {
assertThrows(IllegalArgumentException.class, () -> S2AChannelCredentials.newBuilder(""));
}

@Test
public void setLocalSpiffeId_nullArgument_throwsException() throws Exception {
assertThrows(
NullPointerException.class,
() -> S2AChannelCredentials.createBuilder("s2a_address").setLocalSpiffeId(null));
() -> S2AChannelCredentials.newBuilder("s2a_address").setLocalSpiffeId(null));
}

@Test
public void setLocalHostname_nullArgument_throwsException() throws Exception {
assertThrows(
NullPointerException.class,
() -> S2AChannelCredentials.createBuilder("s2a_address").setLocalHostname(null));
() -> S2AChannelCredentials.newBuilder("s2a_address").setLocalHostname(null));
}

@Test
public void setLocalUid_nullArgument_throwsException() throws Exception {
assertThrows(
NullPointerException.class,
() -> S2AChannelCredentials.createBuilder("s2a_address").setLocalUid(null));
() -> S2AChannelCredentials.newBuilder("s2a_address").setLocalUid(null));
}

@Test
public void build_withLocalSpiffeId_succeeds() throws Exception {
assertThat(
S2AChannelCredentials.createBuilder("s2a_address")
S2AChannelCredentials.newBuilder("s2a_address")
.setLocalSpiffeId("spiffe://test")
.build())
.isNotNull();
Expand All @@ -72,28 +72,28 @@ public void build_withLocalSpiffeId_succeeds() throws Exception {
@Test
public void build_withLocalHostname_succeeds() throws Exception {
assertThat(
S2AChannelCredentials.createBuilder("s2a_address")
S2AChannelCredentials.newBuilder("s2a_address")
.setLocalHostname("local_hostname")
.build())
.isNotNull();
}

@Test
public void build_withLocalUid_succeeds() throws Exception {
assertThat(S2AChannelCredentials.createBuilder("s2a_address").setLocalUid("local_uid").build())
assertThat(S2AChannelCredentials.newBuilder("s2a_address").setLocalUid("local_uid").build())
.isNotNull();
}

@Test
public void build_withNoLocalIdentity_succeeds() throws Exception {
assertThat(S2AChannelCredentials.createBuilder("s2a_address").build())
assertThat(S2AChannelCredentials.newBuilder("s2a_address").build())
.isNotNull();
}

@Test
public void build_withTlsChannelCredentials_succeeds() throws Exception {
assertThat(
S2AChannelCredentials.createBuilder("s2a_address")
S2AChannelCredentials.newBuilder("s2a_address")
.setLocalSpiffeId("spiffe://test")
.setS2AChannelCredentials(getTlsChannelCredentials())
.build())
Expand Down
Loading

0 comments on commit d8f73e0

Please sign in to comment.