Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

s2a: Address comments on PR#11113 #11534

Merged
merged 9 commits into from
Sep 20, 2024
21 changes: 7 additions & 14 deletions s2a/src/main/java/io/grpc/s2a/MtlsToS2AChannelCredentials.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,16 @@
import static com.google.common.base.Strings.isNullOrEmpty;

import io.grpc.ChannelCredentials;
import io.grpc.ExperimentalApi;
import io.grpc.TlsChannelCredentials;
import io.grpc.util.AdvancedTlsX509KeyManager;
import io.grpc.util.AdvancedTlsX509TrustManager;
import java.io.File;
import java.io.IOException;
import java.security.GeneralSecurityException;

/**
* Configures an {@code S2AChannelCredentials.Builder} instance with credentials used to establish a
* connection with the S2A to support talking to the S2A over mTLS.
*/
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/11533")
public final class MtlsToS2AChannelCredentials {
/**
* Creates a {@code S2AChannelCredentials.Builder} builder, that talks to the S2A over mTLS.
Expand All @@ -42,7 +41,7 @@ public final class MtlsToS2AChannelCredentials {
* @param trustBundlePath the path to the trust bundle PEM.
* @return a {@code MtlsToS2AChannelCredentials.Builder} instance.
*/
public static Builder createBuilder(
public static Builder newBuilder(
String s2aAddress, String privateKeyPath, String certChainPath, String trustBundlePath) {
checkArgument(!isNullOrEmpty(s2aAddress), "S2A address must not be null or empty.");
checkArgument(!isNullOrEmpty(privateKeyPath), "privateKeyPath must not be null or empty.");
Expand All @@ -66,7 +65,7 @@ public static final class Builder {
this.trustBundlePath = trustBundlePath;
}

public S2AChannelCredentials.Builder build() throws GeneralSecurityException, IOException {
public S2AChannelCredentials.Builder build() throws IOException {
checkState(!isNullOrEmpty(s2aAddress), "S2A address must not be null or empty.");
checkState(!isNullOrEmpty(privateKeyPath), "privateKeyPath must not be null or empty.");
checkState(!isNullOrEmpty(certChainPath), "certChainPath must not be null or empty.");
Expand All @@ -75,19 +74,13 @@ public S2AChannelCredentials.Builder build() throws GeneralSecurityException, IO
File certChainFile = new File(certChainPath);
File trustBundleFile = new File(trustBundlePath);

AdvancedTlsX509KeyManager keyManager = new AdvancedTlsX509KeyManager();
keyManager.updateIdentityCredentials(certChainFile, privateKeyFile);

AdvancedTlsX509TrustManager trustManager = AdvancedTlsX509TrustManager.newBuilder().build();
trustManager.updateTrustCredentials(trustBundleFile);

ChannelCredentials channelToS2ACredentials =
TlsChannelCredentials.newBuilder()
.keyManager(keyManager)
.trustManager(trustManager)
.keyManager(certChainFile, privateKeyFile)
.trustManager(trustBundleFile)
.build();

return S2AChannelCredentials.createBuilder(s2aAddress)
return S2AChannelCredentials.newBuilder(s2aAddress)
.setS2AChannelCredentials(channelToS2ACredentials);
}
}
Expand Down
12 changes: 7 additions & 5 deletions s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,31 @@
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import io.grpc.Channel;
import io.grpc.ChannelCredentials;
import io.grpc.ExperimentalApi;
import io.grpc.InsecureChannelCredentials;
import io.grpc.internal.ObjectPool;
import io.grpc.internal.SharedResourcePool;
import io.grpc.netty.InternalNettyChannelCredentials;
import io.grpc.netty.InternalProtocolNegotiator;
import io.grpc.s2a.channel.S2AHandshakerServiceChannel;
import io.grpc.s2a.handshaker.S2AIdentity;
import io.grpc.s2a.handshaker.S2AProtocolNegotiatorFactory;
import java.util.Optional;
import javax.annotation.concurrent.NotThreadSafe;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
* Configures gRPC to use S2A for transport security when establishing a secure channel. Only for
* use on the client side of a gRPC connection.
*/
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/11533")
public final class S2AChannelCredentials {
/**
* Creates a channel credentials builder for establishing an S2A-secured connection.
*
* @param s2aAddress the address of the S2A server used to secure the connection.
* @return a {@code S2AChannelCredentials.Builder} instance.
*/
public static Builder createBuilder(String s2aAddress) {
public static Builder newBuilder(String s2aAddress) {
checkArgument(!isNullOrEmpty(s2aAddress), "S2A address must not be null or empty.");
return new Builder(s2aAddress);
}
Expand All @@ -56,13 +58,13 @@ public static Builder createBuilder(String s2aAddress) {
public static final class Builder {
private final String s2aAddress;
private ObjectPool<Channel> s2aChannelPool;
private Optional<ChannelCredentials> s2aChannelCredentials;
private ChannelCredentials s2aChannelCredentials;
private @Nullable S2AIdentity localIdentity = null;

Builder(String s2aAddress) {
this.s2aAddress = s2aAddress;
this.s2aChannelPool = null;
this.s2aChannelCredentials = Optional.empty();
this.s2aChannelCredentials = InsecureChannelCredentials.create();
}

/**
Expand Down Expand Up @@ -107,7 +109,7 @@ public Builder setLocalUid(String localUid) {
/** Sets the credentials to be used when connecting to the S2A. */
@CanIgnoreReturnValue
public Builder setS2AChannelCredentials(ChannelCredentials s2aChannelCredentials) {
this.s2aChannelCredentials = Optional.of(s2aChannelCredentials);
this.s2aChannelCredentials = s2aChannelCredentials;
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import io.grpc.internal.SharedResourceHolder.Resource;
import io.grpc.netty.NettyChannelBuilder;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.ConcurrentMap;
import java.util.logging.Level;
import java.util.logging.Logger;
Expand Down Expand Up @@ -71,8 +70,9 @@ public final class S2AHandshakerServiceChannel {
* running at {@code s2aAddress}.
*/
public static Resource<Channel> getChannelResource(
String s2aAddress, Optional<ChannelCredentials> s2aChannelCredentials) {
String s2aAddress, ChannelCredentials s2aChannelCredentials) {
checkNotNull(s2aAddress);
checkNotNull(s2aChannelCredentials);
return SHARED_RESOURCE_CHANNELS.computeIfAbsent(
s2aAddress, channelResource -> new ChannelResource(s2aAddress, s2aChannelCredentials));
}
Expand All @@ -84,9 +84,9 @@ public static Resource<Channel> getChannelResource(
*/
private static class ChannelResource implements Resource<Channel> {
private final String targetAddress;
private final Optional<ChannelCredentials> channelCredentials;
private final ChannelCredentials channelCredentials;

public ChannelResource(String targetAddress, Optional<ChannelCredentials> channelCredentials) {
public ChannelResource(String targetAddress, ChannelCredentials channelCredentials) {
this.targetAddress = targetAddress;
this.channelCredentials = channelCredentials;
}
Expand All @@ -97,21 +97,10 @@ public ChannelResource(String targetAddress, Optional<ChannelCredentials> channe
*/
@Override
public Channel create() {
ManagedChannel channel = null;
if (channelCredentials.isPresent()) {
// Create a secure channel.
channel =
NettyChannelBuilder.forTarget(targetAddress, channelCredentials.get())
.directExecutor()
.build();
} else {
// Create a plaintext channel.
channel =
NettyChannelBuilder.forTarget(targetAddress)
.directExecutor()
.usePlaintext()
.build();
}
ManagedChannel channel =
NettyChannelBuilder.forTarget(targetAddress, channelCredentials)
.directExecutor()
.build();
return HandshakerServiceChannel.create(channel);
}

Expand Down
3 changes: 3 additions & 0 deletions s2a/src/main/java/io/grpc/s2a/handshaker/S2ATrustManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ private void checkPeerTrusted(X509Certificate[] chain, boolean isCheckingClientC
try {
resp = stub.send(reqBuilder.build());
} catch (IOException | InterruptedException e) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This catch needs to be split, because we don't want to interrupt on IOException.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing that out, done in b060a49

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be more standard as the following (but I don't feel strongly about it as it requires the CertificateException definition to be duplicated)

    } catch (IOException e) {
      Thread.currentThread().interrupt();
      throw new CertificateException("Failed to send request to S2A.", e);
    } catch (InterruptedException e) {
      throw new CertificateException("Failed to send request to S2A.", e);
    }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Larry, I'll address this in a followup along with #11539 (comment)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in f264c58

if (e instanceof InterruptedException) {
Thread.currentThread().interrupt();
}
throw new CertificateException("Failed to send request to S2A.", e);
}
if (resp.hasStatus() && resp.getStatus().getCode() != 0) {
Expand Down
34 changes: 17 additions & 17 deletions s2a/src/test/java/io/grpc/s2a/MtlsToS2AChannelCredentialsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,95 +26,95 @@
@RunWith(JUnit4.class)
public final class MtlsToS2AChannelCredentialsTest {
@Test
public void createBuilder_nullAddress_throwsException() throws Exception {
public void newBuilder_nullAddress_throwsException() throws Exception {
assertThrows(
IllegalArgumentException.class,
() ->
MtlsToS2AChannelCredentials.createBuilder(
MtlsToS2AChannelCredentials.newBuilder(
/* s2aAddress= */ null,
/* privateKeyPath= */ "src/test/resources/client_key.pem",
/* certChainPath= */ "src/test/resources/client_cert.pem",
/* trustBundlePath= */ "src/test/resources/root_cert.pem"));
}

@Test
public void createBuilder_nullPrivateKeyPath_throwsException() throws Exception {
public void newBuilder_nullPrivateKeyPath_throwsException() throws Exception {
assertThrows(
IllegalArgumentException.class,
() ->
MtlsToS2AChannelCredentials.createBuilder(
MtlsToS2AChannelCredentials.newBuilder(
/* s2aAddress= */ "s2a_address",
/* privateKeyPath= */ null,
/* certChainPath= */ "src/test/resources/client_cert.pem",
/* trustBundlePath= */ "src/test/resources/root_cert.pem"));
}

@Test
public void createBuilder_nullCertChainPath_throwsException() throws Exception {
public void newBuilder_nullCertChainPath_throwsException() throws Exception {
assertThrows(
IllegalArgumentException.class,
() ->
MtlsToS2AChannelCredentials.createBuilder(
MtlsToS2AChannelCredentials.newBuilder(
/* s2aAddress= */ "s2a_address",
/* privateKeyPath= */ "src/test/resources/client_key.pem",
/* certChainPath= */ null,
/* trustBundlePath= */ "src/test/resources/root_cert.pem"));
}

@Test
public void createBuilder_nullTrustBundlePath_throwsException() throws Exception {
public void newBuilder_nullTrustBundlePath_throwsException() throws Exception {
assertThrows(
IllegalArgumentException.class,
() ->
MtlsToS2AChannelCredentials.createBuilder(
MtlsToS2AChannelCredentials.newBuilder(
/* s2aAddress= */ "s2a_address",
/* privateKeyPath= */ "src/test/resources/client_key.pem",
/* certChainPath= */ "src/test/resources/client_cert.pem",
/* trustBundlePath= */ null));
}

@Test
public void createBuilder_emptyAddress_throwsException() throws Exception {
public void newBuilder_emptyAddress_throwsException() throws Exception {
assertThrows(
IllegalArgumentException.class,
() ->
MtlsToS2AChannelCredentials.createBuilder(
MtlsToS2AChannelCredentials.newBuilder(
/* s2aAddress= */ "",
/* privateKeyPath= */ "src/test/resources/client_key.pem",
/* certChainPath= */ "src/test/resources/client_cert.pem",
/* trustBundlePath= */ "src/test/resources/root_cert.pem"));
}

@Test
public void createBuilder_emptyPrivateKeyPath_throwsException() throws Exception {
public void newBuilder_emptyPrivateKeyPath_throwsException() throws Exception {
assertThrows(
IllegalArgumentException.class,
() ->
MtlsToS2AChannelCredentials.createBuilder(
MtlsToS2AChannelCredentials.newBuilder(
/* s2aAddress= */ "s2a_address",
/* privateKeyPath= */ "",
/* certChainPath= */ "src/test/resources/client_cert.pem",
/* trustBundlePath= */ "src/test/resources/root_cert.pem"));
}

@Test
public void createBuilder_emptyCertChainPath_throwsException() throws Exception {
public void newBuilder_emptyCertChainPath_throwsException() throws Exception {
assertThrows(
IllegalArgumentException.class,
() ->
MtlsToS2AChannelCredentials.createBuilder(
MtlsToS2AChannelCredentials.newBuilder(
/* s2aAddress= */ "s2a_address",
/* privateKeyPath= */ "src/test/resources/client_key.pem",
/* certChainPath= */ "",
/* trustBundlePath= */ "src/test/resources/root_cert.pem"));
}

@Test
public void createBuilder_emptyTrustBundlePath_throwsException() throws Exception {
public void newBuilder_emptyTrustBundlePath_throwsException() throws Exception {
assertThrows(
IllegalArgumentException.class,
() ->
MtlsToS2AChannelCredentials.createBuilder(
MtlsToS2AChannelCredentials.newBuilder(
/* s2aAddress= */ "s2a_address",
/* privateKeyPath= */ "src/test/resources/client_key.pem",
/* certChainPath= */ "src/test/resources/client_cert.pem",
Expand All @@ -124,7 +124,7 @@ public void createBuilder_emptyTrustBundlePath_throwsException() throws Exceptio
@Test
public void build_s2AChannelCredentials_success() throws Exception {
assertThat(
MtlsToS2AChannelCredentials.createBuilder(
MtlsToS2AChannelCredentials.newBuilder(
/* s2aAddress= */ "s2a_address",
/* privateKeyPath= */ "src/test/resources/client_key.pem",
/* certChainPath= */ "src/test/resources/client_cert.pem",
Expand Down
24 changes: 12 additions & 12 deletions s2a/src/test/java/io/grpc/s2a/S2AChannelCredentialsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,40 +30,40 @@
@RunWith(JUnit4.class)
public final class S2AChannelCredentialsTest {
@Test
public void createBuilder_nullArgument_throwsException() throws Exception {
assertThrows(IllegalArgumentException.class, () -> S2AChannelCredentials.createBuilder(null));
public void newBuilder_nullArgument_throwsException() throws Exception {
assertThrows(IllegalArgumentException.class, () -> S2AChannelCredentials.newBuilder(null));
}

@Test
public void createBuilder_emptyAddress_throwsException() throws Exception {
assertThrows(IllegalArgumentException.class, () -> S2AChannelCredentials.createBuilder(""));
public void newBuilder_emptyAddress_throwsException() throws Exception {
assertThrows(IllegalArgumentException.class, () -> S2AChannelCredentials.newBuilder(""));
}

@Test
public void setLocalSpiffeId_nullArgument_throwsException() throws Exception {
assertThrows(
NullPointerException.class,
() -> S2AChannelCredentials.createBuilder("s2a_address").setLocalSpiffeId(null));
() -> S2AChannelCredentials.newBuilder("s2a_address").setLocalSpiffeId(null));
}

@Test
public void setLocalHostname_nullArgument_throwsException() throws Exception {
assertThrows(
NullPointerException.class,
() -> S2AChannelCredentials.createBuilder("s2a_address").setLocalHostname(null));
() -> S2AChannelCredentials.newBuilder("s2a_address").setLocalHostname(null));
}

@Test
public void setLocalUid_nullArgument_throwsException() throws Exception {
assertThrows(
NullPointerException.class,
() -> S2AChannelCredentials.createBuilder("s2a_address").setLocalUid(null));
() -> S2AChannelCredentials.newBuilder("s2a_address").setLocalUid(null));
}

@Test
public void build_withLocalSpiffeId_succeeds() throws Exception {
assertThat(
S2AChannelCredentials.createBuilder("s2a_address")
S2AChannelCredentials.newBuilder("s2a_address")
.setLocalSpiffeId("spiffe://test")
.build())
.isNotNull();
Expand All @@ -72,28 +72,28 @@ public void build_withLocalSpiffeId_succeeds() throws Exception {
@Test
public void build_withLocalHostname_succeeds() throws Exception {
assertThat(
S2AChannelCredentials.createBuilder("s2a_address")
S2AChannelCredentials.newBuilder("s2a_address")
.setLocalHostname("local_hostname")
.build())
.isNotNull();
}

@Test
public void build_withLocalUid_succeeds() throws Exception {
assertThat(S2AChannelCredentials.createBuilder("s2a_address").setLocalUid("local_uid").build())
assertThat(S2AChannelCredentials.newBuilder("s2a_address").setLocalUid("local_uid").build())
.isNotNull();
}

@Test
public void build_withNoLocalIdentity_succeeds() throws Exception {
assertThat(S2AChannelCredentials.createBuilder("s2a_address").build())
assertThat(S2AChannelCredentials.newBuilder("s2a_address").build())
.isNotNull();
}

@Test
public void build_withTlsChannelCredentials_succeeds() throws Exception {
assertThat(
S2AChannelCredentials.createBuilder("s2a_address")
S2AChannelCredentials.newBuilder("s2a_address")
.setLocalSpiffeId("spiffe://test")
.setS2AChannelCredentials(getTlsChannelCredentials())
.build())
Expand Down
Loading