diff --git a/servicetalk-http-api/src/main/java/io/servicetalk/http/api/DelegatingSingleAddressHttpClientBuilder.java b/servicetalk-http-api/src/main/java/io/servicetalk/http/api/DelegatingSingleAddressHttpClientBuilder.java index 125aa552b4..8f92e495a3 100644 --- a/servicetalk-http-api/src/main/java/io/servicetalk/http/api/DelegatingSingleAddressHttpClientBuilder.java +++ b/servicetalk-http-api/src/main/java/io/servicetalk/http/api/DelegatingSingleAddressHttpClientBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright © 2022 Apple Inc. and the ServiceTalk project authors + * Copyright © 2022-2023 Apple Inc. and the ServiceTalk project authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,6 +28,7 @@ import java.net.SocketOption; import java.util.function.BooleanSupplier; +import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Predicate; @@ -73,6 +74,13 @@ public SingleAddressHttpClientBuilder proxyAddress(final U proxyAddress) { return this; } + @Override + public SingleAddressHttpClientBuilder proxyAddress( + final U proxyAddress, final Consumer connectRequestInitializer) { + delegate = delegate.proxyAddress(proxyAddress, connectRequestInitializer); + return this; + } + @Override public SingleAddressHttpClientBuilder socketOption(final SocketOption option, final T value) { delegate = delegate.socketOption(option, value); diff --git a/servicetalk-http-api/src/main/java/io/servicetalk/http/api/SingleAddressHttpClientBuilder.java b/servicetalk-http-api/src/main/java/io/servicetalk/http/api/SingleAddressHttpClientBuilder.java index bc6518545c..f6f2120d67 100644 --- a/servicetalk-http-api/src/main/java/io/servicetalk/http/api/SingleAddressHttpClientBuilder.java +++ b/servicetalk-http-api/src/main/java/io/servicetalk/http/api/SingleAddressHttpClientBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright © 2018-2019 Apple Inc. and the ServiceTalk project authors + * Copyright © 2018-2023 Apple Inc. and the ServiceTalk project authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,6 +34,7 @@ import java.net.SocketOption; import java.net.StandardSocketOptions; import java.util.function.BooleanSupplier; +import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Predicate; @@ -49,15 +50,40 @@ public interface SingleAddressHttpClientBuilder extends HttpClientBuilder> { /** * Configure proxy to serve as an intermediary for requests. + *

+ * If the client talks to a proxy over http (not https, {@link #sslConfig(ClientSslConfig) ClientSslConfig} is NOT + * configured), it will rewrite the request-target to + * absolute-form, as specified by the RFC. + * * @param proxyAddress Unresolved address of the proxy. When used with a builder created for a resolved address, * {@code proxyAddress} should also be already resolved – otherwise runtime exceptions may occur. * @return {@code this}. */ - default SingleAddressHttpClientBuilder proxyAddress(U proxyAddress) { + default SingleAddressHttpClientBuilder proxyAddress(U proxyAddress) { // FIXME: 0.43 - remove default impl throw new UnsupportedOperationException("Setting proxy address is not yet supported by " + getClass().getName()); } + /** + * Configure proxy to serve as an intermediary for requests. + *

+ * If the client talks to a proxy over http (not https, {@link #sslConfig(ClientSslConfig) ClientSslConfig} is NOT + * configured), it will rewrite the request-target to + * absolute-form, as specified by the RFC. + * + * @param proxyAddress Unresolved address of the proxy. When used with a builder created for a resolved address, + * {@code proxyAddress} should also be already resolved – otherwise runtime exceptions may occur. + * @param connectRequestInitializer {@link Consumer} of {@link StreamingHttpRequest} that can be used to add + * additional info to HTTP/1.1 CONNECT + * request. It can be used to add headers, like {@link HttpHeaderNames#PROXY_AUTHORIZATION}, debugging info, etc. + * @return {@code this}. + */ + default SingleAddressHttpClientBuilder proxyAddress(// FIXME: 0.43 - remove default impl + U proxyAddress, Consumer connectRequestInitializer) { + throw new UnsupportedOperationException( + "Setting proxy address with request initializer is not yet supported by " + getClass().getName()); + } + /** * Adds a {@link SocketOption} for all connections created by this builder. * diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/AbstractStreamingHttpConnection.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/AbstractStreamingHttpConnection.java index 6c22f82953..fb49281ed1 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/AbstractStreamingHttpConnection.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/AbstractStreamingHttpConnection.java @@ -1,5 +1,5 @@ /* - * Copyright © 2018-2019, 2021-2022 Apple Inc. and the ServiceTalk project authors + * Copyright © 2018-2023 Apple Inc. and the ServiceTalk project authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,6 @@ import io.servicetalk.concurrent.api.Publisher; import io.servicetalk.concurrent.api.Single; import io.servicetalk.concurrent.api.TerminalSignalConsumer; -import io.servicetalk.http.api.FilterableStreamingHttpConnection; import io.servicetalk.http.api.HttpConnectionContext; import io.servicetalk.http.api.HttpEventKey; import io.servicetalk.http.api.HttpExecutionContext; @@ -39,6 +38,7 @@ import io.servicetalk.transport.netty.internal.FlushStrategy; import io.servicetalk.transport.netty.internal.NettyConnectionContext; +import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -66,7 +66,7 @@ import static java.util.Objects.requireNonNull; abstract class AbstractStreamingHttpConnection - implements FilterableStreamingHttpConnection { + implements NettyFilterableStreamingHttpConnection { private static final Logger LOGGER = LoggerFactory.getLogger(AbstractStreamingHttpConnection.class); static final IgnoreConsumedEvent ZERO_MAX_CONCURRENCY_EVENT = new IgnoreConsumedEvent<>(0); @@ -168,7 +168,7 @@ public void cancel() { } @Override - public Single request(final StreamingHttpRequest request) { + public final Single request(final StreamingHttpRequest request) { return defer(() -> { Publisher flatRequest; // See https://tools.ietf.org/html/rfc7230#section-3.3.3 @@ -260,6 +260,11 @@ public final StreamingHttpResponseFactory httpResponseFactory() { return reqRespFactory; } + @Override + public Channel nettyChannel() { + return connection.nettyChannel(); + } + @Override public final Completable onClose() { return connectionContext.onClose(); diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/DefaultSingleAddressHttpClientBuilder.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/DefaultSingleAddressHttpClientBuilder.java index bb74597dee..92c1acd19f 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/DefaultSingleAddressHttpClientBuilder.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/DefaultSingleAddressHttpClientBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright © 2018-2022 Apple Inc. and the ServiceTalk project authors + * Copyright © 2018-2023 Apple Inc. and the ServiceTalk project authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -65,6 +65,7 @@ import java.time.Duration; import java.util.Collection; import java.util.function.BooleanSupplier; +import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Predicate; import javax.annotation.Nullable; @@ -109,6 +110,7 @@ final class DefaultSingleAddressHttpClientBuilder implements SingleAddress private final U address; @Nullable private U proxyAddress; + private Consumer proxyConnectRequestInitializer = __ -> { }; private final HttpClientConfig config; final HttpExecutionContextBuilder executionContextBuilder; private final ClientStrategyInfluencerChainBuilder strategyComputation; @@ -146,6 +148,7 @@ private DefaultSingleAddressHttpClientBuilder(@Nullable final U address, final DefaultSingleAddressHttpClientBuilder from) { this.address = address; proxyAddress = from.proxyAddress; + proxyConnectRequestInitializer = from.proxyConnectRequestInitializer; config = new HttpClientConfig(from.config); executionContextBuilder = new HttpExecutionContextBuilder(from.executionContextBuilder); strategyComputation = from.strategyComputation.copy(); @@ -278,14 +281,18 @@ public HttpExecutionStrategy executionStrategy() { H2ProtocolConfig h2Config = roConfig.h2Config(); connectionFactory = new AlpnLBHttpConnectionFactory<>(roConfig, executionContext, connectionFilterFactory, new AlpnReqRespFactoryFunc( - executionContext.bufferAllocator(), - h1Config == null ? null : h1Config.headersFactory(), - h2Config == null ? null : h2Config.headersFactory()), + executionContext.bufferAllocator(), + h1Config == null ? null : h1Config.headersFactory(), + h2Config == null ? null : h2Config.headersFactory()), connectionFactoryStrategy, connectionFactoryFilter, ctx.builder.loadBalancerFactory::toLoadBalancedConnection); + } else if (roConfig.hasProxy() && sslContext != null) { + connectionFactory = new ProxyConnectLBHttpConnectionFactory<>(roConfig, executionContext, + connectionFilterFactory, reqRespFactory, + connectionFactoryStrategy, connectionFactoryFilter, + ctx.builder.loadBalancerFactory::toLoadBalancedConnection, + ctx.builder.proxyConnectRequestInitializer); } else { - H1ProtocolConfig h1Config = roConfig.h1Config(); - assert h1Config != null; connectionFactory = new PipelinedLBHttpConnectionFactory<>(roConfig, executionContext, connectionFilterFactory, reqRespFactory, connectionFactoryStrategy, connectionFactoryFilter, @@ -446,6 +453,14 @@ public DefaultSingleAddressHttpClientBuilder proxyAddress(final U proxyAdd return this; } + @Override + public SingleAddressHttpClientBuilder proxyAddress( + final U proxyAddress, final Consumer connectRequestInitializer) { + this.proxyAddress(proxyAddress); + this.proxyConnectRequestInitializer = requireNonNull(connectRequestInitializer); + return this; + } + @Override public DefaultSingleAddressHttpClientBuilder ioExecutor(final IoExecutor ioExecutor) { executionContextBuilder.ioExecutor(ioExecutor); diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/NettyFilterableStreamingHttpConnection.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/NettyFilterableStreamingHttpConnection.java new file mode 100644 index 0000000000..192e9fa431 --- /dev/null +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/NettyFilterableStreamingHttpConnection.java @@ -0,0 +1,33 @@ +/* + * Copyright © 2023 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.http.netty; + +import io.servicetalk.http.api.FilterableStreamingHttpConnection; + +import io.netty.channel.Channel; + +/** + * {@link FilterableStreamingHttpConnection} that also gives access to Netty {@link Channel}. + */ +interface NettyFilterableStreamingHttpConnection extends FilterableStreamingHttpConnection { + + /** + * Return the Netty {@link Channel} backing this connection. + * + * @return the Netty {@link Channel} backing this connection. + */ + Channel nettyChannel(); +} diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/PipelinedLBHttpConnectionFactory.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/PipelinedLBHttpConnectionFactory.java index 950acb2683..b9b4d20edf 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/PipelinedLBHttpConnectionFactory.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/PipelinedLBHttpConnectionFactory.java @@ -28,6 +28,7 @@ import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1; import static io.servicetalk.http.netty.StreamingConnectionFactory.buildStreaming; +import static java.util.Objects.requireNonNull; final class PipelinedLBHttpConnectionFactory extends AbstractLBHttpConnectionFactory { PipelinedLBHttpConnectionFactory( @@ -39,6 +40,7 @@ final class PipelinedLBHttpConnectionFactory extends AbstractLB final ProtocolBinding protocolBinding) { super(config, executionContext, version -> reqRespFactory, connectStrategy, connectionFactoryFilter, connectionFilterFunction, protocolBinding); + requireNonNull(config.h1Config(), "H1ProtocolConfig is required"); } @Override diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyConnectConnectionFactoryFilter.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyConnectConnectionFactoryFilter.java index 457663ca2a..2a3fada165 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyConnectConnectionFactoryFilter.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyConnectConnectionFactoryFilter.java @@ -1,5 +1,5 @@ /* - * Copyright © 2019-2020, 2022-2023 Apple Inc. and the ServiceTalk project authors + * Copyright © 2019-2023 Apple Inc. and the ServiceTalk project authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,40 +18,31 @@ import io.servicetalk.client.api.ConnectionFactory; import io.servicetalk.client.api.ConnectionFactoryFilter; import io.servicetalk.client.api.DelegatingConnectionFactory; -import io.servicetalk.concurrent.SingleSource; import io.servicetalk.concurrent.api.Single; import io.servicetalk.concurrent.internal.DefaultContextMap; import io.servicetalk.context.api.ContextMap; import io.servicetalk.http.api.FilterableStreamingHttpConnection; +import io.servicetalk.http.api.HttpContextKeys; import io.servicetalk.http.api.HttpExecutionStrategies; import io.servicetalk.http.api.HttpExecutionStrategy; -import io.servicetalk.http.api.StreamingHttpResponse; -import io.servicetalk.transport.api.ConnectExecutionStrategy; import io.servicetalk.transport.api.ExecutionStrategy; -import io.servicetalk.transport.api.IoThreadFactory; import io.servicetalk.transport.api.TransportObserver; -import io.servicetalk.transport.netty.internal.DeferSslHandler; -import io.servicetalk.transport.netty.internal.NettyConnectionContext; -import io.servicetalk.transport.netty.internal.StacklessClosedChannelException; -import io.netty.channel.Channel; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.handler.ssl.SslHandshakeCompletionEvent; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.annotation.Nullable; -import static io.servicetalk.concurrent.api.Processors.newSingleProcessor; -import static io.servicetalk.concurrent.api.Single.failed; -import static io.servicetalk.concurrent.api.SourceAdapters.fromSource; import static io.servicetalk.http.api.HttpContextKeys.HTTP_TARGET_ADDRESS_BEHIND_PROXY; -import static io.servicetalk.http.api.HttpHeaderNames.HOST; -import static io.servicetalk.http.api.HttpResponseStatus.StatusClass.SUCCESSFUL_2XX; /** - * A connection factory filter that sends a `CONNECT` request for https proxying. + * A {@link ConnectionFactoryFilter} that is prepended before any user-defined filters for the purpose of setting a + * {@link HttpContextKeys#HTTP_TARGET_ADDRESS_BEHIND_PROXY} key. + *

+ * The actual logic to do a proxy connect was moved to {@link ProxyConnectLBHttpConnectionFactory}. + *

+ * This filter can be removed when {@link HttpContextKeys#HTTP_TARGET_ADDRESS_BEHIND_PROXY} key is deprecated and + * removed. * * @param The type of resolved addresses that can be used for connecting. * @param The type of connections created by this factory. @@ -62,12 +53,9 @@ final class ProxyConnectConnectionFactoryFilter newConnection(final ResolvedAddress resolvedAddress, final ContextMap contextMap = context != null ? context : new DefaultContextMap(); logUnexpectedAddress(contextMap.put(HTTP_TARGET_ADDRESS_BEHIND_PROXY, connectAddress), connectAddress, LOGGER); - return delegate().newConnection(resolvedAddress, contextMap, observer).flatMap(c -> { - try { - // Send CONNECT request: https://datatracker.ietf.org/doc/html/rfc9110#section-9.3.6 - // Host header value must be equal to CONNECT request target, see - // https://github.com/haproxy/haproxy/issues/1159 - // https://datatracker.ietf.org/doc/html/rfc7230#section-5.4: - // If the target URI includes an authority component, then a client MUST send a field-value - // for Host that is identical to that authority component - return c.request(c.connect(connectAddress).setHeader(HOST, connectAddress)) - // Successful response to CONNECT never has a message body, and we are not interested in - // payload body for any non-200 status code. Drain it asap to free connection and RS - // resources before starting TLS handshake. - .flatMap(response -> response.messageBody().ignoreElements() - .concat(Single.defer(() -> handleConnectResponse(c, response) - .shareContextOnSubscribe())) - .shareContextOnSubscribe()) - // Close recently created connection in case of any error while it connects to the - // proxy: - .onErrorResume(t -> c.closeAsync().concat(failed(t))); - // We do not apply shareContextOnSubscribe() here to isolate a context for `CONNECT` request. - } catch (Throwable t) { - return c.closeAsync().concat(failed(t)); - } - }).shareContextOnSubscribe(); + // The rest of the logic was moved to ProxyConnectLBHttpConnectionFactory + return delegate().newConnection(resolvedAddress, contextMap, observer) + .shareContextOnSubscribe(); }); } - - private Single handleConnectResponse(final C connection, final StreamingHttpResponse response) { - if (response.status().statusClass() != SUCCESSFUL_2XX) { - return failed(new ProxyResponseException(connection + " Non-successful response from proxy CONNECT " + - connectAddress, response.status())); - } - - final Channel channel = ((NettyConnectionContext) connection.connectionContext()).nettyChannel(); - final SingleSource.Processor processor = newSingleProcessor(); - channel.pipeline().addLast(new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) { - if (evt instanceof SslHandshakeCompletionEvent) { - SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt; - if (event.isSuccess()) { - processor.onSuccess(connection); - } else { - processor.onError(event.cause()); - } - } - ctx.fireUserEventTriggered(evt); - } - }); - - final DeferSslHandler deferSslHandler = channel.pipeline().get(DeferSslHandler.class); - if (deferSslHandler == null) { - if (!channel.isActive()) { - LOGGER.info("{} is unexpectedly closed after receiving response: {}. " + - "Investigate logs on a proxy side to identify the cause.", - connection, response.toString((name, value) -> value)); - return failed(StacklessClosedChannelException.newInstance( - ProxyConnectConnectionFactoryFilter.class, "handleConnectResponse: " + - connection + " is unexpectedly closed. Check logs for more info.")); - } - return failed(new IllegalStateException(connection + " Failed to find a handler of type " + - DeferSslHandler.class + " in channel pipeline.")); - } - deferSslHandler.ready(); - - // processor completes on EventLoop thread, apply offloading if required: - return isConnectOffloaded ? - fromSource(processor).publishOn(connection.executionContext().executor(), - IoThreadFactory.IoThread::currentThreadIsIoThread) : - fromSource(processor); - } } static void logUnexpectedAddress(@Nullable final Object current, final Object expected, final Logger logger) { diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyConnectLBHttpConnectionFactory.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyConnectLBHttpConnectionFactory.java new file mode 100644 index 0000000000..83f310560a --- /dev/null +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyConnectLBHttpConnectionFactory.java @@ -0,0 +1,204 @@ +/* + * Copyright © 2019-2023 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.http.netty; + +import io.servicetalk.client.api.ConnectionFactoryFilter; +import io.servicetalk.concurrent.SingleSource; +import io.servicetalk.concurrent.api.Completable; +import io.servicetalk.concurrent.api.Publisher; +import io.servicetalk.concurrent.api.Single; +import io.servicetalk.http.api.FilterableStreamingHttpConnection; +import io.servicetalk.http.api.HttpExecutionContext; +import io.servicetalk.http.api.HttpExecutionStrategy; +import io.servicetalk.http.api.StreamingHttpConnectionFilterFactory; +import io.servicetalk.http.api.StreamingHttpRequest; +import io.servicetalk.http.api.StreamingHttpRequestResponseFactory; +import io.servicetalk.http.api.StreamingHttpResponse; +import io.servicetalk.transport.api.ExecutionStrategy; +import io.servicetalk.transport.api.TransportObserver; +import io.servicetalk.transport.netty.internal.DeferSslHandler; +import io.servicetalk.transport.netty.internal.StacklessClosedChannelException; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.ssl.SslHandshakeCompletionEvent; + +import java.util.function.Consumer; +import javax.annotation.Nullable; + +import static io.servicetalk.concurrent.api.Processors.newSingleProcessor; +import static io.servicetalk.concurrent.api.SourceAdapters.fromSource; +import static io.servicetalk.http.api.HttpApiConversions.isPayloadEmpty; +import static io.servicetalk.http.api.HttpContextKeys.HTTP_EXECUTION_STRATEGY_KEY; +import static io.servicetalk.http.api.HttpExecutionStrategies.customStrategyBuilder; +import static io.servicetalk.http.api.HttpExecutionStrategies.offloadNone; +import static io.servicetalk.http.api.HttpHeaderNames.HOST; +import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1; +import static io.servicetalk.http.api.HttpResponseStatus.StatusClass.SUCCESSFUL_2XX; +import static io.servicetalk.http.netty.StreamingConnectionFactory.buildStreaming; +import static io.servicetalk.utils.internal.ThrowableUtils.addSuppressed; + +/** + * {@link AbstractLBHttpConnectionFactory} implementation that handles HTTP/1.1 CONNECT when a client is configured to + * talk over HTTPS Proxy Tunnel. + * + * @param The type of resolved address. + */ +final class ProxyConnectLBHttpConnectionFactory + extends AbstractLBHttpConnectionFactory { + + private static final HttpExecutionStrategy OFFLOAD_SEND_STRATEGY = customStrategyBuilder().offloadSend().build(); + + private final String connectAddress; + private final Consumer connectRequestInitializer; + + ProxyConnectLBHttpConnectionFactory( + final ReadOnlyHttpClientConfig config, final HttpExecutionContext executionContext, + @Nullable final StreamingHttpConnectionFilterFactory connectionFilterFunction, + final StreamingHttpRequestResponseFactory reqRespFactory, + final ExecutionStrategy connectStrategy, + final ConnectionFactoryFilter connectionFactoryFilter, + final ProtocolBinding protocolBinding, + final Consumer connectRequestInitializer) { + super(config, executionContext, version -> reqRespFactory, connectStrategy, connectionFactoryFilter, + connectionFilterFunction, protocolBinding); + assert config.h1Config() != null : "H1ProtocolConfig is required"; + assert config.tcpConfig().sslContext() != null : "Proxy CONNECT works only for TLS connections"; + assert config.connectAddress() != null : "Address (authority) for CONNECT request is required"; + this.connectAddress = config.connectAddress().toString(); + this.connectRequestInitializer = connectRequestInitializer; + } + + @Override + Single newFilterableConnection(final ResolvedAddress resolvedAddress, + final TransportObserver observer) { + assert config.h1Config() != null; + return buildStreaming(executionContext, resolvedAddress, config, observer) + .map(c -> new PipelinedStreamingHttpConnection(c, config.h1Config(), + reqRespFactoryFunc.apply(HTTP_1_1), config.allowDropTrailersReadFromTransport())) + .flatMap(this::processConnect); + } + + // Visible for testing + Single processConnect(final NettyFilterableStreamingHttpConnection c) { + try { + // Send CONNECT request: https://datatracker.ietf.org/doc/html/rfc9110#section-9.3.6 + // Host header value must be equal to CONNECT request target, see + // https://github.com/haproxy/haproxy/issues/1159 + // https://datatracker.ietf.org/doc/html/rfc7230#section-5.4: + // If the target URI includes an authority component, then a client MUST send a field-value + // for Host that is identical to that authority component + final StreamingHttpRequest request = c.connect(connectAddress).setHeader(HOST, connectAddress); + connectRequestInitializer.accept(request); + configureOffloading(request); + return c.request(request) + .flatMap(response -> { + // Successful response to CONNECT never has a message body, and we are not interested in payload + // body for any non-200 status code. Drain it asap to free connection and RS resources before + // starting TLS handshake or propagating an error. We do this after verifying the status to + // preserve ProxyResponseException even if draining fails with an exception. + if (response.status().statusClass() != SUCCESSFUL_2XX) { + return drainPropagateError(response, new ProxyResponseException(c + + " Non-successful response from proxy CONNECT " + connectAddress, response.status())) + .shareContextOnSubscribe(); + } + return response.messageBody().ignoreElements() + .concat(handshake(c)) + .shareContextOnSubscribe(); + }) + // Close recently created connection in case of any error while it connects to the proxy: + .onErrorResume(t -> closePropagateError(c, t)); + // We do not apply shareContextOnSubscribe() here to isolate a context for `CONNECT` request. + } catch (Throwable t) { + return closePropagateError(c, t); + } + } + + private static void configureOffloading(final StreamingHttpRequest request) { + final HttpExecutionStrategy strategy; + if (isPayloadEmpty(request) || request.messageBody() == Publisher.empty()) { + // No need to offload because there is no user code involved + strategy = offloadNone(); + } else { + // Users added a custom request payload body Publisher, offload send for safety + strategy = OFFLOAD_SEND_STRATEGY; + } + // Put only if users didn't set their own strategy via connectRequestInitializer + request.context().putIfAbsent(HTTP_EXECUTION_STRATEGY_KEY, strategy); + } + + private Single handshake( + final NettyFilterableStreamingHttpConnection connection) { + return Single.defer(() -> { + final SingleSource.Processor + processor = newSingleProcessor(); + final Channel channel = connection.nettyChannel(); + assert channel.eventLoop().inEventLoop(); + channel.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) { + if (evt instanceof SslHandshakeCompletionEvent) { + SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt; + if (event.isSuccess()) { + processor.onSuccess(connection); + } else { + processor.onError(event.cause()); + } + channel.pipeline().remove(this); + } + ctx.fireUserEventTriggered(evt); + } + }); + + final Single result; + final DeferSslHandler deferSslHandler = channel.pipeline().get(DeferSslHandler.class); + if (deferSslHandler == null) { + if (!channel.isActive()) { + result = Single.failed(StacklessClosedChannelException.newInstance(connection + + " Connection is closed, either received a 'Connection: closed' header or" + + " closed by the proxy. Investigate logs on a proxy side to identify the cause.", + ProxyConnectLBHttpConnectionFactory.class, "handshake")); + } else { + result = Single.failed(new IllegalStateException(connection + + " Unexpected connection state: failed to find a handler of type " + + DeferSslHandler.class + " in the channel pipeline.")); + } + } else { + deferSslHandler.ready(); + result = fromSource(processor); + } + return result.shareContextOnSubscribe(); + }); + } + + private static Single drainPropagateError( + final StreamingHttpResponse response, final Throwable error) { + return safeCompletePropagateError(response.messageBody().ignoreElements(), error); + } + + private static Single closePropagateError( + final FilterableStreamingHttpConnection connection, final Throwable error) { + return safeCompletePropagateError(connection.closeAsync(), error); + } + + private static Single safeCompletePropagateError( + final Completable completable, final Throwable error) { + return completable + .onErrorResume(completableError -> Completable.failed(addSuppressed(error, completableError))) + .concat(Single.failed(error)); + } +} diff --git a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/HttpsProxyTest.java b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/HttpsProxyTest.java index eda012d703..12f1b240d8 100644 --- a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/HttpsProxyTest.java +++ b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/HttpsProxyTest.java @@ -1,5 +1,5 @@ /* - * Copyright © 2019, 2021-2022 Apple Inc. and the ServiceTalk project authors + * Copyright © 2019-2023 Apple Inc. and the ServiceTalk project authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,16 +27,18 @@ import io.servicetalk.test.resources.DefaultTestCerts; import io.servicetalk.transport.api.ClientSslConfigBuilder; import io.servicetalk.transport.api.HostAndPort; -import io.servicetalk.transport.api.IoExecutor; import io.servicetalk.transport.api.ServerContext; import io.servicetalk.transport.api.ServerSslConfigBuilder; import io.servicetalk.transport.api.TransportObserver; import io.servicetalk.transport.netty.internal.ExecutionContextExtension; import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.net.InetSocketAddress; import java.util.concurrent.atomic.AtomicReference; @@ -45,11 +47,13 @@ import static io.servicetalk.concurrent.api.Single.succeeded; import static io.servicetalk.http.api.HttpContextKeys.HTTP_TARGET_ADDRESS_BEHIND_PROXY; import static io.servicetalk.http.api.HttpHeaderNames.HOST; +import static io.servicetalk.http.api.HttpHeaderNames.PROXY_AUTHORIZATION; import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1; +import static io.servicetalk.http.api.HttpResponseStatus.INTERNAL_SERVER_ERROR; import static io.servicetalk.http.api.HttpResponseStatus.OK; +import static io.servicetalk.http.api.HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED; import static io.servicetalk.http.api.HttpSerializers.textSerializerUtf8; import static io.servicetalk.test.resources.DefaultTestCerts.serverPemHostname; -import static io.servicetalk.transport.netty.NettyIoExecutors.createIoExecutor; import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort; import static java.nio.charset.StandardCharsets.US_ASCII; import static org.hamcrest.MatcherAssert.assertThat; @@ -60,6 +64,9 @@ class HttpsProxyTest { + private static final Logger LOGGER = LoggerFactory.getLogger(HttpsProxyTest.class); + private static final String AUTH_TOKEN = "aGVsbG86d29ybGQ="; + @RegisterExtension static final ExecutionContextExtension SERVER_CTX = ExecutionContextExtension.cached("server-io", "server-executor") @@ -75,32 +82,26 @@ class HttpsProxyTest { @Nullable private HostAndPort proxyAddress; @Nullable - private IoExecutor serverIoExecutor; - @Nullable private ServerContext serverContext; @Nullable private HostAndPort serverAddress; @Nullable private BlockingHttpClient client; - @BeforeEach - void setUp() throws Exception { + void setUp(boolean withAuth) throws Exception { + if (withAuth) { + proxyTunnel.basicAuthToken(AUTH_TOKEN); + } proxyAddress = proxyTunnel.startProxy(); startServer(); - createClient(); + createClient(withAuth); } @AfterEach void tearDown() throws Exception { - try { - safeClose(client); - safeClose(serverContext); - safeClose(proxyTunnel); - } finally { - if (serverIoExecutor != null) { - serverIoExecutor.closeAsync().toFuture().get(); - } - } + safeClose(client); + safeClose(serverContext); + safeClose(proxyTunnel); } static void safeClose(@Nullable AutoCloseable closeable) { @@ -108,14 +109,13 @@ static void safeClose(@Nullable AutoCloseable closeable) { try { closeable.close(); } catch (Exception e) { - e.printStackTrace(); + LOGGER.error("Unexpected exception while closing", e); } } } void startServer() throws Exception { serverContext = BuilderUtils.newServerBuilder(SERVER_CTX) - .ioExecutor(serverIoExecutor = createIoExecutor("server-io-executor")) .sslConfig(new ServerSslConfigBuilder(DefaultTestCerts::loadServerPem, DefaultTestCerts::loadServerKey).build()) .listenAndAwait((ctx, request, responseFactory) -> succeeded(responseFactory.ok() @@ -123,24 +123,30 @@ void startServer() throws Exception { serverAddress = serverHostAndPort(serverContext); } - void createClient() { + void createClient(boolean withAuth) { assert serverContext != null && proxyAddress != null; client = BuilderUtils.newClientBuilder(serverContext, CLIENT_CTX) - .proxyAddress(proxyAddress) + .proxyAddress(proxyAddress, withAuth ? + request -> request.setHeader(PROXY_AUTHORIZATION, "basic " + AUTH_TOKEN) : + __ -> { }) .sslConfig(new ClientSslConfigBuilder(DefaultTestCerts::loadServerCAPem) .peerHost(serverPemHostname()).build()) .appendConnectionFactoryFilter(new TargetAddressCheckConnectionFactoryFilter(targetAddress, true)) .buildBlocking(); } - @Test - void testClientRequest() throws Exception { + @ParameterizedTest(name = "{displayName} [{index}] withAuth={0}") + @ValueSource(booleans = {false, true}) + void testClientRequest(boolean withAuth) throws Exception { + setUp(withAuth); assert client != null; assertResponse(client.request(client.get("/path"))); } - @Test - void testConnectionRequest() throws Exception { + @ParameterizedTest(name = "{displayName} [{index}] withAuth={0}") + @ValueSource(booleans = {false, true}) + void testConnectionRequest(boolean withAuth) throws Exception { + setUp(withAuth); assert client != null; try (ReservedBlockingHttpConnection connection = client.reserveConnection(client.get("/"))) { assertThat(connection.connectionContext().protocol(), is(HTTP_1_1)); @@ -159,10 +165,24 @@ private void assertResponse(HttpResponse httpResponse) { } @Test - void testBadProxyResponse() { + void testProxyAuthRequired() throws Exception { + setUp(false); + proxyTunnel.basicAuthToken(AUTH_TOKEN); + assert client != null; + ProxyResponseException e = assertThrows(ProxyResponseException.class, + () -> client.request(client.get("/path"))); + assertThat(e.status(), is(PROXY_AUTHENTICATION_REQUIRED)); + assertThat(targetAddress.get(), is(equalTo(serverAddress.toString()))); + } + + @Test + void testBadProxyResponse() throws Exception { + setUp(false); proxyTunnel.badResponseProxy(); assert client != null; - assertThrows(ProxyResponseException.class, () -> client.request(client.get("/path"))); + ProxyResponseException e = assertThrows(ProxyResponseException.class, + () -> client.request(client.get("/path"))); + assertThat(e.status(), is(INTERNAL_SERVER_ERROR)); assertThat(targetAddress.get(), is(equalTo(serverAddress.toString()))); } diff --git a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ProxyConnectConnectionFactoryFilterTest.java b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ProxyConnectLBHttpConnectionFactoryTest.java similarity index 65% rename from servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ProxyConnectConnectionFactoryFilterTest.java rename to servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ProxyConnectLBHttpConnectionFactoryTest.java index 3de9cd9ecb..b4ec9d28da 100644 --- a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ProxyConnectConnectionFactoryFilterTest.java +++ b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ProxyConnectLBHttpConnectionFactoryTest.java @@ -1,5 +1,5 @@ /* - * Copyright © 2020-2021 Apple Inc. and the ServiceTalk project authors + * Copyright © 2020-2023 Apple Inc. and the ServiceTalk project authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,9 +15,10 @@ */ package io.servicetalk.http.netty; -import io.servicetalk.client.api.ConnectionFactory; +import io.servicetalk.client.api.ConnectionFactoryFilter; import io.servicetalk.concurrent.Cancellable; import io.servicetalk.concurrent.PublisherSource; +import io.servicetalk.concurrent.api.Publisher; import io.servicetalk.concurrent.api.TestCompletable; import io.servicetalk.concurrent.api.TestPublisher; import io.servicetalk.concurrent.test.internal.TestSingleSubscriber; @@ -26,28 +27,30 @@ import io.servicetalk.http.api.FilterableStreamingHttpConnection; import io.servicetalk.http.api.HttpConnectionContext; import io.servicetalk.http.api.HttpExecutionContext; -import io.servicetalk.http.api.HttpExecutionStrategies; import io.servicetalk.http.api.HttpExecutionStrategy; -import io.servicetalk.http.api.StreamingHttpRequestFactory; +import io.servicetalk.http.api.StreamingHttpRequest; +import io.servicetalk.http.api.StreamingHttpRequestResponseFactory; import io.servicetalk.http.api.StreamingHttpResponse; +import io.servicetalk.http.netty.AbstractLBHttpConnectionFactory.ProtocolBinding; +import io.servicetalk.transport.api.ClientSslConfig; +import io.servicetalk.transport.api.ClientSslConfigBuilder; import io.servicetalk.transport.api.ConnectExecutionStrategy; import io.servicetalk.transport.netty.internal.DeferSslHandler; -import io.servicetalk.transport.netty.internal.NettyConnectionContext; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandler; import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoop; import io.netty.handler.ssl.SslHandshakeCompletionEvent; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; import java.nio.channels.ClosedChannelException; -import java.util.Queue; -import java.util.concurrent.LinkedBlockingQueue; import java.util.function.Consumer; import javax.annotation.Nullable; @@ -57,10 +60,13 @@ import static io.servicetalk.concurrent.api.Single.succeeded; import static io.servicetalk.concurrent.api.SourceAdapters.toSource; import static io.servicetalk.concurrent.internal.DeliberateException.DELIBERATE_EXCEPTION; +import static io.servicetalk.http.api.HttpContextKeys.HTTP_EXECUTION_STRATEGY_KEY; +import static io.servicetalk.http.api.HttpExecutionStrategies.customStrategyBuilder; +import static io.servicetalk.http.api.HttpExecutionStrategies.offloadNone; import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1; import static io.servicetalk.http.api.HttpResponseStatus.INTERNAL_SERVER_ERROR; import static io.servicetalk.http.api.HttpResponseStatus.OK; -import static io.servicetalk.test.resources.TestUtils.assertNoAsyncErrors; +import static io.servicetalk.http.netty.HttpProtocolConfigs.h1Default; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsString; @@ -69,29 +75,35 @@ import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.sameInstance; +import static org.mockito.ArgumentCaptor.forClass; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -class ProxyConnectConnectionFactoryFilterTest { +class ProxyConnectLBHttpConnectionFactoryTest { - private static final StreamingHttpRequestFactory REQ_FACTORY = new DefaultStreamingHttpRequestResponseFactory( - DEFAULT_ALLOCATOR, DefaultHttpHeadersFactory.INSTANCE, HTTP_1_1); + private static final ClientSslConfig DEFAULT_SSL_CONFIG = new ClientSslConfigBuilder().build(); + private static final StreamingHttpRequestResponseFactory REQ_RES_FACTORY = + new DefaultStreamingHttpRequestResponseFactory(DEFAULT_ALLOCATOR, DefaultHttpHeadersFactory.INSTANCE, + HTTP_1_1); private static final String CONNECT_ADDRESS = "foo.bar"; - private static final String RESOLVED_ADDRESS = "bar.foo"; - private final FilterableStreamingHttpConnection connection; + private final NettyFilterableStreamingHttpConnection connection; private final TestCompletable connectionClose; private final TestPublisher messageBody; private final TestSingleSubscriber subscriber; + private final Consumer connectRequestInitializer; + private final ProxyConnectLBHttpConnectionFactory connectionFactory; - ProxyConnectConnectionFactoryFilterTest() { + @SuppressWarnings("unchecked") + ProxyConnectLBHttpConnectionFactoryTest() { HttpExecutionContext executionContext = new HttpExecutionContextBuilder().build(); HttpConnectionContext connectionContext = mock(HttpConnectionContext.class); when(connectionContext.executionContext()).thenReturn(executionContext); - connection = mock(FilterableStreamingHttpConnection.class); + connection = mock(NettyFilterableStreamingHttpConnection.class); when(connection.connectionContext()).thenReturn(connectionContext); connectionClose = new TestCompletable.Builder().build(subscriber -> { subscriber.onSubscribe(IGNORE_CANCEL); @@ -116,6 +128,15 @@ public void cancel() { }); subscriber = new TestSingleSubscriber<>(); + + connectRequestInitializer = mock(Consumer.class); + HttpClientConfig config = new HttpClientConfig(); + config.connectAddress(CONNECT_ADDRESS); + config.tcpConfig().sslConfig(DEFAULT_SSL_CONFIG); + config.protocolConfigs().protocols(h1Default()); + connectionFactory = new ProxyConnectLBHttpConnectionFactory<>(config.asReadOnly(), + executionContext, null, REQ_RES_FACTORY, ConnectExecutionStrategy.offloadNone(), + ConnectionFactoryFilter.identity(), mock(ProtocolBinding.class), connectRequestInitializer); } private static ChannelPipeline configurePipeline(@Nullable SslHandshakeCompletionEvent event) { @@ -134,22 +155,14 @@ private static void configureDeferSslHandler(ChannelPipeline pipeline) { when(pipeline.get(DeferSslHandler.class)).thenReturn(mock(DeferSslHandler.class)); } - private void configureConnectionContext(final ChannelPipeline pipeline) { - configureConnectionContext(pipeline, HttpExecutionStrategies.defaultStrategy()); - } - - private void configureConnectionContext(final ChannelPipeline pipeline, - final HttpExecutionStrategy executionStrategy) { + private void configureConnectionNettyChannel(final ChannelPipeline pipeline) { Channel channel = mock(Channel.class); + EventLoop eventLoop = mock(EventLoop.class); + when(eventLoop.inEventLoop()).thenReturn(true); + when(channel.eventLoop()).thenReturn(eventLoop); when(channel.pipeline()).thenReturn(pipeline); when(pipeline.channel()).thenReturn(channel); - - HttpExecutionContext executionContext = new HttpExecutionContextBuilder() - .executionStrategy(executionStrategy).build(); - NettyHttpConnectionContext nettyContext = mock(NettyHttpConnectionContext.class); - when(nettyContext.executionContext()).thenReturn(executionContext); - when(nettyContext.nettyChannel()).thenReturn(channel); - when(connection.connectionContext()).thenReturn(nettyContext); + when(connection.nettyChannel()).thenReturn(channel); } private void configureRequestSend() { @@ -160,21 +173,11 @@ private void configureRequestSend() { } private void configureConnectRequest() { - when(connection.connect(any())).thenReturn(REQ_FACTORY.connect(CONNECT_ADDRESS)); + when(connection.connect(any())).thenReturn(REQ_RES_FACTORY.connect(CONNECT_ADDRESS)); } private void subscribeToProxyConnectionFactory() { - subscribeToProxyConnectionFactory(c -> { }); - } - - private void subscribeToProxyConnectionFactory(Consumer onSuccess) { - @SuppressWarnings("unchecked") - ConnectionFactory original = mock(ConnectionFactory.class); - when(original.newConnection(any(), any(), any())).thenReturn(succeeded(connection)); - toSource(new ProxyConnectConnectionFactoryFilter( - CONNECT_ADDRESS, ConnectExecutionStrategy.offloadNone()) - .create(original).newConnection(RESOLVED_ADDRESS, null, null).afterOnSuccess(onSuccess)) - .subscribe(subscriber); + toSource(connectionFactory.processConnect(connection)).subscribe(subscriber); } @Test @@ -185,6 +188,7 @@ void newConnectRequestThrows() { assertThat(subscriber.awaitOnError(), is(DELIBERATE_EXCEPTION)); verify(connection).connect(any()); verify(connection, never()).request(any()); + verify(connectRequestInitializer, never()).accept(any()); assertConnectionClosed(); } @@ -195,6 +199,7 @@ void connectRequestFails() { configureConnectRequest(); subscribeToProxyConnectionFactory(); + verify(connectRequestInitializer).accept(any()); Throwable error = subscriber.awaitOnError(); assertThat("Unexpected error: " + error, error, is(DELIBERATE_EXCEPTION)); assertConnectPayloadConsumed(false); @@ -211,6 +216,7 @@ void nonSuccessfulResponseCode() { configureConnectRequest(); subscribeToProxyConnectionFactory(); + verify(connectRequestInitializer).accept(any()); Throwable error = subscriber.awaitOnError(); assertThat(error, instanceOf(ProxyResponseException.class)); assertThat(((ProxyResponseException) error).status(), is(INTERNAL_SERVER_ERROR)); @@ -218,37 +224,19 @@ void nonSuccessfulResponseCode() { assertConnectionClosed(); } - @Test - void cannotAccessNettyChannel() { - // Does not implement NettyConnectionContext: - HttpExecutionContext executionContext = new HttpExecutionContextBuilder().build(); - - HttpConnectionContext connectionContext = mock(HttpConnectionContext.class); - when(connectionContext.executionContext()).thenReturn(executionContext); - - when(connection.connectionContext()).thenReturn(connectionContext); - - configureRequestSend(); - configureConnectRequest(); - subscribeToProxyConnectionFactory(); - - assertThat(subscriber.awaitOnError(), instanceOf(ClassCastException.class)); - assertConnectPayloadConsumed(true); - assertConnectionClosed(); - } - @ParameterizedTest(name = "{displayName} [{index}] ttl={0}") @ValueSource(booleans = {true, false}) void noDeferSslHandler(boolean channelActive) { ChannelPipeline pipeline = configurePipeline(SslHandshakeCompletionEvent.SUCCESS); // Do not configureDeferSslHandler(pipeline); - configureConnectionContext(pipeline); + configureConnectionNettyChannel(pipeline); Channel channel = pipeline.channel(); when(channel.isActive()).thenReturn(channelActive); configureRequestSend(); configureConnectRequest(); subscribeToProxyConnectionFactory(); + verify(connectRequestInitializer).accept(any()); Throwable error = subscriber.awaitOnError(); assertThat(error, is(notNullValue())); if (channelActive) { @@ -266,11 +254,12 @@ void deferSslHandlerReadyThrows() { ChannelPipeline pipeline = configurePipeline(SslHandshakeCompletionEvent.SUCCESS); when(pipeline.get(DeferSslHandler.class)).thenThrow(DELIBERATE_EXCEPTION); - configureConnectionContext(pipeline); + configureConnectionNettyChannel(pipeline); configureRequestSend(); configureConnectRequest(); subscribeToProxyConnectionFactory(); + verify(connectRequestInitializer).accept(any()); assertThat(subscriber.awaitOnError(), is(DELIBERATE_EXCEPTION)); assertConnectPayloadConsumed(true); assertConnectionClosed(); @@ -281,11 +270,12 @@ void sslHandshakeFailure() { ChannelPipeline pipeline = configurePipeline(new SslHandshakeCompletionEvent(DELIBERATE_EXCEPTION)); configureDeferSslHandler(pipeline); - configureConnectionContext(pipeline); + configureConnectionNettyChannel(pipeline); configureRequestSend(); configureConnectRequest(); subscribeToProxyConnectionFactory(); + verify(connectRequestInitializer).accept(any()); assertThat(subscriber.awaitOnError(), is(DELIBERATE_EXCEPTION)); assertConnectPayloadConsumed(true); assertConnectionClosed(); @@ -297,11 +287,12 @@ void cancelledBeforeSslHandshakeCompletionEvent() { ChannelPipeline pipeline = configurePipeline(null); // Do not generate any SslHandshakeCompletionEvent configureDeferSslHandler(pipeline); - configureConnectionContext(pipeline); + configureConnectionNettyChannel(pipeline); configureRequestSend(); configureConnectRequest(); subscribeToProxyConnectionFactory(); + verify(connectRequestInitializer).accept(any()); Cancellable cancellable = subscriber.awaitSubscription(); assertThat(subscriber.pollTerminal(10, MILLISECONDS), is(nullValue())); assertThat(connectionClose.isSubscribed(), is(false)); @@ -312,51 +303,77 @@ void cancelledBeforeSslHandshakeCompletionEvent() { @Test void successfulConnect() { - ChannelPipeline pipeline = configurePipeline(SslHandshakeCompletionEvent.SUCCESS); - configureDeferSslHandler(pipeline); - configureConnectionContext(pipeline); - configureRequestSend(); - configureConnectRequest(); + prepareForSuccess(); subscribeToProxyConnectionFactory(); + verify(connectRequestInitializer).accept(any()); assertThat(subscriber.awaitOnSuccess(), is(sameInstance(this.connection))); - assertConnectPayloadConsumed(true); - assertThat("Connection closed", connectionClose.isSubscribed(), is(false)); + StreamingHttpRequest request = assertConnectPayloadConsumed(true); + assertExecutionStrategy(request, offloadNone()); + assertConnectionClosed(false); } @Test - void noOffloadingStrategy() { + void connectPayloadBodyChangesExecutionStrategy() { + prepareForSuccess(); + doAnswer(invocation -> { + StreamingHttpRequest request = invocation.getArgument(0); + request.payloadBody(Publisher.from()); + return null; + }).when(connectRequestInitializer).accept(any()); + subscribeToProxyConnectionFactory(); + + verify(connectRequestInitializer).accept(any()); + assertThat(subscriber.awaitOnSuccess(), is(sameInstance(this.connection))); + StreamingHttpRequest request = assertConnectPayloadConsumed(true); + assertExecutionStrategy(request, customStrategyBuilder().offloadSend().build()); + assertConnectionClosed(false); + } + + @Test + void usersCanOverrideExecutionStrategy() { + prepareForSuccess(); + HttpExecutionStrategy customStrategy = customStrategyBuilder().offloadReceiveData().build(); + doAnswer(invocation -> { + StreamingHttpRequest request = invocation.getArgument(0); + request.context().put(HTTP_EXECUTION_STRATEGY_KEY, customStrategy); + return null; + }).when(connectRequestInitializer).accept(any()); + subscribeToProxyConnectionFactory(); + + verify(connectRequestInitializer).accept(any()); + assertThat(subscriber.awaitOnSuccess(), is(sameInstance(this.connection))); + StreamingHttpRequest request = assertConnectPayloadConsumed(true); + assertExecutionStrategy(request, customStrategy); + assertConnectionClosed(false); + } + + private void prepareForSuccess() { ChannelPipeline pipeline = configurePipeline(SslHandshakeCompletionEvent.SUCCESS); configureDeferSslHandler(pipeline); - configureConnectionContext(pipeline, HttpExecutionStrategies.offloadNone()); + configureConnectionNettyChannel(pipeline); configureRequestSend(); configureConnectRequest(); - Queue errors = new LinkedBlockingQueue<>(); - Thread testThread = Thread.currentThread(); - subscribeToProxyConnectionFactory(c -> { - if (Thread.currentThread() != testThread) { - errors.add(new AssertionError("Unexpected Thread for success " + Thread.currentThread())); - } - }); - - assertNoAsyncErrors(errors); - assertThat(subscriber.awaitOnSuccess(), is(sameInstance(this.connection))); - assertConnectPayloadConsumed(true); - assertThat("Connection closed", !connectionClose.isSubscribed()); } - private void assertConnectPayloadConsumed(boolean expected) { + private StreamingHttpRequest assertConnectPayloadConsumed(boolean expected) { + ArgumentCaptor requestCaptor = forClass(StreamingHttpRequest.class); verify(connection).connect(any()); - verify(connection).request(any()); + verify(connection).request(requestCaptor.capture()); assertThat("CONNECT response payload body was " + (expected ? "was" : "unnecessarily") + " consumed", messageBody.isSubscribed(), is(expected)); + return requestCaptor.getValue(); + } + + private static void assertExecutionStrategy(StreamingHttpRequest request, HttpExecutionStrategy expectedStrategy) { + assertThat(request.context().get(HTTP_EXECUTION_STRATEGY_KEY), is(expectedStrategy)); } private void assertConnectionClosed() { - assertThat("Closure of the connection was not triggered", connectionClose.isSubscribed(), is(true)); + assertConnectionClosed(true); } - private interface NettyHttpConnectionContext extends HttpConnectionContext, NettyConnectionContext { - // no methods + private void assertConnectionClosed(boolean closed) { + assertThat("Closure of the connection was not triggered", connectionClose.isSubscribed(), is(closed)); } } diff --git a/servicetalk-http-netty/src/testFixtures/java/io/servicetalk/http/netty/ProxyTunnel.java b/servicetalk-http-netty/src/testFixtures/java/io/servicetalk/http/netty/ProxyTunnel.java index 6cf58aa344..32b75fbf86 100644 --- a/servicetalk-http-netty/src/testFixtures/java/io/servicetalk/http/netty/ProxyTunnel.java +++ b/servicetalk-http-netty/src/testFixtures/java/io/servicetalk/http/netty/ProxyTunnel.java @@ -16,6 +16,7 @@ package io.servicetalk.http.netty; import io.servicetalk.concurrent.api.DefaultThreadFactory; +import io.servicetalk.http.api.HttpHeaderNames; import io.servicetalk.transport.api.HostAndPort; import org.slf4j.Logger; @@ -35,10 +36,13 @@ import static io.servicetalk.http.api.HttpHeaderNames.CONTENT_LENGTH; import static io.servicetalk.http.api.HttpHeaderNames.HOST; +import static io.servicetalk.http.api.HttpHeaderNames.PROXY_AUTHENTICATE; +import static io.servicetalk.http.api.HttpHeaderNames.PROXY_AUTHORIZATION; import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1; import static io.servicetalk.http.api.HttpRequestMethod.CONNECT; import static io.servicetalk.http.api.HttpResponseStatus.BAD_REQUEST; import static io.servicetalk.http.api.HttpResponseStatus.INTERNAL_SERVER_ERROR; +import static io.servicetalk.http.api.HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED; import static java.net.InetAddress.getLoopbackAddress; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.concurrent.Executors.newCachedThreadPool; @@ -56,7 +60,9 @@ public final class ProxyTunnel implements AutoCloseable { @Nullable private ServerSocket serverSocket; - private ProxyRequestHandler handler = this::handleRequest; + @Nullable + private volatile String authToken; + private volatile ProxyRequestHandler handler = this::handleRequest; @SuppressWarnings("ResultOfMethodCallIgnored") @Override @@ -86,38 +92,27 @@ public HostAndPort startProxy() throws IOException { executor.submit(() -> { try { final InputStream in = socket.getInputStream(); - final String host; - final int port; - final String protocol; - try { - final String initialLine = readLine(in); - if (!initialLine.startsWith(CONNECT_PREFIX)) { - throw new IllegalArgumentException("Expected " + CONNECT + " request, but found: " + - initialLine); - } - final int end = initialLine.indexOf(' ', CONNECT_PREFIX.length()); - final String authority = initialLine.substring(CONNECT_PREFIX.length(), end); - final int colon = authority.indexOf(':'); - host = authority.substring(0, colon); - port = Integer.parseInt(authority.substring(colon + 1)); - protocol = initialLine.substring(end + 1); - - final String hostHeader = readLine(in); - if (!hostHeader.toLowerCase(Locale.ROOT).startsWith(HOST.toString())) { - throw new IllegalArgumentException("Expected " + HOST + " header, but found: " + - hostHeader); - } - final String hostHeaderValue = hostHeader.substring(HOST.length() + 2 /* colon & space */); - if (!(host + ':' + port).equalsIgnoreCase(hostHeaderValue)) { - throw new IllegalArgumentException( - "Host header value must be identical to authority component"); - } - - while (readLine(in).length() > 0) { - // Ignore any other headers. - } - } catch (Exception e) { - badRequest(socket, e.getMessage()); + final String initialLine = readLine(in); + if (!initialLine.startsWith(CONNECT_PREFIX)) { + throw new IllegalArgumentException("Expected " + CONNECT + " request, but found: " + + initialLine); + } + final int end = initialLine.indexOf(' ', CONNECT_PREFIX.length()); + final String authority = initialLine.substring(CONNECT_PREFIX.length(), end); + final int colon = authority.indexOf(':'); + final String host = authority.substring(0, colon); + final int port = Integer.parseInt(authority.substring(colon + 1)); + final String protocol = initialLine.substring(end + 1); + + final Headers headers = readHeaders(in); + if (!authority.equalsIgnoreCase(headers.host)) { + badRequest(socket, "Host header value must be identical to authority " + + "component. Expected: " + authority + ", found: " + headers.host); + return; + } + final String authToken = this.authToken; + if (authToken != null && !("basic " + authToken).equals(headers.proxyAuthorization)) { + proxyAuthRequired(socket); return; } handler.handle(socket, host, port, protocol); @@ -146,6 +141,14 @@ private static void badRequest(final Socket socket, final String cause) throws I os.flush(); } + private static void proxyAuthRequired(final Socket socket) throws IOException { + final OutputStream os = socket.getOutputStream(); + os.write((HTTP_1_1 + " " + PROXY_AUTHENTICATION_REQUIRED + "\r\n" + + PROXY_AUTHENTICATE + ": Basic realm=\"simple\"" + "\r\n" + + "\r\n").getBytes(UTF_8)); + os.flush(); + } + /** * Changes the proxy handler to return 500 instead of 200. */ @@ -157,6 +160,16 @@ public void badResponseProxy() { }; } + /** + * Sets a required {@link HttpHeaderNames#PROXY_AUTHORIZATION} header value for "Basic" scheme to validate before + * accepting a {@code CONNECT} request. + * + * @param authToken the auth token to validate + */ + public void basicAuthToken(@Nullable String authToken) { + this.authToken = authToken; + } + /** * Number of established connections to the proxy. * @@ -181,6 +194,22 @@ private static String readLine(final InputStream in) throws IOException { } } + private static Headers readHeaders(final InputStream in) throws IOException { + String host = null; + String proxyAuthorization = null; + String line; + while ((line = readLine(in)).length() > 0) { + final String lowerCaseLine = line.toLowerCase(Locale.ROOT); + if (lowerCaseLine.startsWith(HOST.toString())) { + host = line.substring(HOST.length() + 2 /* colon & space */); + } else if (lowerCaseLine.startsWith(PROXY_AUTHORIZATION.toString())) { + proxyAuthorization = line.substring(PROXY_AUTHORIZATION.length() + 2 /* colon & space */); + } + // Ignore any other headers. + } + return new Headers(host, proxyAuthorization); + } + private void handleRequest(final Socket serverSocket, final String host, final int port, final String protocol) throws IOException { try (Socket clientSocket = new Socket(host, port)) { @@ -232,4 +261,16 @@ private static void copyStream(final OutputStream out, final InputStream cin) th private interface ProxyRequestHandler { void handle(Socket socket, String host, int port, String protocol) throws IOException; } + + private static final class Headers { + @Nullable + final String host; + @Nullable + final String proxyAuthorization; + + Headers(@Nullable final String host, @Nullable final String proxyAuthorization) { + this.host = host; + this.proxyAuthorization = proxyAuthorization; + } + } } diff --git a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/StacklessClosedChannelException.java b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/StacklessClosedChannelException.java index 3242897b3d..99b3c73c61 100644 --- a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/StacklessClosedChannelException.java +++ b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/StacklessClosedChannelException.java @@ -18,6 +18,7 @@ import io.servicetalk.concurrent.internal.ThrowableUtils; import java.nio.channels.ClosedChannelException; +import javax.annotation.Nullable; /** * {@link ClosedChannelException} that will not not fill in the stacktrace but use a cheaper way of producing @@ -26,7 +27,22 @@ public final class StacklessClosedChannelException extends ClosedChannelException { private static final long serialVersionUID = -5021225720136487769L; - private StacklessClosedChannelException() { } + @Nullable + private final String message; + + private StacklessClosedChannelException() { + this(null); + } + + private StacklessClosedChannelException(@Nullable final String message) { + this.message = message; + } + + @Nullable + @Override + public String getMessage() { + return message; + } @Override public Throwable fillInStackTrace() { @@ -41,7 +57,21 @@ public Throwable fillInStackTrace() { * @param method The method from which it will be thrown. * @return a new instance. */ - public static StacklessClosedChannelException newInstance(Class clazz, String method) { + public static StacklessClosedChannelException newInstance(final Class clazz, final String method) { return ThrowableUtils.unknownStackTrace(new StacklessClosedChannelException(), clazz, method); } + + /** + * Creates a new {@link StacklessClosedChannelException} instance. + * + * @param message The description message for more information. + * @param clazz The class in which this {@link StacklessClosedChannelException} will be used. + * @param method The method from which it will be thrown. + * @return a new instance. + */ + public static StacklessClosedChannelException newInstance(final String message, + final Class clazz, + final String method) { + return ThrowableUtils.unknownStackTrace(new StacklessClosedChannelException(message), clazz, method); + } }