Skip to content

Commit

Permalink
Merge branch 'main' into dmihalcik-virtru-patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
mkleene authored May 24, 2024
2 parents 30e452f + 04ca715 commit 8cd73dc
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 37 deletions.
41 changes: 41 additions & 0 deletions sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package io.opentdf.platform.sdk;

import io.grpc.Channel;
import io.opentdf.platform.kas.AccessServiceGrpc;
import io.opentdf.platform.kas.PublicKeyRequest;
import io.opentdf.platform.kas.RewrapRequest;

import java.util.HashMap;
import java.util.function.Function;

public class KASClient implements SDK.KAS {

private final Function<SDK.KASInfo, Channel> channelFactory;

public KASClient(Function <SDK.KASInfo, Channel> channelFactory) {
this.channelFactory = channelFactory;
}

@Override
public String getPublicKey(SDK.KASInfo kasInfo) {
return getStub(kasInfo).publicKey(PublicKeyRequest.getDefaultInstance()).getPublicKey();
}

@Override
public byte[] unwrap(SDK.KASInfo kasInfo, SDK.Policy policy) {
// this is obviously wrong. we still have to generate a correct request and decrypt the payload
return getStub(kasInfo).rewrap(RewrapRequest.getDefaultInstance()).getEntityWrappedKey().toByteArray();
}

private final HashMap<SDK.KASInfo, AccessServiceGrpc.AccessServiceBlockingStub> stubs = new HashMap<>();

private synchronized AccessServiceGrpc.AccessServiceBlockingStub getStub(SDK.KASInfo kasInfo) {
if (!stubs.containsKey(kasInfo)) {
var channel = channelFactory.apply(kasInfo);
var stub = AccessServiceGrpc.newBlockingStub(channel);
stubs.put(kasInfo, stub);
}

return stubs.get(kasInfo);
}
}
22 changes: 19 additions & 3 deletions sdk/src/main/java/io/opentdf/platform/sdk/SDK.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,25 @@
public class SDK {
private final Services services;

public interface KASInfo{
String getAddress();
}
public interface Policy{}

interface KAS {
String getPublicKey(KASInfo kasInfo);
byte[] unwrap(KASInfo kasInfo, Policy policy);
}

// TODO: add KAS
public interface Services {
interface Services {
AttributesServiceFutureStub attributes();
NamespaceServiceFutureStub namespaces();
SubjectMappingServiceFutureStub subjectMappings();
ResourceMappingServiceFutureStub resourceMappings();
KAS kas();

static Services newServices(Channel channel) {
static Services newServices(Channel channel, KAS kas) {
var attributeService = AttributesServiceGrpc.newFutureStub(channel);
var namespaceService = NamespaceServiceGrpc.newFutureStub(channel);
var subjectMappingService = SubjectMappingServiceGrpc.newFutureStub(channel);
Expand All @@ -50,11 +61,16 @@ public SubjectMappingServiceFutureStub subjectMappings() {
public ResourceMappingServiceFutureStub resourceMappings() {
return resourceMappingService;
}

@Override
public KAS kas() {
return kas;
}
};
}
}

public SDK(Services services) {
SDK(Services services) {
this.services = services;
}
}
47 changes: 38 additions & 9 deletions sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import com.nimbusds.oauth2.sdk.id.ClientID;
import com.nimbusds.oauth2.sdk.id.Issuer;
import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata;
import io.grpc.Channel;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Status;
Expand All @@ -23,6 +24,7 @@

import java.io.IOException;
import java.util.UUID;
import java.util.function.Function;

/**
* A builder class for creating instances of the SDK class.
Expand All @@ -33,9 +35,13 @@ public class SDKBuilder {
private ClientAuthentication clientAuth = null;
private Boolean usePlainText;

private static final Logger logger = LoggerFactory.getLogger(SDKBuilder.class);

public static SDKBuilder newBuilder() {
SDKBuilder builder = new SDKBuilder();
builder.usePlainText = false;
builder.clientAuth = null;
builder.platformEndpoint = null;

return builder;
}
Expand All @@ -57,8 +63,16 @@ public SDKBuilder useInsecurePlaintextConnection(Boolean usePlainText) {
return this;
}

// this is not exposed publicly so that it can be tested
ManagedChannel buildChannel() {
private GRPCAuthInterceptor getGrpcAuthInterceptor() {
if (platformEndpoint == null) {
throw new SDKException("cannot build an SDK without specifying the platform endpoint");
}

if (clientAuth == null) {
// this simplifies things for now, if we need to support this case we can revisit
throw new SDKException("cannot build an SDK without specifying OAuth credentials");
}

// we don't add the auth listener to this channel since it is only used to call the
// well known endpoint
ManagedChannel bootstrapChannel = null;
Expand Down Expand Up @@ -107,24 +121,39 @@ ManagedChannel buildChannel() {
throw new SDKException("Error generating DPoP key", e);
}

GRPCAuthInterceptor interceptor = new GRPCAuthInterceptor(clientAuth, rsaKey, providerMetadata.getTokenEndpointURI());
return new GRPCAuthInterceptor(clientAuth, rsaKey, providerMetadata.getTokenEndpointURI());
}

return getManagedChannelBuilder()
.intercept(interceptor)
.build();
SDK.Services buildServices() {
var authInterceptor = getGrpcAuthInterceptor();
var channel = getManagedChannelBuilder().intercept(authInterceptor).build();
var client = new KASClient(getChannelFactory(authInterceptor));
return SDK.Services.newServices(channel, client);
}

public SDK build() {
return new SDK(SDK.Services.newServices(buildChannel()));
return new SDK(buildServices());
}

private ManagedChannelBuilder<?> getManagedChannelBuilder() {
ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder
.forTarget(platformEndpoint);
ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder.forTarget(platformEndpoint);

if (usePlainText) {
channelBuilder = channelBuilder.usePlaintext();
}
return channelBuilder;
}

Function<SDK.KASInfo, Channel> getChannelFactory(GRPCAuthInterceptor authInterceptor) {
var pt = usePlainText; // no need to have the builder be able to influence things from beyond the grave
return (SDK.KASInfo kasInfo) -> {
ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder
.forTarget(kasInfo.getAddress())
.intercept(authInterceptor);
if (pt) {
channelBuilder = channelBuilder.usePlaintext();
}
return channelBuilder.build();
};
}
}
4 changes: 4 additions & 0 deletions sdk/src/main/java/io/opentdf/platform/sdk/SDKException.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,8 @@ public class SDKException extends RuntimeException {
public SDKException(String message, Exception reason) {
super(message, reason);
}

public SDKException(String message) {
super(message);
}
}
101 changes: 76 additions & 25 deletions sdk/src/test/java/io/opentdf/platform/sdk/SDKBuilderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,19 @@
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.stub.StreamObserver;
import io.opentdf.platform.kas.AccessServiceGrpc;
import io.opentdf.platform.kas.RewrapRequest;
import io.opentdf.platform.kas.RewrapResponse;
import io.opentdf.platform.policy.namespaces.GetNamespaceRequest;
import io.opentdf.platform.policy.namespaces.GetNamespaceResponse;
import io.opentdf.platform.policy.namespaces.NamespaceServiceGrpc;
import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationRequest;
import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationResponse;
import io.opentdf.platform.wellknownconfiguration.WellKnownServiceGrpc;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import java.io.IOException;
Expand All @@ -30,8 +38,9 @@
public class SDKBuilderTest {

@Test
void testCreatingSDKChannel() throws IOException, InterruptedException {
Server wellknownServer = null;
void testCreatingSDKServices() throws IOException, InterruptedException {
Server platformServicesServer = null;
Server kasServer = null;
// we use the HTTP server for two things:
// * it returns the OIDC configuration we use at bootstrapping time
// * it fakes out being an IDP and returns an access token when need to retrieve an access token
Expand All @@ -51,6 +60,8 @@ void testCreatingSDKChannel() throws IOException, InterruptedException {
.setHeader("Content-type", "application/json")
);

// this service returns the platform_issuer url to the SDK during bootstrapping. This
// tells the SDK where to download the OIDC discovery document from (our test webserver!)
WellKnownServiceGrpc.WellKnownServiceImplBase wellKnownService = new WellKnownServiceGrpc.WellKnownServiceImplBase() {
@Override
public void getWellKnownConfiguration(GetWellKnownConfigurationRequest request, StreamObserver<GetWellKnownConfigurationResponse> responseObserver) {
Expand All @@ -65,55 +76,76 @@ public void getWellKnownConfiguration(GetWellKnownConfigurationRequest request,
}
};

AtomicReference<String> authHeaderFromRequest = new AtomicReference<>(null);
AtomicReference<String> dpopHeaderFromRequest = new AtomicReference<>(null);
// remember the auth headers that we received during GRPC calls to platform services
AtomicReference<String> servicesAuthHeader = new AtomicReference<>(null);
AtomicReference<String> servicesDPoPHeader = new AtomicReference<>(null);

// remember the auth headers that we received during GRPC calls to KAS
AtomicReference<String> kasAuthHeader = new AtomicReference<>(null);
AtomicReference<String> kasDPoPHeader = new AtomicReference<>(null);
// we use the server in two different ways. the first time we use it to actually return
// issuer for bootstrapping. the second time we use the interception functionality in order
// to make sure that we are including a DPoP proof and an auth header
int randomPort;
try (ServerSocket socket = new ServerSocket(0)) {
randomPort = socket.getLocalPort();
}
wellknownServer = ServerBuilder
.forPort(randomPort)
platformServicesServer = ServerBuilder
.forPort(getRandomPort())
.directExecutor()
.addService(wellKnownService)
.addService(new NamespaceServiceGrpc.NamespaceServiceImplBase() {})
.intercept(new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
servicesAuthHeader.set(headers.get(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER)));
servicesDPoPHeader.set(headers.get(Metadata.Key.of("DPoP", Metadata.ASCII_STRING_MARSHALLER)));
return next.startCall(call, headers);
}
})
.build()
.start();


kasServer = ServerBuilder
.forPort(getRandomPort())
.directExecutor()
.addService(new AccessServiceGrpc.AccessServiceImplBase() {
@Override
public void rewrap(RewrapRequest request, StreamObserver<RewrapResponse> responseObserver) {
responseObserver.onNext(RewrapResponse.getDefaultInstance());
responseObserver.onCompleted();
}
})
.intercept(new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
authHeaderFromRequest.set(headers.get(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER)));
dpopHeaderFromRequest.set(headers.get(Metadata.Key.of("DPoP", Metadata.ASCII_STRING_MARSHALLER)));
kasAuthHeader.set(headers.get(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER)));
kasDPoPHeader.set(headers.get(Metadata.Key.of("DPoP", Metadata.ASCII_STRING_MARSHALLER)));
return next.startCall(call, headers);
}
})
.build()
.start();

ManagedChannel channel = SDKBuilder
SDK.Services services = SDKBuilder
.newBuilder()
.clientSecret("client-id", "client-secret")
.platformEndpoint("localhost:" + wellknownServer.getPort())
.platformEndpoint("localhost:" + platformServicesServer.getPort())
.useInsecurePlaintextConnection(true)
.buildChannel();

assertThat(channel).isNotNull();
assertThat(channel.getState(false)).isEqualTo(ConnectivityState.IDLE);
.buildServices();

var wellKnownStub = WellKnownServiceGrpc.newBlockingStub(channel);
assertThat(services).isNotNull();

httpServer.enqueue(new MockResponse()
.setBody("{\"access_token\": \"hereisthetoken\", \"token_type\": \"Bearer\"}")
.setHeader("Content-Type", "application/json"));

var ignored = wellKnownStub.getWellKnownConfiguration(GetWellKnownConfigurationRequest.getDefaultInstance());
channel.shutdownNow();
var ignored = services.namespaces().getNamespace(GetNamespaceRequest.getDefaultInstance());

// we've now made two requests. one to get the bootstrapping info and one
// call that should activate the token fetching logic
assertThat(httpServer.getRequestCount()).isEqualTo(2);

httpServer.takeRequest();

// validate that we made a reasonable request to our fake IdP to get an access token
var accessTokenRequest = httpServer.takeRequest();
assertThat(accessTokenRequest).isNotNull();
var authHeader = accessTokenRequest.getHeader("Authorization");
Expand All @@ -124,16 +156,35 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, Re
var usernameAndPassword = new String(Base64.getDecoder().decode(authHeaderParts[1]), StandardCharsets.UTF_8);
assertThat(usernameAndPassword).isEqualTo("client-id:client-secret");

assertThat(dpopHeaderFromRequest.get()).isNotNull();
assertThat(authHeaderFromRequest.get()).isEqualTo("DPoP hereisthetoken");
// validate that during the request to the namespace service we supplied a valid token
assertThat(servicesDPoPHeader.get()).isNotNull();
assertThat(servicesAuthHeader.get()).isEqualTo("DPoP hereisthetoken");

var body = new String(accessTokenRequest.getBody().readByteArray(), StandardCharsets.UTF_8);
assertThat(body).contains("grant_type=client_credentials");

// now call KAS _on a different server_ and make sure that the interceptors provide us with auth tokens
int kasPort = kasServer.getPort();
SDK.KASInfo kasInfo = () -> "localhost:" + kasPort;
services.kas().unwrap(kasInfo, new SDK.Policy() {});

assertThat(kasDPoPHeader.get()).isNotNull();
assertThat(kasAuthHeader.get()).isEqualTo("DPoP hereisthetoken");
} finally {
if (wellknownServer != null) {
wellknownServer.shutdownNow();
if (platformServicesServer != null) {
platformServicesServer.shutdownNow();
}
if (kasServer != null) {
kasServer.shutdownNow();
}
}
}

private static int getRandomPort() throws IOException {
int randomPort;
try (ServerSocket socket = new ServerSocket(0)) {
randomPort = socket.getLocalPort();
}
return randomPort;
}
}

0 comments on commit 8cd73dc

Please sign in to comment.