From 906e7094d49d0ddc5de92e7a5792e369dc3e7804 Mon Sep 17 00:00:00 2001 From: Riya Mehta <55350838+rmehta19@users.noreply.github.com> Date: Tue, 8 Oct 2024 10:56:43 -0700 Subject: [PATCH] s2a: Correct type of exception thrown (#11588) * throw IllegalArgumentException in ProtoUtil. * throw exception in TrustManager in more standard way. * handle IllegalArgumentException in SslContextFactory. * Don't throw error on unknown TLS version. --- .../grpc/s2a/internal/handshaker/ProtoUtil.java | 11 ++++++++--- .../s2a/internal/handshaker/S2ATrustManager.java | 8 ++++---- .../internal/handshaker/SslContextFactory.java | 6 ++++-- .../s2a/internal/handshaker/ProtoUtilTest.java | 16 +++++++--------- .../handshaker/SslContextFactoryTest.java | 2 +- 5 files changed, 24 insertions(+), 19 deletions(-) diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/ProtoUtil.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/ProtoUtil.java index 1d88d5a2b55..1f24727a083 100644 --- a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/ProtoUtil.java +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/ProtoUtil.java @@ -27,7 +27,8 @@ final class ProtoUtil { * * @param tlsVersion the {@link TLSVersion} object to be converted. * @return a {@link String} representation of the TLS version. - * @throws AssertionError if the {@code tlsVersion} is not one of the supported TLS versions. + * @throws IllegalArgumentException if the {@code tlsVersion} is not one of + * the supported TLS versions. */ @VisibleForTesting static String convertTlsProtocolVersion(TLSVersion tlsVersion) { @@ -41,7 +42,7 @@ static String convertTlsProtocolVersion(TLSVersion tlsVersion) { case TLS_VERSION_1_0: return "TLSv1"; default: - throw new AssertionError( + throw new IllegalArgumentException( String.format("TLS version %d is not supported.", tlsVersion.getNumber())); } } @@ -62,7 +63,11 @@ static ImmutableSet buildTlsProtocolVersionSet( } if (versionNumber >= minTlsVersion.getNumber() && versionNumber <= maxTlsVersion.getNumber()) { - tlsVersions.add(convertTlsProtocolVersion(tlsVersion)); + try { + tlsVersions.add(convertTlsProtocolVersion(tlsVersion)); + } catch (IllegalArgumentException e) { + continue; + } } } return tlsVersions.build(); diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2ATrustManager.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2ATrustManager.java index 2f7e5750f88..406545b30bf 100644 --- a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2ATrustManager.java +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2ATrustManager.java @@ -120,10 +120,10 @@ private void checkPeerTrusted(X509Certificate[] chain, boolean isCheckingClientC SessionResp resp; try { resp = stub.send(reqBuilder.build()); - } catch (IOException | InterruptedException e) { - if (e instanceof InterruptedException) { - Thread.currentThread().interrupt(); - } + } catch (IOException e) { + throw new CertificateException("Failed to send request to S2A.", e); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); throw new CertificateException("Failed to send request to S2A.", e); } if (resp.hasStatus() && resp.getStatus().getCode() != 0) { diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/SslContextFactory.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/SslContextFactory.java index 72ace2c7885..3e5481daa9e 100644 --- a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/SslContextFactory.java +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/SslContextFactory.java @@ -138,11 +138,13 @@ private static void configureSslContextWithClientTlsConfiguration( NoSuchAlgorithmException, UnrecoverableKeyException { sslContextBuilder.keyManager(createKeylessManager(clientTlsConfiguration)); - ImmutableSet tlsVersions = + ImmutableSet tlsVersions; + tlsVersions = ProtoUtil.buildTlsProtocolVersionSet( clientTlsConfiguration.getMinTlsVersion(), clientTlsConfiguration.getMaxTlsVersion()); if (tlsVersions.isEmpty()) { - throw new S2AConnectionException("Set of TLS versions received from S2A server is empty."); + throw new S2AConnectionException("Set of TLS versions received from S2A server is" + + " empty or not supported."); } sslContextBuilder.protocols(tlsVersions); } diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/ProtoUtilTest.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/ProtoUtilTest.java index b685d0bc755..f60aa1a189b 100644 --- a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/ProtoUtilTest.java +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/ProtoUtilTest.java @@ -46,9 +46,9 @@ public void convertTlsProtocolVersion_success() { @Test public void convertTlsProtocolVersion_withUnknownTlsVersion_fails() { - AssertionError expected = + IllegalArgumentException expected = assertThrows( - AssertionError.class, + IllegalArgumentException.class, () -> ProtoUtil.convertTlsProtocolVersion(TLSVersion.TLS_VERSION_UNSPECIFIED)); expect.that(expected).hasMessageThat().isEqualTo("TLS version 0 is not supported."); } @@ -79,12 +79,10 @@ public void buildTlsProtocolVersionSet_success() { @Test public void buildTlsProtocolVersionSet_failure() { - AssertionError expected = - assertThrows( - AssertionError.class, - () -> - ProtoUtil.buildTlsProtocolVersionSet( - TLSVersion.TLS_VERSION_UNSPECIFIED, TLSVersion.TLS_VERSION_1_3)); - expect.that(expected).hasMessageThat().isEqualTo("TLS version 0 is not supported."); + expect + .that( + ProtoUtil.buildTlsProtocolVersionSet( + TLSVersion.TLS_VERSION_UNSPECIFIED, TLSVersion.TLS_VERSION_1_3)) + .isEqualTo(ImmutableSet.of("TLSv1", "TLSv1.1", "TLSv1.2", "TLSv1.3")); } } \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/SslContextFactoryTest.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/SslContextFactoryTest.java index fc3cfb5e441..17b834abf2a 100644 --- a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/SslContextFactoryTest.java +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/SslContextFactoryTest.java @@ -142,7 +142,7 @@ public void createForClient_getsBadTlsVersionsFromServer_throwsError() throws Ex assertThat(expected) .hasMessageThat() - .contains("Set of TLS versions received from S2A server is empty."); + .contains("Set of TLS versions received from S2A server is empty or not supported."); } @Test