From fa36f83d4419c76e913f98894a091fb7e6aea826 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Mon, 18 Nov 2024 16:51:38 +0000 Subject: [PATCH] In-progress changes. --- .../io/grpc/netty/ProtocolNegotiators.java | 53 +++++++++++++++++-- 1 file changed, 48 insertions(+), 5 deletions(-) diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index 84370ac8153..ddef7bd0c26 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -67,9 +67,15 @@ import io.netty.handler.ssl.SslProvider; import io.netty.util.AsciiString; import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; import java.net.SocketAddress; import java.net.URI; import java.nio.channels.ClosedChannelException; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; import java.util.Arrays; import java.util.EnumSet; import java.util.Optional; @@ -82,6 +88,9 @@ import javax.net.ssl.SSLException; import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSession; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509ExtendedTrustManager; /** * Common {@link ProtocolNegotiator}s used by gRPC. @@ -99,7 +108,8 @@ final class ProtocolNegotiators { private ProtocolNegotiators() { } - public static FromChannelCredentialsResult from(ChannelCredentials creds) { + public static FromChannelCredentialsResult from(ChannelCredentials creds) + throws KeyStoreException, NoSuchAlgorithmException { if (creds instanceof TlsChannelCredentials) { TlsChannelCredentials tlsCreds = (TlsChannelCredentials) creds; Set incomprehensible = @@ -117,11 +127,20 @@ public static FromChannelCredentialsResult from(ChannelCredentials creds) { new ByteArrayInputStream(tlsCreds.getPrivateKey()), tlsCreds.getPrivateKeyPassword()); } + Optional x509ExtendedTrustManager; if (tlsCreds.getTrustManagers() != null) { builder.trustManager(new FixedTrustManagerFactory(tlsCreds.getTrustManagers())); + x509ExtendedTrustManager = tlsCreds.getTrustManagers().stream().filter( + trustManager -> trustManager instanceof X509ExtendedTrustManager).findFirst(); } else if (tlsCreds.getRootCertificates() != null) { builder.trustManager(new ByteArrayInputStream(tlsCreds.getRootCertificates())); - } // else use system default + } else { // else use system default + TrustManagerFactory tmf = TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()); + tmf.init((KeyStore)null); + x509ExtendedTrustManager = Arrays.stream(tmf.getTrustManagers()) + .filter(trustManager -> trustManager instanceof X509ExtendedTrustManager).findFirst(); + } try { return FromChannelCredentialsResult.negotiator(tlsClientFactory(builder.build())); } catch (SSLException ex) { @@ -161,6 +180,26 @@ public static FromChannelCredentialsResult from(ChannelCredentials creds) { } } + private static X509ExtendedTrustManager getX509ExtendedTrustManager(InputStream rootCerts) throws GeneralSecurityException { + KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); + try { + ks.load(null, null); + } catch (IOException ex) { + // Shouldn't really happen, as we're not loading any data. + throw new GeneralSecurityException(ex); + } + X509Certificate[] certs = CertificateUtils.getX509Certificates(rootCerts); + for (X509Certificate cert : certs) { + X500Principal principal = cert.getSubjectX500Principal(); + ks.setCertificateEntry(principal.getName("RFC2253"), cert); + } + + TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(ks); + return trustManagerFactory.getTrustManagers(); + } + public static FromServerCredentialsResult from(ServerCredentials creds) { if (creds instanceof TlsServerCredentials) { TlsServerCredentials tlsCreds = (TlsServerCredentials) creds; @@ -711,17 +750,21 @@ public static ProtocolNegotiator tls(SslContext sslContext) { return tls(sslContext, null, Optional.empty()); } - public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext) { - return new TlsProtocolNegotiatorClientFactory(sslContext); + public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext, + X509ExtendedTrustManager x509ExtendedTrustManager) { + return new TlsProtocolNegotiatorClientFactory(sslContext, x509ExtendedTrustManager); } @VisibleForTesting static final class TlsProtocolNegotiatorClientFactory implements ProtocolNegotiator.ClientFactory { private final SslContext sslContext; + private final X509ExtendedTrustManager x509ExtendedTrustManager; - public TlsProtocolNegotiatorClientFactory(SslContext sslContext) { + public TlsProtocolNegotiatorClientFactory(SslContext sslContext, + X509ExtendedTrustManager x509ExtendedTrustManager) { this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); + this.x509ExtendedTrustManager = x509ExtendedTrustManager; } @Override public ProtocolNegotiator newNegotiator() {