diff --git a/MODULE.bazel b/MODULE.bazel index 1d79e362e11..78e6ccb70f2 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -41,7 +41,9 @@ IO_GRPC_GRPC_JAVA_ARTIFACTS = [ "io.perfmark:perfmark-api:0.26.0", "junit:junit:4.13.2", "org.apache.tomcat:annotations-api:6.0.53", + "org.checkerframework:checker-qual:3.12.0", "org.codehaus.mojo:animal-sniffer-annotations:1.23", + "org.jcommander:jcommander:1.83", ] # GRPC_DEPS_END diff --git a/repositories.bzl b/repositories.bzl index 1f422d3380f..7ed5141fec3 100644 --- a/repositories.bzl +++ b/repositories.bzl @@ -45,7 +45,9 @@ IO_GRPC_GRPC_JAVA_ARTIFACTS = [ "io.perfmark:perfmark-api:0.26.0", "junit:junit:4.13.2", "org.apache.tomcat:annotations-api:6.0.53", + "org.checkerframework:checker-qual:3.12.0", "org.codehaus.mojo:animal-sniffer-annotations:1.23", + "org.jcommander:jcommander:1.83", ] # GRPC_DEPS_END @@ -80,6 +82,7 @@ IO_GRPC_GRPC_JAVA_OVERRIDE_TARGETS = { "io.grpc:grpc-rls": "@io_grpc_grpc_java//rls", "io.grpc:grpc-services": "@io_grpc_grpc_java//services:services_maven", "io.grpc:grpc-stub": "@io_grpc_grpc_java//stub", + "io.grpc:grpc-s2a": "@io_grpc_grpc_java//s2a", "io.grpc:grpc-testing": "@io_grpc_grpc_java//testing", "io.grpc:grpc-xds": "@io_grpc_grpc_java//xds:xds_maven", "io.grpc:grpc-util": "@io_grpc_grpc_java//util", diff --git a/s2a/BUILD.bazel b/s2a/BUILD.bazel new file mode 100644 index 00000000000..0041ad52be6 --- /dev/null +++ b/s2a/BUILD.bazel @@ -0,0 +1,194 @@ +load("@rules_proto//proto:defs.bzl", "proto_library") +load("//:java_grpc_library.bzl", "java_grpc_library") +load("@rules_jvm_external//:defs.bzl", "artifact") + +java_library( + name = "s2a_channel_pool", + srcs = glob([ + "src/main/java/io/grpc/s2a/channel/*.java", + ]), + deps = [ + "//api", + "//core", + "//core:internal", + "//netty", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + artifact("org.checkerframework:checker-qual"), + artifact("io.netty:netty-common"), + artifact("io.netty:netty-transport"), + ], +) + +java_library( + name = "s2a_identity", + srcs = ["src/main/java/io/grpc/s2a/handshaker/S2AIdentity.java"], + deps = [ + ":common_java_proto", + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + ], +) + +java_library( + name = "token_fetcher", + srcs = ["src/main/java/io/grpc/s2a/handshaker/tokenmanager/TokenFetcher.java"], + deps = [ + ":s2a_identity", + ], +) + +java_library( + name = "access_token_manager", + srcs = [ + "src/main/java/io/grpc/s2a/handshaker/tokenmanager/AccessTokenManager.java", + ], + deps = [ + ":s2a_identity", + ":token_fetcher", + artifact("com.google.code.findbugs:jsr305"), + ], +) + +java_library( + name = "single_token_fetcher", + srcs = [ + "src/main/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenFetcher.java", + ], + deps = [ + ":s2a_identity", + ":token_fetcher", + artifact("org.jcommander:jcommander"), + ], +) + +java_library( + name = "s2a_handshaker", + srcs = [ + "src/main/java/io/grpc/s2a/handshaker/ConnectionIsClosedException.java", + "src/main/java/io/grpc/s2a/handshaker/GetAuthenticationMechanisms.java", + "src/main/java/io/grpc/s2a/handshaker/ProtoUtil.java", + "src/main/java/io/grpc/s2a/handshaker/S2AConnectionException.java", + "src/main/java/io/grpc/s2a/handshaker/S2APrivateKeyMethod.java", + "src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java", + "src/main/java/io/grpc/s2a/handshaker/S2AStub.java", + "src/main/java/io/grpc/s2a/handshaker/S2ATrustManager.java", + "src/main/java/io/grpc/s2a/handshaker/SslContextFactory.java", + ], + deps = [ + ":access_token_manager", + ":common_java_proto", + ":s2a_channel_pool", + ":s2a_identity", + ":s2a_java_proto", + ":s2a_java_grpc_proto", + ":single_token_fetcher", + "//api", + "//core:internal", + "//netty", + "//stub", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + artifact("org.checkerframework:checker-qual"), + "@com_google_protobuf//:protobuf_java", + artifact("io.netty:netty-common"), + artifact("io.netty:netty-handler"), + artifact("io.netty:netty-transport"), + ], +) + +java_library( + name = "s2av2_credentials", + srcs = ["src/main/java/io/grpc/s2a/S2AChannelCredentials.java"], + visibility = ["//visibility:public"], + deps = [ + ":s2a_channel_pool", + ":s2a_handshaker", + ":s2a_identity", + "//api", + "//core:internal", + "//netty", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + artifact("org.checkerframework:checker-qual"), + ], +) + +java_library( + name = "mtls_to_s2av2_credentials", + srcs = ["src/main/java/io/grpc/s2a/MtlsToS2AChannelCredentials.java"], + visibility = ["//visibility:public"], + deps = [ + ":s2a_channel_pool", + ":s2av2_credentials", + "//api", + "//util", + artifact("com.google.guava:guava"), + ], +) + +# bazel only accepts proto import with absolute path. +genrule( + name = "protobuf_imports", + srcs = glob(["src/main/proto/grpc/gcp/*.proto"]), + outs = [ + "protobuf_out/grpc/gcp/s2a.proto", + "protobuf_out/grpc/gcp/s2a_context.proto", + "protobuf_out/grpc/gcp/common.proto", + ], + cmd = "for fname in $(SRCS); do " + + "sed 's,import \",import \"s2a/protobuf_out/,g' $$fname > " + + "$(@D)/protobuf_out/grpc/gcp/$$(basename $$fname); done", +) + +proto_library( + name = "common_proto", + srcs = [ + "protobuf_out/grpc/gcp/common.proto", + ], +) + +proto_library( + name = "s2a_context_proto", + srcs = [ + "protobuf_out/grpc/gcp/s2a_context.proto", + ], + deps = [ + ":common_proto", + ], +) + +proto_library( + name = "s2a_proto", + srcs = [ + "protobuf_out/grpc/gcp/s2a.proto", + ], + deps = [ + ":common_proto", + ":s2a_context_proto", + ], +) + +java_proto_library( + name = "s2a_java_proto", + deps = [":s2a_proto"], +) + +java_proto_library( + name = "s2a_context_java_proto", + deps = [":s2a_context_proto"], +) + +java_proto_library( + name = "common_java_proto", + deps = [":common_proto"], +) + +java_grpc_library( + name = "s2a_java_grpc_proto", + srcs = [":s2a_proto"], + deps = [":s2a_java_proto"], +) diff --git a/s2a/build.gradle b/s2a/build.gradle new file mode 100644 index 00000000000..054039571d8 --- /dev/null +++ b/s2a/build.gradle @@ -0,0 +1,153 @@ +buildscript { + dependencies { + classpath 'com.google.gradle:osdetector-gradle-plugin:1.4.0' + } +} + +plugins { + id "java-library" + id "maven-publish" + + id "com.github.johnrengelman.shadow" + id "com.google.protobuf" + id "ru.vyarus.animalsniffer" +} + +description = "gRPC: S2A" + +apply plugin: "com.google.osdetector" + +dependencies { + + api project(':grpc-api') + implementation project(':grpc-stub'), + project(':grpc-protobuf'), + project(':grpc-core'), + libraries.protobuf.java, + libraries.conscrypt, + libraries.guava.jre // JRE required by protobuf-java-util from grpclb + compileOnly 'org.jcommander:jcommander:1.83' + def nettyDependency = implementation project(':grpc-netty') + compileOnly libraries.javax.annotation + + shadow configurations.implementation.getDependencies().minus(nettyDependency) + shadow project(path: ':grpc-netty-shaded', configuration: 'shadow') + + testImplementation project(':grpc-benchmarks'), + project(':grpc-testing'), + project(':grpc-testing-proto'), + testFixtures(project(':grpc-core')), + libraries.guava, + libraries.junit, + libraries.mockito.core, + libraries.truth, + libraries.conscrypt, + libraries.netty.transport.epoll + + testImplementation 'org.jcommander:jcommander:1.83' + testImplementation 'com.google.truth:truth:1.4.2' + testImplementation 'com.google.truth.extensions:truth-proto-extension:1.4.2' + testImplementation libraries.guava.testlib + + testRuntimeOnly libraries.netty.tcnative, + libraries.netty.tcnative.classes + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "linux-x86_64" + } + } + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "linux-aarch_64" + } + } + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "osx-x86_64" + } + } + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "osx-aarch_64" + } + } + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "windows-x86_64" + } + } + testRuntimeOnly (libraries.netty.transport.epoll) { + artifact { + classifier = "linux-x86_64" + } + } + + signature libraries.signature.java +} + +tasks.named("compileJava") { + dependsOn(tasks.named("generateProto")) + //dependsOn(tasks.named("syncGeneratedSourcesmain")) +} + + +tasks.named("sourcesJar") { + dependsOn(tasks.named("generateProto")) + //dependsOn(tasks.named("syncGeneratedSourcesmain")) +} + +sourceSets { + main { + //java.srcDirs += "src/generated/main/java" + //java.srcDirs += "src/generated/main/grpc" + } +} +//println sourceSets.main.java.srcDirs +//println sourceSets.test.resources.srcDirs + +configureProtoCompilation() + +tasks.named("javadoc").configure { + exclude 'io/grpc/s2a/**' +} + +tasks.named("jar").configure { + // Must use a different archiveClassifier to avoid conflicting with shadowJar + archiveClassifier = 'original' + manifest { + attributes('Automatic-Module-Name': 'io.grpc.s2a') + } +} + +// We want to use grpc-netty-shaded instead of grpc-netty. But we also want our +// source to work with Bazel, so we rewrite the code as part of the build. +tasks.named("shadowJar").configure { + archiveClassifier = null + dependencies { + exclude(dependency {true}) + } + relocate 'io.grpc.netty', 'io.grpc.netty.shaded.io.grpc.netty' + relocate 'io.netty', 'io.grpc.netty.shaded.io.netty' +} + +publishing { + publications { + maven(MavenPublication) { + // We want this to throw an exception if it isn't working + def originalJar = artifacts.find { dep -> dep.classifier == 'original'} + artifacts.remove(originalJar) + + pom.withXml { + def dependenciesNode = new Node(null, 'dependencies') + project.configurations.shadow.allDependencies.each { dep -> + def dependencyNode = dependenciesNode.appendNode('dependency') + dependencyNode.appendNode('groupId', dep.group) + dependencyNode.appendNode('artifactId', dep.name) + dependencyNode.appendNode('version', dep.version) + dependencyNode.appendNode('scope', 'compile') + } + asNode().dependencies[0].replaceNode(dependenciesNode) + } + } + } +} diff --git a/s2a/src/generated/main/grpc/io/grpc/s2a/handshaker/S2AServiceGrpc.java b/s2a/src/generated/main/grpc/io/grpc/s2a/handshaker/S2AServiceGrpc.java new file mode 100644 index 00000000000..fd6b991c039 --- /dev/null +++ b/s2a/src/generated/main/grpc/io/grpc/s2a/handshaker/S2AServiceGrpc.java @@ -0,0 +1,285 @@ +package io.grpc.s2a.handshaker; + +import static io.grpc.MethodDescriptor.generateFullMethodName; + +/** + */ +@javax.annotation.Generated( + value = "by gRPC proto compiler", + comments = "Source: grpc/gcp/s2a.proto") +@io.grpc.stub.annotations.GrpcGenerated +public final class S2AServiceGrpc { + + private S2AServiceGrpc() {} + + public static final java.lang.String SERVICE_NAME = "grpc.gcp.S2AService"; + + // Static method descriptors that strictly reflect the proto. + private static volatile io.grpc.MethodDescriptor getSetUpSessionMethod; + + @io.grpc.stub.annotations.RpcMethod( + fullMethodName = SERVICE_NAME + '/' + "SetUpSession", + requestType = io.grpc.s2a.handshaker.SessionReq.class, + responseType = io.grpc.s2a.handshaker.SessionResp.class, + methodType = io.grpc.MethodDescriptor.MethodType.BIDI_STREAMING) + public static io.grpc.MethodDescriptor getSetUpSessionMethod() { + io.grpc.MethodDescriptor getSetUpSessionMethod; + if ((getSetUpSessionMethod = S2AServiceGrpc.getSetUpSessionMethod) == null) { + synchronized (S2AServiceGrpc.class) { + if ((getSetUpSessionMethod = S2AServiceGrpc.getSetUpSessionMethod) == null) { + S2AServiceGrpc.getSetUpSessionMethod = getSetUpSessionMethod = + io.grpc.MethodDescriptor.newBuilder() + .setType(io.grpc.MethodDescriptor.MethodType.BIDI_STREAMING) + .setFullMethodName(generateFullMethodName(SERVICE_NAME, "SetUpSession")) + .setSampledToLocalTracing(true) + .setRequestMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + io.grpc.s2a.handshaker.SessionReq.getDefaultInstance())) + .setResponseMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + io.grpc.s2a.handshaker.SessionResp.getDefaultInstance())) + .setSchemaDescriptor(new S2AServiceMethodDescriptorSupplier("SetUpSession")) + .build(); + } + } + } + return getSetUpSessionMethod; + } + + /** + * Creates a new async stub that supports all call types for the service + */ + public static S2AServiceStub newStub(io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public S2AServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new S2AServiceStub(channel, callOptions); + } + }; + return S2AServiceStub.newStub(factory, channel); + } + + /** + * Creates a new blocking-style stub that supports unary and streaming output calls on the service + */ + public static S2AServiceBlockingStub newBlockingStub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public S2AServiceBlockingStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new S2AServiceBlockingStub(channel, callOptions); + } + }; + return S2AServiceBlockingStub.newStub(factory, channel); + } + + /** + * Creates a new ListenableFuture-style stub that supports unary calls on the service + */ + public static S2AServiceFutureStub newFutureStub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public S2AServiceFutureStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new S2AServiceFutureStub(channel, callOptions); + } + }; + return S2AServiceFutureStub.newStub(factory, channel); + } + + /** + */ + public interface AsyncService { + + /** + *
+     * SetUpSession is a bidirectional stream used by applications to offload
+     * operations from the TLS handshake.
+     * 
+ */ + default io.grpc.stub.StreamObserver setUpSession( + io.grpc.stub.StreamObserver responseObserver) { + return io.grpc.stub.ServerCalls.asyncUnimplementedStreamingCall(getSetUpSessionMethod(), responseObserver); + } + } + + /** + * Base class for the server implementation of the service S2AService. + */ + public static abstract class S2AServiceImplBase + implements io.grpc.BindableService, AsyncService { + + @java.lang.Override public final io.grpc.ServerServiceDefinition bindService() { + return S2AServiceGrpc.bindService(this); + } + } + + /** + * A stub to allow clients to do asynchronous rpc calls to service S2AService. + */ + public static final class S2AServiceStub + extends io.grpc.stub.AbstractAsyncStub { + private S2AServiceStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected S2AServiceStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new S2AServiceStub(channel, callOptions); + } + + /** + *
+     * SetUpSession is a bidirectional stream used by applications to offload
+     * operations from the TLS handshake.
+     * 
+ */ + public io.grpc.stub.StreamObserver setUpSession( + io.grpc.stub.StreamObserver responseObserver) { + return io.grpc.stub.ClientCalls.asyncBidiStreamingCall( + getChannel().newCall(getSetUpSessionMethod(), getCallOptions()), responseObserver); + } + } + + /** + * A stub to allow clients to do synchronous rpc calls to service S2AService. + */ + public static final class S2AServiceBlockingStub + extends io.grpc.stub.AbstractBlockingStub { + private S2AServiceBlockingStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected S2AServiceBlockingStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new S2AServiceBlockingStub(channel, callOptions); + } + } + + /** + * A stub to allow clients to do ListenableFuture-style rpc calls to service S2AService. + */ + public static final class S2AServiceFutureStub + extends io.grpc.stub.AbstractFutureStub { + private S2AServiceFutureStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected S2AServiceFutureStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new S2AServiceFutureStub(channel, callOptions); + } + } + + private static final int METHODID_SET_UP_SESSION = 0; + + private static final class MethodHandlers implements + io.grpc.stub.ServerCalls.UnaryMethod, + io.grpc.stub.ServerCalls.ServerStreamingMethod, + io.grpc.stub.ServerCalls.ClientStreamingMethod, + io.grpc.stub.ServerCalls.BidiStreamingMethod { + private final AsyncService serviceImpl; + private final int methodId; + + MethodHandlers(AsyncService serviceImpl, int methodId) { + this.serviceImpl = serviceImpl; + this.methodId = methodId; + } + + @java.lang.Override + @java.lang.SuppressWarnings("unchecked") + public void invoke(Req request, io.grpc.stub.StreamObserver responseObserver) { + switch (methodId) { + default: + throw new AssertionError(); + } + } + + @java.lang.Override + @java.lang.SuppressWarnings("unchecked") + public io.grpc.stub.StreamObserver invoke( + io.grpc.stub.StreamObserver responseObserver) { + switch (methodId) { + case METHODID_SET_UP_SESSION: + return (io.grpc.stub.StreamObserver) serviceImpl.setUpSession( + (io.grpc.stub.StreamObserver) responseObserver); + default: + throw new AssertionError(); + } + } + } + + public static final io.grpc.ServerServiceDefinition bindService(AsyncService service) { + return io.grpc.ServerServiceDefinition.builder(getServiceDescriptor()) + .addMethod( + getSetUpSessionMethod(), + io.grpc.stub.ServerCalls.asyncBidiStreamingCall( + new MethodHandlers< + io.grpc.s2a.handshaker.SessionReq, + io.grpc.s2a.handshaker.SessionResp>( + service, METHODID_SET_UP_SESSION))) + .build(); + } + + private static abstract class S2AServiceBaseDescriptorSupplier + implements io.grpc.protobuf.ProtoFileDescriptorSupplier, io.grpc.protobuf.ProtoServiceDescriptorSupplier { + S2AServiceBaseDescriptorSupplier() {} + + @java.lang.Override + public com.google.protobuf.Descriptors.FileDescriptor getFileDescriptor() { + return io.grpc.s2a.handshaker.S2AProto.getDescriptor(); + } + + @java.lang.Override + public com.google.protobuf.Descriptors.ServiceDescriptor getServiceDescriptor() { + return getFileDescriptor().findServiceByName("S2AService"); + } + } + + private static final class S2AServiceFileDescriptorSupplier + extends S2AServiceBaseDescriptorSupplier { + S2AServiceFileDescriptorSupplier() {} + } + + private static final class S2AServiceMethodDescriptorSupplier + extends S2AServiceBaseDescriptorSupplier + implements io.grpc.protobuf.ProtoMethodDescriptorSupplier { + private final java.lang.String methodName; + + S2AServiceMethodDescriptorSupplier(java.lang.String methodName) { + this.methodName = methodName; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.MethodDescriptor getMethodDescriptor() { + return getServiceDescriptor().findMethodByName(methodName); + } + } + + private static volatile io.grpc.ServiceDescriptor serviceDescriptor; + + public static io.grpc.ServiceDescriptor getServiceDescriptor() { + io.grpc.ServiceDescriptor result = serviceDescriptor; + if (result == null) { + synchronized (S2AServiceGrpc.class) { + result = serviceDescriptor; + if (result == null) { + serviceDescriptor = result = io.grpc.ServiceDescriptor.newBuilder(SERVICE_NAME) + .setSchemaDescriptor(new S2AServiceFileDescriptorSupplier()) + .addMethod(getSetUpSessionMethod()) + .build(); + } + } + } + return result; + } +} diff --git a/s2a/src/main/java/io/grpc/s2a/MtlsToS2AChannelCredentials.java b/s2a/src/main/java/io/grpc/s2a/MtlsToS2AChannelCredentials.java new file mode 100644 index 00000000000..b2aee6db49e --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/MtlsToS2AChannelCredentials.java @@ -0,0 +1,96 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Strings.isNullOrEmpty; + +import io.grpc.ChannelCredentials; +import io.grpc.TlsChannelCredentials; +import io.grpc.util.AdvancedTlsX509KeyManager; +import io.grpc.util.AdvancedTlsX509TrustManager; +import java.io.File; +import java.io.IOException; +import java.security.GeneralSecurityException; + +/** + * Configures an {@code S2AChannelCredentials.Builder} instance with credentials used to establish a + * connection with the S2A to support talking to the S2A over mTLS. + */ +public final class MtlsToS2AChannelCredentials { + /** + * Creates a {@code S2AChannelCredentials.Builder} builder, that talks to the S2A over mTLS. + * + * @param s2aAddress the address of the S2A server used to secure the connection. + * @param privateKeyPath the path to the private key PEM to use for authenticating to the S2A. + * @param certChainPath the path to the cert chain PEM to use for authenticating to the S2A. + * @param trustBundlePath the path to the trust bundle PEM. + * @return a {@code MtlsToS2AChannelCredentials.Builder} instance. + */ + public static Builder createBuilder( + String s2aAddress, String privateKeyPath, String certChainPath, String trustBundlePath) { + checkArgument(!isNullOrEmpty(s2aAddress), "S2A address must not be null or empty."); + checkArgument(!isNullOrEmpty(privateKeyPath), "privateKeyPath must not be null or empty."); + checkArgument(!isNullOrEmpty(certChainPath), "certChainPath must not be null or empty."); + checkArgument(!isNullOrEmpty(trustBundlePath), "trustBundlePath must not be null or empty."); + return new Builder(s2aAddress, privateKeyPath, certChainPath, trustBundlePath); + } + + /** Builds an {@code MtlsToS2AChannelCredentials} instance. */ + public static final class Builder { + private final String s2aAddress; + private final String privateKeyPath; + private final String certChainPath; + private final String trustBundlePath; + + Builder( + String s2aAddress, String privateKeyPath, String certChainPath, String trustBundlePath) { + this.s2aAddress = s2aAddress; + this.privateKeyPath = privateKeyPath; + this.certChainPath = certChainPath; + this.trustBundlePath = trustBundlePath; + } + + public S2AChannelCredentials.Builder build() throws GeneralSecurityException, IOException { + checkState(!isNullOrEmpty(s2aAddress), "S2A address must not be null or empty."); + checkState(!isNullOrEmpty(privateKeyPath), "privateKeyPath must not be null or empty."); + checkState(!isNullOrEmpty(certChainPath), "certChainPath must not be null or empty."); + checkState(!isNullOrEmpty(trustBundlePath), "trustBundlePath must not be null or empty."); + File privateKeyFile = new File(privateKeyPath); + File certChainFile = new File(certChainPath); + File trustBundleFile = new File(trustBundlePath); + + AdvancedTlsX509KeyManager keyManager = new AdvancedTlsX509KeyManager(); + keyManager.updateIdentityCredentialsFromFile(privateKeyFile, certChainFile); + + AdvancedTlsX509TrustManager trustManager = AdvancedTlsX509TrustManager.newBuilder().build(); + trustManager.updateTrustCredentialsFromFile(trustBundleFile); + + ChannelCredentials channelToS2ACredentials = + TlsChannelCredentials.newBuilder() + .keyManager(keyManager) + .trustManager(trustManager) + .build(); + + return S2AChannelCredentials.createBuilder(s2aAddress) + .setS2AChannelCredentials(channelToS2ACredentials); + } + } + + private MtlsToS2AChannelCredentials() {} +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java b/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java new file mode 100644 index 00000000000..4ad05b4541a --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java @@ -0,0 +1,132 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Strings.isNullOrEmpty; + +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import io.grpc.Channel; +import io.grpc.ChannelCredentials; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; +import io.grpc.netty.InternalNettyChannelCredentials; +import io.grpc.netty.InternalProtocolNegotiator; +import io.grpc.s2a.channel.S2AHandshakerServiceChannel; +import io.grpc.s2a.handshaker.S2AIdentity; +import io.grpc.s2a.handshaker.S2AProtocolNegotiatorFactory; +import java.util.Optional; +import javax.annotation.concurrent.NotThreadSafe; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Configures gRPC to use S2A for transport security when establishing a secure channel. Only for + * use on the client side of a gRPC connection. + */ +@NotThreadSafe +public final class S2AChannelCredentials { + /** + * Creates a channel credentials builder for establishing an S2A-secured connection. + * + * @param s2aAddress the address of the S2A server used to secure the connection. + * @return a {@code S2AChannelCredentials.Builder} instance. + */ + public static Builder createBuilder(String s2aAddress) { + checkArgument(!isNullOrEmpty(s2aAddress), "S2A address must not be null or empty."); + return new Builder(s2aAddress); + } + + /** Builds an {@code S2AChannelCredentials} instance. */ + public static final class Builder { + private final String s2aAddress; + private ObjectPool s2aChannelPool; + private Optional s2aChannelCredentials; + private @Nullable S2AIdentity localIdentity = null; + + Builder(String s2aAddress) { + this.s2aAddress = s2aAddress; + this.s2aChannelPool = null; + this.s2aChannelCredentials = Optional.empty(); + } + + /** + * Sets the local identity of the client in the form of a SPIFFE ID. The client may set at most + * 1 local identity. If no local identity is specified, then the S2A chooses a default local + * identity, if one exists. + */ + @CanIgnoreReturnValue + public Builder setLocalSpiffeId(String localSpiffeId) { + checkNotNull(localSpiffeId); + localIdentity = S2AIdentity.fromSpiffeId(localSpiffeId); + return this; + } + + /** + * Sets the local identity of the client in the form of a hostname. The client may set at most 1 + * local identity. If no local identity is specified, then the S2A chooses a default local + * identity, if one exists. + */ + @CanIgnoreReturnValue + public Builder setLocalHostname(String localHostname) { + checkNotNull(localHostname); + localIdentity = S2AIdentity.fromHostname(localHostname); + return this; + } + + /** + * Sets the local identity of the client in the form of a UID. The client may set at most 1 + * local identity. If no local identity is specified, then the S2A chooses a default local + * identity, if one exists. + */ + @CanIgnoreReturnValue + public Builder setLocalUid(String localUid) { + checkNotNull(localUid); + localIdentity = S2AIdentity.fromUid(localUid); + return this; + } + + /** Sets the credentials to be used when connecting to the S2A. */ + @CanIgnoreReturnValue + public Builder setS2AChannelCredentials(ChannelCredentials s2aChannelCredentials) { + this.s2aChannelCredentials = Optional.of(s2aChannelCredentials); + return this; + } + + public ChannelCredentials build() { + checkState(!isNullOrEmpty(s2aAddress), "S2A address must not be null or empty."); + ObjectPool s2aChannelPool = + SharedResourcePool.forResource( + S2AHandshakerServiceChannel.getChannelResource(s2aAddress, s2aChannelCredentials)); + checkNotNull(s2aChannelPool, "s2aChannelPool"); + this.s2aChannelPool = s2aChannelPool; + return InternalNettyChannelCredentials.create(buildProtocolNegotiatorFactory()); + } + + InternalProtocolNegotiator.ClientFactory buildProtocolNegotiatorFactory() { + if (localIdentity == null) { + return S2AProtocolNegotiatorFactory.createClientFactory(Optional.empty(), s2aChannelPool); + } else { + return S2AProtocolNegotiatorFactory.createClientFactory( + Optional.of(localIdentity), s2aChannelPool); + } + } + } + + private S2AChannelCredentials() {} +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/channel/S2AChannelPool.java b/s2a/src/main/java/io/grpc/s2a/channel/S2AChannelPool.java new file mode 100644 index 00000000000..e0501e91c66 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/channel/S2AChannelPool.java @@ -0,0 +1,43 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.channel; + +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import io.grpc.Channel; +import javax.annotation.concurrent.ThreadSafe; + +/** Manages a channel pool to be used for communication with the S2A. */ +@ThreadSafe +public interface S2AChannelPool extends AutoCloseable { + /** + * Retrieves an open channel to the S2A from the channel pool. + * + *

If no channel is available, blocks until a channel can be retrieved from the channel pool. + */ + @CanIgnoreReturnValue + Channel getChannel(); + + /** Returns a channel to the channel pool. */ + void returnChannel(Channel channel); + + /** + * Returns all channels to the channel pool and closes the pool so that no new channels can be + * retrieved from the pool. + */ + @Override + void close(); +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/channel/S2AGrpcChannelPool.java b/s2a/src/main/java/io/grpc/s2a/channel/S2AGrpcChannelPool.java new file mode 100644 index 00000000000..1d1de28e64e --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/channel/S2AGrpcChannelPool.java @@ -0,0 +1,112 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.channel; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import com.google.errorprone.annotations.concurrent.GuardedBy; +import io.grpc.Channel; +import io.grpc.internal.ObjectPool; +import javax.annotation.concurrent.ThreadSafe; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Manages a gRPC channel pool and a cached gRPC channel to be used for communication with the S2A. + */ +@ThreadSafe +public final class S2AGrpcChannelPool implements S2AChannelPool { + private static final int MAX_NUMBER_USERS_OF_CACHED_CHANNEL = 100000; + private final ObjectPool channelPool; + + @GuardedBy("this") + private @Nullable Channel cachedChannel; + + @GuardedBy("this") + private int numberOfUsersOfCachedChannel = 0; + + private enum State { + OPEN, + CLOSED, + } + + ; + + @GuardedBy("this") + private State state = State.OPEN; + + public static S2AChannelPool create(ObjectPool channelPool) { + checkNotNull(channelPool, "Channel pool should not be null."); + return new S2AGrpcChannelPool(channelPool); + } + + private S2AGrpcChannelPool(ObjectPool channelPool) { + this.channelPool = channelPool; + } + + /** + * Retrieves a channel from {@code channelPool} if {@code channel} is null, and returns {@code + * channel} otherwise. + * + * @return a {@link Channel} obtained from the channel pool. + */ + @Override + public synchronized Channel getChannel() { + checkState(state.equals(State.OPEN), "Channel pool is not open."); + checkState( + numberOfUsersOfCachedChannel >= 0, + "Number of users of cached channel must be non-negative."); + checkState( + numberOfUsersOfCachedChannel < MAX_NUMBER_USERS_OF_CACHED_CHANNEL, + "Max number of channels have been retrieved from the channel pool."); + if (cachedChannel == null) { + cachedChannel = channelPool.getObject(); + } + numberOfUsersOfCachedChannel += 1; + return cachedChannel; + } + + /** + * Returns {@code channel} to {@code channelPool}. + * + *

The caller must ensure that {@code channel} was retrieved from this channel pool. + */ + @Override + public synchronized void returnChannel(Channel channel) { + checkState(state.equals(State.OPEN), "Channel pool is not open."); + checkArgument( + cachedChannel != null && numberOfUsersOfCachedChannel > 0 && cachedChannel.equals(channel), + "Cannot return the channel to channel pool because the channel was not obtained from" + + " channel pool."); + numberOfUsersOfCachedChannel -= 1; + if (numberOfUsersOfCachedChannel == 0) { + channelPool.returnObject(channel); + cachedChannel = null; + } + } + + @Override + public synchronized void close() { + state = State.CLOSED; + numberOfUsersOfCachedChannel = 0; + if (cachedChannel != null) { + channelPool.returnObject(cachedChannel); + cachedChannel = null; + } + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/channel/S2AHandshakerServiceChannel.java b/s2a/src/main/java/io/grpc/s2a/channel/S2AHandshakerServiceChannel.java new file mode 100644 index 00000000000..75ec7347bb5 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/channel/S2AHandshakerServiceChannel.java @@ -0,0 +1,195 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.channel; + +import static com.google.common.base.Preconditions.checkNotNull; +import static java.util.concurrent.TimeUnit.SECONDS; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Maps; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ChannelCredentials; +import io.grpc.ClientCall; +import io.grpc.ManagedChannel; +import io.grpc.MethodDescriptor; +import io.grpc.internal.SharedResourceHolder.Resource; +import io.grpc.netty.NettyChannelBuilder; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.util.concurrent.DefaultThreadFactory; +import java.time.Duration; +import java.util.Optional; +import java.util.concurrent.ConcurrentMap; +import javax.annotation.concurrent.ThreadSafe; + +/** + * Provides APIs for managing gRPC channels to S2A servers. Each channel is local and plaintext. If + * credentials are provided, they are used to secure the channel. + * + *

This is done as follows: for each S2A server, provides an implementation of gRPC's {@link + * SharedResourceHolder.Resource} interface called a {@code Resource}. A {@code + * Resource} is a factory for creating gRPC channels to the S2A server at a given address, + * and a channel must be returned to the {@code Resource} when it is no longer needed. + * + *

Typical usage pattern is below: + * + *

{@code
+ * Resource resource = S2AHandshakerServiceChannel.getChannelResource("localhost:1234",
+ * creds);
+ * Channel channel = resource.create();
+ * // Send an RPC over the channel to the S2A server running at localhost:1234.
+ * resource.close(channel);
+ * }
+ */ +@ThreadSafe +public final class S2AHandshakerServiceChannel { + private static final ConcurrentMap> SHARED_RESOURCE_CHANNELS = + Maps.newConcurrentMap(); + private static final Duration DELEGATE_TERMINATION_TIMEOUT = Duration.ofSeconds(2); + private static final Duration CHANNEL_SHUTDOWN_TIMEOUT = Duration.ofSeconds(10); + + /** + * Returns a {@link SharedResourceHolder.Resource} instance for managing channels to an S2A server + * running at {@code s2aAddress}. + * + * @param s2aAddress the address of the S2A, typically in the format {@code host:port}. + * @param s2aChannelCredentials the credentials to use when establishing a connection to the S2A. + * @return a {@link ChannelResource} instance that manages a {@link Channel} to the S2A server + * running at {@code s2aAddress}. + */ + public static Resource getChannelResource( + String s2aAddress, Optional s2aChannelCredentials) { + checkNotNull(s2aAddress); + return SHARED_RESOURCE_CHANNELS.computeIfAbsent( + s2aAddress, channelResource -> new ChannelResource(s2aAddress, s2aChannelCredentials)); + } + + /** + * Defines how to create and destroy a {@link Channel} instance that uses shared resources. A + * channel created by {@code ChannelResource} is a plaintext, local channel to the service running + * at {@code targetAddress}. + */ + private static class ChannelResource implements Resource { + private final String targetAddress; + private final Optional channelCredentials; + + public ChannelResource(String targetAddress, Optional channelCredentials) { + this.targetAddress = targetAddress; + this.channelCredentials = channelCredentials; + } + + /** + * Creates a {@code EventLoopHoldingChannel} instance to the service running at {@code + * targetAddress}. This channel uses a dedicated thread pool for its {@code EventLoopGroup} + * instance to avoid blocking. + */ + @Override + public Channel create() { + EventLoopGroup eventLoopGroup = + new NioEventLoopGroup(1, new DefaultThreadFactory("S2A channel pool", true)); + ManagedChannel channel = null; + if (channelCredentials.isPresent()) { + // Create a secure channel. + channel = + NettyChannelBuilder.forTarget(targetAddress, channelCredentials.get()) + .channelType(NioSocketChannel.class) + .directExecutor() + .eventLoopGroup(eventLoopGroup) + .build(); + } else { + // Create a plaintext channel. + channel = + NettyChannelBuilder.forTarget(targetAddress) + .channelType(NioSocketChannel.class) + .directExecutor() + .eventLoopGroup(eventLoopGroup) + .usePlaintext() + .build(); + } + return EventLoopHoldingChannel.create(channel, eventLoopGroup); + } + + /** Destroys a {@code EventLoopHoldingChannel} instance. */ + @Override + public void close(Channel instanceChannel) { + checkNotNull(instanceChannel); + EventLoopHoldingChannel channel = (EventLoopHoldingChannel) instanceChannel; + channel.close(); + } + + @Override + public String toString() { + return "grpc-s2a-channel"; + } + } + + /** + * Manages a channel using a {@link ManagedChannel} instance that belong to the {@code + * EventLoopGroup} thread pool. + */ + @VisibleForTesting + static class EventLoopHoldingChannel extends Channel { + private final ManagedChannel delegate; + private final EventLoopGroup eventLoopGroup; + + static EventLoopHoldingChannel create(ManagedChannel delegate, EventLoopGroup eventLoopGroup) { + checkNotNull(delegate); + checkNotNull(eventLoopGroup); + return new EventLoopHoldingChannel(delegate, eventLoopGroup); + } + + private EventLoopHoldingChannel(ManagedChannel delegate, EventLoopGroup eventLoopGroup) { + this.delegate = delegate; + this.eventLoopGroup = eventLoopGroup; + } + + /** + * Returns the address of the service to which the {@code delegate} channel connects, which is + * typically of the form {@code host:port}. + */ + @Override + public String authority() { + return delegate.authority(); + } + + /** Creates a {@link ClientCall} that invokes the operations in {@link MethodDescriptor}. */ + @Override + public ClientCall newCall( + MethodDescriptor methodDescriptor, CallOptions options) { + return delegate.newCall(methodDescriptor, options); + } + + @SuppressWarnings("FutureReturnValueIgnored") + public void close() { + delegate.shutdownNow(); + boolean isDelegateTerminated; + try { + isDelegateTerminated = + delegate.awaitTermination(DELEGATE_TERMINATION_TIMEOUT.getSeconds(), SECONDS); + } catch (InterruptedException e) { + isDelegateTerminated = false; + } + long quietPeriodSeconds = isDelegateTerminated ? 0 : 1; + eventLoopGroup.shutdownGracefully( + quietPeriodSeconds, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS); + } + } + + private S2AHandshakerServiceChannel() {} +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/ConnectionIsClosedException.java b/s2a/src/main/java/io/grpc/s2a/handshaker/ConnectionIsClosedException.java new file mode 100644 index 00000000000..1f9b2d5a23a --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/ConnectionIsClosedException.java @@ -0,0 +1,27 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import java.io.IOException; + +/** Indicates that a connection has been closed. */ +@SuppressWarnings("serial") // This class is never serialized. +final class ConnectionIsClosedException extends IOException { + public ConnectionIsClosedException(String errorMessage) { + super(errorMessage); + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/GetAuthenticationMechanisms.java b/s2a/src/main/java/io/grpc/s2a/handshaker/GetAuthenticationMechanisms.java new file mode 100644 index 00000000000..3b17a5ed322 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/GetAuthenticationMechanisms.java @@ -0,0 +1,60 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import com.google.errorprone.annotations.Immutable; +import io.grpc.s2a.handshaker.S2AIdentity; +import io.grpc.s2a.handshaker.tokenmanager.AccessTokenManager; +import java.util.Optional; + +/** Retrieves the authentication mechanism for a given local identity. */ +@Immutable +final class GetAuthenticationMechanisms { + private static final Optional TOKEN_MANAGER = AccessTokenManager.create(); + + /** + * Retrieves the authentication mechanism for a given local identity. + * + * @param localIdentity the identity for which to fetch a token. + * @return an {@link AuthenticationMechanism} for the given local identity. + */ + static Optional getAuthMechanism(Optional localIdentity) { + Optional authMechanism = Optional.empty(); + if (!TOKEN_MANAGER.isPresent()) { + return Optional.empty(); + } + AccessTokenManager manager = TOKEN_MANAGER.get(); + // If no identity is provided, fetch the default access token and DO NOT attach an identity + // to the request. + if (!localIdentity.isPresent()) { + authMechanism = + Optional.of( + AuthenticationMechanism.newBuilder().setToken(manager.getDefaultToken()).build()); + } else { + // Fetch an access token for the provided identity. + authMechanism = + Optional.of( + AuthenticationMechanism.newBuilder() + .setIdentity(localIdentity.get().identity()) + .setToken(manager.getToken(localIdentity.get())) + .build()); + } + return authMechanism; + } + + private GetAuthenticationMechanisms() {} +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/ProtoUtil.java b/s2a/src/main/java/io/grpc/s2a/handshaker/ProtoUtil.java new file mode 100644 index 00000000000..34cc4bbe737 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/ProtoUtil.java @@ -0,0 +1,72 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +/** Converts proto messages to Netty strings. */ +final class ProtoUtil { + /** + * Converts {@link Ciphersuite} to its {@link String} representation. + * + * @param ciphersuite the {@link Ciphersuite} to be converted. + * @return a {@link String} representing the ciphersuite. + * @throws AssertionError if the {@link Ciphersuite} is not one of the supported ciphersuites. + */ + static String convertCiphersuite(Ciphersuite ciphersuite) { + switch (ciphersuite) { + case CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: + return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"; + case CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: + return "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"; + case CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256: + return "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"; + case CIPHERSUITE_ECDHE_RSA_WITH_AES_128_GCM_SHA256: + return "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"; + case CIPHERSUITE_ECDHE_RSA_WITH_AES_256_GCM_SHA384: + return "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"; + case CIPHERSUITE_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256: + return "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256"; + default: + throw new AssertionError( + String.format("Ciphersuite %d is not supported.", ciphersuite.getNumber())); + } + } + + /** + * Converts a {@link TLSVersion} object to its {@link String} representation. + * + * @param tlsVersion the {@link TLSVersion} object to be converted. + * @return a {@link String} representation of the TLS version. + * @throws AssertionError if the {@code tlsVersion} is not one of the supported TLS versions. + */ + static String convertTlsProtocolVersion(TLSVersion tlsVersion) { + switch (tlsVersion) { + case TLS_VERSION_1_3: + return "TLSv1.3"; + case TLS_VERSION_1_2: + return "TLSv1.2"; + case TLS_VERSION_1_1: + return "TLSv1.1"; + case TLS_VERSION_1_0: + return "TLSv1"; + default: + throw new AssertionError( + String.format("TLS version %d is not supported.", tlsVersion.getNumber())); + } + } + + private ProtoUtil() {} +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AConnectionException.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AConnectionException.java new file mode 100644 index 00000000000..d976308ad22 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AConnectionException.java @@ -0,0 +1,25 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +/** Exception that denotes a runtime error that was encountered when talking to the S2A server. */ +@SuppressWarnings("serial") // This class is never serialized. +public class S2AConnectionException extends RuntimeException { + S2AConnectionException(String message) { + super(message); + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AIdentity.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AIdentity.java new file mode 100644 index 00000000000..30957acd521 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AIdentity.java @@ -0,0 +1,62 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.errorprone.annotations.ThreadSafe; + +/** + * Stores an identity in such a way that it can be sent to the S2A handshaker service. The identity + * may be formatted as a SPIFFE ID or as a hostname. + */ +@ThreadSafe +public final class S2AIdentity { + private final Identity identity; + + /** Returns an {@link S2AIdentity} instance with SPIFFE ID set to {@code spiffeId}. */ + public static S2AIdentity fromSpiffeId(String spiffeId) { + checkNotNull(spiffeId); + return new S2AIdentity(Identity.newBuilder().setSpiffeId(spiffeId).build()); + } + + /** Returns an {@link S2AIdentity} instance with hostname set to {@code hostname}. */ + public static S2AIdentity fromHostname(String hostname) { + checkNotNull(hostname); + return new S2AIdentity(Identity.newBuilder().setHostname(hostname).build()); + } + + /** Returns an {@link S2AIdentity} instance with UID set to {@code uid}. */ + public static S2AIdentity fromUid(String uid) { + checkNotNull(uid); + return new S2AIdentity(Identity.newBuilder().setUid(uid).build()); + } + + /** Returns an {@link S2AIdentity} instance with {@code identity} set. */ + public static S2AIdentity fromIdentity(Identity identity) { + return new S2AIdentity(identity == null ? Identity.getDefaultInstance() : identity); + } + + private S2AIdentity(Identity identity) { + this.identity = identity; + } + + /** Returns the proto {@link Identity} representation of this identity instance. */ + public Identity identity() { + return identity; + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2APrivateKeyMethod.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2APrivateKeyMethod.java new file mode 100644 index 00000000000..fb4908d99fc --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2APrivateKeyMethod.java @@ -0,0 +1,143 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.ByteString; +import io.grpc.s2a.handshaker.S2AIdentity; +import io.netty.handler.ssl.OpenSslPrivateKeyMethod; +import java.io.IOException; +import java.util.Optional; +import javax.annotation.concurrent.NotThreadSafe; +import javax.net.ssl.SSLEngine; + +/** + * Handles requests on signing bytes with a private key designated by {@code stub}. + * + *

This is done by sending the to-be-signed bytes to an S2A server (designated by {@code stub}) + * and read the signature from the server. + * + *

OpenSSL libraries must be appropriately initialized before using this class. One possible way + * to initialize OpenSSL library is to call {@code + * GrpcSslContexts.configure(SslContextBuilder.forClient());}. + */ +@NotThreadSafe +final class S2APrivateKeyMethod implements OpenSslPrivateKeyMethod { + private final S2AStub stub; + private final Optional localIdentity; + private static final ImmutableMap + OPENSSL_TO_S2A_SIGNATURE_ALGORITHM_MAP = + ImmutableMap.of( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA256, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA256, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA384, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA384, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA512, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA512, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP256R1_SHA256, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP384R1_SHA384, + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP384R1_SHA384, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP521R1_SHA512, + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP521R1_SHA512, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA256, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA256, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA384, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA384, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA512, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA512); + + public static S2APrivateKeyMethod create(S2AStub stub, Optional localIdentity) { + checkNotNull(stub); + return new S2APrivateKeyMethod(stub, localIdentity); + } + + private S2APrivateKeyMethod(S2AStub stub, Optional localIdentity) { + this.stub = stub; + this.localIdentity = localIdentity; + } + + /** + * Converts the signature algorithm to an enum understood by S2A. + * + * @param signatureAlgorithm the int representation of the signature algorithm define by {@code + * OpenSslPrivateKeyMethod}. + * @return the signature algorithm enum defined by S2A proto. + * @throws UnsupportedOperationException if the algorithm is not supported by S2A. + */ + @VisibleForTesting + static SignatureAlgorithm convertOpenSslSignAlgToS2ASignAlg(int signatureAlgorithm) { + SignatureAlgorithm sig = OPENSSL_TO_S2A_SIGNATURE_ALGORITHM_MAP.get(signatureAlgorithm); + if (sig == null) { + throw new UnsupportedOperationException( + String.format("Signature Algorithm %d is not supported.", signatureAlgorithm)); + } + return sig; + } + + /** + * Signs the input bytes by sending the request to the S2A srever. + * + * @param engine not used. + * @param signatureAlgorithm the {@link OpenSslPrivateKeyMethod}'s signature algorithm + * representation + * @param input the bytes to be signed. + * @return the signature of the {@code input}. + * @throws IOException if the connection to the S2A server is corrupted. + * @throws InterruptedException if the connection to the S2A server is interrupted. + * @throws S2AConnectionException if the response from the S2A server does not contain valid data. + */ + @Override + public byte[] sign(SSLEngine engine, int signatureAlgorithm, byte[] input) + throws IOException, InterruptedException { + checkArgument(input.length > 0, "No bytes to sign."); + SignatureAlgorithm s2aSignatureAlgorithm = + convertOpenSslSignAlgToS2ASignAlg(signatureAlgorithm); + SessionReq.Builder reqBuilder = + SessionReq.newBuilder() + .setOffloadPrivateKeyOperationReq( + OffloadPrivateKeyOperationReq.newBuilder() + .setOperation(OffloadPrivateKeyOperationReq.PrivateKeyOperation.SIGN) + .setSignatureAlgorithm(s2aSignatureAlgorithm) + .setRawBytes(ByteString.copyFrom(input))); + if (localIdentity.isPresent()) { + reqBuilder.setLocalIdentity(localIdentity.get().identity()); + } + + SessionResp resp = stub.send(reqBuilder.build()); + + if (resp.hasStatus() && resp.getStatus().getCode() != 0) { + throw new S2AConnectionException( + String.format( + "Error occurred in response from S2A, error code: %d, error message: \"%s\".", + resp.getStatus().getCode(), resp.getStatus().getDetails())); + } + if (!resp.hasOffloadPrivateKeyOperationResp()) { + throw new S2AConnectionException("No valid response received from S2A."); + } + return resp.getOffloadPrivateKeyOperationResp().getOutBytes().toByteArray(); + } + + @Override + public byte[] decrypt(SSLEngine engine, byte[] input) { + throw new UnsupportedOperationException("decrypt is not supported."); + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java new file mode 100644 index 00000000000..7f00e198fae --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java @@ -0,0 +1,194 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.net.HostAndPort; +import com.google.errorprone.annotations.ThreadSafe; +import io.grpc.Channel; +import io.grpc.ChannelLogger; +import io.grpc.internal.ObjectPool; +import io.grpc.netty.GrpcHttp2ConnectionHandler; +import io.grpc.netty.InternalProtocolNegotiator; +import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; +import io.grpc.netty.InternalProtocolNegotiators; +import io.grpc.netty.InternalProtocolNegotiators.ProtocolNegotiationHandler; +import io.grpc.s2a.channel.S2AChannelPool; +import io.grpc.s2a.channel.S2AGrpcChannelPool; +import io.grpc.s2a.handshaker.S2AIdentity; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.ssl.SslContext; +import io.netty.util.AsciiString; +import java.io.IOException; +import java.security.GeneralSecurityException; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.UnrecoverableKeyException; +import java.security.cert.CertificateException; +import java.util.Optional; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** Factory for performing negotiation of a secure channel using the S2A. */ +@ThreadSafe +public final class S2AProtocolNegotiatorFactory { + @VisibleForTesting static final int DEFAULT_PORT = 443; + private static final AsciiString SCHEME = AsciiString.of("https"); + + /** + * Creates a {@code S2AProtocolNegotiatorFactory} configured for a client to establish secure + * connections using the S2A. + * + * @param localIdentity the identity of the client; if none is provided, the S2A will use the + * client's default identity. + * @param s2aChannelPool a pool of shared channels that can be used to connect to the S2A. + * @return a factory for creating a client-side protocol negotiator. + */ + public static InternalProtocolNegotiator.ClientFactory createClientFactory( + Optional localIdentity, ObjectPool s2aChannelPool) { + checkNotNull(s2aChannelPool, "S2A channel pool should not be null."); + checkNotNull(localIdentity, "Local identity should not be null on the client side."); + S2AChannelPool channelPool = S2AGrpcChannelPool.create(s2aChannelPool); + return new S2AClientProtocolNegotiatorFactory(localIdentity, channelPool); + } + + static final class S2AClientProtocolNegotiatorFactory + implements InternalProtocolNegotiator.ClientFactory { + private final Optional localIdentity; + private final S2AChannelPool channelPool; + + S2AClientProtocolNegotiatorFactory( + Optional localIdentity, S2AChannelPool channelPool) { + this.localIdentity = localIdentity; + this.channelPool = channelPool; + } + + @Override + public ProtocolNegotiator newNegotiator() { + return S2AProtocolNegotiator.createForClient(channelPool, localIdentity); + } + + @Override + public int getDefaultPort() { + return DEFAULT_PORT; + } + } + + /** Negotiates the TLS handshake using S2A. */ + @VisibleForTesting + static final class S2AProtocolNegotiator implements ProtocolNegotiator { + + private final S2AChannelPool channelPool; + private final Optional localIdentity; + + static S2AProtocolNegotiator createForClient( + S2AChannelPool channelPool, Optional localIdentity) { + checkNotNull(channelPool, "Channel pool should not be null."); + checkNotNull(localIdentity, "Local identity should not be null on the client side."); + return new S2AProtocolNegotiator(channelPool, localIdentity); + } + + @VisibleForTesting + static @Nullable String getHostNameFromAuthority(@Nullable String authority) { + if (authority == null) { + return null; + } + return HostAndPort.fromString(authority).getHost(); + } + + private S2AProtocolNegotiator(S2AChannelPool channelPool, Optional localIdentity) { + this.channelPool = channelPool; + this.localIdentity = localIdentity; + } + + @Override + public AsciiString scheme() { + return SCHEME; + } + + @Override + public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { + checkNotNull(grpcHandler, "grpcHandler should not be null."); + String hostname = getHostNameFromAuthority(grpcHandler.getAuthority()); + checkNotNull(hostname, "hostname should not be null."); + return new S2AProtocolNegotiationHandler( + InternalProtocolNegotiators.grpcNegotiationHandler(grpcHandler), + grpcHandler.getNegotiationLogger(), + channelPool, + localIdentity, + hostname, + grpcHandler); + } + + @Override + public void close() { + channelPool.close(); + } + } + + private static final class S2AProtocolNegotiationHandler extends ProtocolNegotiationHandler { + private final S2AChannelPool channelPool; + private final Optional localIdentity; + private final String hostname; + private InternalProtocolNegotiator.ProtocolNegotiator negotiator; + private final GrpcHttp2ConnectionHandler grpcHandler; + + private S2AProtocolNegotiationHandler( + ChannelHandler next, + ChannelLogger negotiationLogger, + S2AChannelPool channelPool, + Optional localIdentity, + String hostname, + GrpcHttp2ConnectionHandler grpcHandler) { + super(next, negotiationLogger); + this.channelPool = channelPool; + this.localIdentity = localIdentity; + this.hostname = hostname; + this.grpcHandler = grpcHandler; + } + + @Override + protected void handlerAdded0(ChannelHandlerContext ctx) throws GeneralSecurityException { + SslContext sslContext; + try { + // Establish a stream to S2A server. + Channel ch = channelPool.getChannel(); + S2AServiceGrpc.S2AServiceStub stub = S2AServiceGrpc.newStub(ch); + S2AStub s2aStub = S2AStub.newInstance(stub); + sslContext = SslContextFactory.createForClient(s2aStub, hostname, localIdentity); + } catch (InterruptedException + | IOException + | IllegalArgumentException + | UnrecoverableKeyException + | CertificateException + | NoSuchAlgorithmException + | KeyStoreException e) { + // GeneralSecurityException is intentionally not caught, and rather propagated. This is done + // because throwing a GeneralSecurityException in this context indicates that we encountered + // a retryable error. + throw new IllegalArgumentException( + "Something went wrong during the initialization of SslContext.", e); + } + negotiator = InternalProtocolNegotiators.tls(sslContext); + ctx.pipeline().addBefore(ctx.name(), /* name= */ null, negotiator.newHandler(grpcHandler)); + } + } + + private S2AProtocolNegotiatorFactory() {} +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AStub.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AStub.java new file mode 100644 index 00000000000..aa2502cd4fa --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AStub.java @@ -0,0 +1,225 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Verify.verify; +import static java.util.concurrent.TimeUnit.SECONDS; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.stub.StreamObserver; +import java.io.IOException; +import java.util.Optional; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.concurrent.NotThreadSafe; + +/** Reads and writes messages to and from the S2A. */ +@NotThreadSafe +class S2AStub implements AutoCloseable { + private static final Logger logger = Logger.getLogger(S2AStub.class.getName()); + private static final long HANDSHAKE_RPC_DEADLINE_SECS = 20; + private final StreamObserver reader = new Reader(); + private final BlockingQueue responses = new ArrayBlockingQueue<>(10); + private S2AServiceGrpc.S2AServiceStub serviceStub; + private StreamObserver writer; + private boolean doneReading = false; + private boolean doneWriting = false; + + static S2AStub newInstance(S2AServiceGrpc.S2AServiceStub serviceStub) { + checkNotNull(serviceStub); + return new S2AStub(serviceStub); + } + + @VisibleForTesting + static S2AStub newInstanceForTesting(StreamObserver writer) { + checkNotNull(writer); + return new S2AStub(writer); + } + + private S2AStub(S2AServiceGrpc.S2AServiceStub serviceStub) { + this.serviceStub = serviceStub; + } + + private S2AStub(StreamObserver writer) { + this.writer = writer; + } + + @VisibleForTesting + StreamObserver getReader() { + return reader; + } + + @VisibleForTesting + BlockingQueue getResponses() { + return responses; + } + + /** + * Sends a request and returns the response. Caller must wait until this method executes prior to + * calling it again. If this method throws {@code ConnectionIsClosedException}, then it should not + * be called again, and both {@code reader} and {@code writer} are closed. + * + * @param req the {@code SessionReq} message to be sent to the S2A server. + * @return the {@code SessionResp} message received from the S2A server. + * @throws ConnectionIsClosedException if {@code reader} or {@code writer} calls their {@code + * onCompleted} method. + * @throws IOException if an unexpected response is received, or if the {@code reader} or {@code + * writer} calls their {@code onError} method. + */ + public SessionResp send(SessionReq req) throws IOException, InterruptedException { + if (doneWriting && doneReading) { + logger.log(Level.INFO, "Stream to the S2A is closed."); + throw new ConnectionIsClosedException("Stream to the S2A is closed."); + } + createWriterIfNull(); + if (!responses.isEmpty()) { + IOException exception = null; + SessionResp resp = null; + try { + resp = responses.take().getResultOrThrow(); + } catch (IOException e) { + exception = e; + } + responses.clear(); + if (exception != null) { + logger.log( + Level.WARNING, + "Received an unexpected response from a host at the S2A's address. The S2A might be" + + " unavailable. " + + exception.getMessage()); + throw new IOException( + "Received an unexpected response from a host at the S2A's address. The S2A might be" + + " unavailable." + + exception.getMessage()); + } + return resp; + } + try { + writer.onNext(req); + } catch (RuntimeException e) { + logger.log(Level.WARNING, "Error occurred while writing to the S2A.", e); + writer.onError(e); + responses.offer(Result.createWithThrowable(e)); + } + try { + return responses.take().getResultOrThrow(); + } catch (ConnectionIsClosedException e) { + // A ConnectionIsClosedException is thrown by getResultOrThrow when reader calls its + // onCompleted method. The close method is called to also close the writer, and then the + // ConnectionIsClosedException is re-thrown in order to indicate to the caller that send + // should not be called again. + close(); + throw e; + } + } + + @Override + public void close() { + if (doneWriting && doneReading) { + return; + } + verify(!doneWriting); + doneReading = true; + doneWriting = true; + if (writer != null) { + writer.onCompleted(); + } + } + + /** Create a new writer if the writer is null. */ + private void createWriterIfNull() { + if (writer == null) { + writer = + serviceStub.withDeadlineAfter(HANDSHAKE_RPC_DEADLINE_SECS, SECONDS).setUpSession(reader); + } + } + + private class Reader implements StreamObserver { + /** + * Places a {@code SessionResp} message in the {@code responses} queue, or an {@code + * IOException} if reading is complete. + * + * @param resp the {@code SessionResp} message received from the S2A handshaker module. + */ + @Override + public void onNext(SessionResp resp) { + verify(!doneReading); + responses.offer(Result.createWithResponse(resp)); + } + + /** + * Places a {@code Throwable} in the {@code responses} queue. + * + * @param t the {@code Throwable} caught when reading the stream to the S2A handshaker module. + */ + @Override + public void onError(Throwable t) { + logger.log(Level.WARNING, "Error occurred while reading from the S2A.", t); + responses.offer(Result.createWithThrowable(t)); + } + + /** + * Sets {@code doneReading} to true, and places a {@code ConnectionIsClosedException} in the + * {@code responses} queue. + */ + @Override + public void onCompleted() { + logger.log(Level.INFO, "Reading from the S2A is complete."); + doneReading = true; + responses.offer( + Result.createWithThrowable( + new ConnectionIsClosedException("Reading from the S2A is complete."))); + } + } + + private static final class Result { + private final Optional response; + private final Optional throwable; + + static Result createWithResponse(SessionResp response) { + return new Result(Optional.of(response), Optional.empty()); + } + + static Result createWithThrowable(Throwable throwable) { + return new Result(Optional.empty(), Optional.of(throwable)); + } + + private Result(Optional response, Optional throwable) { + checkArgument(response.isPresent() != throwable.isPresent()); + this.response = response; + this.throwable = throwable; + } + + /** Throws {@code throwable} if present, and returns {@code response} otherwise. */ + SessionResp getResultOrThrow() throws IOException { + if (throwable.isPresent()) { + if (throwable.get() instanceof ConnectionIsClosedException) { + ConnectionIsClosedException exception = (ConnectionIsClosedException) throwable.get(); + throw exception; + } else { + throw new IOException(throwable.get()); + } + } + verify(response.isPresent()); + return response.get(); + } + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2ATrustManager.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2ATrustManager.java new file mode 100644 index 00000000000..014fcf4c4f8 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2ATrustManager.java @@ -0,0 +1,152 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ByteString; +import io.grpc.s2a.handshaker.S2AIdentity; +import io.grpc.s2a.handshaker.ValidatePeerCertificateChainReq.VerificationMode; +import java.io.IOException; +import java.security.cert.CertificateEncodingException; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.Optional; +import javax.annotation.concurrent.NotThreadSafe; +import javax.net.ssl.X509TrustManager; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** Offloads verification of the peer certificate chain to S2A. */ +@NotThreadSafe +final class S2ATrustManager implements X509TrustManager { + private final Optional localIdentity; + private final S2AStub stub; + private final String hostname; + + static S2ATrustManager createForClient( + S2AStub stub, String hostname, Optional localIdentity) { + checkNotNull(stub); + checkNotNull(hostname); + return new S2ATrustManager(stub, hostname, localIdentity); + } + + private S2ATrustManager(S2AStub stub, String hostname, Optional localIdentity) { + this.stub = stub; + this.hostname = hostname; + this.localIdentity = localIdentity; + } + + /** + * Validates the given certificate chain provided by the peer. + * + * @param chain the peer certificate chain + * @param authType the authentication type based on the client certificate + * @throws IllegalArgumentException if null or zero-length chain is passed in for the chain + * parameter. + * @throws CertificateException if the certificate chain is not trusted by this TrustManager. + */ + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + checkPeerTrusted(chain, /* isCheckingClientCertificateChain= */ true); + } + + /** + * Validates the given certificate chain provided by the peer. + * + * @param chain the peer certificate chain + * @param authType the authentication type based on the client certificate + * @throws IllegalArgumentException if null or zero-length chain is passed in for the chain + * parameter. + * @throws CertificateException if the certificate chain is not trusted by this TrustManager. + */ + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + checkPeerTrusted(chain, /* isCheckingClientCertificateChain= */ false); + } + + /** + * Returns null because the accepted issuers are held in S2A and this class receives decision made + * from S2A on the fly about which to use to verify a given chain. + * + * @return null. + */ + @Override + public X509Certificate @Nullable [] getAcceptedIssuers() { + return null; + } + + private void checkPeerTrusted(X509Certificate[] chain, boolean isCheckingClientCertificateChain) + throws CertificateException { + checkNotNull(chain); + checkArgument(chain.length > 0, "Certificate chain has zero certificates."); + + ValidatePeerCertificateChainReq.Builder validatePeerCertificateChainReq = + ValidatePeerCertificateChainReq.newBuilder().setMode(VerificationMode.UNSPECIFIED); + if (isCheckingClientCertificateChain) { + validatePeerCertificateChainReq.setClientPeer( + ValidatePeerCertificateChainReq.ClientPeer.newBuilder() + .addAllCertificateChain(certificateChainToDerChain(chain))); + } else { + validatePeerCertificateChainReq.setServerPeer( + ValidatePeerCertificateChainReq.ServerPeer.newBuilder() + .addAllCertificateChain(certificateChainToDerChain(chain)) + .setServerHostname(hostname)); + } + + SessionReq.Builder reqBuilder = + SessionReq.newBuilder().setValidatePeerCertificateChainReq(validatePeerCertificateChainReq); + if (localIdentity.isPresent()) { + reqBuilder.setLocalIdentity(localIdentity.get().identity()); + } + + SessionResp resp; + try { + resp = stub.send(reqBuilder.build()); + } catch (IOException | InterruptedException e) { + throw new CertificateException("Failed to send request to S2A.", e); + } + if (resp.hasStatus() && resp.getStatus().getCode() != 0) { + throw new CertificateException( + String.format( + "Error occurred in response from S2A, error code: %d, error message: %s.", + resp.getStatus().getCode(), resp.getStatus().getDetails())); + } + + if (!resp.hasValidatePeerCertificateChainResp()) { + throw new CertificateException("No valid response received from S2A."); + } + + ValidatePeerCertificateChainResp validationResult = resp.getValidatePeerCertificateChainResp(); + if (validationResult.getValidationResult() + != ValidatePeerCertificateChainResp.ValidationResult.SUCCESS) { + throw new CertificateException(validationResult.getValidationDetails()); + } + } + + private static ImmutableList certificateChainToDerChain(X509Certificate[] chain) + throws CertificateEncodingException { + ImmutableList.Builder derChain = ImmutableList.builder(); + for (X509Certificate certificate : chain) { + derChain.add(ByteString.copyFrom(certificate.getEncoded())); + } + return derChain.build(); + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/SslContextFactory.java b/s2a/src/main/java/io/grpc/s2a/handshaker/SslContextFactory.java new file mode 100644 index 00000000000..bfa45146625 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/SslContextFactory.java @@ -0,0 +1,179 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.base.Preconditions.checkNotNull; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.collect.ImmutableList; +import io.grpc.netty.GrpcSslContexts; +import io.grpc.s2a.handshaker.S2AIdentity; +import io.netty.handler.ssl.OpenSslContextOption; +import io.netty.handler.ssl.OpenSslSessionContext; +import io.netty.handler.ssl.OpenSslX509KeyManagerFactory; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.security.GeneralSecurityException; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.UnrecoverableKeyException; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.Optional; +import javax.net.ssl.KeyManager; +import javax.net.ssl.SSLSessionContext; + +/** Creates {@link SslContext} objects with TLS configurations from S2A server. */ +final class SslContextFactory { + + /** + * Creates {@link SslContext} objects for client with TLS configurations from S2A server. + * + * @param stub the {@link S2AStub} to talk to the S2A server. + * @param targetName the {@link String} of the server that this client makes connection to. + * @param localIdentity the {@link S2AIdentity} that should be used when talking to S2A server. + * Will use default identity if empty. + * @return a {@link SslContext} object. + * @throws NullPointerException if either {@code stub} or {@code targetName} is null. + * @throws IOException if an unexpected response from S2A server is received. + * @throws InterruptedException if {@code stub} is closed. + */ + static SslContext createForClient( + S2AStub stub, String targetName, Optional localIdentity) + throws IOException, + InterruptedException, + CertificateException, + KeyStoreException, + NoSuchAlgorithmException, + UnrecoverableKeyException, + GeneralSecurityException { + checkNotNull(stub, "stub should not be null."); + checkNotNull(targetName, "targetName should not be null on client side."); + GetTlsConfigurationResp.ClientTlsConfiguration clientTlsConfiguration; + try { + clientTlsConfiguration = getClientTlsConfigurationFromS2A(stub, localIdentity); + } catch (IOException | InterruptedException e) { + throw new GeneralSecurityException("Failed to get client TLS configuration from S2A.", e); + } + + // Use the default value for timeout. + // Use the smallest possible value for cache size. + // The Provider is by default OPENSSL. No need to manually set it. + SslContextBuilder sslContextBuilder = + GrpcSslContexts.configure(SslContextBuilder.forClient()) + .sessionCacheSize(1) + .sessionTimeout(0); + + configureSslContextWithClientTlsConfiguration(clientTlsConfiguration, sslContextBuilder); + sslContextBuilder.trustManager( + S2ATrustManager.createForClient(stub, targetName, localIdentity)); + sslContextBuilder.option( + OpenSslContextOption.PRIVATE_KEY_METHOD, S2APrivateKeyMethod.create(stub, localIdentity)); + + SslContext sslContext = sslContextBuilder.build(); + SSLSessionContext sslSessionContext = sslContext.sessionContext(); + if (sslSessionContext instanceof OpenSslSessionContext) { + OpenSslSessionContext openSslSessionContext = (OpenSslSessionContext) sslSessionContext; + openSslSessionContext.setSessionCacheEnabled(false); + } + + return sslContext; + } + + private static GetTlsConfigurationResp.ClientTlsConfiguration getClientTlsConfigurationFromS2A( + S2AStub stub, Optional localIdentity) throws IOException, InterruptedException { + checkNotNull(stub, "stub should not be null."); + SessionReq.Builder reqBuilder = SessionReq.newBuilder(); + if (localIdentity.isPresent()) { + reqBuilder.setLocalIdentity(localIdentity.get().identity()); + } + Optional authMechanism = + GetAuthenticationMechanisms.getAuthMechanism(localIdentity); + if (authMechanism.isPresent()) { + reqBuilder.addAuthenticationMechanisms(authMechanism.get()); + } + SessionResp resp = + stub.send( + reqBuilder + .setGetTlsConfigurationReq( + GetTlsConfigurationReq.newBuilder() + .setConnectionSide(ConnectionSide.CONNECTION_SIDE_CLIENT)) + .build()); + if (resp.hasStatus() && resp.getStatus().getCode() != 0) { + throw new S2AConnectionException( + String.format( + "response from S2A server has ean error %d with error message %s.", + resp.getStatus().getCode(), resp.getStatus().getDetails())); + } + if (!resp.getGetTlsConfigurationResp().hasClientTlsConfiguration()) { + throw new S2AConnectionException( + "Response from S2A server does NOT contain ClientTlsConfiguration."); + } + return resp.getGetTlsConfigurationResp().getClientTlsConfiguration(); + } + + private static void configureSslContextWithClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration clientTlsConfiguration, + SslContextBuilder sslContextBuilder) + throws CertificateException, + IOException, + KeyStoreException, + NoSuchAlgorithmException, + UnrecoverableKeyException { + sslContextBuilder.keyManager(createKeylessManager(clientTlsConfiguration)); + sslContextBuilder.protocols( + ProtoUtil.convertTlsProtocolVersion(clientTlsConfiguration.getMinTlsVersion()), + ProtoUtil.convertTlsProtocolVersion(clientTlsConfiguration.getMaxTlsVersion())); + ImmutableList.Builder ciphersuites = ImmutableList.builder(); + for (int i = 0; i < clientTlsConfiguration.getCiphersuitesCount(); ++i) { + ciphersuites.add(ProtoUtil.convertCiphersuite(clientTlsConfiguration.getCiphersuites(i))); + } + sslContextBuilder.ciphers(ciphersuites.build()); + } + + private static KeyManager createKeylessManager( + GetTlsConfigurationResp.ClientTlsConfiguration clientTlsConfiguration) + throws CertificateException, + IOException, + KeyStoreException, + NoSuchAlgorithmException, + UnrecoverableKeyException { + X509Certificate[] certificates = + new X509Certificate[clientTlsConfiguration.getCertificateChainCount()]; + for (int i = 0; i < clientTlsConfiguration.getCertificateChainCount(); ++i) { + certificates[i] = convertStringToX509Cert(clientTlsConfiguration.getCertificateChain(i)); + } + KeyManager[] keyManagers = + OpenSslX509KeyManagerFactory.newKeyless(certificates).getKeyManagers(); + if (keyManagers == null || keyManagers.length == 0) { + throw new IllegalStateException("No key managers created."); + } + return keyManagers[0]; + } + + private static X509Certificate convertStringToX509Cert(String certificate) + throws CertificateException { + return (X509Certificate) + CertificateFactory.getInstance("X509") + .generateCertificate(new ByteArrayInputStream(certificate.getBytes(UTF_8))); + } + + private SslContextFactory() {} +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/AccessTokenManager.java b/s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/AccessTokenManager.java new file mode 100644 index 00000000000..94549d11c87 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/AccessTokenManager.java @@ -0,0 +1,61 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker.tokenmanager; + +import io.grpc.s2a.handshaker.S2AIdentity; +import java.lang.reflect.Method; +import java.util.Optional; +import javax.annotation.concurrent.ThreadSafe; + +/** Manages access tokens for authenticating to the S2A. */ +@ThreadSafe +public final class AccessTokenManager { + private final TokenFetcher tokenFetcher; + + /** Creates an {@code AccessTokenManager} based on the environment where the application runs. */ + @SuppressWarnings("RethrowReflectiveOperationExceptionAsLinkageError") + public static Optional create() { + Optional tokenFetcher; + try { + Class singleTokenFetcherClass = + Class.forName("io.grpc.s2a.handshaker.tokenmanager.SingleTokenFetcher"); + Method createTokenFetcher = singleTokenFetcherClass.getMethod("create"); + tokenFetcher = (Optional) createTokenFetcher.invoke(null); + } catch (ClassNotFoundException e) { + tokenFetcher = Optional.empty(); + } catch (ReflectiveOperationException e) { + throw new AssertionError(e); + } + return tokenFetcher.isPresent() + ? Optional.of(new AccessTokenManager((TokenFetcher) tokenFetcher.get())) + : Optional.empty(); + } + + private AccessTokenManager(TokenFetcher tokenFetcher) { + this.tokenFetcher = tokenFetcher; + } + + /** Returns an access token when no identity is specified. */ + public String getDefaultToken() { + return tokenFetcher.getDefaultToken(); + } + + /** Returns an access token for the given identity. */ + public String getToken(S2AIdentity identity) { + return tokenFetcher.getToken(identity); + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenFetcher.java b/s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenFetcher.java new file mode 100644 index 00000000000..3b2bd051e84 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenFetcher.java @@ -0,0 +1,64 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker.tokenmanager; + +import com.beust.jcommander.Parameter; +import com.beust.jcommander.Parameters; +import io.grpc.s2a.handshaker.S2AIdentity; +import java.util.Optional; + +/** Fetches a single access token via an environment variable. */ +public final class SingleTokenFetcher implements TokenFetcher { + private static final String ENVIRONMENT_VARIABLE = "S2A_ACCESS_TOKEN"; + + /** Set an access token via a flag. */ + @Parameters(separators = "=") + public static class Flags { + @Parameter( + names = "--s2a_access_token", + description = "The access token used to authenticate to S2A.") + private static String accessToken = System.getenv(ENVIRONMENT_VARIABLE); + + public synchronized void reset() { + accessToken = null; + } + } + + private final String token; + + /** + * Creates a {@code SingleTokenFetcher} from {@code ENVIRONMENT_VARIABLE}, and returns an empty + * {@code Optional} instance if the token could not be fetched. + */ + public static Optional create() { + return Optional.ofNullable(Flags.accessToken).map(SingleTokenFetcher::new); + } + + private SingleTokenFetcher(String token) { + this.token = token; + } + + @Override + public String getDefaultToken() { + return token; + } + + @Override + public String getToken(S2AIdentity identity) { + return token; + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/TokenFetcher.java b/s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/TokenFetcher.java new file mode 100644 index 00000000000..9eeddaad844 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/TokenFetcher.java @@ -0,0 +1,28 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker.tokenmanager; + +import io.grpc.s2a.handshaker.S2AIdentity; + +/** Fetches tokens used to authenticate to S2A. */ +interface TokenFetcher { + /** Returns an access token when no identity is specified. */ + String getDefaultToken(); + + /** Returns an access token for the given identity. */ + String getToken(S2AIdentity identity); +} \ No newline at end of file diff --git a/s2a/src/main/proto/grpc/gcp/common.proto b/s2a/src/main/proto/grpc/gcp/common.proto new file mode 100644 index 00000000000..7c105c2ce05 --- /dev/null +++ b/s2a/src/main/proto/grpc/gcp/common.proto @@ -0,0 +1,79 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package grpc.gcp; + +option java_multiple_files = true; +option java_outer_classname = "CommonProto"; +option java_package = "io.grpc.s2a.handshaker"; + +// The TLS 1.0-1.2 ciphersuites that the application can negotiate when using +// S2A. +enum Ciphersuite { + CIPHERSUITE_UNSPECIFIED = 0; + CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 = 1; + CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 = 2; + CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 = 3; + CIPHERSUITE_ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 4; + CIPHERSUITE_ECDHE_RSA_WITH_AES_256_GCM_SHA384 = 5; + CIPHERSUITE_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 = 6; +} + +// The TLS versions supported by S2A's handshaker module. +enum TLSVersion { + TLS_VERSION_UNSPECIFIED = 0; + TLS_VERSION_1_0 = 1; + TLS_VERSION_1_1 = 2; + TLS_VERSION_1_2 = 3; + TLS_VERSION_1_3 = 4; +} + +// The side in the TLS connection. +enum ConnectionSide { + CONNECTION_SIDE_UNSPECIFIED = 0; + CONNECTION_SIDE_CLIENT = 1; + CONNECTION_SIDE_SERVER = 2; +} + +// The ALPN protocols that the application can negotiate during a TLS handshake. +enum AlpnProtocol { + ALPN_PROTOCOL_UNSPECIFIED = 0; + ALPN_PROTOCOL_GRPC = 1; + ALPN_PROTOCOL_HTTP2 = 2; + ALPN_PROTOCOL_HTTP1_1 = 3; +} + +message Identity { + oneof identity_oneof { + // The SPIFFE ID of a connection endpoint. + string spiffe_id = 1; + + // The hostname of a connection endpoint. + string hostname = 2; + + // The UID of a connection endpoint. + string uid = 4; + + // The username of a connection endpoint. + string username = 5; + + // The GCP ID of a connection endpoint. + string gcp_id = 6; + } + + // Additional identity-specific attributes. + map attributes = 3; +} diff --git a/s2a/src/main/proto/grpc/gcp/s2a.proto b/s2a/src/main/proto/grpc/gcp/s2a.proto new file mode 100644 index 00000000000..1a05b546ebb --- /dev/null +++ b/s2a/src/main/proto/grpc/gcp/s2a.proto @@ -0,0 +1,369 @@ +// Copyright 2024 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The canonical version of this proto can be found at +// https://github.com/grpc/grpc-proto/blob/master/grpc/gcp/s2a/s2a.proto + +syntax = "proto3"; + +package grpc.gcp; + +import "grpc/gcp/common.proto"; +import "grpc/gcp/s2a_context.proto"; + +option java_multiple_files = true; +option java_outer_classname = "S2AProto"; +option java_package = "io.grpc.s2a.handshaker"; + +enum SignatureAlgorithm { + S2A_SSL_SIGN_UNSPECIFIED = 0; + // RSA Public-Key Cryptography Standards #1. + S2A_SSL_SIGN_RSA_PKCS1_SHA256 = 1; + S2A_SSL_SIGN_RSA_PKCS1_SHA384 = 2; + S2A_SSL_SIGN_RSA_PKCS1_SHA512 = 3; + // ECDSA. + S2A_SSL_SIGN_ECDSA_SECP256R1_SHA256 = 4; + S2A_SSL_SIGN_ECDSA_SECP384R1_SHA384 = 5; + S2A_SSL_SIGN_ECDSA_SECP521R1_SHA512 = 6; + // RSA Probabilistic Signature Scheme. + S2A_SSL_SIGN_RSA_PSS_RSAE_SHA256 = 7; + S2A_SSL_SIGN_RSA_PSS_RSAE_SHA384 = 8; + S2A_SSL_SIGN_RSA_PSS_RSAE_SHA512 = 9; + // ED25519. + S2A_SSL_SIGN_ED25519 = 10; +} + +message AlpnPolicy { + // If true, the application MUST perform ALPN negotiation. + bool enable_alpn_negotiation = 1; + + // The ordered list of ALPN protocols that specify how the application SHOULD + // negotiate ALPN during the TLS handshake. + // + // The application MAY ignore any ALPN protocols in this list that are not + // supported by the application. + repeated AlpnProtocol alpn_protocols = 2; +} + +message AuthenticationMechanism { + // Applications may specify an identity associated to an authentication + // mechanism. Otherwise, S2A assumes that the authentication mechanism is + // associated with the default identity. If the default identity cannot be + // determined, the request is rejected. + Identity identity = 1; + + oneof mechanism_oneof { + // A token that the application uses to authenticate itself to S2A. + string token = 2; + } +} + +message Status { + // The status code that is specific to the application and the implementation + // of S2A, e.g., gRPC status code. + uint32 code = 1; + + // The status details. + string details = 2; +} + +message GetTlsConfigurationReq { + // The role of the application in the TLS connection. + ConnectionSide connection_side = 1; + + // The server name indication (SNI) extension, which MAY be populated when a + // server is offloading to S2A. The SNI is used to determine the server + // identity if the local identity in the request is empty. + string sni = 2; +} + +message GetTlsConfigurationResp { + // Next ID: 8 + message ClientTlsConfiguration { + reserved 4, 5; + + // The certificate chain that the client MUST use for the TLS handshake. + // It's a list of PEM-encoded certificates, ordered from leaf to root, + // excluding the root. + repeated string certificate_chain = 1; + + // The minimum TLS version number that the client MUST use for the TLS + // handshake. If this field is not provided, the client MUST use the default + // minimum version of the client's TLS library. + TLSVersion min_tls_version = 2; + + // The maximum TLS version number that the client MUST use for the TLS + // handshake. If this field is not provided, the client MUST use the default + // maximum version of the client's TLS library. + TLSVersion max_tls_version = 3; + + // The ordered list of TLS 1.0-1.2 ciphersuites that the client MAY offer to + // negotiate in the TLS handshake. + repeated Ciphersuite ciphersuites = 6; + + // The policy that dictates how the client negotiates ALPN during the TLS + // handshake. + AlpnPolicy alpn_policy = 7; + } + + // Next ID: 12 + message ServerTlsConfiguration { + reserved 4, 5; + + enum RequestClientCertificate { + UNSPECIFIED = 0; + DONT_REQUEST_CLIENT_CERTIFICATE = 1; + REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY = 2; + REQUEST_CLIENT_CERTIFICATE_AND_VERIFY = 3; + REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY = 4; + REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY = 5; + } + + // The certificate chain that the server MUST use for the TLS handshake. + // It's a list of PEM-encoded certificates, ordered from leaf to root, + // excluding the root. + repeated string certificate_chain = 1; + + // The minimum TLS version number that the server MUST use for the TLS + // handshake. If this field is not provided, the server MUST use the default + // minimum version of the server's TLS library. + TLSVersion min_tls_version = 2; + + // The maximum TLS version number that the server MUST use for the TLS + // handshake. If this field is not provided, the server MUST use the default + // maximum version of the server's TLS library. + TLSVersion max_tls_version = 3; + + // The ordered list of TLS 1.0-1.2 ciphersuites that the server MAY offer to + // negotiate in the TLS handshake. + repeated Ciphersuite ciphersuites = 10; + + // Whether to enable TLS resumption. + bool tls_resumption_enabled = 6; + + // Whether the server MUST request a client certificate (i.e. to negotiate + // TLS vs. mTLS). + RequestClientCertificate request_client_certificate = 7; + + // Returns the maximum number of extra bytes that + // |OffloadResumptionKeyOperation| can add to the number of unencrypted + // bytes to form the encrypted bytes. + uint32 max_overhead_of_ticket_aead = 9; + + // The policy that dictates how the server negotiates ALPN during the TLS + // handshake. + AlpnPolicy alpn_policy = 11; + } + + oneof tls_configuration { + ClientTlsConfiguration client_tls_configuration = 1; + ServerTlsConfiguration server_tls_configuration = 2; + } +} + +message OffloadPrivateKeyOperationReq { + enum PrivateKeyOperation { + UNSPECIFIED = 0; + // When performing a TLS 1.2 or 1.3 handshake, the (partial) transcript of + // the TLS handshake must be signed to prove possession of the private key. + // + // See https://www.rfc-editor.org/rfc/rfc8446.html#section-4.4.3. + SIGN = 1; + // When performing a TLS 1.2 handshake using an RSA algorithm, the key + // exchange algorithm involves the client generating a premaster secret, + // encrypting it using the server's public key, and sending this encrypted + // blob to the server in a ClientKeyExchange message. + // + // See https://www.rfc-editor.org/rfc/rfc4346#section-7.4.7.1. + DECRYPT = 2; + } + + // The operation the private key is used for. + PrivateKeyOperation operation = 1; + + // The signature algorithm to be used for signing operations. + SignatureAlgorithm signature_algorithm = 2; + + // The input bytes to be signed or decrypted. + oneof in_bytes { + // Raw bytes to be hashed and signed, or decrypted. + bytes raw_bytes = 4; + // A SHA256 hash to be signed. Must be 32 bytes. + bytes sha256_digest = 5; + // A SHA384 hash to be signed. Must be 48 bytes. + bytes sha384_digest = 6; + // A SHA512 hash to be signed. Must be 64 bytes. + bytes sha512_digest = 7; + } +} + +message OffloadPrivateKeyOperationResp { + // The signed or decrypted output bytes. + bytes out_bytes = 1; +} + +message OffloadResumptionKeyOperationReq { + enum ResumptionKeyOperation { + UNSPECIFIED = 0; + ENCRYPT = 1; + DECRYPT = 2; + } + + // The operation the resumption key is used for. + ResumptionKeyOperation operation = 1; + + // The bytes to be encrypted or decrypted. + bytes in_bytes = 2; +} + +message OffloadResumptionKeyOperationResp { + // The encrypted or decrypted bytes. + bytes out_bytes = 1; +} + +message ValidatePeerCertificateChainReq { + enum VerificationMode { + // The default verification mode supported by S2A. + UNSPECIFIED = 0; + // The SPIFFE verification mode selects the set of trusted certificates to + // use for path building based on the SPIFFE trust domain in the peer's leaf + // certificate. + SPIFFE = 1; + // The connect-to-Google verification mode uses the trust bundle for + // connecting to Google, e.g. *.mtls.googleapis.com endpoints. + CONNECT_TO_GOOGLE = 2; + } + + message ClientPeer { + // The certificate chain to be verified. The chain MUST be a list of + // DER-encoded certificates, ordered from leaf to root, excluding the root. + repeated bytes certificate_chain = 1; + } + + message ServerPeer { + // The certificate chain to be verified. The chain MUST be a list of + // DER-encoded certificates, ordered from leaf to root, excluding the root. + repeated bytes certificate_chain = 1; + + // The expected hostname of the server. + string server_hostname = 2; + + // The UnrestrictedClientPolicy specified by the user. + bytes serialized_unrestricted_client_policy = 3; + } + + // The verification mode that S2A MUST use to validate the peer certificate + // chain. + VerificationMode mode = 1; + + oneof peer_oneof { + ClientPeer client_peer = 2; + ServerPeer server_peer = 3; + } +} + +message ValidatePeerCertificateChainResp { + enum ValidationResult { + UNSPECIFIED = 0; + SUCCESS = 1; + FAILURE = 2; + } + + // The result of validating the peer certificate chain. + ValidationResult validation_result = 1; + + // The validation details. This field is only populated when the validation + // result is NOT SUCCESS. + string validation_details = 2; + + // The S2A context contains information from the peer certificate chain. + // + // The S2A context MAY be populated even if validation of the peer certificate + // chain fails. + S2AContext context = 3; +} + +message SessionReq { + // The identity corresponding to the TLS configurations that MUST be used for + // the TLS handshake. + // + // If a managed identity already exists, the local identity and authentication + // mechanisms are ignored. If a managed identity doesn't exist and the local + // identity is not populated, S2A will try to deduce the managed identity to + // use from the SNI extension. If that also fails, S2A uses the default + // identity (if one exists). + Identity local_identity = 1; + + // The authentication mechanisms that the application wishes to use to + // authenticate to S2A, ordered by preference. S2A will always use the first + // authentication mechanism that matches the managed identity. + repeated AuthenticationMechanism authentication_mechanisms = 2; + + oneof req_oneof { + // Requests the certificate chain and TLS configuration corresponding to the + // local identity, which the application MUST use to negotiate the TLS + // handshake. + GetTlsConfigurationReq get_tls_configuration_req = 3; + + // Signs or decrypts the input bytes using a private key corresponding to + // the local identity in the request. + // + // WARNING: More than one OffloadPrivateKeyOperationReq may be sent to the + // S2Av2 by a server during a TLS 1.2 handshake. + OffloadPrivateKeyOperationReq offload_private_key_operation_req = 4; + + // Encrypts or decrypts the input bytes using a resumption key corresponding + // to the local identity in the request. + OffloadResumptionKeyOperationReq offload_resumption_key_operation_req = 5; + + // Verifies the peer's certificate chain using + // (a) trust bundles corresponding to the local identity in the request, and + // (b) the verification mode in the request. + ValidatePeerCertificateChainReq validate_peer_certificate_chain_req = 6; + } +} + +message SessionResp { + // Status of the session response. + // + // The status field is populated so that if an error occurs when making an + // individual request, then communication with the S2A may continue. If an + // error is returned directly (e.g. at the gRPC layer), then it may result + // that the bidirectional stream being closed. + Status status = 1; + + oneof resp_oneof { + // Contains the certificate chain and TLS configurations corresponding to + // the local identity. + GetTlsConfigurationResp get_tls_configuration_resp = 2; + + // Contains the signed or encrypted output bytes using the private key + // corresponding to the local identity. + OffloadPrivateKeyOperationResp offload_private_key_operation_resp = 3; + + // Contains the encrypted or decrypted output bytes using the resumption key + // corresponding to the local identity. + OffloadResumptionKeyOperationResp offload_resumption_key_operation_resp = 4; + + // Contains the validation result, peer identity and fingerprints of peer + // certificates. + ValidatePeerCertificateChainResp validate_peer_certificate_chain_resp = 5; + } +} + +service S2AService { + // SetUpSession is a bidirectional stream used by applications to offload + // operations from the TLS handshake. + rpc SetUpSession(stream SessionReq) returns (stream SessionResp) {} +} diff --git a/s2a/src/main/proto/grpc/gcp/s2a_context.proto b/s2a/src/main/proto/grpc/gcp/s2a_context.proto new file mode 100644 index 00000000000..5ad264bf875 --- /dev/null +++ b/s2a/src/main/proto/grpc/gcp/s2a_context.proto @@ -0,0 +1,61 @@ +// Copyright 2024 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The canonical version of this proto can be found at +// https://github.com/grpc/grpc-proto/blob/master/grpc/gcp/s2a/s2a_context.proto +syntax = "proto3"; + +package grpc.gcp; + +import "grpc/gcp/common.proto"; + +option java_multiple_files = true; +option java_outer_classname = "S2AContextProto"; +option java_package = "io.grpc.s2a.handshaker"; + +message S2AContext { + // The SPIFFE ID from the peer leaf certificate, if present. + // + // This field is only populated if the leaf certificate is a valid SPIFFE + // SVID; in particular, there is a unique URI SAN and this URI SAN is a valid + // SPIFFE ID. + string leaf_cert_spiffe_id = 1; + + // The URIs that are present in the SubjectAltName extension of the peer leaf + // certificate. + // + // Note that the extracted URIs are not validated and may not be properly + // formatted. + repeated string leaf_cert_uris = 2; + + // The DNSNames that are present in the SubjectAltName extension of the peer + // leaf certificate. + repeated string leaf_cert_dnsnames = 3; + + // The (ordered) list of fingerprints in the certificate chain used to verify + // the given leaf certificate. The order MUST be from leaf certificate + // fingerprint to root certificate fingerprint. + // + // A fingerprint is the base-64 encoding of the SHA256 hash of the + // DER-encoding of a certificate. The list MAY be populated even if the peer + // certificate chain was NOT validated successfully. + repeated string peer_certificate_chain_fingerprints = 4; + + // The local identity used during session setup. + Identity local_identity = 5; + + // The SHA256 hash of the DER-encoding of the local leaf certificate used in + // the handshake. + bytes local_leaf_cert_fingerprint = 6; +} diff --git a/s2a/src/test/java/io/grpc/s2a/MtlsToS2AChannelCredentialsTest.java b/s2a/src/test/java/io/grpc/s2a/MtlsToS2AChannelCredentialsTest.java new file mode 100644 index 00000000000..5ccc522292e --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/MtlsToS2AChannelCredentialsTest.java @@ -0,0 +1,135 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class MtlsToS2AChannelCredentialsTest { + @Test + public void createBuilder_nullAddress_throwsException() throws Exception { + assertThrows( + IllegalArgumentException.class, + () -> + MtlsToS2AChannelCredentials.createBuilder( + /* s2aAddress= */ null, + /* privateKeyPath= */ "src/test/resources/client_key.pem", + /* certChainPath= */ "src/test/resources/client_cert.pem", + /* trustBundlePath= */ "src/test/resources/root_cert.pem")); + } + + @Test + public void createBuilder_nullPrivateKeyPath_throwsException() throws Exception { + assertThrows( + IllegalArgumentException.class, + () -> + MtlsToS2AChannelCredentials.createBuilder( + /* s2aAddress= */ "s2a_address", + /* privateKeyPath= */ null, + /* certChainPath= */ "src/test/resources/client_cert.pem", + /* trustBundlePath= */ "src/test/resources/root_cert.pem")); + } + + @Test + public void createBuilder_nullCertChainPath_throwsException() throws Exception { + assertThrows( + IllegalArgumentException.class, + () -> + MtlsToS2AChannelCredentials.createBuilder( + /* s2aAddress= */ "s2a_address", + /* privateKeyPath= */ "src/test/resources/client_key.pem", + /* certChainPath= */ null, + /* trustBundlePath= */ "src/test/resources/root_cert.pem")); + } + + @Test + public void createBuilder_nullTrustBundlePath_throwsException() throws Exception { + assertThrows( + IllegalArgumentException.class, + () -> + MtlsToS2AChannelCredentials.createBuilder( + /* s2aAddress= */ "s2a_address", + /* privateKeyPath= */ "src/test/resources/client_key.pem", + /* certChainPath= */ "src/test/resources/client_cert.pem", + /* trustBundlePath= */ null)); + } + + @Test + public void createBuilder_emptyAddress_throwsException() throws Exception { + assertThrows( + IllegalArgumentException.class, + () -> + MtlsToS2AChannelCredentials.createBuilder( + /* s2aAddress= */ "", + /* privateKeyPath= */ "src/test/resources/client_key.pem", + /* certChainPath= */ "src/test/resources/client_cert.pem", + /* trustBundlePath= */ "src/test/resources/root_cert.pem")); + } + + @Test + public void createBuilder_emptyPrivateKeyPath_throwsException() throws Exception { + assertThrows( + IllegalArgumentException.class, + () -> + MtlsToS2AChannelCredentials.createBuilder( + /* s2aAddress= */ "s2a_address", + /* privateKeyPath= */ "", + /* certChainPath= */ "src/test/resources/client_cert.pem", + /* trustBundlePath= */ "src/test/resources/root_cert.pem")); + } + + @Test + public void createBuilder_emptyCertChainPath_throwsException() throws Exception { + assertThrows( + IllegalArgumentException.class, + () -> + MtlsToS2AChannelCredentials.createBuilder( + /* s2aAddress= */ "s2a_address", + /* privateKeyPath= */ "src/test/resources/client_key.pem", + /* certChainPath= */ "", + /* trustBundlePath= */ "src/test/resources/root_cert.pem")); + } + + @Test + public void createBuilder_emptyTrustBundlePath_throwsException() throws Exception { + assertThrows( + IllegalArgumentException.class, + () -> + MtlsToS2AChannelCredentials.createBuilder( + /* s2aAddress= */ "s2a_address", + /* privateKeyPath= */ "src/test/resources/client_key.pem", + /* certChainPath= */ "src/test/resources/client_cert.pem", + /* trustBundlePath= */ "")); + } + + @Test + public void build_s2AChannelCredentials_success() throws Exception { + assertThat( + MtlsToS2AChannelCredentials.createBuilder( + /* s2aAddress= */ "s2a_address", + /* privateKeyPath= */ "src/test/resources/client_key.pem", + /* certChainPath= */ "src/test/resources/client_cert.pem", + /* trustBundlePath= */ "src/test/resources/root_cert.pem") + .build()) + .isInstanceOf(S2AChannelCredentials.Builder.class); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/S2AChannelCredentialsTest.java b/s2a/src/test/java/io/grpc/s2a/S2AChannelCredentialsTest.java new file mode 100644 index 00000000000..a6133ed0af8 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/S2AChannelCredentialsTest.java @@ -0,0 +1,112 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import io.grpc.ChannelCredentials; +import io.grpc.TlsChannelCredentials; +import java.io.File; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@code S2AChannelCredentials}. */ +@RunWith(JUnit4.class) +public final class S2AChannelCredentialsTest { + @Test + public void createBuilder_nullArgument_throwsException() throws Exception { + assertThrows(IllegalArgumentException.class, () -> S2AChannelCredentials.createBuilder(null)); + } + + @Test + public void createBuilder_emptyAddress_throwsException() throws Exception { + assertThrows(IllegalArgumentException.class, () -> S2AChannelCredentials.createBuilder("")); + } + + @Test + public void setLocalSpiffeId_nullArgument_throwsException() throws Exception { + assertThrows( + NullPointerException.class, + () -> S2AChannelCredentials.createBuilder("s2a_address").setLocalSpiffeId(null)); + } + + @Test + public void setLocalHostname_nullArgument_throwsException() throws Exception { + assertThrows( + NullPointerException.class, + () -> S2AChannelCredentials.createBuilder("s2a_address").setLocalHostname(null)); + } + + @Test + public void setLocalUid_nullArgument_throwsException() throws Exception { + assertThrows( + NullPointerException.class, + () -> S2AChannelCredentials.createBuilder("s2a_address").setLocalUid(null)); + } + + @Test + public void build_withLocalSpiffeId_succeeds() throws Exception { + assertThat( + S2AChannelCredentials.createBuilder("s2a_address") + .setLocalSpiffeId("spiffe://test") + .build()) + .isNotNull(); + } + + @Test + public void build_withLocalHostname_succeeds() throws Exception { + assertThat( + S2AChannelCredentials.createBuilder("s2a_address") + .setLocalHostname("local_hostname") + .build()) + .isNotNull(); + } + + @Test + public void build_withLocalUid_succeeds() throws Exception { + assertThat(S2AChannelCredentials.createBuilder("s2a_address").setLocalUid("local_uid").build()) + .isNotNull(); + } + + @Test + public void build_withNoLocalIdentity_succeeds() throws Exception { + assertThat(S2AChannelCredentials.createBuilder("s2a_address").build()) + .isNotNull(); + } + + @Test + public void build_withTlsChannelCredentials_succeeds() throws Exception { + assertThat( + S2AChannelCredentials.createBuilder("s2a_address") + .setLocalSpiffeId("spiffe://test") + .setS2AChannelCredentials(getTlsChannelCredentials()) + .build()) + .isNotNull(); + } + + private static ChannelCredentials getTlsChannelCredentials() throws Exception { + File clientCert = new File("src/test/resources/client_cert.pem"); + File clientKey = new File("src/test/resources/client_key.pem"); + File rootCert = new File("src/test/resources/root_cert.pem"); + return TlsChannelCredentials.newBuilder() + .keyManager(clientCert, clientKey) + .trustManager(rootCert) + .build(); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/channel/S2AGrpcChannelPoolTest.java b/s2a/src/test/java/io/grpc/s2a/channel/S2AGrpcChannelPoolTest.java new file mode 100644 index 00000000000..13eccac682d --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/channel/S2AGrpcChannelPoolTest.java @@ -0,0 +1,125 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.channel; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; + +import io.grpc.Channel; +import io.grpc.internal.ObjectPool; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link S2AGrpcChannelPool}. */ +@RunWith(JUnit4.class) +public final class S2AGrpcChannelPoolTest { + @Test + public void getChannel_success() throws Exception { + FakeChannelPool fakeChannelPool = new FakeChannelPool(); + S2AChannelPool s2aChannelPool = S2AGrpcChannelPool.create(fakeChannelPool); + + Channel channel = s2aChannelPool.getChannel(); + + assertThat(channel).isNotNull(); + assertThat(fakeChannelPool.isChannelCached()).isTrue(); + assertThat(s2aChannelPool.getChannel()).isEqualTo(channel); + } + + @Test + public void returnChannel_success() throws Exception { + FakeChannelPool fakeChannelPool = new FakeChannelPool(); + S2AChannelPool s2aChannelPool = S2AGrpcChannelPool.create(fakeChannelPool); + + s2aChannelPool.returnChannel(s2aChannelPool.getChannel()); + + assertThat(fakeChannelPool.isChannelCached()).isFalse(); + } + + @Test + public void returnChannel_channelStillCachedBecauseMultipleChannelsRetrieved() throws Exception { + FakeChannelPool fakeChannelPool = new FakeChannelPool(); + S2AChannelPool s2aChannelPool = S2AGrpcChannelPool.create(fakeChannelPool); + + s2aChannelPool.getChannel(); + s2aChannelPool.returnChannel(s2aChannelPool.getChannel()); + + assertThat(fakeChannelPool.isChannelCached()).isTrue(); + } + + @Test + public void returnChannel_failureBecauseChannelWasNotFromPool() throws Exception { + S2AChannelPool s2aChannelPool = S2AGrpcChannelPool.create(new FakeChannelPool()); + + IllegalArgumentException expected = + assertThrows( + IllegalArgumentException.class, + () -> s2aChannelPool.returnChannel(mock(Channel.class))); + assertThat(expected) + .hasMessageThat() + .isEqualTo( + "Cannot return the channel to channel pool because the channel was not obtained from" + + " channel pool."); + } + + @Test + public void close_success() throws Exception { + FakeChannelPool fakeChannelPool = new FakeChannelPool(); + try (S2AChannelPool s2aChannelPool = S2AGrpcChannelPool.create(fakeChannelPool)) { + s2aChannelPool.getChannel(); + } + + assertThat(fakeChannelPool.isChannelCached()).isFalse(); + } + + @Test + public void close_poolIsUnusable() throws Exception { + S2AChannelPool s2aChannelPool = S2AGrpcChannelPool.create(new FakeChannelPool()); + s2aChannelPool.close(); + + IllegalStateException expected = + assertThrows(IllegalStateException.class, s2aChannelPool::getChannel); + + assertThat(expected).hasMessageThat().isEqualTo("Channel pool is not open."); + } + + private static class FakeChannelPool implements ObjectPool { + private final Channel mockChannel = mock(Channel.class); + private @Nullable Channel cachedChannel = null; + + @Override + public Channel getObject() { + if (cachedChannel == null) { + cachedChannel = mockChannel; + } + return cachedChannel; + } + + @Override + public Channel returnObject(Object object) { + assertThat(object).isSameInstanceAs(mockChannel); + cachedChannel = null; + return null; + } + + public boolean isChannelCached() { + return (cachedChannel != null); + } + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/channel/S2AHandshakerServiceChannelTest.java b/s2a/src/test/java/io/grpc/s2a/channel/S2AHandshakerServiceChannelTest.java new file mode 100644 index 00000000000..57288be1b6f --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/channel/S2AHandshakerServiceChannelTest.java @@ -0,0 +1,390 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.channel; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ChannelCredentials; +import io.grpc.ClientCall; +import io.grpc.ManagedChannel; +import io.grpc.MethodDescriptor; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.ServerCredentials; +import io.grpc.StatusRuntimeException; +import io.grpc.TlsChannelCredentials; +import io.grpc.TlsServerCredentials; +import io.grpc.benchmarks.Utils; +import io.grpc.internal.SharedResourceHolder.Resource; +import io.grpc.netty.NettyServerBuilder; +import io.grpc.s2a.channel.S2AHandshakerServiceChannel.EventLoopHoldingChannel; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.testing.protobuf.SimpleRequest; +import io.grpc.testing.protobuf.SimpleResponse; +import io.grpc.testing.protobuf.SimpleServiceGrpc; +import io.netty.channel.EventLoopGroup; +import java.io.File; +import java.time.Duration; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link S2AHandshakerServiceChannel}. */ +@RunWith(JUnit4.class) +public final class S2AHandshakerServiceChannelTest { + @ClassRule public static final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private static final Duration CHANNEL_SHUTDOWN_TIMEOUT = Duration.ofSeconds(10); + private final EventLoopGroup mockEventLoopGroup = mock(EventLoopGroup.class); + private Server mtlsServer; + private Server plaintextServer; + + @Before + public void setUp() throws Exception { + mtlsServer = createMtlsServer(); + plaintextServer = createPlaintextServer(); + mtlsServer.start(); + plaintextServer.start(); + } + + /** + * Creates a {@code Resource} and verifies that it produces a {@code ChannelResource} + * instance by using its {@code toString()} method. + */ + @Test + public void getChannelResource_success() { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + /* s2aChannelCredentials= */ Optional.empty()); + assertThat(resource.toString()).isEqualTo("grpc-s2a-channel"); + } + + /** Same as getChannelResource_success, but use mTLS. */ + @Test + public void getChannelResource_mtlsSuccess() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + assertThat(resource.toString()).isEqualTo("grpc-s2a-channel"); + } + + /** + * Creates two {@code Resoure}s for the same target address and verifies that they are + * equal. + */ + @Test + public void getChannelResource_twoEqualChannels() { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + /* s2aChannelCredentials= */ Optional.empty()); + Resource resourceTwo = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + /* s2aChannelCredentials= */ Optional.empty()); + assertThat(resource).isEqualTo(resourceTwo); + } + + /** Same as getChannelResource_twoEqualChannels, but use mTLS. */ + @Test + public void getChannelResource_mtlsTwoEqualChannels() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + Resource resourceTwo = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + assertThat(resource).isEqualTo(resourceTwo); + } + + /** + * Creates two {@code Resoure}s for different target addresses and verifies that they are + * distinct. + */ + @Test + public void getChannelResource_twoDistinctChannels() { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + /* s2aChannelCredentials= */ Optional.empty()); + Resource resourceTwo = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + Utils.pickUnusedPort(), /* s2aChannelCredentials= */ Optional.empty()); + assertThat(resourceTwo).isNotEqualTo(resource); + } + + /** Same as getChannelResource_twoDistinctChannels, but use mTLS. */ + @Test + public void getChannelResource_mtlsTwoDistinctChannels() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + Resource resourceTwo = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + Utils.pickUnusedPort(), getTlsChannelCredentials()); + assertThat(resourceTwo).isNotEqualTo(resource); + } + + /** + * Uses a {@code Resource} to create a channel, closes the channel, and verifies that the + * channel is closed by attempting to make a simple RPC. + */ + @Test + public void close_success() { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + /* s2aChannelCredentials= */ Optional.empty()); + Channel channel = resource.create(); + resource.close(channel); + StatusRuntimeException expected = + assertThrows( + StatusRuntimeException.class, + () -> + SimpleServiceGrpc.newBlockingStub(channel) + .unaryRpc(SimpleRequest.getDefaultInstance())); + assertThat(expected).hasMessageThat().isEqualTo("UNAVAILABLE: Channel shutdown invoked"); + } + + /** Same as close_success, but use mTLS. */ + @Test + public void close_mtlsSuccess() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + Channel channel = resource.create(); + resource.close(channel); + StatusRuntimeException expected = + assertThrows( + StatusRuntimeException.class, + () -> + SimpleServiceGrpc.newBlockingStub(channel) + .unaryRpc(SimpleRequest.getDefaultInstance())); + assertThat(expected).hasMessageThat().isEqualTo("UNAVAILABLE: Channel shutdown invoked"); + } + + /** + * Verifies that an {@code EventLoopHoldingChannel}'s {@code newCall} method can be used to + * perform a simple RPC. + */ + @Test + public void newCall_performSimpleRpcSuccess() { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + /* s2aChannelCredentials= */ Optional.empty()); + Channel channel = resource.create(); + assertThat(channel).isInstanceOf(EventLoopHoldingChannel.class); + assertThat( + SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance())) + .isEqualToDefaultInstance(); + } + + /** Same as newCall_performSimpleRpcSuccess, but use mTLS. */ + @Test + public void newCall_mtlsPerformSimpleRpcSuccess() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + Channel channel = resource.create(); + assertThat(channel).isInstanceOf(EventLoopHoldingChannel.class); + assertThat( + SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance())) + .isEqualToDefaultInstance(); + } + + /** Creates a {@code EventLoopHoldingChannel} instance and verifies its authority. */ + @Test + public void authority_success() throws Exception { + ManagedChannel channel = new FakeManagedChannel(true); + EventLoopHoldingChannel eventLoopHoldingChannel = + EventLoopHoldingChannel.create(channel, mockEventLoopGroup); + assertThat(eventLoopHoldingChannel.authority()).isEqualTo("FakeManagedChannel"); + } + + /** + * Creates and closes a {@code EventLoopHoldingChannel} when its {@code ManagedChannel} terminates + * successfully. + */ + @Test + public void close_withDelegateTerminatedSuccess() throws Exception { + ManagedChannel channel = new FakeManagedChannel(true); + EventLoopHoldingChannel eventLoopHoldingChannel = + EventLoopHoldingChannel.create(channel, mockEventLoopGroup); + eventLoopHoldingChannel.close(); + assertThat(channel.isShutdown()).isTrue(); + verify(mockEventLoopGroup, times(1)) + .shutdownGracefully(0, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS); + } + + /** + * Creates and closes a {@code EventLoopHoldingChannel} when its {@code ManagedChannel} does not + * terminate successfully. + */ + @Test + public void close_withDelegateTerminatedFailure() throws Exception { + ManagedChannel channel = new FakeManagedChannel(false); + EventLoopHoldingChannel eventLoopHoldingChannel = + EventLoopHoldingChannel.create(channel, mockEventLoopGroup); + eventLoopHoldingChannel.close(); + assertThat(channel.isShutdown()).isTrue(); + verify(mockEventLoopGroup, times(1)) + .shutdownGracefully(1, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS); + } + + /** + * Creates and closes a {@code EventLoopHoldingChannel}, creates a new channel from the same + * resource, and verifies that this second channel is useable. + */ + @Test + public void create_succeedsAfterCloseIsCalledOnce() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + /* s2aChannelCredentials= */ Optional.empty()); + Channel channelOne = resource.create(); + resource.close(channelOne); + + Channel channelTwo = resource.create(); + assertThat(channelTwo).isInstanceOf(EventLoopHoldingChannel.class); + assertThat( + SimpleServiceGrpc.newBlockingStub(channelTwo) + .unaryRpc(SimpleRequest.getDefaultInstance())) + .isEqualToDefaultInstance(); + resource.close(channelTwo); + } + + /** Same as create_succeedsAfterCloseIsCalledOnce, but use mTLS. */ + @Test + public void create_mtlsSucceedsAfterCloseIsCalledOnce() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + Channel channelOne = resource.create(); + resource.close(channelOne); + + Channel channelTwo = resource.create(); + assertThat(channelTwo).isInstanceOf(EventLoopHoldingChannel.class); + assertThat( + SimpleServiceGrpc.newBlockingStub(channelTwo) + .unaryRpc(SimpleRequest.getDefaultInstance())) + .isEqualToDefaultInstance(); + resource.close(channelTwo); + } + + private static Server createMtlsServer() throws Exception { + SimpleServiceImpl service = new SimpleServiceImpl(); + File serverCert = new File("src/test/resources/server_cert.pem"); + File serverKey = new File("src/test/resources/server_key.pem"); + File rootCert = new File("src/test/resources/root_cert.pem"); + ServerCredentials creds = + TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverKey) + .trustManager(rootCert) + .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE) + .build(); + return grpcCleanup.register( + NettyServerBuilder.forPort(Utils.pickUnusedPort(), creds).addService(service).build()); + } + + private static Server createPlaintextServer() { + SimpleServiceImpl service = new SimpleServiceImpl(); + return grpcCleanup.register( + ServerBuilder.forPort(Utils.pickUnusedPort()).addService(service).build()); + } + + private static Optional getTlsChannelCredentials() throws Exception { + File clientCert = new File("src/test/resources/client_cert.pem"); + File clientKey = new File("src/test/resources/client_key.pem"); + File rootCert = new File("src/test/resources/root_cert.pem"); + return Optional.of( + TlsChannelCredentials.newBuilder() + .keyManager(clientCert, clientKey) + .trustManager(rootCert) + .build()); + } + + private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { + @Override + public void unaryRpc(SimpleRequest request, StreamObserver streamObserver) { + streamObserver.onNext(SimpleResponse.getDefaultInstance()); + streamObserver.onCompleted(); + } + } + + private static class FakeManagedChannel extends ManagedChannel { + private final boolean isDelegateTerminatedSuccess; + private boolean isShutdown = false; + + FakeManagedChannel(boolean isDelegateTerminatedSuccess) { + this.isDelegateTerminatedSuccess = isDelegateTerminatedSuccess; + } + + @Override + public String authority() { + return "FakeManagedChannel"; + } + + @Override + public ClientCall newCall( + MethodDescriptor methodDescriptor, CallOptions options) { + throw new UnsupportedOperationException("This method should not be called."); + } + + @Override + public ManagedChannel shutdown() { + throw new UnsupportedOperationException("This method should not be called."); + } + + @Override + public boolean isShutdown() { + return isShutdown; + } + + @Override + public boolean isTerminated() { + throw new UnsupportedOperationException("This method should not be called."); + } + + @Override + public ManagedChannel shutdownNow() { + isShutdown = true; + return null; + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + if (isDelegateTerminatedSuccess) { + return true; + } + throw new InterruptedException("Await termination was interrupted."); + } + } +} diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/FakeS2AServer.java b/s2a/src/test/java/io/grpc/s2a/handshaker/FakeS2AServer.java new file mode 100644 index 00000000000..66f636ada22 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/FakeS2AServer.java @@ -0,0 +1,55 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import io.grpc.stub.StreamObserver; +import java.security.NoSuchAlgorithmException; +import java.security.spec.InvalidKeySpecException; +import java.util.logging.Logger; + +/** A fake S2Av2 server that should be used for testing only. */ +public final class FakeS2AServer extends S2AServiceGrpc.S2AServiceImplBase { + private static final Logger logger = Logger.getLogger(FakeS2AServer.class.getName()); + + private final FakeWriter writer; + + public FakeS2AServer() throws InvalidKeySpecException, NoSuchAlgorithmException { + this.writer = new FakeWriter(); + this.writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS).initializePrivateKey(); + } + + @Override + public StreamObserver setUpSession(StreamObserver responseObserver) { + return new StreamObserver() { + @Override + public void onNext(SessionReq req) { + logger.info("Received a request from client."); + responseObserver.onNext(writer.handleResponse(req)); + } + + @Override + public void onError(Throwable t) { + responseObserver.onError(t); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/FakeS2AServerTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/FakeS2AServerTest.java new file mode 100644 index 00000000000..e200d119867 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/FakeS2AServerTest.java @@ -0,0 +1,265 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat; +import static java.util.concurrent.TimeUnit.SECONDS; + +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ByteString; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.benchmarks.Utils; +import io.grpc.s2a.handshaker.ValidatePeerCertificateChainReq.VerificationMode; +import io.grpc.stub.StreamObserver; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.logging.Logger; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link FakeS2AServer}. */ +@RunWith(JUnit4.class) +public final class FakeS2AServerTest { + private static final Logger logger = Logger.getLogger(FakeS2AServerTest.class.getName()); + + private static final ImmutableList FAKE_CERT_DER_CHAIN = + ImmutableList.of( + ByteString.copyFrom( + new byte[] {'f', 'a', 'k', 'e', '-', 'd', 'e', 'r', '-', 'c', 'h', 'a', 'i', 'n'})); + private int port; + private String serverAddress; + private SessionResp response = null; + private Server fakeS2AServer; + + @Before + public void setUp() throws Exception { + port = Utils.pickUnusedPort(); + fakeS2AServer = ServerBuilder.forPort(port).addService(new FakeS2AServer()).build(); + fakeS2AServer.start(); + serverAddress = String.format("localhost:%d", port); + } + + @After + public void tearDown() { + fakeS2AServer.shutdown(); + } + + @Test + public void callS2AServerOnce_getTlsConfiguration_returnsValidResult() + throws InterruptedException { + ExecutorService executor = Executors.newSingleThreadExecutor(); + logger.info("Client connecting to: " + serverAddress); + ManagedChannel channel = + Grpc.newChannelBuilder(serverAddress, InsecureChannelCredentials.create()) + .executor(executor) + .build(); + + try { + S2AServiceGrpc.S2AServiceStub asyncStub = S2AServiceGrpc.newStub(channel); + StreamObserver requestObserver = + asyncStub.setUpSession( + new StreamObserver() { + @Override + public void onNext(SessionResp resp) { + response = resp; + } + + @Override + public void onError(Throwable t) { + throw new RuntimeException(t); + } + + @Override + public void onCompleted() {} + }); + try { + requestObserver.onNext( + SessionReq.newBuilder() + .setGetTlsConfigurationReq( + GetTlsConfigurationReq.newBuilder() + .setConnectionSide(ConnectionSide.CONNECTION_SIDE_CLIENT)) + .build()); + } catch (RuntimeException e) { + // Cancel the RPC. + requestObserver.onError(e); + throw e; + } + // Mark the end of requests. + requestObserver.onCompleted(); + // Wait for receiving to happen. + } finally { + channel.shutdown(); + channel.awaitTermination(1, SECONDS); + executor.shutdown(); + executor.awaitTermination(1, SECONDS); + } + + SessionResp expected = + SessionResp.newBuilder() + .setGetTlsConfigurationResp( + GetTlsConfigurationResp.newBuilder() + .setClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration.newBuilder() + .addCertificateChain(FakeWriter.LEAF_CERT) + .addCertificateChain(FakeWriter.INTERMEDIATE_CERT_2) + .addCertificateChain(FakeWriter.INTERMEDIATE_CERT_1) + .setMinTlsVersion(TLSVersion.TLS_VERSION_1_3) + .setMaxTlsVersion(TLSVersion.TLS_VERSION_1_3) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256))) + .build(); + assertThat(response).ignoringRepeatedFieldOrder().isEqualTo(expected); + } + + @Test + public void callS2AServerOnce_validatePeerCertifiate_returnsValidResult() + throws InterruptedException { + ExecutorService executor = Executors.newSingleThreadExecutor(); + logger.info("Client connecting to: " + serverAddress); + ManagedChannel channel = + Grpc.newChannelBuilder(serverAddress, InsecureChannelCredentials.create()) + .executor(executor) + .build(); + + try { + S2AServiceGrpc.S2AServiceStub asyncStub = S2AServiceGrpc.newStub(channel); + StreamObserver requestObserver = + asyncStub.setUpSession( + new StreamObserver() { + @Override + public void onNext(SessionResp resp) { + response = resp; + } + + @Override + public void onError(Throwable t) { + throw new RuntimeException(t); + } + + @Override + public void onCompleted() {} + }); + try { + requestObserver.onNext( + SessionReq.newBuilder() + .setValidatePeerCertificateChainReq( + ValidatePeerCertificateChainReq.newBuilder() + .setMode(VerificationMode.UNSPECIFIED) + .setClientPeer( + ValidatePeerCertificateChainReq.ClientPeer.newBuilder() + .addAllCertificateChain(FAKE_CERT_DER_CHAIN))) + .build()); + } catch (RuntimeException e) { + // Cancel the RPC. + requestObserver.onError(e); + throw e; + } + // Mark the end of requests. + requestObserver.onCompleted(); + // Wait for receiving to happen. + } finally { + channel.shutdown(); + channel.awaitTermination(1, SECONDS); + executor.shutdown(); + executor.awaitTermination(1, SECONDS); + } + + SessionResp expected = + SessionResp.newBuilder() + .setValidatePeerCertificateChainResp( + ValidatePeerCertificateChainResp.newBuilder() + .setValidationResult(ValidatePeerCertificateChainResp.ValidationResult.SUCCESS)) + .build(); + assertThat(response).ignoringRepeatedFieldOrder().isEqualTo(expected); + } + + @Test + public void callS2AServerRepeatedly_returnsValidResult() throws InterruptedException { + final int numberOfRequests = 10; + ExecutorService executor = Executors.newSingleThreadExecutor(); + logger.info("Client connecting to: " + serverAddress); + ManagedChannel channel = + Grpc.newChannelBuilder(serverAddress, InsecureChannelCredentials.create()) + .executor(executor) + .build(); + + try { + S2AServiceGrpc.S2AServiceStub asyncStub = S2AServiceGrpc.newStub(channel); + CountDownLatch finishLatch = new CountDownLatch(1); + StreamObserver requestObserver = + asyncStub.setUpSession( + new StreamObserver() { + private int expectedNumberOfReplies = numberOfRequests; + + @Override + public void onNext(SessionResp reply) { + System.out.println("Received a message from the S2AService service."); + expectedNumberOfReplies -= 1; + } + + @Override + public void onError(Throwable t) { + finishLatch.countDown(); + if (expectedNumberOfReplies != 0) { + throw new RuntimeException(t); + } + } + + @Override + public void onCompleted() { + finishLatch.countDown(); + if (expectedNumberOfReplies != 0) { + throw new RuntimeException(); + } + } + }); + try { + for (int i = 0; i < numberOfRequests; i++) { + requestObserver.onNext(SessionReq.getDefaultInstance()); + } + } catch (RuntimeException e) { + // Cancel the RPC. + requestObserver.onError(e); + throw e; + } + // Mark the end of requests. + requestObserver.onCompleted(); + // Wait for receiving to happen. + if (!finishLatch.await(10, SECONDS)) { + throw new RuntimeException(); + } + } finally { + channel.shutdown(); + channel.awaitTermination(1, SECONDS); + executor.shutdown(); + executor.awaitTermination(1, SECONDS); + } + } + +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/FakeWriter.java b/s2a/src/test/java/io/grpc/s2a/handshaker/FakeWriter.java new file mode 100644 index 00000000000..505a0cf4a3a --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/FakeWriter.java @@ -0,0 +1,347 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static io.grpc.s2a.handshaker.TLSVersion.TLS_VERSION_1_3; + +import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.protobuf.ByteString; +import io.grpc.stub.StreamObserver; +import java.io.IOException; +import java.security.KeyFactory; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.Signature; +import java.security.spec.InvalidKeySpecException; +import java.security.spec.PKCS8EncodedKeySpec; +import java.util.Base64; + +/** A fake Writer Class to mock the behavior of S2A server. */ +final class FakeWriter implements StreamObserver { + /** Fake behavior of S2A service. */ + enum Behavior { + OK_STATUS, + EMPTY_RESPONSE, + ERROR_STATUS, + ERROR_RESPONSE, + COMPLETE_STATUS + } + + enum VerificationResult { + UNSPECIFIED, + SUCCESS, + FAILURE + } + + public static final String LEAF_CERT = + "-----BEGIN CERTIFICATE-----\n" + + "MIICkDCCAjagAwIBAgIUSAtcrPhNNs1zxv51lIfGOVtkw6QwCgYIKoZIzj0EAwIw\n" + + "QTEXMBUGA1UECgwOc2VjdXJpdHktcmVhbG0xEDAOBgNVBAsMB2NvbnRleHQxFDAS\n" + + "BgorBgEEAdZ5AggBDAQyMDIyMCAXDTIzMDcxNDIyMzYwNFoYDzIwNTAxMTI5MjIz\n" + + "NjA0WjARMQ8wDQYDVQQDDAZ1bnVzZWQwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC\n" + + "AAQGFlJpLxJMh4HuUm0DKjnUF7larH3tJvroQ12xpk+pPKQepn4ILoq9lZ8Xd3jz\n" + + "U98eDRXG5f4VjnX98DDHE4Ido4IBODCCATQwDgYDVR0PAQH/BAQDAgeAMCAGA1Ud\n" + + "JQEB/wQWMBQGCCsGAQUFBwMCBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMIGxBgNV\n" + + "HREBAf8EgaYwgaOGSnNwaWZmZTovL3NpZ25lci1yb2xlLmNvbnRleHQuc2VjdXJp\n" + + "dHktcmVhbG0ucHJvZC5nb29nbGUuY29tL3JvbGUvbGVhZi1yb2xlgjNzaWduZXIt\n" + + "cm9sZS5jb250ZXh0LnNlY3VyaXR5LXJlYWxtLnByb2Quc3BpZmZlLmdvb2eCIGZx\n" + + "ZG4tb2YtdGhlLW5vZGUucHJvZC5nb29nbGUuY29tMB0GA1UdDgQWBBSWSd5Fw6dI\n" + + "TGpt0m1Uxwf0iKqebzAfBgNVHSMEGDAWgBRm5agVVdpWfRZKM7u6OMuzHhqPcDAK\n" + + "BggqhkjOPQQDAgNIADBFAiB0sjRPSYy2eFq8Y0vQ8QN4AZ2NMajskvxnlifu7O4U\n" + + "RwIhANTh5Fkyx2nMYFfyl+W45dY8ODTw3HnlZ4b51hTAdkWl\n" + + "-----END CERTIFICATE-----"; + public static final String INTERMEDIATE_CERT_2 = + "-----BEGIN CERTIFICATE-----\n" + + "MIICQjCCAeigAwIBAgIUKxXRDlnWXefNV5lj5CwhDuXEq7MwCgYIKoZIzj0EAwIw\n" + + "OzEXMBUGA1UECgwOc2VjdXJpdHktcmVhbG0xEDAOBgNVBAsMB2NvbnRleHQxDjAM\n" + + "BgNVBAMMBTEyMzQ1MCAXDTIzMDcxNDIyMzYwNFoYDzIwNTAxMTI5MjIzNjA0WjBB\n" + + "MRcwFQYDVQQKDA5zZWN1cml0eS1yZWFsbTEQMA4GA1UECwwHY29udGV4dDEUMBIG\n" + + "CisGAQQB1nkCCAEMBDIwMjIwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAT/Zu7x\n" + + "UYVyg+T/vg2H+y4I6t36Kc4qxD0eqqZjRLYBVKkUQHxBqc14t0DpoROMYQCNd4DF\n" + + "pcxv/9m6DaJbRk6Ao4HBMIG+MA4GA1UdDwEB/wQEAwIBBjASBgNVHRMBAf8ECDAG\n" + + "AQH/AgEBMFgGA1UdHgEB/wROMEygSjA1gjNzaWduZXItcm9sZS5jb250ZXh0LnNl\n" + + "Y3VyaXR5LXJlYWxtLnByb2Quc3BpZmZlLmdvb2cwEYIPcHJvZC5nb29nbGUuY29t\n" + + "MB0GA1UdDgQWBBRm5agVVdpWfRZKM7u6OMuzHhqPcDAfBgNVHSMEGDAWgBQcjNAh\n" + + "SCHTj+BW8KrzSSLo2ASEgjAKBggqhkjOPQQDAgNIADBFAiEA6KyGd9VxXDZceMZG\n" + + "IsbC40rtunFjLYI0mjZw9RcRWx8CIHCIiIHxafnDaCi+VB99NZfzAdu37g6pJptB\n" + + "gjIY71MO\n" + + "-----END CERTIFICATE-----"; + public static final String INTERMEDIATE_CERT_1 = + "-----BEGIN CERTIFICATE-----\n" + + "MIICODCCAd6gAwIBAgIUXtZECORWRSKnS9rRTJYkiALUXswwCgYIKoZIzj0EAwIw\n" + + "NzEXMBUGA1UECgwOc2VjdXJpdHktcmVhbG0xDTALBgNVBAsMBHJvb3QxDTALBgNV\n" + + "BAMMBDEyMzQwIBcNMjMwNzE0MjIzNjA0WhgPMjA1MDExMjkyMjM2MDRaMDsxFzAV\n" + + "BgNVBAoMDnNlY3VyaXR5LXJlYWxtMRAwDgYDVQQLDAdjb250ZXh0MQ4wDAYDVQQD\n" + + "DAUxMjM0NTBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABAycVTZrjockbpD59f1a\n" + + "4l1SNL7nSyXz66Guz4eDveQqLmaMBg7vpACfO4CtiAGnolHEffuRtSkdM434m5En\n" + + "bXCjgcEwgb4wDgYDVR0PAQH/BAQDAgEGMBIGA1UdEwEB/wQIMAYBAf8CAQIwWAYD\n" + + "VR0eAQH/BE4wTKBKMDWCM3NpZ25lci1yb2xlLmNvbnRleHQuc2VjdXJpdHktcmVh\n" + + "bG0ucHJvZC5zcGlmZmUuZ29vZzARgg9wcm9kLmdvb2dsZS5jb20wHQYDVR0OBBYE\n" + + "FByM0CFIIdOP4FbwqvNJIujYBISCMB8GA1UdIwQYMBaAFMX+vebuj/lYfYEC23IA\n" + + "8HoIW0HsMAoGCCqGSM49BAMCA0gAMEUCIQCfxeXEBd7UPmeImT16SseCRu/6cHxl\n" + + "kTDsq9sKZ+eXBAIgA+oViAVOUhUQO1/6Mjlczg8NmMy2vNtG4V/7g9dMMVU=\n" + + "-----END CERTIFICATE-----"; + + private static final String PRIVATE_KEY = + "MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgqA2U0ld1OOHLMXWf" + + "uyN4GSaqhhudEIaKkll3rdIq0M+hRANCAAQGFlJpLxJMh4HuUm0DKjnUF7larH3t" + + "JvroQ12xpk+pPKQepn4ILoq9lZ8Xd3jzU98eDRXG5f4VjnX98DDHE4Id"; + private static final ImmutableMap + ALGORITHM_TO_SIGNATURE_INSTANCE_IDENTIFIER = + ImmutableMap.of( + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP256R1_SHA256, + "SHA256withECDSA", + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP384R1_SHA384, + "SHA384withECDSA", + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP521R1_SHA512, + "SHA512withECDSA"); + + private boolean fakeWriterClosed = false; + private Behavior behavior = Behavior.OK_STATUS; + private StreamObserver reader; + private VerificationResult verificationResult = VerificationResult.UNSPECIFIED; + private String failureReason; + private PrivateKey privateKey; + + @CanIgnoreReturnValue + FakeWriter setReader(StreamObserver reader) { + this.reader = reader; + return this; + } + + @CanIgnoreReturnValue + FakeWriter setBehavior(Behavior behavior) { + this.behavior = behavior; + return this; + } + + @CanIgnoreReturnValue + FakeWriter setVerificationResult(VerificationResult verificationResult) { + this.verificationResult = verificationResult; + return this; + } + + @CanIgnoreReturnValue + FakeWriter setFailureReason(String failureReason) { + this.failureReason = failureReason; + return this; + } + + @CanIgnoreReturnValue + FakeWriter initializePrivateKey() throws InvalidKeySpecException, NoSuchAlgorithmException { + privateKey = + KeyFactory.getInstance("EC") + .generatePrivate(new PKCS8EncodedKeySpec(Base64.getDecoder().decode(PRIVATE_KEY))); + return this; + } + + @CanIgnoreReturnValue + FakeWriter resetPrivateKey() { + privateKey = null; + return this; + } + + void sendUnexpectedResponse() { + reader.onNext(SessionResp.getDefaultInstance()); + } + + void sendIoError() { + reader.onError(new IOException("Intended ERROR from FakeWriter.")); + } + + void sendGetTlsConfigResp() { + reader.onNext( + SessionResp.newBuilder() + .setGetTlsConfigurationResp( + GetTlsConfigurationResp.newBuilder() + .setClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration.newBuilder() + .addCertificateChain(LEAF_CERT) + .addCertificateChain(INTERMEDIATE_CERT_2) + .addCertificateChain(INTERMEDIATE_CERT_1) + .setMinTlsVersion(TLS_VERSION_1_3) + .setMaxTlsVersion(TLS_VERSION_1_3) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256))) + .build()); + } + + boolean isFakeWriterClosed() { + return fakeWriterClosed; + } + + @Override + public void onNext(SessionReq sessionReq) { + switch (behavior) { + case OK_STATUS: + reader.onNext(handleResponse(sessionReq)); + break; + case EMPTY_RESPONSE: + reader.onNext(SessionResp.getDefaultInstance()); + break; + case ERROR_STATUS: + reader.onNext( + SessionResp.newBuilder() + .setStatus( + Status.newBuilder() + .setCode(1) + .setDetails("Intended ERROR Status from FakeWriter.")) + .build()); + break; + case ERROR_RESPONSE: + reader.onError(new S2AConnectionException("Intended ERROR from FakeWriter.")); + break; + case COMPLETE_STATUS: + reader.onCompleted(); + break; + default: + reader.onNext(handleResponse(sessionReq)); + } + } + + SessionResp handleResponse(SessionReq sessionReq) { + if (sessionReq.hasGetTlsConfigurationReq()) { + return handleGetTlsConfigurationReq(sessionReq.getGetTlsConfigurationReq()); + } + + if (sessionReq.hasValidatePeerCertificateChainReq()) { + return handleValidatePeerCertificateChainReq(sessionReq.getValidatePeerCertificateChainReq()); + } + + if (sessionReq.hasOffloadPrivateKeyOperationReq()) { + return handleOffloadPrivateKeyOperationReq(sessionReq.getOffloadPrivateKeyOperationReq()); + } + + return SessionResp.newBuilder() + .setStatus( + Status.newBuilder().setCode(255).setDetails("No supported operation designated.")) + .build(); + } + + private SessionResp handleGetTlsConfigurationReq(GetTlsConfigurationReq req) { + if (!req.getConnectionSide().equals(ConnectionSide.CONNECTION_SIDE_CLIENT)) { + return SessionResp.newBuilder() + .setStatus( + Status.newBuilder() + .setCode(255) + .setDetails("No TLS configuration for the server side.")) + .build(); + } + return SessionResp.newBuilder() + .setGetTlsConfigurationResp( + GetTlsConfigurationResp.newBuilder() + .setClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration.newBuilder() + .addCertificateChain(LEAF_CERT) + .addCertificateChain(INTERMEDIATE_CERT_2) + .addCertificateChain(INTERMEDIATE_CERT_1) + .setMinTlsVersion(TLS_VERSION_1_3) + .setMaxTlsVersion(TLS_VERSION_1_3) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256))) + .build(); + } + + private SessionResp handleValidatePeerCertificateChainReq(ValidatePeerCertificateChainReq req) { + if (verifyValidatePeerCertificateChainReq(req) + && verificationResult == VerificationResult.SUCCESS) { + return SessionResp.newBuilder() + .setValidatePeerCertificateChainResp( + ValidatePeerCertificateChainResp.newBuilder() + .setValidationResult(ValidatePeerCertificateChainResp.ValidationResult.SUCCESS)) + .build(); + } + return SessionResp.newBuilder() + .setValidatePeerCertificateChainResp( + ValidatePeerCertificateChainResp.newBuilder() + .setValidationResult( + verificationResult == VerificationResult.FAILURE + ? ValidatePeerCertificateChainResp.ValidationResult.FAILURE + : ValidatePeerCertificateChainResp.ValidationResult.UNSPECIFIED) + .setValidationDetails(failureReason)) + .build(); + } + + private boolean verifyValidatePeerCertificateChainReq(ValidatePeerCertificateChainReq req) { + if (req.getMode() != ValidatePeerCertificateChainReq.VerificationMode.UNSPECIFIED) { + return false; + } + if (req.getClientPeer().getCertificateChainCount() > 0) { + return true; + } + if (req.getServerPeer().getCertificateChainCount() > 0 + && !req.getServerPeer().getServerHostname().isEmpty()) { + return true; + } + return false; + } + + private SessionResp handleOffloadPrivateKeyOperationReq(OffloadPrivateKeyOperationReq req) { + if (privateKey == null) { + return SessionResp.newBuilder() + .setStatus(Status.newBuilder().setCode(255).setDetails("No Private Key available.")) + .build(); + } + String signatureIdentifier = + ALGORITHM_TO_SIGNATURE_INSTANCE_IDENTIFIER.get(req.getSignatureAlgorithm()); + if (signatureIdentifier == null) { + return SessionResp.newBuilder() + .setStatus( + Status.newBuilder() + .setCode(255) + .setDetails("Only ECDSA key algorithms are supported.")) + .build(); + } + + byte[] signature; + try { + Signature sig = Signature.getInstance(signatureIdentifier); + sig.initSign(privateKey); + sig.update(req.getRawBytes().toByteArray()); + signature = sig.sign(); + } catch (Exception e) { + return SessionResp.newBuilder() + .setStatus(Status.newBuilder().setCode(255).setDetails(e.getMessage())) + .build(); + } + + return SessionResp.newBuilder() + .setOffloadPrivateKeyOperationResp( + OffloadPrivateKeyOperationResp.newBuilder().setOutBytes(ByteString.copyFrom(signature))) + .build(); + } + + @Override + public void onError(Throwable t) { + throw new UnsupportedOperationException("onError is not supported by FakeWriter."); + } + + @Override + public void onCompleted() { + fakeWriterClosed = true; + reader.onCompleted(); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/GetAuthenticationMechanismsTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/GetAuthenticationMechanismsTest.java new file mode 100644 index 00000000000..aea279ed8c5 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/GetAuthenticationMechanismsTest.java @@ -0,0 +1,64 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import com.beust.jcommander.JCommander; +import com.google.common.truth.Expect; +import io.grpc.s2a.handshaker.S2AIdentity; +import io.grpc.s2a.handshaker.tokenmanager.SingleTokenFetcher; +import java.util.Optional; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link GetAuthenticationMechanisms}. */ +@RunWith(JUnit4.class) +public final class GetAuthenticationMechanismsTest { + @Rule public final Expect expect = Expect.create(); + private static final String TOKEN = "access_token"; + private static final String[] SET_TOKEN = {"--s2a_access_token", TOKEN}; + private static final SingleTokenFetcher.Flags FLAGS = new SingleTokenFetcher.Flags(); + + @BeforeClass + public static void setUpClass() { + // Set the token that the client will use to authenticate to the S2A. + JCommander.newBuilder().addObject(FLAGS).build().parse(SET_TOKEN); + } + + @Test + public void getAuthMechanisms_emptyIdentity_success() { + expect + .that(GetAuthenticationMechanisms.getAuthMechanism(Optional.empty())) + .isEqualTo( + Optional.of(AuthenticationMechanism.newBuilder().setToken("access_token").build())); + } + + @Test + public void getAuthMechanisms_nonEmptyIdentity_success() { + S2AIdentity fakeIdentity = S2AIdentity.fromSpiffeId("fake-spiffe-id"); + expect + .that(GetAuthenticationMechanisms.getAuthMechanism(Optional.of(fakeIdentity))) + .isEqualTo( + Optional.of( + AuthenticationMechanism.newBuilder() + .setIdentity(fakeIdentity.identity()) + .setToken("access_token") + .build())); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/IntegrationTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/IntegrationTest.java new file mode 100644 index 00000000000..f9de9765527 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/IntegrationTest.java @@ -0,0 +1,322 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.concurrent.TimeUnit.SECONDS; + +import io.grpc.ChannelCredentials; +import io.grpc.Grpc; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.ServerCredentials; +import io.grpc.TlsServerCredentials; +import io.grpc.benchmarks.Utils; +import io.grpc.netty.GrpcSslContexts; +import io.grpc.netty.NettyServerBuilder; +import io.grpc.s2a.MtlsToS2AChannelCredentials; +import io.grpc.s2a.S2AChannelCredentials; +import io.grpc.s2a.handshaker.FakeS2AServer; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.protobuf.SimpleRequest; +import io.grpc.testing.protobuf.SimpleResponse; +import io.grpc.testing.protobuf.SimpleServiceGrpc; +import io.netty.handler.ssl.ClientAuth; +import io.netty.handler.ssl.OpenSslSessionContext; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.SslProvider; +import java.io.ByteArrayInputStream; +import java.io.File; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLSessionContext; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class IntegrationTest { + private static final Logger logger = Logger.getLogger(FakeS2AServer.class.getName()); + + private static final String CERT_CHAIN = + "-----BEGIN CERTIFICATE-----\n" + + "MIICkDCCAjagAwIBAgIUSAtcrPhNNs1zxv51lIfGOVtkw6QwCgYIKoZIzj0EAwIw\n" + + "QTEXMBUGA1UECgwOc2VjdXJpdHktcmVhbG0xEDAOBgNVBAsMB2NvbnRleHQxFDAS\n" + + "BgorBgEEAdZ5AggBDAQyMDIyMCAXDTIzMDcxNDIyMzYwNFoYDzIwNTAxMTI5MjIz\n" + + "NjA0WjARMQ8wDQYDVQQDDAZ1bnVzZWQwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC\n" + + "AAQGFlJpLxJMh4HuUm0DKjnUF7larH3tJvroQ12xpk+pPKQepn4ILoq9lZ8Xd3jz\n" + + "U98eDRXG5f4VjnX98DDHE4Ido4IBODCCATQwDgYDVR0PAQH/BAQDAgeAMCAGA1Ud\n" + + "JQEB/wQWMBQGCCsGAQUFBwMCBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMIGxBgNV\n" + + "HREBAf8EgaYwgaOGSnNwaWZmZTovL3NpZ25lci1yb2xlLmNvbnRleHQuc2VjdXJp\n" + + "dHktcmVhbG0ucHJvZC5nb29nbGUuY29tL3JvbGUvbGVhZi1yb2xlgjNzaWduZXIt\n" + + "cm9sZS5jb250ZXh0LnNlY3VyaXR5LXJlYWxtLnByb2Quc3BpZmZlLmdvb2eCIGZx\n" + + "ZG4tb2YtdGhlLW5vZGUucHJvZC5nb29nbGUuY29tMB0GA1UdDgQWBBSWSd5Fw6dI\n" + + "TGpt0m1Uxwf0iKqebzAfBgNVHSMEGDAWgBRm5agVVdpWfRZKM7u6OMuzHhqPcDAK\n" + + "BggqhkjOPQQDAgNIADBFAiB0sjRPSYy2eFq8Y0vQ8QN4AZ2NMajskvxnlifu7O4U\n" + + "RwIhANTh5Fkyx2nMYFfyl+W45dY8ODTw3HnlZ4b51hTAdkWl\n" + + "-----END CERTIFICATE-----\n" + + "-----BEGIN CERTIFICATE-----\n" + + "MIICQjCCAeigAwIBAgIUKxXRDlnWXefNV5lj5CwhDuXEq7MwCgYIKoZIzj0EAwIw\n" + + "OzEXMBUGA1UECgwOc2VjdXJpdHktcmVhbG0xEDAOBgNVBAsMB2NvbnRleHQxDjAM\n" + + "BgNVBAMMBTEyMzQ1MCAXDTIzMDcxNDIyMzYwNFoYDzIwNTAxMTI5MjIzNjA0WjBB\n" + + "MRcwFQYDVQQKDA5zZWN1cml0eS1yZWFsbTEQMA4GA1UECwwHY29udGV4dDEUMBIG\n" + + "CisGAQQB1nkCCAEMBDIwMjIwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAT/Zu7x\n" + + "UYVyg+T/vg2H+y4I6t36Kc4qxD0eqqZjRLYBVKkUQHxBqc14t0DpoROMYQCNd4DF\n" + + "pcxv/9m6DaJbRk6Ao4HBMIG+MA4GA1UdDwEB/wQEAwIBBjASBgNVHRMBAf8ECDAG\n" + + "AQH/AgEBMFgGA1UdHgEB/wROMEygSjA1gjNzaWduZXItcm9sZS5jb250ZXh0LnNl\n" + + "Y3VyaXR5LXJlYWxtLnByb2Quc3BpZmZlLmdvb2cwEYIPcHJvZC5nb29nbGUuY29t\n" + + "MB0GA1UdDgQWBBRm5agVVdpWfRZKM7u6OMuzHhqPcDAfBgNVHSMEGDAWgBQcjNAh\n" + + "SCHTj+BW8KrzSSLo2ASEgjAKBggqhkjOPQQDAgNIADBFAiEA6KyGd9VxXDZceMZG\n" + + "IsbC40rtunFjLYI0mjZw9RcRWx8CIHCIiIHxafnDaCi+VB99NZfzAdu37g6pJptB\n" + + "gjIY71MO\n" + + "-----END CERTIFICATE-----\n" + + "-----BEGIN CERTIFICATE-----\n" + + "MIICODCCAd6gAwIBAgIUXtZECORWRSKnS9rRTJYkiALUXswwCgYIKoZIzj0EAwIw\n" + + "NzEXMBUGA1UECgwOc2VjdXJpdHktcmVhbG0xDTALBgNVBAsMBHJvb3QxDTALBgNV\n" + + "BAMMBDEyMzQwIBcNMjMwNzE0MjIzNjA0WhgPMjA1MDExMjkyMjM2MDRaMDsxFzAV\n" + + "BgNVBAoMDnNlY3VyaXR5LXJlYWxtMRAwDgYDVQQLDAdjb250ZXh0MQ4wDAYDVQQD\n" + + "DAUxMjM0NTBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABAycVTZrjockbpD59f1a\n" + + "4l1SNL7nSyXz66Guz4eDveQqLmaMBg7vpACfO4CtiAGnolHEffuRtSkdM434m5En\n" + + "bXCjgcEwgb4wDgYDVR0PAQH/BAQDAgEGMBIGA1UdEwEB/wQIMAYBAf8CAQIwWAYD\n" + + "VR0eAQH/BE4wTKBKMDWCM3NpZ25lci1yb2xlLmNvbnRleHQuc2VjdXJpdHktcmVh\n" + + "bG0ucHJvZC5zcGlmZmUuZ29vZzARgg9wcm9kLmdvb2dsZS5jb20wHQYDVR0OBBYE\n" + + "FByM0CFIIdOP4FbwqvNJIujYBISCMB8GA1UdIwQYMBaAFMX+vebuj/lYfYEC23IA\n" + + "8HoIW0HsMAoGCCqGSM49BAMCA0gAMEUCIQCfxeXEBd7UPmeImT16SseCRu/6cHxl\n" + + "kTDsq9sKZ+eXBAIgA+oViAVOUhUQO1/6Mjlczg8NmMy2vNtG4V/7g9dMMVU=\n" + + "-----END CERTIFICATE-----"; + private static final String ROOT_PEM = + "-----BEGIN CERTIFICATE-----\n" + + "MIIBtTCCAVqgAwIBAgIUbAe+8OocndQXRBCElLBxBSdfdV8wCgYIKoZIzj0EAwIw\n" + + "NzEXMBUGA1UECgwOc2VjdXJpdHktcmVhbG0xDTALBgNVBAsMBHJvb3QxDTALBgNV\n" + + "BAMMBDEyMzQwIBcNMjMwNzE0MjIzNjA0WhgPMjA1MDExMjkyMjM2MDRaMDcxFzAV\n" + + "BgNVBAoMDnNlY3VyaXR5LXJlYWxtMQ0wCwYDVQQLDARyb290MQ0wCwYDVQQDDAQx\n" + + "MjM0MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEaMY2tBW5r1t0+vhayz0ZoGMF\n" + + "boX/ZmmCmIh0iTWg4madvwNOh74CMVVvDUlXZcuVqZ3vVIX/a7PTFVqUwQlKW6NC\n" + + "MEAwDgYDVR0PAQH/BAQDAgGGMA8GA1UdEwEB/wQFMAMBAf8wHQYDVR0OBBYEFMX+\n" + + "vebuj/lYfYEC23IA8HoIW0HsMAoGCCqGSM49BAMCA0kAMEYCIQDETd27nsUTXKWY\n" + + "CiOno78O09gK95NoTkPU5e2chJYMqAIhALYFAyh7PU5xgFQsN9hiqgsHUc5/pmBG\n" + + "BGjJ1iz8rWGJ\n" + + "-----END CERTIFICATE-----"; + private static final String PRIVATE_KEY = + "-----BEGIN PRIVATE KEY-----\n" + + "MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgqA2U0ld1OOHLMXWf\n" + + "uyN4GSaqhhudEIaKkll3rdIq0M+hRANCAAQGFlJpLxJMh4HuUm0DKjnUF7larH3t\n" + + "JvroQ12xpk+pPKQepn4ILoq9lZ8Xd3jzU98eDRXG5f4VjnX98DDHE4Id\n" + + "-----END PRIVATE KEY-----"; + + private String s2aAddress; + private int s2aPort; + private Server s2aServer; + private String s2aDelayAddress; + private int s2aDelayPort; + private Server s2aDelayServer; + private String mtlsS2AAddress; + private int mtlsS2APort; + private Server mtlsS2AServer; + private int serverPort; + private String serverAddress; + private Server server; + + @BeforeClass + public static void setUpClass() { + System.setProperty("GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST", "false"); + } + + @Before + public void setUp() throws Exception { + s2aPort = Utils.pickUnusedPort(); + s2aAddress = "localhost:" + s2aPort; + s2aServer = ServerBuilder.forPort(s2aPort).addService(new FakeS2AServer()).build(); + logger.info("S2A service listening on localhost:" + s2aPort); + s2aServer.start(); + + mtlsS2APort = Utils.pickUnusedPort(); + mtlsS2AAddress = "localhost:" + mtlsS2APort; + File s2aCert = new File("src/test/resources/server_cert.pem"); + File s2aKey = new File("src/test/resources/server_key.pem"); + File rootCert = new File("src/test/resources/root_cert.pem"); + ServerCredentials s2aCreds = + TlsServerCredentials.newBuilder() + .keyManager(s2aCert, s2aKey) + .trustManager(rootCert) + .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE) + .build(); + mtlsS2AServer = + NettyServerBuilder.forPort(mtlsS2APort, s2aCreds).addService(new FakeS2AServer()).build(); + logger.info("mTLS S2A service listening on localhost:" + mtlsS2APort); + mtlsS2AServer.start(); + + s2aDelayPort = Utils.pickUnusedPort(); + s2aDelayAddress = "localhost:" + s2aDelayPort; + s2aDelayServer = ServerBuilder.forPort(s2aDelayPort).addService(new FakeS2AServer()).build(); + + serverPort = Utils.pickUnusedPort(); + serverAddress = "localhost:" + serverPort; + server = + NettyServerBuilder.forPort(serverPort) + .addService(new SimpleServiceImpl()) + .sslContext(buildSslContext()) + .build(); + logger.info("Simple Service listening on localhost:" + serverPort); + server.start(); + } + + @After + public void tearDown() throws Exception { + server.shutdown(); + s2aServer.shutdown(); + s2aDelayServer.shutdown(); + mtlsS2AServer.shutdown(); + } + + @Test + public void clientCommunicateUsingS2ACredentials_succeeds() throws Exception { + ExecutorService executor = Executors.newSingleThreadExecutor(); + ChannelCredentials credentials = + S2AChannelCredentials.createBuilder(s2aAddress).setLocalSpiffeId("test-spiffe-id").build(); + ManagedChannel channel = + Grpc.newChannelBuilder(serverAddress, credentials).executor(executor).build(); + + assertThat(doUnaryRpc(executor, channel)).isTrue(); + } + + @Test + public void clientCommunicateUsingS2ACredentialsNoLocalIdentity_succeeds() throws Exception { + ExecutorService executor = Executors.newSingleThreadExecutor(); + ChannelCredentials credentials = S2AChannelCredentials.createBuilder(s2aAddress).build(); + ManagedChannel channel = + Grpc.newChannelBuilder(serverAddress, credentials).executor(executor).build(); + + assertThat(doUnaryRpc(executor, channel)).isTrue(); + } + + @Test + public void clientCommunicateUsingMtlsToS2ACredentials_succeeds() throws Exception { + ExecutorService executor = Executors.newSingleThreadExecutor(); + ChannelCredentials credentials = + MtlsToS2AChannelCredentials.createBuilder( + /* s2aAddress= */ mtlsS2AAddress, + /* privateKeyPath= */ "src/test/resources/client_key.pem", + /* certChainPath= */ "src/test/resources/client_cert.pem", + /* trustBundlePath= */ "src/test/resources/root_cert.pem") + .build() + .setLocalSpiffeId("test-spiffe-id") + .build(); + ManagedChannel channel = + Grpc.newChannelBuilder(serverAddress, credentials).executor(executor).build(); + + assertThat(doUnaryRpc(executor, channel)).isTrue(); + } + + @Test + public void clientCommunicateUsingS2ACredentials_s2AdelayStart_succeeds() throws Exception { + DoUnaryRpc doUnaryRpc = new DoUnaryRpc(); + doUnaryRpc.start(); + Thread.sleep(2000); + s2aDelayServer.start(); + doUnaryRpc.join(); + } + + private class DoUnaryRpc extends Thread { + @Override + public void run() { + ExecutorService executor = Executors.newSingleThreadExecutor(); + ChannelCredentials credentials = S2AChannelCredentials.createBuilder(s2aDelayAddress).build(); + ManagedChannel channel = + Grpc.newChannelBuilder(serverAddress, credentials).executor(executor).build(); + boolean result = false; + try { + result = doUnaryRpc(executor, channel); + } catch (InterruptedException e) { + logger.log(Level.SEVERE, "Failed to do unary rpc", e); + result = false; + } + assertThat(result).isTrue(); + } + } + + public static boolean doUnaryRpc(ExecutorService executor, ManagedChannel channel) + throws InterruptedException { + try { + SimpleServiceGrpc.SimpleServiceBlockingStub stub = + SimpleServiceGrpc.newBlockingStub(channel).withWaitForReady(); + SimpleResponse resp = stub.unaryRpc(SimpleRequest.newBuilder() + .setRequestMessage("S2A team") + .build()); + if (!resp.getResponseMessage().equals("Hello, S2A team!")) { + logger.info( + "Received unexpected message from the Simple Service: " + resp.getResponseMessage()); + throw new RuntimeException(); + } else { + System.out.println( + "We received this message from the Simple Service: " + resp.getResponseMessage()); + return true; + } + } finally { + channel.shutdown(); + channel.awaitTermination(1, SECONDS); + executor.shutdown(); + executor.awaitTermination(1, SECONDS); + } + } + + private static SslContext buildSslContext() throws SSLException { + SslContextBuilder sslServerContextBuilder = + SslContextBuilder.forServer( + new ByteArrayInputStream(CERT_CHAIN.getBytes(UTF_8)), + new ByteArrayInputStream(PRIVATE_KEY.getBytes(UTF_8))); + SslContext sslServerContext = + GrpcSslContexts.configure(sslServerContextBuilder, SslProvider.OPENSSL) + .protocols("TLSv1.3", "TLSv1.2") + .trustManager(new ByteArrayInputStream(ROOT_PEM.getBytes(UTF_8))) + .clientAuth(ClientAuth.REQUIRE) + .build(); + + // Enable TLS resumption. This requires using the OpenSSL provider, since the JDK provider does + // not allow a server to send session tickets. + SSLSessionContext sslSessionContext = sslServerContext.sessionContext(); + if (!(sslSessionContext instanceof OpenSslSessionContext)) { + throw new SSLException("sslSessionContext does not use OpenSSL."); + } + OpenSslSessionContext openSslSessionContext = (OpenSslSessionContext) sslSessionContext; + // Calling {@code setTicketKeys} without specifying any keys means that the SSL libraries will + // handle the generation of the resumption master secret. + openSslSessionContext.setTicketKeys(); + + return sslServerContext; + } + + public static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { + @Override + public void unaryRpc(SimpleRequest request, StreamObserver observer) { + observer.onNext( + SimpleResponse.newBuilder() + .setResponseMessage("Hello, " + request.getRequestMessage() + "!") + .build()); + observer.onCompleted(); + } + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/ProtoUtilTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/ProtoUtilTest.java new file mode 100644 index 00000000000..0191398b6b7 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/ProtoUtilTest.java @@ -0,0 +1,95 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static org.junit.Assert.assertThrows; + +import com.google.common.truth.Expect; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link ProtoUtil}. */ +@RunWith(JUnit4.class) +public final class ProtoUtilTest { + @Rule public final Expect expect = Expect.create(); + + @Test + public void convertCiphersuite_success() { + expect + .that( + ProtoUtil.convertCiphersuite( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256)) + .isEqualTo("TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"); + expect + .that( + ProtoUtil.convertCiphersuite( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384)) + .isEqualTo("TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"); + expect + .that( + ProtoUtil.convertCiphersuite( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256)) + .isEqualTo("TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"); + expect + .that( + ProtoUtil.convertCiphersuite(Ciphersuite.CIPHERSUITE_ECDHE_RSA_WITH_AES_128_GCM_SHA256)) + .isEqualTo("TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"); + expect + .that( + ProtoUtil.convertCiphersuite(Ciphersuite.CIPHERSUITE_ECDHE_RSA_WITH_AES_256_GCM_SHA384)) + .isEqualTo("TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"); + expect + .that( + ProtoUtil.convertCiphersuite( + Ciphersuite.CIPHERSUITE_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256)) + .isEqualTo("TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256"); + } + + @Test + public void convertCiphersuite_withUnspecifiedCiphersuite_fails() { + AssertionError expected = + assertThrows( + AssertionError.class, + () -> ProtoUtil.convertCiphersuite(Ciphersuite.CIPHERSUITE_UNSPECIFIED)); + expect.that(expected).hasMessageThat().isEqualTo("Ciphersuite 0 is not supported."); + } + + @Test + public void convertTlsProtocolVersion_success() { + expect + .that(ProtoUtil.convertTlsProtocolVersion(TLSVersion.TLS_VERSION_1_3)) + .isEqualTo("TLSv1.3"); + expect + .that(ProtoUtil.convertTlsProtocolVersion(TLSVersion.TLS_VERSION_1_2)) + .isEqualTo("TLSv1.2"); + expect + .that(ProtoUtil.convertTlsProtocolVersion(TLSVersion.TLS_VERSION_1_1)) + .isEqualTo("TLSv1.1"); + expect.that(ProtoUtil.convertTlsProtocolVersion(TLSVersion.TLS_VERSION_1_0)).isEqualTo("TLSv1"); + } + + @Test + public void convertTlsProtocolVersion_withUnknownTlsVersion_fails() { + AssertionError expected = + assertThrows( + AssertionError.class, + () -> ProtoUtil.convertTlsProtocolVersion(TLSVersion.TLS_VERSION_UNSPECIFIED)); + expect.that(expected).hasMessageThat().isEqualTo("TLS version 0 is not supported."); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/S2APrivateKeyMethodTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/S2APrivateKeyMethodTest.java new file mode 100644 index 00000000000..4024e8a6e36 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/S2APrivateKeyMethodTest.java @@ -0,0 +1,308 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.truth.Expect; +import com.google.protobuf.ByteString; +import io.grpc.netty.GrpcSslContexts; +import io.grpc.s2a.handshaker.S2AIdentity; +import io.netty.handler.ssl.OpenSslPrivateKeyMethod; +import io.netty.handler.ssl.SslContextBuilder; +import java.io.ByteArrayInputStream; +import java.security.PublicKey; +import java.security.Signature; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.Optional; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class S2APrivateKeyMethodTest { + @Rule public final Expect expect = Expect.create(); + private static final byte[] DATA_TO_SIGN = "random bytes for signing.".getBytes(UTF_8); + + private S2AStub stub; + private FakeWriter writer; + private S2APrivateKeyMethod keyMethod; + + private static PublicKey extractPublicKeyFromPem(String pem) throws Exception { + X509Certificate cert = + (X509Certificate) + CertificateFactory.getInstance("X.509") + .generateCertificate(new ByteArrayInputStream(pem.getBytes(UTF_8))); + return cert.getPublicKey(); + } + + private static boolean verifySignature( + byte[] dataToSign, byte[] signature, String signatureAlgorithm) throws Exception { + Signature sig = Signature.getInstance(signatureAlgorithm); + sig.initVerify(extractPublicKeyFromPem(FakeWriter.LEAF_CERT)); + sig.update(dataToSign); + return sig.verify(signature); + } + + @Before + public void setUp() { + // This is line is to ensure that JNI correctly links the necessary objects. Without this, we + // get `java.lang.UnsatisfiedLinkError` on + // `io.netty.internal.tcnative.NativeStaticallyReferencedJniMethods.sslSignRsaPkcsSha1()` + GrpcSslContexts.configure(SslContextBuilder.forClient()); + + writer = new FakeWriter(); + stub = S2AStub.newInstanceForTesting(writer); + writer.setReader(stub.getReader()); + keyMethod = S2APrivateKeyMethod.create(stub, /* localIdentity= */ Optional.empty()); + } + + @Test + public void signatureAlgorithmConversion_success() { + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA256)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA256); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA384)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA384); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA512)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA512); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP256R1_SHA256); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP384R1_SHA384)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP384R1_SHA384); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP521R1_SHA512)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP521R1_SHA512); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA256)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA256); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA384)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA384); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA512)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA512); + } + + @Test + public void signatureAlgorithmConversion_unsupportedOperation() { + UnsupportedOperationException e = + assertThrows( + UnsupportedOperationException.class, + () -> S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg(-1)); + + assertThat(e).hasMessageThat().contains("Signature Algorithm -1 is not supported."); + } + + @Test + public void createOnNullStub_returnsNullPointerException() { + assertThrows( + NullPointerException.class, + () -> S2APrivateKeyMethod.create(/* stub= */ null, /* localIdentity= */ Optional.empty())); + } + + @Test + public void decrypt_unsupportedOperation() { + UnsupportedOperationException e = + assertThrows( + UnsupportedOperationException.class, + () -> keyMethod.decrypt(/* engine= */ null, DATA_TO_SIGN)); + + assertThat(e).hasMessageThat().contains("decrypt is not supported."); + } + + @Test + public void fakelocalIdentity_signWithSha256_success() throws Exception { + S2AIdentity fakeIdentity = S2AIdentity.fromSpiffeId("fake-spiffe-id"); + S2AStub mockStub = mock(S2AStub.class); + OpenSslPrivateKeyMethod keyMethodWithFakeIdentity = + S2APrivateKeyMethod.create(mockStub, Optional.of(fakeIdentity)); + SessionReq req = + SessionReq.newBuilder() + .setLocalIdentity(fakeIdentity.identity()) + .setOffloadPrivateKeyOperationReq( + OffloadPrivateKeyOperationReq.newBuilder() + .setOperation(OffloadPrivateKeyOperationReq.PrivateKeyOperation.SIGN) + .setSignatureAlgorithm(SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP256R1_SHA256) + .setRawBytes(ByteString.copyFrom(DATA_TO_SIGN))) + .build(); + byte[] expectedOutbytes = "fake out bytes".getBytes(UTF_8); + when(mockStub.send(req)) + .thenReturn( + SessionResp.newBuilder() + .setOffloadPrivateKeyOperationResp( + OffloadPrivateKeyOperationResp.newBuilder() + .setOutBytes(ByteString.copyFrom(expectedOutbytes))) + .build()); + + byte[] signature = + keyMethodWithFakeIdentity.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + DATA_TO_SIGN); + verify(mockStub).send(req); + assertThat(signature).isEqualTo(expectedOutbytes); + } + + @Test + public void signWithSha256_success() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.OK_STATUS); + + byte[] signature = + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + DATA_TO_SIGN); + + assertThat(signature).isNotEmpty(); + assertThat(verifySignature(DATA_TO_SIGN, signature, "SHA256withECDSA")).isTrue(); + } + + @Test + public void signWithSha384_success() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.OK_STATUS); + + byte[] signature = + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP384R1_SHA384, + DATA_TO_SIGN); + + assertThat(signature).isNotEmpty(); + assertThat(verifySignature(DATA_TO_SIGN, signature, "SHA384withECDSA")).isTrue(); + } + + @Test + public void signWithSha512_success() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.OK_STATUS); + + byte[] signature = + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP521R1_SHA512, + DATA_TO_SIGN); + + assertThat(signature).isNotEmpty(); + assertThat(verifySignature(DATA_TO_SIGN, signature, "SHA512withECDSA")).isTrue(); + } + + @Test + public void sign_noKeyAvailable() throws Exception { + writer.resetPrivateKey().setBehavior(FakeWriter.Behavior.OK_STATUS); + + S2AConnectionException e = + assertThrows( + S2AConnectionException.class, + () -> + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + DATA_TO_SIGN)); + + assertThat(e) + .hasMessageThat() + .contains( + "Error occurred in response from S2A, error code: 255, error message: \"No Private Key" + + " available.\"."); + } + + @Test + public void sign_algorithmNotSupported() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.OK_STATUS); + + S2AConnectionException e = + assertThrows( + S2AConnectionException.class, + () -> + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA256, + DATA_TO_SIGN)); + + assertThat(e) + .hasMessageThat() + .contains( + "Error occurred in response from S2A, error code: 255, error message: \"Only ECDSA key" + + " algorithms are supported.\"."); + } + + @Test + public void sign_getsErrorResponse() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.ERROR_STATUS); + + S2AConnectionException e = + assertThrows( + S2AConnectionException.class, + () -> + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + DATA_TO_SIGN)); + + assertThat(e) + .hasMessageThat() + .contains( + "Error occurred in response from S2A, error code: 1, error message: \"Intended ERROR" + + " Status from FakeWriter.\"."); + } + + @Test + public void sign_getsEmptyResponse() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.EMPTY_RESPONSE); + + S2AConnectionException e = + assertThrows( + S2AConnectionException.class, + () -> + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + DATA_TO_SIGN)); + + assertThat(e).hasMessageThat().contains("No valid response received from S2A."); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java new file mode 100644 index 00000000000..82db6d4a144 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java @@ -0,0 +1,267 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import com.google.common.testing.NullPointerTester; +import com.google.common.testing.NullPointerTester.Visibility; +import io.grpc.Channel; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.benchmarks.Utils; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; +import io.grpc.internal.TestUtils.NoopChannelLogger; +import io.grpc.netty.GrpcHttp2ConnectionHandler; +import io.grpc.netty.InternalProtocolNegotiator; +import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; +import io.grpc.s2a.channel.S2AChannelPool; +import io.grpc.s2a.channel.S2AGrpcChannelPool; +import io.grpc.s2a.channel.S2AHandshakerServiceChannel; +import io.grpc.s2a.handshaker.S2AIdentity; +import io.grpc.s2a.handshaker.S2AProtocolNegotiatorFactory.S2AProtocolNegotiator; +import io.grpc.stub.StreamObserver; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http2.Http2ConnectionDecoder; +import io.netty.handler.codec.http2.Http2ConnectionEncoder; +import io.netty.handler.codec.http2.Http2Settings; +import io.netty.util.AsciiString; +import java.util.Optional; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link S2AProtocolNegotiatorFactory}. */ +@RunWith(JUnit4.class) +public class S2AProtocolNegotiatorFactoryTest { + private static final S2AIdentity LOCAL_IDENTITY = S2AIdentity.fromSpiffeId("local identity"); + private final ChannelHandlerContext mockChannelHandlerContext = mock(ChannelHandlerContext.class); + private GrpcHttp2ConnectionHandler fakeConnectionHandler; + private String authority; + private int port; + private Server fakeS2AServer; + private ObjectPool channelPool; + + @Before + public void setUp() throws Exception { + port = Utils.pickUnusedPort(); + fakeS2AServer = ServerBuilder.forPort(port).addService(new S2AServiceImpl()).build(); + fakeS2AServer.start(); + channelPool = new FakeChannelPool(); + authority = "localhost:" + port; + fakeConnectionHandler = FakeConnectionHandler.create(authority); + } + + @After + public void tearDown() { + fakeS2AServer.shutdown(); + } + + @Test + public void createProtocolNegotiatorFactory_nullArgument() throws Exception { + NullPointerTester tester = new NullPointerTester().setDefault(Optional.class, Optional.empty()); + + tester.testStaticMethods(S2AProtocolNegotiatorFactory.class, Visibility.PUBLIC); + } + + @Test + public void createProtocolNegotiator_nullArgument() throws Exception { + S2AChannelPool pool = + S2AGrpcChannelPool.create( + SharedResourcePool.forResource( + S2AHandshakerServiceChannel.getChannelResource( + "localhost:8080", /* s2aChannelCredentials= */ Optional.empty()))); + + NullPointerTester tester = + new NullPointerTester() + .setDefault(S2AChannelPool.class, pool) + .setDefault(Optional.class, Optional.empty()); + + tester.testStaticMethods(S2AProtocolNegotiator.class, Visibility.PACKAGE); + } + + @Test + public void createProtocolNegotiatorFactory_getsDefaultPort_succeeds() throws Exception { + InternalProtocolNegotiator.ClientFactory clientFactory = + S2AProtocolNegotiatorFactory.createClientFactory(Optional.of(LOCAL_IDENTITY), channelPool); + + assertThat(clientFactory.getDefaultPort()).isEqualTo(S2AProtocolNegotiatorFactory.DEFAULT_PORT); + } + + @Test + public void s2aProtocolNegotiator_getHostNameOnNull_returnsNull() throws Exception { + assertThat(S2AProtocolNegotiatorFactory.S2AProtocolNegotiator.getHostNameFromAuthority(null)) + .isNull(); + } + + @Test + public void s2aProtocolNegotiator_getHostNameOnValidAuthority_returnsValidHostname() + throws Exception { + assertThat( + S2AProtocolNegotiatorFactory.S2AProtocolNegotiator.getHostNameFromAuthority( + "hostname:80")) + .isEqualTo("hostname"); + } + + @Test + public void createProtocolNegotiatorFactory_buildsAnS2AProtocolNegotiatorOnClientSide_succeeds() + throws Exception { + InternalProtocolNegotiator.ClientFactory clientFactory = + S2AProtocolNegotiatorFactory.createClientFactory(Optional.of(LOCAL_IDENTITY), channelPool); + + ProtocolNegotiator clientNegotiator = clientFactory.newNegotiator(); + + assertThat(clientNegotiator).isInstanceOf(S2AProtocolNegotiator.class); + assertThat(clientNegotiator.scheme()).isEqualTo(AsciiString.of("https")); + } + + @Test + public void closeProtocolNegotiator_verifyProtocolNegotiatorIsClosedOnClientSide() + throws Exception { + InternalProtocolNegotiator.ClientFactory clientFactory = + S2AProtocolNegotiatorFactory.createClientFactory(Optional.of(LOCAL_IDENTITY), channelPool); + ProtocolNegotiator clientNegotiator = clientFactory.newNegotiator(); + + clientNegotiator.close(); + + assertThat(((FakeChannelPool) channelPool).isChannelCached()).isFalse(); + } + + @Test + public void createChannelHandler_addHandlerToMockContext() throws Exception { + ExecutorService executor = Executors.newSingleThreadExecutor(); + ManagedChannel channel = + Grpc.newChannelBuilder(authority, InsecureChannelCredentials.create()) + .executor(executor) + .build(); + FakeS2AChannelPool fakeChannelPool = new FakeS2AChannelPool(channel); + ProtocolNegotiator clientNegotiator = + S2AProtocolNegotiatorFactory.S2AProtocolNegotiator.createForClient( + fakeChannelPool, Optional.of(LOCAL_IDENTITY)); + + ChannelHandler channelHandler = clientNegotiator.newHandler(fakeConnectionHandler); + + ((ChannelDuplexHandler) channelHandler).userEventTriggered(mockChannelHandlerContext, "event"); + verify(mockChannelHandlerContext).fireUserEventTriggered("event"); + } + + /** A {@link S2AChannelPool} that returns the given channel. */ + private static class FakeS2AChannelPool implements S2AChannelPool { + private final Channel channel; + + FakeS2AChannelPool(Channel channel) { + this.channel = channel; + } + + @Override + public Channel getChannel() { + return channel; + } + + @Override + public void returnChannel(Channel channel) {} + + @Override + public void close() {} + } + + /** A {@code GrpcHttp2ConnectionHandler} that does nothing. */ + private static class FakeConnectionHandler extends GrpcHttp2ConnectionHandler { + private static final Http2ConnectionDecoder DECODER = mock(Http2ConnectionDecoder.class); + private static final Http2ConnectionEncoder ENCODER = mock(Http2ConnectionEncoder.class); + private static final Http2Settings SETTINGS = new Http2Settings(); + private final String authority; + + static FakeConnectionHandler create(String authority) { + return new FakeConnectionHandler(null, DECODER, ENCODER, SETTINGS, authority); + } + + private FakeConnectionHandler( + ChannelPromise channelUnused, + Http2ConnectionDecoder decoder, + Http2ConnectionEncoder encoder, + Http2Settings initialSettings, + String authority) { + super(channelUnused, decoder, encoder, initialSettings, new NoopChannelLogger()); + this.authority = authority; + } + + @Override + public String getAuthority() { + return authority; + } + } + + /** An S2A server that handles GetTlsConfiguration request. */ + private static class S2AServiceImpl extends S2AServiceGrpc.S2AServiceImplBase { + static final FakeWriter writer = new FakeWriter(); + + @Override + public StreamObserver setUpSession(StreamObserver responseObserver) { + return new StreamObserver() { + @Override + public void onNext(SessionReq req) { + responseObserver.onNext(writer.handleResponse(req)); + } + + @Override + public void onError(Throwable t) {} + + @Override + public void onCompleted() {} + }; + } + } + + private static class FakeChannelPool implements ObjectPool { + private final Channel mockChannel = mock(Channel.class); + private @Nullable Channel cachedChannel = null; + + @Override + public Channel getObject() { + if (cachedChannel == null) { + cachedChannel = mockChannel; + } + return cachedChannel; + } + + @Override + public Channel returnObject(Object object) { + assertThat(object).isSameInstanceAs(mockChannel); + cachedChannel = null; + return null; + } + + public boolean isChannelCached() { + return (cachedChannel != null); + } + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/S2AStubTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/S2AStubTest.java new file mode 100644 index 00000000000..a2b0e673313 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/S2AStubTest.java @@ -0,0 +1,260 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.common.truth.Expect; +import io.grpc.internal.SharedResourcePool; +import io.grpc.s2a.channel.S2AChannelPool; +import io.grpc.s2a.channel.S2AGrpcChannelPool; +import io.grpc.s2a.channel.S2AHandshakerServiceChannel; +import io.grpc.stub.StreamObserver; +import java.io.IOException; +import java.util.Optional; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link S2AStub}. */ +@RunWith(JUnit4.class) +public class S2AStubTest { + @Rule public final Expect expect = Expect.create(); + private static final String S2A_ADDRESS = "localhost:8080"; + private S2AStub stub; + private FakeWriter writer; + + @Before + public void setUp() { + writer = new FakeWriter(); + stub = S2AStub.newInstanceForTesting(writer); + writer.setReader(stub.getReader()); + } + + @Test + public void send_receiveOkStatus() throws Exception { + S2AChannelPool channelPool = + S2AGrpcChannelPool.create( + SharedResourcePool.forResource( + S2AHandshakerServiceChannel.getChannelResource( + S2A_ADDRESS, /* s2aChannelCredentials= */ Optional.empty()))); + S2AServiceGrpc.S2AServiceStub serviceStub = S2AServiceGrpc.newStub(channelPool.getChannel()); + S2AStub newStub = S2AStub.newInstance(serviceStub); + + IOException expected = + assertThrows(IOException.class, () -> newStub.send(SessionReq.getDefaultInstance())); + + assertThat(expected).hasMessageThat().contains("UNAVAILABLE"); + } + + @Test + public void send_clientTlsConfiguration_receiveOkStatus() throws Exception { + SessionReq req = + SessionReq.newBuilder() + .setGetTlsConfigurationReq( + GetTlsConfigurationReq.newBuilder() + .setConnectionSide(ConnectionSide.CONNECTION_SIDE_CLIENT)) + .build(); + + SessionResp resp = stub.send(req); + + SessionResp expected = + SessionResp.newBuilder() + .setGetTlsConfigurationResp( + GetTlsConfigurationResp.newBuilder() + .setClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration.newBuilder() + .addCertificateChain(FakeWriter.LEAF_CERT) + .addCertificateChain(FakeWriter.INTERMEDIATE_CERT_2) + .addCertificateChain(FakeWriter.INTERMEDIATE_CERT_1) + .setMinTlsVersion(TLSVersion.TLS_VERSION_1_3) + .setMaxTlsVersion(TLSVersion.TLS_VERSION_1_3) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256))) + .build(); + assertThat(resp).ignoringRepeatedFieldOrder().isEqualTo(expected); + } + + @Test + public void send_serverTlsConfiguration_receiveErrorStatus() throws Exception { + SessionReq req = + SessionReq.newBuilder() + .setGetTlsConfigurationReq( + GetTlsConfigurationReq.newBuilder() + .setConnectionSide(ConnectionSide.CONNECTION_SIDE_SERVER)) + .build(); + + SessionResp resp = stub.send(req); + + SessionResp expected = + SessionResp.newBuilder() + .setStatus( + Status.newBuilder() + .setCode(255) + .setDetails("No TLS configuration for the server side.")) + .build(); + assertThat(resp).isEqualTo(expected); + } + + @Test + public void send_receiveErrorStatus() throws Exception { + writer.setBehavior(FakeWriter.Behavior.ERROR_STATUS); + + SessionResp resp = stub.send(SessionReq.getDefaultInstance()); + + SessionResp expected = + SessionResp.newBuilder() + .setStatus( + Status.newBuilder().setCode(1).setDetails("Intended ERROR Status from FakeWriter.")) + .build(); + assertThat(resp).isEqualTo(expected); + } + + @Test + public void send_receiveErrorResponse() throws InterruptedException { + writer.setBehavior(FakeWriter.Behavior.ERROR_RESPONSE); + + IOException expected = + assertThrows(IOException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + expect.that(expected).hasCauseThat().isInstanceOf(RuntimeException.class); + expect.that(expected).hasMessageThat().contains("Intended ERROR from FakeWriter."); + } + + @Test + public void send_receiveCompleteStatus() throws Exception { + writer.setBehavior(FakeWriter.Behavior.COMPLETE_STATUS); + + ConnectionIsClosedException expected = + assertThrows( + ConnectionIsClosedException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + assertThat(expected).hasMessageThat().contains("Reading from the S2A is complete."); + } + + @Test + public void send_receiveUnexpectedResponse() throws Exception { + writer.sendIoError(); + + IOException expected = + assertThrows(IOException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + assertThat(expected) + .hasMessageThat() + .contains( + "Received an unexpected response from a host at the S2A's address. The S2A might be" + + " unavailable."); + } + + @Test + public void send_receiveManyUnexpectedResponse_expectResponsesEmpty() throws Exception { + writer.sendIoError(); + writer.sendIoError(); + writer.sendIoError(); + + IOException expected = + assertThrows(IOException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + assertThat(expected) + .hasMessageThat() + .contains( + "Received an unexpected response from a host at the S2A's address. The S2A might be" + + " unavailable."); + + assertThat(stub.getResponses()).isEmpty(); + } + + @Test + public void send_receiveDelayedResponse() throws Exception { + writer.sendGetTlsConfigResp(); + SessionResp resp = stub.send(SessionReq.getDefaultInstance()); + SessionResp expected = + SessionResp.newBuilder() + .setGetTlsConfigurationResp( + GetTlsConfigurationResp.newBuilder() + .setClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration.newBuilder() + .addCertificateChain(FakeWriter.LEAF_CERT) + .addCertificateChain(FakeWriter.INTERMEDIATE_CERT_2) + .addCertificateChain(FakeWriter.INTERMEDIATE_CERT_1) + .setMinTlsVersion(TLSVersion.TLS_VERSION_1_3) + .setMaxTlsVersion(TLSVersion.TLS_VERSION_1_3) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256))) + .build(); + assertThat(resp).ignoringRepeatedFieldOrder().isEqualTo(expected); + } + + @Test + public void send_afterEarlyClose_receivesClosedException() throws InterruptedException { + stub.close(); + expect.that(writer.isFakeWriterClosed()).isTrue(); + + ConnectionIsClosedException expected = + assertThrows( + ConnectionIsClosedException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + assertThat(expected).hasMessageThat().contains("Stream to the S2A is closed."); + } + + @Test + public void send_failToWrite() throws Exception { + FailWriter failWriter = new FailWriter(); + stub = S2AStub.newInstanceForTesting(failWriter); + + IOException expected = + assertThrows(IOException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + expect.that(expected).hasCauseThat().isInstanceOf(S2AConnectionException.class); + expect + .that(expected) + .hasCauseThat() + .hasMessageThat() + .isEqualTo("Could not send request to S2A."); + } + + /** Fails whenever a write is attempted. */ + private static class FailWriter implements StreamObserver { + @Override + public void onNext(SessionReq req) { + assertThat(req).isNotNull(); + throw new S2AConnectionException("Could not send request to S2A."); + } + + @Override + public void onError(Throwable t) { + assertThat(t).isInstanceOf(S2AConnectionException.class); + } + + @Override + public void onCompleted() { + throw new UnsupportedOperationException(); + } + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/S2ATrustManagerTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/S2ATrustManagerTest.java new file mode 100644 index 00000000000..384e1aba5cc --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/S2ATrustManagerTest.java @@ -0,0 +1,262 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import io.grpc.s2a.handshaker.S2AIdentity; +import java.io.ByteArrayInputStream; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.Base64; +import java.util.Optional; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class S2ATrustManagerTest { + private S2AStub stub; + private FakeWriter writer; + private static final String FAKE_HOSTNAME = "Fake-Hostname"; + private static final String CLIENT_CERT_PEM = + "MIICKjCCAc+gAwIBAgIUC2GShcVO+5Zkml+7VO3OQ+B2c7EwCgYIKoZIzj0EAwIw" + + "HzEdMBsGA1UEAwwUcm9vdGNlcnQuZXhhbXBsZS5jb20wIBcNMjMwMTI2MTk0OTUx" + + "WhgPMjA1MDA2MTMxOTQ5NTFaMB8xHTAbBgNVBAMMFGxlYWZjZXJ0LmV4YW1wbGUu" + + "Y29tMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEeciYZgFAZjxyzTrklCRIWpad" + + "8wkyCZQzJSf0IfNn9NKtfzL2V/blteULO0o9Da8e2Avaj+XCKfFTc7salMo/waOB" + + "5jCB4zAOBgNVHQ8BAf8EBAMCB4AwIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwIGCCsG" + + "AQUFBwMBMAwGA1UdEwEB/wQCMAAwYQYDVR0RBFowWIYic3BpZmZlOi8vZm9vLnBy" + + "b2QuZ29vZ2xlLmNvbS9wMS9wMoIUZm9vLnByb2Quc3BpZmZlLmdvb2eCHG1hY2hp" + + "bmUtbmFtZS5wcm9kLmdvb2dsZS5jb20wHQYDVR0OBBYEFETY6Cu/aW924nfvUrOs" + + "yXCC1hrpMB8GA1UdIwQYMBaAFJLkXGlTYKISiGd+K/Ijh4IOEpHBMAoGCCqGSM49" + + "BAMCA0kAMEYCIQCZDW472c1/4jEOHES/88X7NTqsYnLtIpTjp5PZ62z3sAIhAN1J" + + "vxvbxt9ySdFO+cW7oLBEkCwUicBhxJi5VfQeQypT"; + + @Before + public void setUp() { + writer = new FakeWriter(); + stub = S2AStub.newInstanceForTesting(writer); + writer.setReader(stub.getReader()); + } + + @Test + public void createForClient_withNullStub_throwsError() { + NullPointerException expected = + assertThrows( + NullPointerException.class, + () -> + S2ATrustManager.createForClient( + /* stub= */ null, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty())); + + assertThat(expected).hasMessageThat().isNull(); + } + + @Test + public void createForClient_withNullHostname_throwsError() { + NullPointerException expected = + assertThrows( + NullPointerException.class, + () -> + S2ATrustManager.createForClient( + stub, /* hostname= */ null, /* localIdentity= */ Optional.empty())); + + assertThat(expected).hasMessageThat().isNull(); + } + + @Test + public void getAcceptedIssuers_returnsExpectedNullResult() { + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + assertThat(trustManager.getAcceptedIssuers()).isNull(); + } + + @Test + public void checkClientTrusted_withEmptyCertificateChain_throwsException() + throws CertificateException { + writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + IllegalArgumentException expected = + assertThrows( + IllegalArgumentException.class, + () -> trustManager.checkClientTrusted(new X509Certificate[] {}, /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Certificate chain has zero certificates."); + } + + @Test + public void checkServerTrusted_withEmptyCertificateChain_throwsException() + throws CertificateException { + writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + IllegalArgumentException expected = + assertThrows( + IllegalArgumentException.class, + () -> trustManager.checkServerTrusted(new X509Certificate[] {}, /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Certificate chain has zero certificates."); + } + + @Test + public void checkClientTrusted_getsSuccessResponse() throws CertificateException { + writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + // Expect no exception. + trustManager.checkClientTrusted(getCerts(), /* authType= */ ""); + } + + @Test + public void checkClientTrusted_withLocalIdentity_getsSuccessResponse() + throws CertificateException { + writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient( + stub, FAKE_HOSTNAME, Optional.of(S2AIdentity.fromSpiffeId("fake-spiffe-id"))); + + // Expect no exception. + trustManager.checkClientTrusted(getCerts(), /* authType= */ ""); + } + + @Test + public void checkServerTrusted_getsSuccessResponse() throws CertificateException { + writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + // Expect no exception. + trustManager.checkServerTrusted(getCerts(), /* authType= */ ""); + } + + @Test + public void checkServerTrusted_withLocalIdentity_getsSuccessResponse() + throws CertificateException { + writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient( + stub, FAKE_HOSTNAME, Optional.of(S2AIdentity.fromSpiffeId("fake-spiffe-id"))); + + // Expect no exception. + trustManager.checkServerTrusted(getCerts(), /* authType= */ ""); + } + + @Test + public void checkClientTrusted_getsIntendedFailureResponse() throws CertificateException { + writer + .setVerificationResult(FakeWriter.VerificationResult.FAILURE) + .setFailureReason("Intended failure."); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkClientTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Intended failure."); + } + + @Test + public void checkClientTrusted_getsIntendedFailureStatusInResponse() throws CertificateException { + writer.setBehavior(FakeWriter.Behavior.ERROR_STATUS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkClientTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Error occurred in response from S2A"); + } + + @Test + public void checkClientTrusted_getsIntendedFailureFromServer() throws CertificateException { + writer.setBehavior(FakeWriter.Behavior.ERROR_RESPONSE); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkClientTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().isEqualTo("Failed to send request to S2A."); + } + + @Test + public void checkServerTrusted_getsIntendedFailureResponse() throws CertificateException { + writer + .setVerificationResult(FakeWriter.VerificationResult.FAILURE) + .setFailureReason("Intended failure."); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkServerTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Intended failure."); + } + + @Test + public void checkServerTrusted_getsIntendedFailureStatusInResponse() throws CertificateException { + writer.setBehavior(FakeWriter.Behavior.ERROR_STATUS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkServerTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Error occurred in response from S2A"); + } + + @Test + public void checkServerTrusted_getsIntendedFailureFromServer() throws CertificateException { + writer.setBehavior(FakeWriter.Behavior.ERROR_RESPONSE); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkServerTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().isEqualTo("Failed to send request to S2A."); + } + + private X509Certificate[] getCerts() throws CertificateException { + byte[] decoded = Base64.getDecoder().decode(CLIENT_CERT_PEM); + return new X509Certificate[] { + (X509Certificate) + CertificateFactory.getInstance("X.509") + .generateCertificate(new ByteArrayInputStream(decoded)) + }; + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/SslContextFactoryTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/SslContextFactoryTest.java new file mode 100644 index 00000000000..c33fd820e4c --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/SslContextFactoryTest.java @@ -0,0 +1,173 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.common.truth.Expect; +import io.grpc.s2a.handshaker.S2AIdentity; +import io.netty.handler.ssl.OpenSslSessionContext; +import io.netty.handler.ssl.SslContext; +import java.security.GeneralSecurityException; +import java.util.Optional; +import javax.net.ssl.SSLSessionContext; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link SslContextFactory}. */ +@RunWith(JUnit4.class) +public final class SslContextFactoryTest { + @Rule public final Expect expect = Expect.create(); + private static final String FAKE_TARGET_NAME = "fake_target_name"; + private S2AStub stub; + private FakeWriter writer; + + @Before + public void setUp() { + writer = new FakeWriter(); + stub = S2AStub.newInstanceForTesting(writer); + writer.setReader(stub.getReader()); + } + + @Test + public void createForClient_returnsValidSslContext() throws Exception { + SslContext sslContext = + SslContextFactory.createForClient( + stub, FAKE_TARGET_NAME, /* localIdentity= */ Optional.empty()); + + expect.that(sslContext).isNotNull(); + expect.that(sslContext.sessionCacheSize()).isEqualTo(1); + expect.that(sslContext.sessionTimeout()).isEqualTo(300); + expect.that(sslContext.isClient()).isTrue(); + expect.that(sslContext.applicationProtocolNegotiator().protocols()).containsExactly("h2"); + expect + .that(sslContext.cipherSuites()) + .containsExactly( + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"); + SSLSessionContext sslSessionContext = sslContext.sessionContext(); + if (sslSessionContext instanceof OpenSslSessionContext) { + OpenSslSessionContext openSslSessionContext = (OpenSslSessionContext) sslSessionContext; + expect.that(openSslSessionContext.isSessionCacheEnabled()).isFalse(); + } + } + + @Test + public void createForClient_withLocalIdentity_returnsValidSslContext() throws Exception { + SslContext sslContext = + SslContextFactory.createForClient( + stub, FAKE_TARGET_NAME, Optional.of(S2AIdentity.fromSpiffeId("fake-spiffe-id"))); + + expect.that(sslContext).isNotNull(); + expect.that(sslContext.sessionCacheSize()).isEqualTo(1); + expect.that(sslContext.sessionTimeout()).isEqualTo(300); + expect.that(sslContext.isClient()).isTrue(); + expect.that(sslContext.applicationProtocolNegotiator().protocols()).containsExactly("h2"); + expect + .that(sslContext.cipherSuites()) + .containsExactly( + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"); + SSLSessionContext sslSessionContext = sslContext.sessionContext(); + if (sslSessionContext instanceof OpenSslSessionContext) { + OpenSslSessionContext openSslSessionContext = (OpenSslSessionContext) sslSessionContext; + expect.that(openSslSessionContext.isSessionCacheEnabled()).isFalse(); + } + } + + @Test + public void createForClient_returnsEmptyResponse_error() throws Exception { + writer.setBehavior(FakeWriter.Behavior.EMPTY_RESPONSE); + + S2AConnectionException expected = + assertThrows( + S2AConnectionException.class, + () -> + SslContextFactory.createForClient( + stub, FAKE_TARGET_NAME, /* localIdentity= */ Optional.empty())); + + assertThat(expected) + .hasMessageThat() + .contains("Response from S2A server does NOT contain ClientTlsConfiguration."); + } + + @Test + public void createForClient_returnsErrorStatus_error() throws Exception { + writer.setBehavior(FakeWriter.Behavior.ERROR_STATUS); + + S2AConnectionException expected = + assertThrows( + S2AConnectionException.class, + () -> + SslContextFactory.createForClient( + stub, FAKE_TARGET_NAME, /* localIdentity= */ Optional.empty())); + + assertThat(expected).hasMessageThat().contains("Intended ERROR Status from FakeWriter."); + } + + @Test + public void createForClient_getsErrorFromServer_throwsError() throws Exception { + writer.sendIoError(); + + GeneralSecurityException expected = + assertThrows( + GeneralSecurityException.class, + () -> + SslContextFactory.createForClient( + stub, FAKE_TARGET_NAME, /* localIdentity= */ Optional.empty())); + + assertThat(expected) + .hasMessageThat() + .contains("Failed to get client TLS configuration from S2A."); + } + + @Test + public void createForClient_nullStub_throwsError() throws Exception { + writer.sendUnexpectedResponse(); + + NullPointerException expected = + assertThrows( + NullPointerException.class, + () -> + SslContextFactory.createForClient( + /* stub= */ null, FAKE_TARGET_NAME, /* localIdentity= */ Optional.empty())); + + assertThat(expected).hasMessageThat().isEqualTo("stub should not be null."); + } + + @Test + public void createForClient_nullTargetName_throwsError() throws Exception { + writer.sendUnexpectedResponse(); + + NullPointerException expected = + assertThrows( + NullPointerException.class, + () -> + SslContextFactory.createForClient( + stub, /* targetName= */ null, /* localIdentity= */ Optional.empty())); + + assertThat(expected) + .hasMessageThat() + .isEqualTo("targetName should not be null on client side."); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenAccessTokenManagerTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenAccessTokenManagerTest.java new file mode 100644 index 00000000000..806e412b784 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenAccessTokenManagerTest.java @@ -0,0 +1,74 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker.tokenmanager; + +import static com.google.common.truth.Truth.assertThat; + +import com.beust.jcommander.JCommander; +import io.grpc.s2a.handshaker.S2AIdentity; +import java.util.Optional; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class SingleTokenAccessTokenManagerTest { + private static final S2AIdentity IDENTITY = S2AIdentity.fromSpiffeId("spiffe_id"); + private static final String TOKEN = "token"; + private static final String[] SET_TOKEN = {"--s2a_access_token", TOKEN}; + private static final SingleTokenFetcher.Flags FLAGS = new SingleTokenFetcher.Flags(); + + @Before + public void setUp() { + FLAGS.reset(); + } + + @Test + public void getDefaultToken_success() throws Exception { + JCommander.newBuilder().addObject(FLAGS).build().parse(SET_TOKEN); + Optional manager = AccessTokenManager.create(); + assertThat(manager).isPresent(); + assertThat(manager.get().getDefaultToken()).isEqualTo(TOKEN); + } + + @Test + public void getToken_success() throws Exception { + JCommander.newBuilder().addObject(FLAGS).build().parse(SET_TOKEN); + Optional manager = AccessTokenManager.create(); + assertThat(manager).isPresent(); + assertThat(manager.get().getToken(IDENTITY)).isEqualTo(TOKEN); + } + + @Test + public void getToken_noEnvironmentVariable() throws Exception { + assertThat(SingleTokenFetcher.create()).isEmpty(); + } + + @Test + public void create_success() throws Exception { + JCommander.newBuilder().addObject(FLAGS).build().parse(SET_TOKEN); + Optional manager = AccessTokenManager.create(); + assertThat(manager).isPresent(); + assertThat(manager.get().getToken(IDENTITY)).isEqualTo(TOKEN); + } + + @Test + public void create_noEnvironmentVariable() throws Exception { + assertThat(AccessTokenManager.create()).isEmpty(); + } +} \ No newline at end of file diff --git a/s2a/src/test/resources/README.md b/s2a/src/test/resources/README.md new file mode 100644 index 00000000000..00901015444 --- /dev/null +++ b/s2a/src/test/resources/README.md @@ -0,0 +1,31 @@ +# Generating certificates and keys for testing mTLS-S2A + +Content from: https://github.com/google/s2a-go/blob/main/testdata/README.md + +Create root CA + +``` +openssl req -x509 -sha256 -days 7305 -newkey rsa:2048 -keyout root_key.pem -out +root_cert.pem +``` + +Generate private keys for server and client + +``` +openssl genrsa -out server_key.pem 2048 +openssl genrsa -out client_key.pem 2048 +``` + +Generate CSRs for server and client + +``` +openssl req -key server_key.pem -new -out server.csr -config config.cnf +openssl req -key client_key.pem -new -out client.csr -config config.cnf +``` + +Sign CSRs for server and client + +``` +openssl x509 -req -CA root_cert.pem -CAkey root_key.pem -in server.csr -out server_cert.pem -days 7305 -extfile config.cnf -extensions req_ext +openssl x509 -req -CA root_cert.pem -CAkey root_key.pem -in client.csr -out client_cert.pem -days 7305 +``` \ No newline at end of file diff --git a/s2a/src/test/resources/client.csr b/s2a/src/test/resources/client.csr new file mode 100644 index 00000000000..664f5a4cf86 --- /dev/null +++ b/s2a/src/test/resources/client.csr @@ -0,0 +1,16 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIIChzCCAW8CAQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0B +AQEFAAOCAQ8AMIIBCgKCAQEAoSS3KtFgiXX4vAUNscFGIB/r2OOMgiZMKHz72dN0 +5kSxwdpQxpMIhwEoe0lhHNfOiuE7/r6VbGG9RGGIcQcoSonc3InPRfpnzfj9KohJ +i8pYkLL9EwElAEl9sWnvVKTza8jTApDP2Z/fntBEsWAMsLPpuRZT6tgN1sXe4vNG +4wufJSxuImyCVAx1fkZjRkYEKOtm1osnEDng4R0WXZ6S+q5lYzYPk1wxgbjdZu2U +fWxP6V63SphV0NFXTx0E401j2h258cIqTVj8lRX6dfl9gO0d43Rd+hSU7R4iXGEw +arixuH9g5H745AFf9H52twHPcNP9cEKBljBpSV5z3MvTkQIDAQABoC4wLAYJKoZI +hvcNAQkOMR8wHTAbBgNVHREEFDAShxAAAAAAAAAAAAAAAAAAAAAAMA0GCSqGSIb3 +DQEBCwUAA4IBAQCQHim3aIpGJs5u6JhEA07Rwm8YKyVALDEklhsHILlFhdNr2uV7 +S+3bHV79mDGjxNWvFcgK5h5ENkT60tXbhbie1gYmFT0RMCYHDsL09NGTh8G9Bbdl +UKeA9DMhRSYzE7Ks3Lo1dJvX7OAEI0qV77dGpQknufYpmHiBXuqtB9I0SpYi1c4O +9IUn/NY0yiYFPsIEsVRz/1dK97wazusLnijaMwNNhUc9bJwTyujhlr+b8ioPyADG +e+GDF97d0nQ8806DOJF4GTRKwaXD+R5zN5t4ULhZ7ERqLNeE9EnWRe4CvSGvBoNA +hIVeYaLd761Z9ZKvOnsgCr8qvMDilDFY6OfB +-----END CERTIFICATE REQUEST----- \ No newline at end of file diff --git a/s2a/src/test/resources/client_cert.pem b/s2a/src/test/resources/client_cert.pem new file mode 100644 index 00000000000..b72f6991c91 --- /dev/null +++ b/s2a/src/test/resources/client_cert.pem @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC9DCCAdwCFB+cDXee2sIHjdlBhdNpTo+G2XAjMA0GCSqGSIb3DQEBCwUAMFkx +CzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRl +cm5ldCBXaWRnaXRzIFB0eSBMdGQxEjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0yMzEw +MTcyMzA5MDNaFw00MzEwMTcyMzA5MDNaMBQxEjAQBgNVBAMMCWxvY2FsaG9zdDCC +ASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKEktyrRYIl1+LwFDbHBRiAf +69jjjIImTCh8+9nTdOZEscHaUMaTCIcBKHtJYRzXzorhO/6+lWxhvURhiHEHKEqJ +3NyJz0X6Z834/SqISYvKWJCy/RMBJQBJfbFp71Sk82vI0wKQz9mf357QRLFgDLCz +6bkWU+rYDdbF3uLzRuMLnyUsbiJsglQMdX5GY0ZGBCjrZtaLJxA54OEdFl2ekvqu +ZWM2D5NcMYG43WbtlH1sT+let0qYVdDRV08dBONNY9odufHCKk1Y/JUV+nX5fYDt +HeN0XfoUlO0eIlxhMGq4sbh/YOR++OQBX/R+drcBz3DT/XBCgZYwaUlec9zL05EC +AwEAATANBgkqhkiG9w0BAQsFAAOCAQEARorc1t2OJnwm1lxhf2KpTpNvNOI9FJak +iSHz/MxhMdu4BG/dQHkKkWoVC6W2Kaimx4OImBwRlGEmGf4P0bXOLSTOumk2k1np +ZUbw7Z2cJzvBmT2BLoHRXcBvbFIBW5DJUSHR37eXEKP57BeD+Og4/3XhNzehSpTX +DRd2Ix/D39JjYA462nqPHQP8HDMf6+0BFmvf9ZRYmFucccYQRCUCKDqb8+wGf9W6 +tKNRE6qPG2jpAQ9qkgO7XuucbLvpywt5xj+yDRbOIq43l40mHaz4lRp697oaxjP8 +HSVcMydW3cluoW3AVInNIaqbM1dr6931MllK62DKipFtmCycq/56XA== +-----END CERTIFICATE----- \ No newline at end of file diff --git a/s2a/src/test/resources/client_key.pem b/s2a/src/test/resources/client_key.pem new file mode 100644 index 00000000000..dd3e2ff78f1 --- /dev/null +++ b/s2a/src/test/resources/client_key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQChJLcq0WCJdfi8 +BQ2xwUYgH+vY44yCJkwofPvZ03TmRLHB2lDGkwiHASh7SWEc186K4Tv+vpVsYb1E +YYhxByhKidzcic9F+mfN+P0qiEmLyliQsv0TASUASX2xae9UpPNryNMCkM/Zn9+e +0ESxYAyws+m5FlPq2A3Wxd7i80bjC58lLG4ibIJUDHV+RmNGRgQo62bWiycQOeDh +HRZdnpL6rmVjNg+TXDGBuN1m7ZR9bE/pXrdKmFXQ0VdPHQTjTWPaHbnxwipNWPyV +Ffp1+X2A7R3jdF36FJTtHiJcYTBquLG4f2DkfvjkAV/0fna3Ac9w0/1wQoGWMGlJ +XnPcy9ORAgMBAAECggEALAUqoGDIHWUDyOEch5WDwZzWwc4PgTJTFbBm4G96fLkB +UjKAZG6gIrk3RM6b39Q4UQoMaJ/Jk+zzVi3Kpw3MfOhCVGC1JamtF8BP8IGAjdZ9 +8TFkHv/uCrEIzCFjRt00vhoDQq0qiom4/dppGYdikBbl3zDxRbM1vJkbNSY+FCGW +dA0uJ5XdMLR6lPeB5odqjUggnfUgPCOLdV/F+HkSM9NP1bzmHLiKznzwFsfat139 +7LdzJwNN5IX4Io6cxsxNlrX/NNvPkKdGv07Z6FYxWROyKCunjh48xFcQg0ltoRuq +R9P8/LwS8GYrcc1uC/uBc0e6VgM9D9fsvh+8SQtf3QKBgQDXX+z2GnsFoEs7xv9U +qN0HEX4jOkihZvFu43layUmeCeE8wlEctJ0TsM5Bd7FMoUG6e5/btwhsAIYW89Xn +l/R8OzxR6Kh952Dce4DAULuIeopiw7ASJwTZtO9lWhxw0hjM1hxXTG+xxOqQvsRX +c+d+vtvdIqyJ4ELfzg9kUtkdpwKBgQC/ig3cmej7dQdRAMn0YAwgwhuLkCqVFh4y +WIlqyPPejKf8RXubqgtaSYx/T7apP87SMMSfSLaUdrYAGjST6k+tG5cmwutPIbw/ +osL7U3hcIhjX3hfHgI69Ojcpplbd5yqTxZHpxIs6iAQCEqNuasLXIDMouqNhGF1D +YssD6qxcBwKBgQCdZqWvVrsB6ZwSG+UO4jpmqAofhMD/9FQOToCqMOF0dpP966WL +7RO/CEA06FzTPCblOuQhlyq4g8l7jMiPcSZkhIYY9oftO+Q2Pqxh4J6tp6DrfUh4 +e7u3v9wVnj2a1nD5gqFDy8D1kow7LLAhmbtdje7xNh4SxasaFWZ6U3IJkQKBgGS1 +F5i3q9IatCAZBBZjMb0/kfANevYsTPA3sPjec6q91c1EUzuDarisFx0RMn9Gt124 +mokNWEIzMHpZTO/AsOfZq92LeuF+YVYsI8y1FIGMw/csJOCWbXZ812gkt2OxGafc +p118I6BAx6q3VgrGQ2+M1JlDmIeCofa+SPPkPX+dAoGBAJrOgEJ+oyEaX/YR1g+f +33pWoPQbRCG7T4+Y0oetCCWIcMg1/IUvGUCGmRDxj5dMqB+a0vJtviQN9rjpSuNS +0EVw79AJkIjHhi6KDOfAuyBvzGhxpqxGufnQ2GU0QL65NxQfd290xkxikN0ZGtuB +SDgZoJxMOGYwf8EX5i9h27Db +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/s2a/src/test/resources/config.cnf b/s2a/src/test/resources/config.cnf new file mode 100644 index 00000000000..38d9a9ccdb0 --- /dev/null +++ b/s2a/src/test/resources/config.cnf @@ -0,0 +1,17 @@ +[req] +distinguished_name = req_distinguished_name +req_extensions = req_ext + +[req_distinguished_name] +countryName = Country Name (2 letter code) +stateOrProvinceName = State or Province Name (full name) +localityName = Locality Name (eg, city) +organizationalUnitName = Organizational Unit Name (eg, section) +commonName = Common Name (eg, your name or your server\'s hostname) +emailAddress = Email Address + +[req_ext] +subjectAltName = @alt_names + +[alt_names] +IP.1 = :: \ No newline at end of file diff --git a/s2a/src/test/resources/root_cert.pem b/s2a/src/test/resources/root_cert.pem new file mode 100644 index 00000000000..737e601691c --- /dev/null +++ b/s2a/src/test/resources/root_cert.pem @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDkzCCAnugAwIBAgIUb7RsINwsFgKf0Q0RuzfOgp48j6UwDQYJKoZIhvcNAQEL +BQAwWTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X +DTIzMTAxNzIzMDczOFoXDTQzMTAxNzIzMDczOFowWTELMAkGA1UEBhMCQVUxEzAR +BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 +IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAkIFnQLuhzYnm3rvmi/U7zMgEP2Tqgb3VC00frSXEV6olZcLgyC9g +0DAGdt9l9lP90DQTG5KCOtoW2BTqM/aaVpR0OaDFOCy90FIj6YyZLZ9w2PQxQcxS +GQHyEvWszTkNxeDyG1mPTj+Go8JLKqdvLg/9GUgPg6stxyAZwYhyUTGuEM4bv0sn +b3vmHRmIGJ/w6aLtd7nK8LkNHa3WVrbvRGHrzdMHfpzF/M/5fAk8GfRYugo39knf +VLKGyQCXNI8Y1iHGEmPqQZIFPTjBL6caIlbEV0VHlxoSOGB6JVxcllxAEvd6abqX +RJVJPQzzGfEnMNYp9SiZQ9bvDRUsUkWyYwIDAQABo1MwUTAdBgNVHQ4EFgQUAZMN +F9JAGHbA3jGOeu6bWFvSdWkwHwYDVR0jBBgwFoAUAZMNF9JAGHbA3jGOeu6bWFvS +dWkwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAicBli36ISJFu +lrJqOHVqTeNP6go0I35VGnP44nEEP5cBvRD3XntBFEk5D3mSNNOGt+2ncxom8VR9 +FsLuTfHAipXePJI6MSxFuBPea8V/YPBs3npk5f1FRvJ5vEgtzFvBjsKmp1dS9hH0 +KUWtWcsAkO2Anc/LVc0xxSidL8NjzYoEFqiki0TNNwCJjmd9XwnBLHW38sEb/pgy +KTyRpOyG3Zg2UDjBHiXPBrmIvVFLB6+LrPNvfr1k4HjIgVY539ZXUvVMDKytMrDY +h63EMDn4kkPpxXlufgWGybjN5D51OylyWBZLe+L1DQyWEg0Vd7GwPzb6p7bmI7MP +pooqbgbDpQ== +-----END CERTIFICATE----- \ No newline at end of file diff --git a/s2a/src/test/resources/root_key.pem b/s2a/src/test/resources/root_key.pem new file mode 100644 index 00000000000..aae992426d7 --- /dev/null +++ b/s2a/src/test/resources/root_key.pem @@ -0,0 +1,30 @@ +-----BEGIN ENCRYPTED PRIVATE KEY----- +MIIFHDBOBgkqhkiG9w0BBQ0wQTApBgkqhkiG9w0BBQwwHAQInmQVkXP3TFcCAggA +MAwGCCqGSIb3DQIJBQAwFAYIKoZIhvcNAwcECGeCAVH1pefxBIIEyD3Nj1Dy19oy +fogU+z8YBLXuSCx8s3zncYPF9nYlegGSSo0ace/WxfPu8AEPus1P2MxlxfcCQ1A+ +5+vMihtEpgpTg9R4RlLAWs45jz4AduGiwqW05+W5zgDn6g7p7HIL0+M5FxKRkAW0 +KEH4Jy8Vc1XQxkhOm1Q4NLI8PT94rcBDE9Od03sdrW/hQgaOFz5AWOlT5jF1uUOz +glF1RQQxfJygTB6qlPTC3BAaiAnWij3NOg5L5vvUhjLa7iOZOhRQBRkf4YtHsM+2 +rFy8Z7MeHOvrqFf8LXosNy3JreQW036rLGR0Xh5myATkNrEwA8df37AgLUmwqyfz +hjZefPW77LgMAXlaN8s345AGikOX8yQKEFzPV/Nag32p6t4oiRRcUUfdB4wzKi6T +mzZ6lKcGR3qqL4V6lJSV3I2fmgkYZnUwymolyu+1+CVYDLuE53TBi5dRXwgOghi7 +npw7PqqQCian8yxHF9c1rYukD0ov0/y8ratjOu9XoJG2/wWQJNvDkAyc3mSJf+3y +6Wtu1qhLszU8pZOGW0fK6bGyHSp+wkoah/vRzB0+yFjvuMIG6py2ZDQeqhqS3ZV2 +nZHHjj0tZ45Wbdf4k17ujEK34pFXluPH//zADnd6ym2W0t6x+jtqR5tYu3poORQg +jFgpudkn2RUSq8N/gIiHDwblYBxU2dmyzEVudv1zNgVSHyetGLxsFoNB7Prn89rJ +u24a/xtuCyC2pshWo3KiL74hkkCsC8rLbEAAbADheb35b+Ca3JnMwgyUHbHL6Hqf +EiVIgm14lB/1uz651X58Boo6tDFkgrxEtGDUIZm8yk2n0tGflp7BtYbMCw+7gqhb +XN4hlhFDcCJm8peXcyCtGajOnBuNO9JJDNYor6QjptaIpQBFb7/0rc7kyO12BIUv +F9mrCHF18Hd/9AtUO93+tyDAnL64Jqq9tUv8dOVtIfbcHXZSYHf24l0XAiKByb8y +9NQLUZkIuF4aUZVHV8ZBDdHNqjzqVglKQlGHdw1XBexSal5pC9HvknOmWBgl0aza +flzeTRPX7TPrMJDE5lgSy58czGpvZzhFYwOp6cwpfjNsiqdzD78Zs0xsRbNg519s +d+cLmbiU3plWCoYCuDb68eZRRzT+o41+QJG2PoMCpzPw5wMLl6HuW7HXMRFpZKJc +tPKpeTIzb8hjhA+TwVIVpTPHvvQehtTUQD2mRujdvNM6PF8tnuC3F3sB3PTjeeJg +uzfEfs3BynRTIj/gX6y87gzwsrwWIEN6U0cCbQ6J1EcgdQCiH8vbhIgfd4DkLgLN +Kkif+fI/HgBOqaiwSw3sHmWgB6PllVQOKH6qAiejTHR/UUvJTPvgKJFLunmBiF12 +N1bRge1sSXE1eLKVdi+dP1j0o6RxhaRrbX7ie3y/wYHwCJnb8h08DEprgCqoswFs +SuNKmvlibBHAsnOdhyCTOd9I5n8XzAUUp6mT+C5WDfl7qfYvh6IHFlSrhZ9aS9b6 +RY873cnphKbqU5d7Cr8Ufx4b4SgS+hEnuP8y5IToLQ3BONGQH2lu7nmd89wjW0uo +IMRXybwf/5FnKhEy8Aw+pD6AxiXC3DZVTKl3SHmjkYBDvNElsJVgygVTKgbOa1Z+ +ovIK/D7QV7Nv3uVortH8XA== +-----END ENCRYPTED PRIVATE KEY----- \ No newline at end of file diff --git a/s2a/src/test/resources/server.csr b/s2a/src/test/resources/server.csr new file mode 100644 index 00000000000..1657b191133 --- /dev/null +++ b/s2a/src/test/resources/server.csr @@ -0,0 +1,16 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIIChzCCAW8CAQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0B +AQEFAAOCAQ8AMIIBCgKCAQEAlPThqu8tfJ4hQKRiUw/vNPfo2L2LQU8NlrRL7rvV +71E345LGK1h/hM3MHp5VgEvaaIibb0hSNv/TYz3HVCQyNuPlcmkHZTJ9mB0icilU +rYWdM0LPIg46iThmIQVhMiNfpMKQLDLQ7o3Jktjm32OxnQdtYSV+7NFnw8/0pB4j +iaiBYfZIMeGzEJIOFG8GSNJG0pfCI71DyLRonIcb2XzfeDPHeWSF7lbIoMGAuKIE +2mXpwHmAjTMJzIShSgLqCvmbz7wR3ZeVMknXcgcqMmagGphy8SjizIWC5KRbrnRq +F22Ouxdat6scIevRXGp5nYawFYdpK9qo+82gEouVX3dtSQIDAQABoC4wLAYJKoZI +hvcNAQkOMR8wHTAbBgNVHREEFDAShxAAAAAAAAAAAAAAAAAAAAAAMA0GCSqGSIb3 +DQEBCwUAA4IBAQB2qU354OlNVunhZhiOFNwabovxLcgKoQz+GtJ2EzsMEza+NPvV +dttPxXzqL/U+gDghvGzSYGuh2yMfTTPO+XtZKpvMUmIWonN5jItbFwSTaWcoE8Qs +zFZokRuFJ9dy017u642mpdf6neUzjbfCjWs8+3jyFzWlkrMF3RlSTxPuksWjhXsX +dxxLNu8YWcsYRB3fODHqrlBNuDn+9kb9z8to+yq76MA0HtdDkjd/dfgghiTDJhqm +IcwhBXufwQUrOP4YiuiwM0mo7Xlhw65gnSmRcwR9ha98SV2zG5kiRYE+m+94bDbd +kGBRfhpQSzh1w09cVzmLgzkfxRShEB+bb9Ss +-----END CERTIFICATE REQUEST----- \ No newline at end of file diff --git a/s2a/src/test/resources/server_cert.pem b/s2a/src/test/resources/server_cert.pem new file mode 100644 index 00000000000..10a98cf5c21 --- /dev/null +++ b/s2a/src/test/resources/server_cert.pem @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDWjCCAkKgAwIBAgIUMZkgD5gtoa39H9jdI/ijVkyxC/swDQYJKoZIhvcNAQEL +BQAwWTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X +DTIzMTAxNzIzMDg1M1oXDTQzMTAxNzIzMDg1M1owFDESMBAGA1UEAwwJbG9jYWxo +b3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAlPThqu8tfJ4hQKRi +Uw/vNPfo2L2LQU8NlrRL7rvV71E345LGK1h/hM3MHp5VgEvaaIibb0hSNv/TYz3H +VCQyNuPlcmkHZTJ9mB0icilUrYWdM0LPIg46iThmIQVhMiNfpMKQLDLQ7o3Jktjm +32OxnQdtYSV+7NFnw8/0pB4jiaiBYfZIMeGzEJIOFG8GSNJG0pfCI71DyLRonIcb +2XzfeDPHeWSF7lbIoMGAuKIE2mXpwHmAjTMJzIShSgLqCvmbz7wR3ZeVMknXcgcq +MmagGphy8SjizIWC5KRbrnRqF22Ouxdat6scIevRXGp5nYawFYdpK9qo+82gEouV +X3dtSQIDAQABo18wXTAbBgNVHREEFDAShxAAAAAAAAAAAAAAAAAAAAAAMB0GA1Ud +DgQWBBTKJU+NK7Q6ZPccSigRCMBCBgjkaDAfBgNVHSMEGDAWgBQBkw0X0kAYdsDe +MY567ptYW9J1aTANBgkqhkiG9w0BAQsFAAOCAQEAXuCs6MGVoND8TaJ6qaDmqtpy +wKEW2hsGclI9yv5cMS0XCVTkmKYnIoijtqv6Pdh8PfhIx5oJqJC8Ml16w4Iou4+6 +kKF0DdzdQyiM0OlNCgLYPiR4rh0ZCAFFCvOsDum1g+b9JTFZGooK4TMd9thwms4D +SqpP5v1NWf/ZLH5TYnp2CkPzBxDlnMJZphuWtPHL+78TbgQuQaKu2nMLBGBJqtFi +HDOGxckgZuwBsy0c+aC/ZwaV7FdMP42kxUZduCEx8+BDSGwPoEpz6pwVIkjiyYAm +3O8FUeEPzYzwpkANIbbEIDWV6FVH9IahKRRkE+bL3BqoQkv8SMciEA5zWsPrbA== +-----END CERTIFICATE----- \ No newline at end of file diff --git a/s2a/src/test/resources/server_key.pem b/s2a/src/test/resources/server_key.pem new file mode 100644 index 00000000000..44f087dee94 --- /dev/null +++ b/s2a/src/test/resources/server_key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCU9OGq7y18niFA +pGJTD+809+jYvYtBTw2WtEvuu9XvUTfjksYrWH+EzcwenlWAS9poiJtvSFI2/9Nj +PcdUJDI24+VyaQdlMn2YHSJyKVSthZ0zQs8iDjqJOGYhBWEyI1+kwpAsMtDujcmS +2ObfY7GdB21hJX7s0WfDz/SkHiOJqIFh9kgx4bMQkg4UbwZI0kbSl8IjvUPItGic +hxvZfN94M8d5ZIXuVsigwYC4ogTaZenAeYCNMwnMhKFKAuoK+ZvPvBHdl5UySddy +ByoyZqAamHLxKOLMhYLkpFuudGoXbY67F1q3qxwh69FcanmdhrAVh2kr2qj7zaAS +i5Vfd21JAgMBAAECggEACTBuN4hXywdKT92UP0GNZTwh/jT7QUUqNnDa+lhWI1Rk +WUK1vPjRrRSxEfZ8mdSUHbzHsf7JK6FungGyqUsuWdqHTh6SmTibLOYnONm54paK +kx38/0HXdJ2pF0Jos5ohDV3/XOqpnv3aQJfm7kMNMv3BTqvsf5mPiDHtCq7dTGGj +rGiLc0zirKZq79C6YSB1UMB01BsDl2ScflK8b3osT18uYx/BOdjLT4yZWQsU/nbB +OeF+ziWTTUAVjodGeTf+NYG7cFN/9N9PdSnAwuw8Nche3xZKbHTh2I578Zd4bsDX +H+hoMN862nzOXEvD6KyLB8xDdnEZ+p+njeDROJVmgQKBgQDQhzQEl/co1LYc5IDO +mynhCOtKJeRWBLhYEPIuaSY3qF+lrOWzqyOUNppWDx+HeKOq70X1Q+ETeSXtbaL1 +qHBkNcApQ2lStcpkR9whcVbr9NIWC8y8UQxyerEK3x3l0bZ99dfJ/z6lbxdS7prc +Hhxy6pUj8Q8AgpTZA8HfQUF1EQKBgQC23ek24kTVvWeWX2C/82H1Yfia6ITL7WHz +3aEJaZaO5JD3KmOSZgY88Ob3pkDTRYjFZND5zSB7PnM68gpo/OEDla6ZYtfwBWCX +q4QhFtv2obehobmDk+URVfvlOcBikoEP1i8oy7WdZ5CgC4gNKkkD15l68W+g5IIG +2ZOA97yUuQKBgDAzoI2TRxmUGciR9UhMy6Bt/F12ZtKPYsFQoXqi6aeh7wIP9kTS +wXWoLYLJGiOpekOv7X7lQujKbz7zweCBIAG5/wJKx9TLms4VYkgEt+/w9oMMFTZO +kc8Al14I9xNBp6p0In5Z1vRMupp79yX8e90AZpsZRLt8c8W6PZ1Kq0PRAoGBAKmD +7LzD46t/eJccs0M9CoG94Ac5pGCmHTdDLBTdnIO5vehhkwwTJ5U2e+T2aQFwY+kY +G+B1FrconQj3dk78nFoGV2Q5DJOjaHcwt7s0xZNLNj7O/HnMj3wSiP9lGcJGrP1R +P0ZCEIlph9fU2LnbiPPW2J/vT9uF+EMBTosvG9GBAoGAEVaDLLXOHj+oh1i6YY7s +0qokN2CdeKY4gG7iKjuDFb0r/l6R9uFvpUwJMhLEkF5SPQMyrzKFdnTpw3n/jnRa +AWG6GoV+D7LES+lHP5TXKKijbnHJdFjW8PtfDXHCJ6uGG91vH0TMMp1LqhcvGfTv +lcNGXkk6gUNSecxBC1uJfKE= +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/settings.gradle b/settings.gradle index ae6e395e7a1..a550789aaca 100644 --- a/settings.gradle +++ b/settings.gradle @@ -64,6 +64,7 @@ include ":grpc-benchmarks" include ":grpc-services" include ":grpc-servlet" include ":grpc-servlet-jakarta" +include ":grpc-s2a" include ":grpc-xds" include ":grpc-bom" include ":grpc-rls" @@ -98,6 +99,7 @@ project(':grpc-benchmarks').projectDir = "$rootDir/benchmarks" as File project(':grpc-services').projectDir = "$rootDir/services" as File project(':grpc-servlet').projectDir = "$rootDir/servlet" as File project(':grpc-servlet-jakarta').projectDir = "$rootDir/servlet/jakarta" as File +project(':grpc-s2a').projectDir = "$rootDir/s2a" as File project(':grpc-xds').projectDir = "$rootDir/xds" as File project(':grpc-bom').projectDir = "$rootDir/bom" as File project(':grpc-rls').projectDir = "$rootDir/rls" as File