diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java b/sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java new file mode 100644 index 00000000..9621c935 --- /dev/null +++ b/sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java @@ -0,0 +1,41 @@ +package io.opentdf.platform.sdk; + +import io.grpc.Channel; +import io.opentdf.platform.kas.AccessServiceGrpc; +import io.opentdf.platform.kas.PublicKeyRequest; +import io.opentdf.platform.kas.RewrapRequest; + +import java.util.HashMap; +import java.util.function.Function; + +public class KASClient implements SDK.KAS { + + private final Function channelFactory; + + public KASClient(Function channelFactory) { + this.channelFactory = channelFactory; + } + + @Override + public String getPublicKey(SDK.KASInfo kasInfo) { + return getStub(kasInfo).publicKey(PublicKeyRequest.getDefaultInstance()).getPublicKey(); + } + + @Override + public byte[] unwrap(SDK.KASInfo kasInfo, SDK.Policy policy) { + // this is obviously wrong. we still have to generate a correct request and decrypt the payload + return getStub(kasInfo).rewrap(RewrapRequest.getDefaultInstance()).getEntityWrappedKey().toByteArray(); + } + + private final HashMap stubs = new HashMap<>(); + + private synchronized AccessServiceGrpc.AccessServiceBlockingStub getStub(SDK.KASInfo kasInfo) { + if (!stubs.containsKey(kasInfo)) { + var channel = channelFactory.apply(kasInfo); + var stub = AccessServiceGrpc.newBlockingStub(channel); + stubs.put(kasInfo, stub); + } + + return stubs.get(kasInfo); + } +} diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java b/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java index 49af620d..ef7494ee 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java @@ -17,14 +17,25 @@ public class SDK { private final Services services; + public interface KASInfo{ + String getAddress(); + } + public interface Policy{} + + interface KAS { + String getPublicKey(KASInfo kasInfo); + byte[] unwrap(KASInfo kasInfo, Policy policy); + } + // TODO: add KAS - public interface Services { + interface Services { AttributesServiceFutureStub attributes(); NamespaceServiceFutureStub namespaces(); SubjectMappingServiceFutureStub subjectMappings(); ResourceMappingServiceFutureStub resourceMappings(); + KAS kas(); - static Services newServices(Channel channel) { + static Services newServices(Channel channel, KAS kas) { var attributeService = AttributesServiceGrpc.newFutureStub(channel); var namespaceService = NamespaceServiceGrpc.newFutureStub(channel); var subjectMappingService = SubjectMappingServiceGrpc.newFutureStub(channel); @@ -50,11 +61,16 @@ public SubjectMappingServiceFutureStub subjectMappings() { public ResourceMappingServiceFutureStub resourceMappings() { return resourceMappingService; } + + @Override + public KAS kas() { + return kas; + } }; } } - public SDK(Services services) { + SDK(Services services) { this.services = services; } } \ No newline at end of file diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java b/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java index 5cf830b9..675d27cb 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java @@ -11,6 +11,7 @@ import com.nimbusds.oauth2.sdk.id.ClientID; import com.nimbusds.oauth2.sdk.id.Issuer; import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata; +import io.grpc.Channel; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import io.grpc.Status; @@ -23,6 +24,7 @@ import java.io.IOException; import java.util.UUID; +import java.util.function.Function; /** * A builder class for creating instances of the SDK class. @@ -33,9 +35,13 @@ public class SDKBuilder { private ClientAuthentication clientAuth = null; private Boolean usePlainText; + private static final Logger logger = LoggerFactory.getLogger(SDKBuilder.class); + public static SDKBuilder newBuilder() { SDKBuilder builder = new SDKBuilder(); builder.usePlainText = false; + builder.clientAuth = null; + builder.platformEndpoint = null; return builder; } @@ -57,8 +63,16 @@ public SDKBuilder useInsecurePlaintextConnection(Boolean usePlainText) { return this; } - // this is not exposed publicly so that it can be tested - ManagedChannel buildChannel() { + private GRPCAuthInterceptor getGrpcAuthInterceptor() { + if (platformEndpoint == null) { + throw new SDKException("cannot build an SDK without specifying the platform endpoint"); + } + + if (clientAuth == null) { + // this simplifies things for now, if we need to support this case we can revisit + throw new SDKException("cannot build an SDK without specifying OAuth credentials"); + } + // we don't add the auth listener to this channel since it is only used to call the // well known endpoint ManagedChannel bootstrapChannel = null; @@ -107,24 +121,39 @@ ManagedChannel buildChannel() { throw new SDKException("Error generating DPoP key", e); } - GRPCAuthInterceptor interceptor = new GRPCAuthInterceptor(clientAuth, rsaKey, providerMetadata.getTokenEndpointURI()); + return new GRPCAuthInterceptor(clientAuth, rsaKey, providerMetadata.getTokenEndpointURI()); + } - return getManagedChannelBuilder() - .intercept(interceptor) - .build(); + SDK.Services buildServices() { + var authInterceptor = getGrpcAuthInterceptor(); + var channel = getManagedChannelBuilder().intercept(authInterceptor).build(); + var client = new KASClient(getChannelFactory(authInterceptor)); + return SDK.Services.newServices(channel, client); } public SDK build() { - return new SDK(SDK.Services.newServices(buildChannel())); + return new SDK(buildServices()); } private ManagedChannelBuilder getManagedChannelBuilder() { - ManagedChannelBuilder channelBuilder = ManagedChannelBuilder - .forTarget(platformEndpoint); + ManagedChannelBuilder channelBuilder = ManagedChannelBuilder.forTarget(platformEndpoint); if (usePlainText) { channelBuilder = channelBuilder.usePlaintext(); } return channelBuilder; } + + Function getChannelFactory(GRPCAuthInterceptor authInterceptor) { + var pt = usePlainText; // no need to have the builder be able to influence things from beyond the grave + return (SDK.KASInfo kasInfo) -> { + ManagedChannelBuilder channelBuilder = ManagedChannelBuilder + .forTarget(kasInfo.getAddress()) + .intercept(authInterceptor); + if (pt) { + channelBuilder = channelBuilder.usePlaintext(); + } + return channelBuilder.build(); + }; + } } diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/SDKException.java b/sdk/src/main/java/io/opentdf/platform/sdk/SDKException.java index 0db5da43..2d0b390d 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/SDKException.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/SDKException.java @@ -4,4 +4,8 @@ public class SDKException extends RuntimeException { public SDKException(String message, Exception reason) { super(message, reason); } + + public SDKException(String message) { + super(message); + } } diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/SDKBuilderTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/SDKBuilderTest.java index d24e28b3..28ce0ecb 100644 --- a/sdk/src/test/java/io/opentdf/platform/sdk/SDKBuilderTest.java +++ b/sdk/src/test/java/io/opentdf/platform/sdk/SDKBuilderTest.java @@ -11,11 +11,19 @@ import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; import io.grpc.stub.StreamObserver; +import io.opentdf.platform.kas.AccessServiceGrpc; +import io.opentdf.platform.kas.RewrapRequest; +import io.opentdf.platform.kas.RewrapResponse; +import io.opentdf.platform.policy.namespaces.GetNamespaceRequest; +import io.opentdf.platform.policy.namespaces.GetNamespaceResponse; +import io.opentdf.platform.policy.namespaces.NamespaceServiceGrpc; import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationRequest; import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationResponse; import io.opentdf.platform.wellknownconfiguration.WellKnownServiceGrpc; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import java.io.IOException; @@ -30,8 +38,9 @@ public class SDKBuilderTest { @Test - void testCreatingSDKChannel() throws IOException, InterruptedException { - Server wellknownServer = null; + void testCreatingSDKServices() throws IOException, InterruptedException { + Server platformServicesServer = null; + Server kasServer = null; // we use the HTTP server for two things: // * it returns the OIDC configuration we use at bootstrapping time // * it fakes out being an IDP and returns an access token when need to retrieve an access token @@ -51,6 +60,8 @@ void testCreatingSDKChannel() throws IOException, InterruptedException { .setHeader("Content-type", "application/json") ); + // this service returns the platform_issuer url to the SDK during bootstrapping. This + // tells the SDK where to download the OIDC discovery document from (our test webserver!) WellKnownServiceGrpc.WellKnownServiceImplBase wellKnownService = new WellKnownServiceGrpc.WellKnownServiceImplBase() { @Override public void getWellKnownConfiguration(GetWellKnownConfigurationRequest request, StreamObserver responseObserver) { @@ -65,55 +76,76 @@ public void getWellKnownConfiguration(GetWellKnownConfigurationRequest request, } }; - AtomicReference authHeaderFromRequest = new AtomicReference<>(null); - AtomicReference dpopHeaderFromRequest = new AtomicReference<>(null); + // remember the auth headers that we received during GRPC calls to platform services + AtomicReference servicesAuthHeader = new AtomicReference<>(null); + AtomicReference servicesDPoPHeader = new AtomicReference<>(null); + // remember the auth headers that we received during GRPC calls to KAS + AtomicReference kasAuthHeader = new AtomicReference<>(null); + AtomicReference kasDPoPHeader = new AtomicReference<>(null); // we use the server in two different ways. the first time we use it to actually return // issuer for bootstrapping. the second time we use the interception functionality in order // to make sure that we are including a DPoP proof and an auth header - int randomPort; - try (ServerSocket socket = new ServerSocket(0)) { - randomPort = socket.getLocalPort(); - } - wellknownServer = ServerBuilder - .forPort(randomPort) + platformServicesServer = ServerBuilder + .forPort(getRandomPort()) .directExecutor() .addService(wellKnownService) + .addService(new NamespaceServiceGrpc.NamespaceServiceImplBase() {}) + .intercept(new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall(ServerCall call, Metadata headers, ServerCallHandler next) { + servicesAuthHeader.set(headers.get(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER))); + servicesDPoPHeader.set(headers.get(Metadata.Key.of("DPoP", Metadata.ASCII_STRING_MARSHALLER))); + return next.startCall(call, headers); + } + }) + .build() + .start(); + + + kasServer = ServerBuilder + .forPort(getRandomPort()) + .directExecutor() + .addService(new AccessServiceGrpc.AccessServiceImplBase() { + @Override + public void rewrap(RewrapRequest request, StreamObserver responseObserver) { + responseObserver.onNext(RewrapResponse.getDefaultInstance()); + responseObserver.onCompleted(); + } + }) .intercept(new ServerInterceptor() { @Override public ServerCall.Listener interceptCall(ServerCall call, Metadata headers, ServerCallHandler next) { - authHeaderFromRequest.set(headers.get(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER))); - dpopHeaderFromRequest.set(headers.get(Metadata.Key.of("DPoP", Metadata.ASCII_STRING_MARSHALLER))); + kasAuthHeader.set(headers.get(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER))); + kasDPoPHeader.set(headers.get(Metadata.Key.of("DPoP", Metadata.ASCII_STRING_MARSHALLER))); return next.startCall(call, headers); } }) .build() .start(); - ManagedChannel channel = SDKBuilder + SDK.Services services = SDKBuilder .newBuilder() .clientSecret("client-id", "client-secret") - .platformEndpoint("localhost:" + wellknownServer.getPort()) + .platformEndpoint("localhost:" + platformServicesServer.getPort()) .useInsecurePlaintextConnection(true) - .buildChannel(); - - assertThat(channel).isNotNull(); - assertThat(channel.getState(false)).isEqualTo(ConnectivityState.IDLE); + .buildServices(); - var wellKnownStub = WellKnownServiceGrpc.newBlockingStub(channel); + assertThat(services).isNotNull(); httpServer.enqueue(new MockResponse() .setBody("{\"access_token\": \"hereisthetoken\", \"token_type\": \"Bearer\"}") .setHeader("Content-Type", "application/json")); - var ignored = wellKnownStub.getWellKnownConfiguration(GetWellKnownConfigurationRequest.getDefaultInstance()); - channel.shutdownNow(); + var ignored = services.namespaces().getNamespace(GetNamespaceRequest.getDefaultInstance()); // we've now made two requests. one to get the bootstrapping info and one // call that should activate the token fetching logic assertThat(httpServer.getRequestCount()).isEqualTo(2); httpServer.takeRequest(); + + // validate that we made a reasonable request to our fake IdP to get an access token var accessTokenRequest = httpServer.takeRequest(); assertThat(accessTokenRequest).isNotNull(); var authHeader = accessTokenRequest.getHeader("Authorization"); @@ -124,16 +156,35 @@ public ServerCall.Listener interceptCall(ServerCall "localhost:" + kasPort; + services.kas().unwrap(kasInfo, new SDK.Policy() {}); + + assertThat(kasDPoPHeader.get()).isNotNull(); + assertThat(kasAuthHeader.get()).isEqualTo("DPoP hereisthetoken"); } finally { - if (wellknownServer != null) { - wellknownServer.shutdownNow(); + if (platformServicesServer != null) { + platformServicesServer.shutdownNow(); + } + if (kasServer != null) { + kasServer.shutdownNow(); } } } + + private static int getRandomPort() throws IOException { + int randomPort; + try (ServerSocket socket = new ServerSocket(0)) { + randomPort = socket.getLocalPort(); + } + return randomPort; + } }