From f00a486f3f4e010b64e8654915f9a1985bcd2a57 Mon Sep 17 00:00:00 2001 From: Bryce Anderson Date: Mon, 16 Oct 2023 17:04:08 -0700 Subject: [PATCH] Extract inner types from RoundRobinLoadBalancer (#2724) Motivation: The RoundRobinLoadBalancer contains a bunch of code that is doing a good deal more than load balancing, namely it's doing pooling, circuit breaking, and other things. This both makes the implementation hard to read and reason about and makes the inner pieces difficult to reuse for future load balancer implementation. Modifications: - Extract the types into their own files. - Hide a bit of the internal state to start to define abstraction boundaries. Note that this is pretty minimal so far, just accessors etc, so to keep this PR mechanical in nature. --- .../servicetalk/loadbalancer/Exceptions.java | 80 +++ .../loadbalancer/HealthCheckConfig.java | 40 ++ .../io/servicetalk/loadbalancer/Host.java | 491 ++++++++++++++++ .../loadbalancer/RoundRobinLoadBalancer.java | 536 +----------------- .../RoundRobinLoadBalancerFactory.java | 1 - 5 files changed, 618 insertions(+), 530 deletions(-) create mode 100644 servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/Exceptions.java create mode 100644 servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/HealthCheckConfig.java create mode 100644 servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/Host.java diff --git a/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/Exceptions.java b/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/Exceptions.java new file mode 100644 index 0000000000..a89e9e5791 --- /dev/null +++ b/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/Exceptions.java @@ -0,0 +1,80 @@ +/* + * Copyright © 2021-2023 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.loadbalancer; + +import io.servicetalk.client.api.ConnectionRejectedException; +import io.servicetalk.client.api.NoActiveHostException; +import io.servicetalk.client.api.NoAvailableHostException; +import io.servicetalk.concurrent.internal.ThrowableUtils; + +final class Exceptions { + + static final class StacklessNoAvailableHostException extends NoAvailableHostException { + private static final long serialVersionUID = 5942960040738091793L; + + private StacklessNoAvailableHostException(final String message) { + super(message); + } + + @Override + public Throwable fillInStackTrace() { + return this; + } + + static StacklessNoAvailableHostException newInstance(String message, Class clazz, String method) { + return ThrowableUtils.unknownStackTrace(new StacklessNoAvailableHostException(message), clazz, method); + } + } + + static final class StacklessNoActiveHostException extends NoActiveHostException { + + private static final long serialVersionUID = 7500474499335155869L; + + private StacklessNoActiveHostException(final String message) { + super(message); + } + + @Override + public Throwable fillInStackTrace() { + return this; + } + + static StacklessNoActiveHostException newInstance(String message, Class clazz, String method) { + return ThrowableUtils.unknownStackTrace(new StacklessNoActiveHostException(message), clazz, method); + } + } + + static final class StacklessConnectionRejectedException extends ConnectionRejectedException { + private static final long serialVersionUID = -4940708893680455819L; + + private StacklessConnectionRejectedException(final String message) { + super(message); + } + + @Override + public Throwable fillInStackTrace() { + return this; + } + + static StacklessConnectionRejectedException newInstance(String message, Class clazz, String method) { + return ThrowableUtils.unknownStackTrace(new StacklessConnectionRejectedException(message), clazz, method); + } + } + + private Exceptions() { + // no instances + } +} diff --git a/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/HealthCheckConfig.java b/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/HealthCheckConfig.java new file mode 100644 index 0000000000..8fad1f4198 --- /dev/null +++ b/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/HealthCheckConfig.java @@ -0,0 +1,40 @@ +/* + * Copyright © 2021-2023 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.loadbalancer; + +import io.servicetalk.concurrent.api.Executor; + +import java.time.Duration; + +final class HealthCheckConfig { + final Executor executor; + final Duration healthCheckInterval; + final Duration jitter; + final int failedThreshold; + final long healthCheckResubscribeLowerBound; + final long healthCheckResubscribeUpperBound; + + HealthCheckConfig(final Executor executor, final Duration healthCheckInterval, final Duration healthCheckJitter, + final int failedThreshold, final long healthCheckResubscribeLowerBound, + final long healthCheckResubscribeUpperBound) { + this.executor = executor; + this.healthCheckInterval = healthCheckInterval; + this.failedThreshold = failedThreshold; + this.jitter = healthCheckJitter; + this.healthCheckResubscribeLowerBound = healthCheckResubscribeLowerBound; + this.healthCheckResubscribeUpperBound = healthCheckResubscribeUpperBound; + } +} diff --git a/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/Host.java b/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/Host.java new file mode 100644 index 0000000000..f4961c2eb4 --- /dev/null +++ b/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/Host.java @@ -0,0 +1,491 @@ +/* + * Copyright © 2021-2023 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.loadbalancer; + +import io.servicetalk.client.api.ConnectionFactory; +import io.servicetalk.client.api.ConnectionLimitReachedException; +import io.servicetalk.client.api.LoadBalancedConnection; +import io.servicetalk.concurrent.api.AsyncCloseable; +import io.servicetalk.concurrent.api.AsyncContext; +import io.servicetalk.concurrent.api.Completable; +import io.servicetalk.concurrent.api.ListenableAsyncCloseable; +import io.servicetalk.concurrent.internal.DelayedCancellable; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.AbstractMap; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Function; +import java.util.stream.Stream; +import javax.annotation.Nullable; + +import static io.servicetalk.concurrent.api.AsyncCloseables.toAsyncCloseable; +import static io.servicetalk.concurrent.api.Completable.completed; +import static io.servicetalk.concurrent.api.Publisher.from; +import static io.servicetalk.concurrent.api.RetryStrategies.retryWithConstantBackoffDeltaJitter; +import static io.servicetalk.concurrent.internal.FlowControlUtils.addWithOverflowProtection; +import static java.util.concurrent.atomic.AtomicReferenceFieldUpdater.newUpdater; +import static java.util.stream.Collectors.toList; + +final class Host implements ListenableAsyncCloseable { + + private static final Object[] EMPTY_ARRAY = new Object[0]; + private static final Logger LOGGER = LoggerFactory.getLogger(Host.class); + + private enum State { + // The enum is not exhaustive, as other states have dynamic properties. + // For clarity, the other state classes are listed as comments: + // ACTIVE - see ActiveState + // UNHEALTHY - see HealthCheck + EXPIRED, + CLOSED + } + + private static final ActiveState STATE_ACTIVE_NO_FAILURES = new ActiveState(); + private static final ConnState ACTIVE_EMPTY_CONN_STATE = new ConnState(EMPTY_ARRAY, STATE_ACTIVE_NO_FAILURES); + private static final ConnState CLOSED_CONN_STATE = new ConnState(EMPTY_ARRAY, State.CLOSED); + + @SuppressWarnings("rawtypes") + private static final AtomicReferenceFieldUpdater connStateUpdater = + newUpdater(Host.class, ConnState.class, "connState"); + + private final String lbDescription; + final Addr address; + @Nullable + private final HealthCheckConfig healthCheckConfig; + private final ListenableAsyncCloseable closeable; + private volatile ConnState connState = ACTIVE_EMPTY_CONN_STATE; + + Host(String lbDescription, Addr address, @Nullable HealthCheckConfig healthCheckConfig) { + this.lbDescription = lbDescription; + this.address = address; + this.healthCheckConfig = healthCheckConfig; + this.closeable = toAsyncCloseable(graceful -> + graceful ? doClose(AsyncCloseable::closeAsyncGracefully) : doClose(AsyncCloseable::closeAsync)); + } + + boolean markActiveIfNotClosed() { + final Object oldState = connStateUpdater.getAndUpdate(this, oldConnState -> { + if (oldConnState.state == State.EXPIRED) { + return new ConnState(oldConnState.connections, STATE_ACTIVE_NO_FAILURES); + } + // If oldConnState.state == State.ACTIVE this could mean either a duplicate event, + // or a repeated CAS operation. We could issue a warning, but as we don't know, we don't log anything. + // UNHEALTHY state cannot transition to ACTIVE without passing the health check. + return oldConnState; + }).state; + return oldState != State.CLOSED; + } + + void markClosed() { + final ConnState oldState = closeConnState(); + final Object[] toRemove = oldState.connections; + cancelIfHealthCheck(oldState); + LOGGER.debug("{}: closing {} connection(s) gracefully to the closed address: {}.", + lbDescription, toRemove.length, address); + for (Object conn : toRemove) { + @SuppressWarnings("unchecked") + final C cConn = (C) conn; + cConn.closeAsyncGracefully().subscribe(); + } + } + + private ConnState closeConnState() { + for (;;) { + // We need to keep the oldState.connections around even if we are closed because the user may do + // closeGracefully with a timeout, which fails, and then force close. If we discard connections when + // closeGracefully is started we may leak connections. + final ConnState oldState = connState; + if (oldState.state == State.CLOSED || connStateUpdater.compareAndSet(this, oldState, + new ConnState(oldState.connections, State.CLOSED))) { + return oldState; + } + } + } + + void markExpired() { + for (;;) { + ConnState oldState = connStateUpdater.get(this); + if (oldState.state == State.EXPIRED || oldState.state == State.CLOSED) { + break; + } + Object nextState = oldState.connections.length == 0 ? State.CLOSED : State.EXPIRED; + + if (connStateUpdater.compareAndSet(this, oldState, + new ConnState(oldState.connections, nextState))) { + cancelIfHealthCheck(oldState); + if (nextState == State.CLOSED) { + // Trigger the callback to remove the host from usedHosts array. + this.closeAsync().subscribe(); + } + break; + } + } + } + + void markHealthy(final HealthCheck originalHealthCheckState) { + // Marking healthy is called when we need to recover from an unexpected error. + // However, it is possible that in the meantime, the host entered an EXPIRED state, then ACTIVE, then failed + // to open connections and entered the UNHEALTHY state before the original thread continues execution here. + // In such case, the flipped state is not the same as the one that just succeeded to open a connection. + // In an unlikely scenario that the following connection attempts fail indefinitely, a health check task + // would leak and would not be cancelled. Therefore, we cancel it here and allow failures to trigger a new + // health check. + ConnState oldState = connStateUpdater.getAndUpdate(this, previous -> { + if (Host.isUnhealthy(previous)) { + return new ConnState(previous.connections, STATE_ACTIVE_NO_FAILURES); + } + return previous; + }); + if (oldState.state != originalHealthCheckState) { + cancelIfHealthCheck(oldState); + } + } + + void markUnhealthy(final Throwable cause, final ConnectionFactory connectionFactory) { + assert healthCheckConfig != null; + for (;;) { + ConnState previous = connStateUpdater.get(this); + + if (!Host.isActive(previous) || previous.connections.length > 0 + || cause instanceof ConnectionLimitReachedException) { + LOGGER.debug("{}: failed to open a new connection to the host on address {}. {}.", + lbDescription, address, previous, cause); + break; + } + + ActiveState previousState = (ActiveState) previous.state; + if (previousState.failedConnections + 1 < healthCheckConfig.failedThreshold) { + final ActiveState nextState = previousState.forNextFailedConnection(); + if (connStateUpdater.compareAndSet(this, previous, + new ConnState(previous.connections, nextState))) { + LOGGER.debug("{}: failed to open a new connection to the host on address {}" + + " {} time(s) ({} consecutive failures will trigger health-checking).", + lbDescription, address, nextState.failedConnections, + healthCheckConfig.failedThreshold, cause); + break; + } + // another thread won the race, try again + continue; + } + + final HealthCheck healthCheck = new HealthCheck<>(connectionFactory, this, cause); + final ConnState nextState = new ConnState(previous.connections, healthCheck); + if (connStateUpdater.compareAndSet(this, previous, nextState)) { + LOGGER.info("{}: failed to open a new connection to the host on address {} " + + "{} time(s) in a row. Error counting threshold reached, marking this host as " + + "UNHEALTHY for the selection algorithm and triggering background health-checking.", + lbDescription, address, healthCheckConfig.failedThreshold, cause); + healthCheck.schedule(cause); + break; + } + } + } + + boolean isActiveAndHealthy() { + return isActive(connState); + } + + boolean isUnhealthy() { + return isUnhealthy(connState); + } + + private static boolean isActive(final ConnState connState) { + return ActiveState.class.equals(connState.state.getClass()); + } + + private static boolean isUnhealthy(ConnState connState) { + return HealthCheck.class.equals(connState.state.getClass()); + } + + boolean addConnection(final C connection, final @Nullable HealthCheck currentHealthCheck) { + int addAttempt = 0; + for (;;) { + final ConnState previous = connStateUpdater.get(this); + if (previous.state == State.CLOSED) { + return false; + } + ++addAttempt; + + final Object[] existing = previous.connections; + // Brute force iteration to avoid duplicates. If connections grow larger and faster lookup is required + // we can keep a Set for faster lookups (at the cost of more memory) as well as array. + for (final Object o : existing) { + if (o.equals(connection)) { + return true; + } + } + Object[] newList = Arrays.copyOf(existing, existing.length + 1); + newList[existing.length] = connection; + + // If we were able to add a new connection to the list, we should mark the host as ACTIVE again and + // reset its failures counter. + final Object newState = Host.isActive(previous) || Host.isUnhealthy(previous) ? + STATE_ACTIVE_NO_FAILURES : previous.state; + + if (connStateUpdater.compareAndSet(this, + previous, new ConnState(newList, newState))) { + // It could happen that the Host turned into UNHEALTHY state either concurrently with adding a new + // connection or with passing a previous health-check (if SD turned it into ACTIVE state). In both + // cases we have to cancel the "previous" ongoing health check. See "markHealthy" for more context. + if (Host.isUnhealthy(previous) && + (currentHealthCheck == null || previous.state != currentHealthCheck)) { + assert newState == STATE_ACTIVE_NO_FAILURES; + cancelIfHealthCheck(previous); + } + break; + } + } + + LOGGER.trace("{}: added a new connection {} to {} after {} attempt(s).", + lbDescription, connection, this, addAttempt); + // Instrument the new connection so we prune it on close + connection.onClose().beforeFinally(() -> { + int removeAttempt = 0; + for (;;) { + final ConnState currentConnState = this.connState; + if (currentConnState.state == State.CLOSED) { + break; + } + assert currentConnState.connections.length > 0; + ++removeAttempt; + int i = 0; + final Object[] connections = currentConnState.connections; + // Search for the connection in the list. + for (; i < connections.length; ++i) { + if (connections[i].equals(connection)) { + break; + } + } + if (i == connections.length) { + // Connection was already removed, nothing to do. + break; + } else if (connections.length == 1) { + assert !Host.isUnhealthy(currentConnState) : "Cannot be UNHEALTHY with #connections > 0"; + if (Host.isActive(currentConnState)) { + if (connStateUpdater.compareAndSet(this, currentConnState, + new ConnState(EMPTY_ARRAY, currentConnState.state))) { + break; + } + } else if (currentConnState.state == State.EXPIRED + // We're closing the last connection, close the Host. + // Closing the host will trigger the Host's onClose method, which will remove the host + // from used hosts list. If a race condition appears and a new connection was added + // in the meantime, that would mean the host is available again and the CAS operation + // will allow for determining that. It will prevent closing the Host and will only + // remove the connection (previously considered as the last one) from the array + // in the next iteration. + && connStateUpdater.compareAndSet(this, currentConnState, CLOSED_CONN_STATE)) { + this.closeAsync().subscribe(); + break; + } + } else { + Object[] newList = new Object[connections.length - 1]; + System.arraycopy(connections, 0, newList, 0, i); + System.arraycopy(connections, i + 1, newList, i, newList.length - i); + if (connStateUpdater.compareAndSet(this, + currentConnState, new ConnState(newList, currentConnState.state))) { + break; + } + } + } + LOGGER.trace("{}: removed connection {} from {} after {} attempt(s).", + lbDescription, connection, this, removeAttempt); + }).onErrorComplete(t -> { + // Use onErrorComplete instead of whenOnError to avoid double logging of an error inside subscribe(): + // SimpleCompletableSubscriber. + LOGGER.error("{}: unexpected error while processing connection.onClose() for {}.", + lbDescription, connection, t); + return true; + }).subscribe(); + return true; + } + + // Used for testing only + @SuppressWarnings("unchecked") + Map.Entry> asEntry() { + return new AbstractMap.SimpleImmutableEntry<>(address, + Stream.of(connState.connections).map(conn -> (C) conn).collect(toList())); + } + + Object[] connections() { + return connState.connections; + } + + @Override + public Completable closeAsync() { + return closeable.closeAsync(); + } + + @Override + public Completable closeAsyncGracefully() { + return closeable.closeAsyncGracefully(); + } + + @Override + public Completable onClose() { + return closeable.onClose(); + } + + @Override + public Completable onClosing() { + return closeable.onClosing(); + } + + @SuppressWarnings("unchecked") + private Completable doClose(final Function closeFunction) { + return Completable.defer(() -> { + final ConnState oldState = closeConnState(); + cancelIfHealthCheck(oldState); + final Object[] connections = oldState.connections; + return (connections.length == 0 ? completed() : + from(connections).flatMapCompletableDelayError(conn -> closeFunction.apply((C) conn))) + .shareContextOnSubscribe(); + }); + } + + private void cancelIfHealthCheck(ConnState connState) { + if (Host.isUnhealthy(connState)) { + @SuppressWarnings("unchecked") + HealthCheck healthCheck = (HealthCheck) connState.state; + LOGGER.debug("{}: health check cancelled for {}.", lbDescription, healthCheck.host); + healthCheck.cancel(); + } + } + + @Override + public String toString() { + final ConnState connState = this.connState; + return "Host{" + + "lbDescription=" + lbDescription + + ", address=" + address + + ", state=" + connState.state + + ", #connections=" + connState.connections.length + + '}'; + } + + private static final class ActiveState { + private final int failedConnections; + + ActiveState() { + this(0); + } + + private ActiveState(int failedConnections) { + this.failedConnections = failedConnections; + } + + ActiveState forNextFailedConnection() { + return new ActiveState(addWithOverflowProtection(this.failedConnections, 1)); + } + + @Override + public String toString() { + return "ACTIVE(failedConnections=" + failedConnections + ')'; + } + } + + private static final class HealthCheck + extends DelayedCancellable { + private final ConnectionFactory connectionFactory; + private final Host host; + private final Throwable lastError; + + private HealthCheck(final ConnectionFactory connectionFactory, + final Host host, final Throwable lastError) { + this.connectionFactory = connectionFactory; + this.host = host; + this.lastError = lastError; + } + + public void schedule(final Throwable originalCause) { + assert host.healthCheckConfig != null; + delayedCancellable( + // Use retry strategy to utilize jitter. + retryWithConstantBackoffDeltaJitter(cause -> true, + host.healthCheckConfig.healthCheckInterval, + host.healthCheckConfig.jitter, + host.healthCheckConfig.executor) + .apply(0, originalCause) + // Remove any state from async context + .beforeOnSubscribe(__ -> AsyncContext.clear()) + .concat(connectionFactory.newConnection(host.address, null, null) + // There is no risk for StackOverflowError because result of each connection + // attempt will be invoked on IoExecutor as a new task. + .retryWhen(retryWithConstantBackoffDeltaJitter( + cause -> { + LOGGER.debug("{}: health check failed for {}.", + host.lbDescription, host, cause); + return true; + }, + host.healthCheckConfig.healthCheckInterval, + host.healthCheckConfig.jitter, + host.healthCheckConfig.executor))) + .flatMapCompletable(newCnx -> { + if (host.addConnection(newCnx, this)) { + LOGGER.info("{}: health check passed for {}, marked this " + + "host as ACTIVE for the selection algorithm.", + host.lbDescription, host); + return completed(); + } else { + // This happens only if the host is closed, no need to mark as healthy. + assert host.connState.state == State.CLOSED; + LOGGER.debug("{}: health check passed for {}, but the " + + "host rejected a new connection {}. Closing it now.", + host.lbDescription, host, newCnx); + return newCnx.closeAsync(); + } + }) + // Use onErrorComplete instead of whenOnError to avoid double logging of an error inside + // subscribe(): SimpleCompletableSubscriber. + .onErrorComplete(t -> { + LOGGER.error("{}: health check terminated with " + + "an unexpected error for {}. Marking this host as ACTIVE as a fallback " + + "to allow connection attempts.", host.lbDescription, host, t); + host.markHealthy(this); + return true; + }) + .subscribe()); + } + + @Override + public String toString() { + return "UNHEALTHY(" + lastError + ')'; + } + } + + private static final class ConnState { + final Object[] connections; + final Object state; + + ConnState(final Object[] connections, final Object state) { + this.connections = connections; + this.state = state; + } + + @Override + public String toString() { + return "ConnState{" + + "state=" + state + + ", #connections=" + connections.length + + '}'; + } + } +} diff --git a/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/RoundRobinLoadBalancer.java b/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/RoundRobinLoadBalancer.java index 4a1034dc43..35c3a1a52f 100644 --- a/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/RoundRobinLoadBalancer.java +++ b/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/RoundRobinLoadBalancer.java @@ -16,36 +16,27 @@ package io.servicetalk.loadbalancer; import io.servicetalk.client.api.ConnectionFactory; -import io.servicetalk.client.api.ConnectionLimitReachedException; -import io.servicetalk.client.api.ConnectionRejectedException; import io.servicetalk.client.api.LoadBalancedConnection; import io.servicetalk.client.api.LoadBalancer; -import io.servicetalk.client.api.NoActiveHostException; -import io.servicetalk.client.api.NoAvailableHostException; import io.servicetalk.client.api.ServiceDiscovererEvent; import io.servicetalk.concurrent.PublisherSource.Processor; import io.servicetalk.concurrent.PublisherSource.Subscriber; import io.servicetalk.concurrent.PublisherSource.Subscription; -import io.servicetalk.concurrent.api.AsyncCloseable; -import io.servicetalk.concurrent.api.AsyncContext; import io.servicetalk.concurrent.api.Completable; import io.servicetalk.concurrent.api.CompositeCloseable; -import io.servicetalk.concurrent.api.Executor; import io.servicetalk.concurrent.api.ListenableAsyncCloseable; import io.servicetalk.concurrent.api.Publisher; import io.servicetalk.concurrent.api.Single; -import io.servicetalk.concurrent.internal.DelayedCancellable; import io.servicetalk.concurrent.internal.SequentialCancellable; -import io.servicetalk.concurrent.internal.ThrowableUtils; import io.servicetalk.context.api.ContextMap; +import io.servicetalk.loadbalancer.Exceptions.StacklessConnectionRejectedException; +import io.servicetalk.loadbalancer.Exceptions.StacklessNoActiveHostException; +import io.servicetalk.loadbalancer.Exceptions.StacklessNoAvailableHostException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.time.Duration; -import java.util.AbstractMap.SimpleImmutableEntry; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.Comparator; import java.util.Iterator; @@ -58,7 +49,6 @@ import java.util.concurrent.atomic.AtomicLongFieldUpdater; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import java.util.function.Consumer; -import java.util.function.Function; import java.util.function.Predicate; import java.util.function.UnaryOperator; import java.util.stream.Stream; @@ -71,16 +61,12 @@ import static io.servicetalk.client.api.ServiceDiscovererEvent.Status.UNAVAILABLE; import static io.servicetalk.concurrent.api.AsyncCloseables.newCompositeCloseable; import static io.servicetalk.concurrent.api.AsyncCloseables.toAsyncCloseable; -import static io.servicetalk.concurrent.api.Completable.completed; import static io.servicetalk.concurrent.api.Processors.newPublisherProcessorDropHeadOnOverflow; -import static io.servicetalk.concurrent.api.Publisher.from; -import static io.servicetalk.concurrent.api.RetryStrategies.retryWithConstantBackoffDeltaJitter; import static io.servicetalk.concurrent.api.Single.defer; import static io.servicetalk.concurrent.api.Single.failed; import static io.servicetalk.concurrent.api.Single.succeeded; import static io.servicetalk.concurrent.api.SourceAdapters.fromSource; import static io.servicetalk.concurrent.api.SourceAdapters.toSource; -import static io.servicetalk.concurrent.internal.FlowControlUtils.addWithOverflowProtection; import static java.lang.Integer.toHexString; import static java.lang.Math.min; import static java.lang.System.identityHashCode; @@ -88,7 +74,6 @@ import static java.util.Collections.singletonList; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.NANOSECONDS; -import static java.util.concurrent.atomic.AtomicReferenceFieldUpdater.newUpdater; import static java.util.stream.Collectors.toList; /** @@ -101,7 +86,6 @@ final class RoundRobinLoadBalancer { private static final Logger LOGGER = LoggerFactory.getLogger(RoundRobinLoadBalancer.class); - private static final Object[] EMPTY_ARRAY = new Object[0]; @SuppressWarnings("rawtypes") private static final AtomicReferenceFieldUpdater usedHostsUpdater = @@ -233,7 +217,7 @@ private static boolean allUn final List> usedHosts) { boolean allUnhealthy = !usedHosts.isEmpty(); for (Host host : usedHosts) { - if (!Host.isUnhealthy(host.connState)) { + if (!host.isUnhealthy()) { allUnhealthy = false; break; } @@ -382,6 +366,7 @@ private List> markHostAsExpired( } private Host createHost(ResolvedAddress addr) { + // All hosts will share the healthcheck config of the parent RR loadbalancer. Host host = new Host<>(RoundRobinLoadBalancer.this.toString(), addr, healthCheckConfig); host.onClose().afterFinally(() -> usedHostsUpdater.updateAndGet(RoundRobinLoadBalancer.this, previousHosts -> { @@ -515,7 +500,7 @@ private Single selectConnection0(final Predicate selector, @Nullable final if (!forceNewConnectionAndReserve) { // Try first to see if an existing connection can be used - final Object[] connections = host.connState.connections; + final Object[] connections = host.connections(); // Exhaust the linear search space first: final int linearAttempts = min(connections.length, linearSearchSpace); for (int j = 0; j < linearAttempts; ++j) { @@ -568,7 +553,7 @@ private Single selectConnection0(final Predicate selector, @Nullable final // This LB implementation does not automatically provide TransportObserver. Therefore, we pass "null" here. // Users can apply a ConnectionFactoryFilter if they need to override this "null" value with TransportObserver. Single establishConnection = connectionFactory.newConnection(host.address, context, null); - if (host.healthCheckConfig != null) { + if (healthCheckConfig != null) { // Schedule health check before returning establishConnection = establishConnection.beforeOnError(t -> host.markUnhealthy(t, connectionFactory)); } @@ -636,513 +621,6 @@ List>> usedAddresses() { return usedHosts.stream().map(Host::asEntry).collect(toList()); } - static final class HealthCheckConfig { - private final Executor executor; - private final Duration healthCheckInterval; - private final Duration jitter; - private final int failedThreshold; - private final long healthCheckResubscribeLowerBound; - private final long healthCheckResubscribeUpperBound; - - HealthCheckConfig(final Executor executor, final Duration healthCheckInterval, final Duration healthCheckJitter, - final int failedThreshold, final long healthCheckResubscribeLowerBound, - final long healthCheckResubscribeUpperBound) { - this.executor = executor; - this.healthCheckInterval = healthCheckInterval; - this.failedThreshold = failedThreshold; - this.jitter = healthCheckJitter; - this.healthCheckResubscribeLowerBound = healthCheckResubscribeLowerBound; - this.healthCheckResubscribeUpperBound = healthCheckResubscribeUpperBound; - } - } - - private static final class Host implements ListenableAsyncCloseable { - - private enum State { - // The enum is not exhaustive, as other states have dynamic properties. - // For clarity, the other state classes are listed as comments: - // ACTIVE - see ActiveState - // UNHEALTHY - see HealthCheck - EXPIRED, - CLOSED - } - - private static final ActiveState STATE_ACTIVE_NO_FAILURES = new ActiveState(); - private static final ConnState ACTIVE_EMPTY_CONN_STATE = new ConnState(EMPTY_ARRAY, STATE_ACTIVE_NO_FAILURES); - private static final ConnState CLOSED_CONN_STATE = new ConnState(EMPTY_ARRAY, State.CLOSED); - - @SuppressWarnings("rawtypes") - private static final AtomicReferenceFieldUpdater connStateUpdater = - newUpdater(Host.class, ConnState.class, "connState"); - - private final String lbDescription; - final Addr address; - @Nullable - private final HealthCheckConfig healthCheckConfig; - private final ListenableAsyncCloseable closeable; - private volatile ConnState connState = ACTIVE_EMPTY_CONN_STATE; - - Host(String lbDescription, Addr address, @Nullable HealthCheckConfig healthCheckConfig) { - this.lbDescription = lbDescription; - this.address = address; - this.healthCheckConfig = healthCheckConfig; - this.closeable = toAsyncCloseable(graceful -> - graceful ? doClose(AsyncCloseable::closeAsyncGracefully) : doClose(AsyncCloseable::closeAsync)); - } - - boolean markActiveIfNotClosed() { - final Object oldState = connStateUpdater.getAndUpdate(this, oldConnState -> { - if (oldConnState.state == State.EXPIRED) { - return new ConnState(oldConnState.connections, STATE_ACTIVE_NO_FAILURES); - } - // If oldConnState.state == State.ACTIVE this could mean either a duplicate event, - // or a repeated CAS operation. We could issue a warning, but as we don't know, we don't log anything. - // UNHEALTHY state cannot transition to ACTIVE without passing the health check. - return oldConnState; - }).state; - return oldState != State.CLOSED; - } - - void markClosed() { - final ConnState oldState = closeConnState(); - final Object[] toRemove = oldState.connections; - cancelIfHealthCheck(oldState); - LOGGER.debug("{}: closing {} connection(s) gracefully to the closed address: {}.", - lbDescription, toRemove.length, address); - for (Object conn : toRemove) { - @SuppressWarnings("unchecked") - final C cConn = (C) conn; - cConn.closeAsyncGracefully().subscribe(); - } - } - - private ConnState closeConnState() { - for (;;) { - // We need to keep the oldState.connections around even if we are closed because the user may do - // closeGracefully with a timeout, which fails, and then force close. If we discard connections when - // closeGracefully is started we may leak connections. - final ConnState oldState = connState; - if (oldState.state == State.CLOSED || connStateUpdater.compareAndSet(this, oldState, - new ConnState(oldState.connections, State.CLOSED))) { - return oldState; - } - } - } - - void markExpired() { - for (;;) { - ConnState oldState = connStateUpdater.get(this); - if (oldState.state == State.EXPIRED || oldState.state == State.CLOSED) { - break; - } - Object nextState = oldState.connections.length == 0 ? State.CLOSED : State.EXPIRED; - - if (connStateUpdater.compareAndSet(this, oldState, - new ConnState(oldState.connections, nextState))) { - cancelIfHealthCheck(oldState); - if (nextState == State.CLOSED) { - // Trigger the callback to remove the host from usedHosts array. - this.closeAsync().subscribe(); - } - break; - } - } - } - - void markHealthy(final HealthCheck originalHealthCheckState) { - // Marking healthy is called when we need to recover from an unexpected error. - // However, it is possible that in the meantime, the host entered an EXPIRED state, then ACTIVE, then failed - // to open connections and entered the UNHEALTHY state before the original thread continues execution here. - // In such case, the flipped state is not the same as the one that just succeeded to open a connection. - // In an unlikely scenario that the following connection attempts fail indefinitely, a health check task - // would leak and would not be cancelled. Therefore, we cancel it here and allow failures to trigger a new - // health check. - ConnState oldState = connStateUpdater.getAndUpdate(this, previous -> { - if (Host.isUnhealthy(previous)) { - return new ConnState(previous.connections, STATE_ACTIVE_NO_FAILURES); - } - return previous; - }); - if (oldState.state != originalHealthCheckState) { - cancelIfHealthCheck(oldState); - } - } - - void markUnhealthy(final Throwable cause, final ConnectionFactory connectionFactory) { - assert healthCheckConfig != null; - for (;;) { - ConnState previous = connStateUpdater.get(this); - - if (!Host.isActive(previous) || previous.connections.length > 0 - || cause instanceof ConnectionLimitReachedException) { - LOGGER.debug("{}: failed to open a new connection to the host on address {}. {}.", - lbDescription, address, previous, cause); - break; - } - - ActiveState previousState = (ActiveState) previous.state; - if (previousState.failedConnections + 1 < healthCheckConfig.failedThreshold) { - final ActiveState nextState = previousState.forNextFailedConnection(); - if (connStateUpdater.compareAndSet(this, previous, - new ConnState(previous.connections, nextState))) { - LOGGER.debug("{}: failed to open a new connection to the host on address {}" + - " {} time(s) ({} consecutive failures will trigger health-checking).", - lbDescription, address, nextState.failedConnections, - healthCheckConfig.failedThreshold, cause); - break; - } - // another thread won the race, try again - continue; - } - - final HealthCheck healthCheck = new HealthCheck<>(connectionFactory, this, cause); - final ConnState nextState = new ConnState(previous.connections, healthCheck); - if (connStateUpdater.compareAndSet(this, previous, nextState)) { - LOGGER.info("{}: failed to open a new connection to the host on address {} " + - "{} time(s) in a row. Error counting threshold reached, marking this host as " + - "UNHEALTHY for the selection algorithm and triggering background health-checking.", - lbDescription, address, healthCheckConfig.failedThreshold, cause); - healthCheck.schedule(cause); - break; - } - } - } - - boolean isActiveAndHealthy() { - return isActive(connState); - } - - static boolean isActive(final ConnState connState) { - return ActiveState.class.equals(connState.state.getClass()); - } - - static boolean isUnhealthy(final ConnState connState) { - return HealthCheck.class.equals(connState.state.getClass()); - } - - boolean addConnection(final C connection, final @Nullable HealthCheck currentHealthCheck) { - int addAttempt = 0; - for (;;) { - final ConnState previous = connStateUpdater.get(this); - if (previous.state == State.CLOSED) { - return false; - } - ++addAttempt; - - final Object[] existing = previous.connections; - // Brute force iteration to avoid duplicates. If connections grow larger and faster lookup is required - // we can keep a Set for faster lookups (at the cost of more memory) as well as array. - for (final Object o : existing) { - if (o.equals(connection)) { - return true; - } - } - Object[] newList = Arrays.copyOf(existing, existing.length + 1); - newList[existing.length] = connection; - - // If we were able to add a new connection to the list, we should mark the host as ACTIVE again and - // reset its failures counter. - final Object newState = Host.isActive(previous) || Host.isUnhealthy(previous) ? - STATE_ACTIVE_NO_FAILURES : previous.state; - - if (connStateUpdater.compareAndSet(this, - previous, new ConnState(newList, newState))) { - // It could happen that the Host turned into UNHEALTHY state either concurrently with adding a new - // connection or with passing a previous health-check (if SD turned it into ACTIVE state). In both - // cases we have to cancel the "previous" ongoing health check. See "markHealthy" for more context. - if (Host.isUnhealthy(previous) && - (currentHealthCheck == null || previous.state != currentHealthCheck)) { - assert newState == STATE_ACTIVE_NO_FAILURES; - cancelIfHealthCheck(previous); - } - break; - } - } - - LOGGER.trace("{}: added a new connection {} to {} after {} attempt(s).", - lbDescription, connection, this, addAttempt); - // Instrument the new connection so we prune it on close - connection.onClose().beforeFinally(() -> { - int removeAttempt = 0; - for (;;) { - final ConnState currentConnState = this.connState; - if (currentConnState.state == State.CLOSED) { - break; - } - assert currentConnState.connections.length > 0; - ++removeAttempt; - int i = 0; - final Object[] connections = currentConnState.connections; - // Search for the connection in the list. - for (; i < connections.length; ++i) { - if (connections[i].equals(connection)) { - break; - } - } - if (i == connections.length) { - // Connection was already removed, nothing to do. - break; - } else if (connections.length == 1) { - assert !Host.isUnhealthy(currentConnState) : "Cannot be UNHEALTHY with #connections > 0"; - if (Host.isActive(currentConnState)) { - if (connStateUpdater.compareAndSet(this, currentConnState, - new ConnState(EMPTY_ARRAY, currentConnState.state))) { - break; - } - } else if (currentConnState.state == State.EXPIRED - // We're closing the last connection, close the Host. - // Closing the host will trigger the Host's onClose method, which will remove the host - // from used hosts list. If a race condition appears and a new connection was added - // in the meantime, that would mean the host is available again and the CAS operation - // will allow for determining that. It will prevent closing the Host and will only - // remove the connection (previously considered as the last one) from the array - // in the next iteration. - && connStateUpdater.compareAndSet(this, currentConnState, CLOSED_CONN_STATE)) { - this.closeAsync().subscribe(); - break; - } - } else { - Object[] newList = new Object[connections.length - 1]; - System.arraycopy(connections, 0, newList, 0, i); - System.arraycopy(connections, i + 1, newList, i, newList.length - i); - if (connStateUpdater.compareAndSet(this, - currentConnState, new ConnState(newList, currentConnState.state))) { - break; - } - } - } - LOGGER.trace("{}: removed connection {} from {} after {} attempt(s).", - lbDescription, connection, this, removeAttempt); - }).onErrorComplete(t -> { - // Use onErrorComplete instead of whenOnError to avoid double logging of an error inside subscribe(): - // SimpleCompletableSubscriber. - LOGGER.error("{}: unexpected error while processing connection.onClose() for {}.", - lbDescription, connection, t); - return true; - }).subscribe(); - return true; - } - - // Used for testing only - @SuppressWarnings("unchecked") - Entry> asEntry() { - return new SimpleImmutableEntry<>(address, - Stream.of(connState.connections).map(conn -> (C) conn).collect(toList())); - } - - @Override - public Completable closeAsync() { - return closeable.closeAsync(); - } - - @Override - public Completable closeAsyncGracefully() { - return closeable.closeAsyncGracefully(); - } - - @Override - public Completable onClose() { - return closeable.onClose(); - } - - @Override - public Completable onClosing() { - return closeable.onClosing(); - } - - @SuppressWarnings("unchecked") - private Completable doClose(final Function closeFunction) { - return Completable.defer(() -> { - final ConnState oldState = closeConnState(); - cancelIfHealthCheck(oldState); - final Object[] connections = oldState.connections; - return (connections.length == 0 ? completed() : - from(connections).flatMapCompletableDelayError(conn -> closeFunction.apply((C) conn))) - .shareContextOnSubscribe(); - }); - } - - private void cancelIfHealthCheck(ConnState connState) { - if (Host.isUnhealthy(connState)) { - @SuppressWarnings("unchecked") - HealthCheck healthCheck = (HealthCheck) connState.state; - LOGGER.debug("{}: health check cancelled for {}.", lbDescription, healthCheck.host); - healthCheck.cancel(); - } - } - - @Override - public String toString() { - final ConnState connState = this.connState; - return "Host{" + - "lbDescription=" + lbDescription + - ", address=" + address + - ", state=" + connState.state + - ", #connections=" + connState.connections.length + - '}'; - } - - private static final class ActiveState { - private final int failedConnections; - - ActiveState() { - this(0); - } - - private ActiveState(int failedConnections) { - this.failedConnections = failedConnections; - } - - ActiveState forNextFailedConnection() { - return new ActiveState(addWithOverflowProtection(this.failedConnections, 1)); - } - - @Override - public String toString() { - return "ACTIVE(failedConnections=" + failedConnections + ')'; - } - } - - private static final class HealthCheck - extends DelayedCancellable { - private final ConnectionFactory connectionFactory; - private final Host host; - private final Throwable lastError; - - private HealthCheck(final ConnectionFactory connectionFactory, - final Host host, final Throwable lastError) { - this.connectionFactory = connectionFactory; - this.host = host; - this.lastError = lastError; - } - - public void schedule(final Throwable originalCause) { - assert host.healthCheckConfig != null; - delayedCancellable( - // Use retry strategy to utilize jitter. - retryWithConstantBackoffDeltaJitter(cause -> true, - host.healthCheckConfig.healthCheckInterval, - host.healthCheckConfig.jitter, - host.healthCheckConfig.executor) - .apply(0, originalCause) - // Remove any state from async context - .beforeOnSubscribe(__ -> AsyncContext.clear()) - .concat(connectionFactory.newConnection(host.address, null, null) - // There is no risk for StackOverflowError because result of each connection - // attempt will be invoked on IoExecutor as a new task. - .retryWhen(retryWithConstantBackoffDeltaJitter( - cause -> { - LOGGER.debug("{}: health check failed for {}.", - host.lbDescription, host, cause); - return true; - }, - host.healthCheckConfig.healthCheckInterval, - host.healthCheckConfig.jitter, - host.healthCheckConfig.executor))) - .flatMapCompletable(newCnx -> { - if (host.addConnection(newCnx, this)) { - LOGGER.info("{}: health check passed for {}, marked this " + - "host as ACTIVE for the selection algorithm.", - host.lbDescription, host); - return completed(); - } else { - // This happens only if the host is closed, no need to mark as healthy. - assert host.connState.state == State.CLOSED; - LOGGER.debug("{}: health check passed for {}, but the " + - "host rejected a new connection {}. Closing it now.", - host.lbDescription, host, newCnx); - return newCnx.closeAsync(); - } - }) - // Use onErrorComplete instead of whenOnError to avoid double logging of an error inside - // subscribe(): SimpleCompletableSubscriber. - .onErrorComplete(t -> { - LOGGER.error("{}: health check terminated with " + - "an unexpected error for {}. Marking this host as ACTIVE as a fallback " + - "to allow connection attempts.", host.lbDescription, host, t); - host.markHealthy(this); - return true; - }) - .subscribe()); - } - - @Override - public String toString() { - return "UNHEALTHY(" + lastError + ')'; - } - } - - private static final class ConnState { - final Object[] connections; - final Object state; - - ConnState(final Object[] connections, final Object state) { - this.connections = connections; - this.state = state; - } - - @Override - public String toString() { - return "ConnState{" + - "state=" + state + - ", #connections=" + connections.length + - '}'; - } - } - } - - private static final class StacklessNoAvailableHostException extends NoAvailableHostException { - private static final long serialVersionUID = 5942960040738091793L; - - private StacklessNoAvailableHostException(final String message) { - super(message); - } - - @Override - public Throwable fillInStackTrace() { - return this; - } - - public static StacklessNoAvailableHostException newInstance(String message, Class clazz, String method) { - return ThrowableUtils.unknownStackTrace(new StacklessNoAvailableHostException(message), clazz, method); - } - } - - private static final class StacklessNoActiveHostException extends NoActiveHostException { - - private static final long serialVersionUID = 7500474499335155869L; - - private StacklessNoActiveHostException(final String message) { - super(message); - } - - @Override - public Throwable fillInStackTrace() { - return this; - } - - public static StacklessNoActiveHostException newInstance(String message, Class clazz, String method) { - return ThrowableUtils.unknownStackTrace(new StacklessNoActiveHostException(message), clazz, method); - } - } - - private static final class StacklessConnectionRejectedException extends ConnectionRejectedException { - private static final long serialVersionUID = -4940708893680455819L; - - private StacklessConnectionRejectedException(final String message) { - super(message); - } - - @Override - public Throwable fillInStackTrace() { - return this; - } - - public static StacklessConnectionRejectedException newInstance(String message, Class clazz, String method) { - return ThrowableUtils.unknownStackTrace(new StacklessConnectionRejectedException(message), clazz, method); - } - } - private static boolean isClosedList(List list) { return list.getClass().equals(ClosedList.class); } diff --git a/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/RoundRobinLoadBalancerFactory.java b/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/RoundRobinLoadBalancerFactory.java index 9b4362b6f1..e47b9a3464 100644 --- a/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/RoundRobinLoadBalancerFactory.java +++ b/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/RoundRobinLoadBalancerFactory.java @@ -24,7 +24,6 @@ import io.servicetalk.concurrent.api.Executor; import io.servicetalk.concurrent.api.Executors; import io.servicetalk.concurrent.api.Publisher; -import io.servicetalk.loadbalancer.RoundRobinLoadBalancer.HealthCheckConfig; import io.servicetalk.transport.api.ExecutionStrategy; import java.time.Duration;