From 0c24e2a3b69c2a0863c44b8de8b27829a6ad1388 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Fri, 17 Jan 2025 12:14:13 -0800 Subject: [PATCH] Simplify ES|QL execution info --- .../esql/plugin/ClusterComputeHandler.java | 166 ++++---- .../xpack/esql/plugin/ComputeListener.java | 240 +---------- .../xpack/esql/plugin/ComputeService.java | 181 ++++++--- .../esql/plugin/DataNodeComputeHandler.java | 162 ++++---- .../esql/plugin/ComputeListenerTests.java | 381 ++---------------- 5 files changed, 341 insertions(+), 789 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ClusterComputeHandler.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ClusterComputeHandler.java index 1f2b8faf83ee3..656cf3680dfc7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ClusterComputeHandler.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ClusterComputeHandler.java @@ -11,10 +11,10 @@ import org.elasticsearch.action.ActionListenerResponseHandler; import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.action.support.ChannelActionListener; -import org.elasticsearch.compute.EsqlRefCountingListener; import org.elasticsearch.compute.operator.exchange.ExchangeService; import org.elasticsearch.compute.operator.exchange.ExchangeSourceHandler; import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskCancelledException; @@ -25,17 +25,16 @@ import org.elasticsearch.transport.TransportRequestHandler; import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.esql.action.EsqlExecutionInfo; import org.elasticsearch.xpack.esql.plan.physical.ExchangeSinkExec; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; import org.elasticsearch.xpack.esql.session.Configuration; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicReference; /** * Manages computes across multiple clusters by sending {@link ClusterComputeRequest} to remote clusters and executing the computes. @@ -63,47 +62,39 @@ final class ClusterComputeHandler implements TransportRequestHandler clusters, - ComputeListener computeListener + RemoteCluster cluster, + ActionListener listener ) { var queryPragmas = configuration.pragmas(); - var linkExchangeListeners = ActionListener.releaseAfter(computeListener.acquireAvoid(), exchangeSource.addEmptySink()); - try (EsqlRefCountingListener refs = new EsqlRefCountingListener(linkExchangeListeners)) { - for (RemoteCluster cluster : clusters) { - final var childSessionId = computeService.newChildSession(sessionId); - ExchangeService.openExchange( - transportService, + listener = ActionListener.runBefore(listener, exchangeSource.addEmptySink()::close); + final var childSessionId = computeService.newChildSession(sessionId); + ExchangeService.openExchange( + transportService, + cluster.connection, + childSessionId, + queryPragmas.exchangeBufferSize(), + esqlExecutor, + listener.delegateFailureAndWrap((l, unused) -> { + var remoteSink = exchangeService.newRemoteSink(rootTask, childSessionId, transportService, cluster.connection); + exchangeSource.addRemoteSink(remoteSink, true, queryPragmas.concurrentExchangeClients(), ActionListener.noop()); + var remotePlan = new RemoteClusterPlan(plan, cluster.concreteIndices, cluster.originalIndices); + var clusterRequest = new ClusterComputeRequest(cluster.clusterAlias, childSessionId, configuration, remotePlan); + transportService.sendChildRequest( cluster.connection, - childSessionId, - queryPragmas.exchangeBufferSize(), - esqlExecutor, - refs.acquire().delegateFailureAndWrap((l, unused) -> { - var remoteSink = exchangeService.newRemoteSink(rootTask, childSessionId, transportService, cluster.connection); - exchangeSource.addRemoteSink(remoteSink, true, queryPragmas.concurrentExchangeClients(), ActionListener.noop()); - var remotePlan = new RemoteClusterPlan(plan, cluster.concreteIndices, cluster.originalIndices); - var clusterRequest = new ClusterComputeRequest(cluster.clusterAlias, childSessionId, configuration, remotePlan); - var clusterListener = ActionListener.runBefore( - computeListener.acquireCompute(cluster.clusterAlias()), - () -> l.onResponse(null) - ); - transportService.sendChildRequest( - cluster.connection, - ComputeService.CLUSTER_ACTION_NAME, - clusterRequest, - rootTask, - TransportRequestOptions.EMPTY, - new ActionListenerResponseHandler<>(clusterListener, ComputeResponse::new, esqlExecutor) - ); - }) + ComputeService.CLUSTER_ACTION_NAME, + clusterRequest, + rootTask, + TransportRequestOptions.EMPTY, + new ActionListenerResponseHandler<>(l, ComputeResponse::new, esqlExecutor) ); - } - } + }) + ); } List getRemoteClusters( @@ -141,28 +132,16 @@ public void messageReceived(ClusterComputeRequest request, TransportChannel chan listener.onFailure(new IllegalStateException("expected exchange sink for a remote compute; got " + plan)); return; } - String clusterAlias = request.clusterAlias(); - /* - * This handler runs only on remote cluster coordinators, so it creates a new local EsqlExecutionInfo object to record - * execution metadata for ES|QL processing local to this cluster. The execution info will be copied into the - * ComputeResponse that is sent back to the primary coordinating cluster. - */ - EsqlExecutionInfo execInfo = new EsqlExecutionInfo(true); - execInfo.swapCluster(clusterAlias, (k, v) -> new EsqlExecutionInfo.Cluster(clusterAlias, Arrays.toString(request.indices()))); - CancellableTask cancellable = (CancellableTask) task; - try (var computeListener = ComputeListener.create(clusterAlias, transportService, cancellable, execInfo, listener)) { - runComputeOnRemoteCluster( - clusterAlias, - request.sessionId(), - (CancellableTask) task, - request.configuration(), - (ExchangeSinkExec) plan, - Set.of(remoteClusterPlan.targetIndices()), - remoteClusterPlan.originalIndices(), - execInfo, - computeListener - ); - } + runComputeOnRemoteCluster( + request.clusterAlias(), + request.sessionId(), + (CancellableTask) task, + request.configuration(), + (ExchangeSinkExec) plan, + Set.of(remoteClusterPlan.targetIndices()), + remoteClusterPlan.originalIndices(), + listener + ); } /** @@ -182,8 +161,7 @@ void runComputeOnRemoteCluster( ExchangeSinkExec plan, Set concreteIndices, OriginalIndices originalIndices, - EsqlExecutionInfo executionInfo, - ComputeListener computeListener + ActionListener listener ) { final var exchangeSink = exchangeService.getSinkHandler(globalSessionId); parentTask.addListener( @@ -191,39 +169,51 @@ void runComputeOnRemoteCluster( ); final String localSessionId = clusterAlias + ":" + globalSessionId; final PhysicalPlan coordinatorPlan = ComputeService.reductionPlan(plan, true); - var exchangeSource = new ExchangeSourceHandler( - configuration.pragmas().exchangeBufferSize(), - transportService.getThreadPool().executor(ThreadPool.Names.SEARCH), - computeListener.acquireAvoid() - ); - try (Releasable ignored = exchangeSource.addEmptySink()) { - exchangeSink.addCompletionListener(computeListener.acquireAvoid()); - computeService.runCompute( - parentTask, - new ComputeContext( + final AtomicReference finalResponse = new AtomicReference<>(); + final long startTimeInNanos = System.nanoTime(); + final Runnable cancelQueryOnFailure = computeService.cancelQueryOnFailure(parentTask); + try (var computeListener = new ComputeListener(transportService.getThreadPool(), cancelQueryOnFailure, listener.map(profiles -> { + final TimeValue took = TimeValue.timeValueNanos(System.nanoTime() - startTimeInNanos); + final ComputeResponse r = finalResponse.get(); + return new ComputeResponse(profiles, took, r.totalShards, r.successfulShards, r.skippedShards, r.failedShards); + }))) { + var exchangeSource = new ExchangeSourceHandler( + configuration.pragmas().exchangeBufferSize(), + transportService.getThreadPool().executor(ThreadPool.Names.SEARCH), + computeListener.acquireAvoid() + ); + try (Releasable ignored = exchangeSource.addEmptySink()) { + exchangeSink.addCompletionListener(computeListener.acquireAvoid()); + computeService.runCompute( + parentTask, + new ComputeContext( + localSessionId, + clusterAlias, + List.of(), + configuration, + configuration.newFoldContext(), + exchangeSource, + exchangeSink + ), + coordinatorPlan, + computeListener.acquireCompute() + ); + dataNodeComputeHandler.startComputeOnDataNodes( localSessionId, clusterAlias, - List.of(), + parentTask, configuration, - configuration.newFoldContext(), + plan, + concreteIndices, + originalIndices, exchangeSource, - exchangeSink - ), - coordinatorPlan, - computeListener.acquireCompute(clusterAlias) - ); - dataNodeComputeHandler.startComputeOnDataNodes( - localSessionId, - clusterAlias, - parentTask, - configuration, - plan, - concreteIndices, - originalIndices, - exchangeSource, - executionInfo, - computeListener - ); + cancelQueryOnFailure, + computeListener.acquireCompute().map(r -> { + finalResponse.set(r); + return r.getProfiles(); + }) + ); + } } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeListener.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeListener.java index 8bd23230fcde7..3d358b8c7a8a2 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeListener.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeListener.java @@ -12,189 +12,44 @@ import org.elasticsearch.compute.EsqlRefCountingListener; import org.elasticsearch.compute.operator.DriverProfile; import org.elasticsearch.compute.operator.ResponseHeadersCollector; -import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; -import org.elasticsearch.core.TimeValue; -import org.elasticsearch.logging.LogManager; -import org.elasticsearch.logging.Logger; -import org.elasticsearch.tasks.CancellableTask; -import org.elasticsearch.transport.RemoteClusterAware; -import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.esql.action.EsqlExecutionInfo; +import org.elasticsearch.threadpool.ThreadPool; import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; /** * A variant of {@link RefCountingListener} with the following differences: - * 1. Automatically cancels sub tasks on failure. + * 1. Automatically cancels sub tasks on failure (via runOnTaskFailure) * 2. Collects driver profiles from sub tasks. * 3. Collects response headers from sub tasks, specifically warnings emitted during compute * 4. Collects failures and returns the most appropriate exception to the caller. - * 5. Updates {@link EsqlExecutionInfo} for display in the response for cross-cluster searches */ final class ComputeListener implements Releasable { - private static final Logger LOGGER = LogManager.getLogger(ComputeService.class); - private final EsqlRefCountingListener refs; - private final AtomicBoolean cancelled = new AtomicBoolean(); - private final CancellableTask task; - private final TransportService transportService; private final List collectedProfiles; private final ResponseHeadersCollector responseHeaders; - private final EsqlExecutionInfo esqlExecutionInfo; - // clusterAlias indicating where this ComputeListener is running - // used by the top level ComputeListener in ComputeService on both local and remote clusters - private final String whereRunning; - - /** - * Create a ComputeListener that does not need to gather any metadata in EsqlExecutionInfo - * (currently that's the ComputeListener in DataNodeRequestHandler). - */ - public static ComputeListener create( - TransportService transportService, - CancellableTask task, - ActionListener delegate - ) { - return new ComputeListener(transportService, task, null, null, delegate); - } + private final Runnable runOnFailure; - /** - * Create a ComputeListener that gathers metadata in EsqlExecutionInfo - * (currently that's the top level ComputeListener in ComputeService). - * @param clusterAlias the clusterAlias where this ComputeListener is running. For the querying cluster, use - * RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY. For remote clusters that are part of a CCS, - * the remote cluster is given its clusterAlias in the request sent to it, so that should be - * passed in here. This gives context to the ComputeListener as to where this listener is running - * and thus how it should behave with respect to the {@link EsqlExecutionInfo} metadata it gathers. - * @param transportService - * @param task - * @param executionInfo {@link EsqlExecutionInfo} to capture execution metadata - * @param delegate - */ - public static ComputeListener create( - String clusterAlias, - TransportService transportService, - CancellableTask task, - EsqlExecutionInfo executionInfo, - ActionListener delegate - ) { - return new ComputeListener(transportService, task, clusterAlias, executionInfo, delegate); - } - - private ComputeListener( - TransportService transportService, - CancellableTask task, - String clusterAlias, - EsqlExecutionInfo executionInfo, - ActionListener delegate - ) { - this.transportService = transportService; - this.task = task; - this.responseHeaders = new ResponseHeadersCollector(transportService.getThreadPool().getThreadContext()); + ComputeListener(ThreadPool threadPool, Runnable runOnFailure, ActionListener> delegate) { + this.runOnFailure = runOnFailure; + this.responseHeaders = new ResponseHeadersCollector(threadPool.getThreadContext()); this.collectedProfiles = Collections.synchronizedList(new ArrayList<>()); - this.esqlExecutionInfo = executionInfo; - this.whereRunning = clusterAlias; - // for the DataNodeHandler ComputeListener, clusterAlias and executionInfo will be null - // for the top level ComputeListener in ComputeService both will be non-null - assert (clusterAlias == null && executionInfo == null) || (clusterAlias != null && executionInfo != null) - : "clusterAlias and executionInfo must both be null or both non-null"; - // listener that executes after all the sub-listeners refs (created via acquireCompute) have completed this.refs = new EsqlRefCountingListener(delegate.delegateFailure((l, ignored) -> { responseHeaders.finish(); - ComputeResponse result; - - if (runningOnRemoteCluster()) { - // for remote executions - this ComputeResponse is created on the remote cluster/node and will be serialized and - // received by the acquireCompute method callback on the coordinating cluster - setFinalStatusAndShardCounts(clusterAlias, executionInfo); - EsqlExecutionInfo.Cluster cluster = esqlExecutionInfo.getCluster(clusterAlias); - result = new ComputeResponse( - collectedProfiles.isEmpty() ? List.of() : collectedProfiles.stream().toList(), - cluster.getTook(), - cluster.getTotalShards(), - cluster.getSuccessfulShards(), - cluster.getSkippedShards(), - cluster.getFailedShards() - ); - } else { - result = new ComputeResponse(collectedProfiles.isEmpty() ? List.of() : collectedProfiles.stream().toList()); - if (coordinatingClusterIsSearchedInCCS()) { - // if not already marked as SKIPPED, mark the local cluster as finished once the coordinator and all - // data nodes have finished processing - setFinalStatusAndShardCounts(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, executionInfo); - } - } - delegate.onResponse(result); + delegate.onResponse(collectedProfiles.stream().toList()); })); } - private static void setFinalStatusAndShardCounts(String clusterAlias, EsqlExecutionInfo executionInfo) { - executionInfo.swapCluster(clusterAlias, (k, v) -> { - // TODO: once PARTIAL status is supported (partial results work to come), modify this code as needed - if (v.getStatus() != EsqlExecutionInfo.Cluster.Status.SKIPPED) { - assert v.getTotalShards() != null && v.getSkippedShards() != null : "Null total or skipped shard count: " + v; - return new EsqlExecutionInfo.Cluster.Builder(v).setStatus(EsqlExecutionInfo.Cluster.Status.SUCCESSFUL) - /* - * Total and skipped shard counts are set early in execution (after can-match). - * Until ES|QL supports shard-level partial results, we just set all non-skipped shards - * as successful and none are failed. - */ - .setSuccessfulShards(v.getTotalShards()) - .setFailedShards(0) - .build(); - } else { - return v; - } - }); - } - - /** - * @return true if the "local" querying/coordinator cluster is being searched in a cross-cluster search - */ - private boolean coordinatingClusterIsSearchedInCCS() { - return esqlExecutionInfo != null - && esqlExecutionInfo.isCrossClusterSearch() - && esqlExecutionInfo.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY) != null; - } - - /** - * @return true if this Listener is running on a remote cluster (i.e., not the querying cluster) - */ - private boolean runningOnRemoteCluster() { - return whereRunning != null && whereRunning.equals(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY) == false; - } - - /** - * @return true if the listener is in a context where the took time needs to be recorded into the EsqlExecutionInfo - */ - private boolean shouldRecordTookTime() { - return runningOnRemoteCluster() || coordinatingClusterIsSearchedInCCS(); - } - - /** - * @param computeClusterAlias the clusterAlias passed to the acquireCompute method - * @return true if this listener is waiting for a remote response in a CCS search - */ - private boolean isCCSListener(String computeClusterAlias) { - return RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY.equals(whereRunning) - && computeClusterAlias.equals(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY) == false; - } - /** * Acquires a new listener that doesn't collect result */ ActionListener acquireAvoid() { return refs.acquire().delegateResponse((l, e) -> { try { - if (cancelled.compareAndSet(false, true)) { - LOGGER.debug("cancelling ESQL task {} on failure", task); - transportService.getTaskManager().cancelTaskAndDescendants(task, "cancelled on failure", false, ActionListener.noop()); - } + runOnFailure.run(); } finally { l.onFailure(e); } @@ -203,86 +58,21 @@ ActionListener acquireAvoid() { /** * Acquires a new listener that collects compute result. This listener will also collect warnings emitted during compute - * @param computeClusterAlias The cluster alias where the compute is happening. Used when metadata needs to be gathered - * into the {@link EsqlExecutionInfo} Cluster objects. Callers that do not required execution - * info to be gathered (namely, the DataNodeRequestHandler ComputeListener) should pass in null. */ - ActionListener acquireCompute(@Nullable String computeClusterAlias) { - assert computeClusterAlias == null || (esqlExecutionInfo != null && esqlExecutionInfo.getRelativeStartNanos() != null) - : "When clusterAlias is provided to acquireCompute, executionInfo and relativeStartTimeNanos must be non-null"; - - return acquireAvoid().map(resp -> { + ActionListener> acquireCompute() { + final ActionListener delegate = acquireAvoid(); + return ActionListener.wrap(profiles -> { responseHeaders.collect(); - var profiles = resp.getProfiles(); if (profiles != null && profiles.isEmpty() == false) { collectedProfiles.addAll(profiles); } - if (computeClusterAlias == null) { - return null; - } - if (isCCSListener(computeClusterAlias)) { - // this is the callback for the listener on the primary coordinator that receives a remote ComputeResponse - updateExecutionInfoWithRemoteResponse(computeClusterAlias, resp); - - } else if (shouldRecordTookTime()) { - Long relativeStartNanos = esqlExecutionInfo.getRelativeStartNanos(); - // handler for this cluster's data node and coordinator completion (runs on "local" and remote clusters) - assert relativeStartNanos != null : "queryStartTimeNanos not set properly"; - TimeValue tookTime = new TimeValue(System.nanoTime() - relativeStartNanos, TimeUnit.NANOSECONDS); - esqlExecutionInfo.swapCluster(computeClusterAlias, (k, v) -> { - if (v.getStatus() != EsqlExecutionInfo.Cluster.Status.SKIPPED - && (v.getTook() == null || v.getTook().nanos() < tookTime.nanos())) { - return new EsqlExecutionInfo.Cluster.Builder(v).setTook(tookTime).build(); - } else { - return v; - } - }); - } - return null; + delegate.onResponse(null); + }, e -> { + responseHeaders.collect(); + delegate.onFailure(e); }); } - private void updateExecutionInfoWithRemoteResponse(String computeClusterAlias, ComputeResponse resp) { - TimeValue tookOnCluster; - if (resp.getTook() != null) { - TimeValue remoteExecutionTime = resp.getTook(); - TimeValue planningTookTime = esqlExecutionInfo.planningTookTime(); - tookOnCluster = new TimeValue(planningTookTime.nanos() + remoteExecutionTime.nanos(), TimeUnit.NANOSECONDS); - esqlExecutionInfo.swapCluster( - computeClusterAlias, - (k, v) -> new EsqlExecutionInfo.Cluster.Builder(v) - // for now ESQL doesn't return partial results, so set status to SUCCESSFUL - .setStatus(EsqlExecutionInfo.Cluster.Status.SUCCESSFUL) - .setTook(tookOnCluster) - .setTotalShards(resp.getTotalShards()) - .setSuccessfulShards(resp.getSuccessfulShards()) - .setSkippedShards(resp.getSkippedShards()) - .setFailedShards(resp.getFailedShards()) - .build() - ); - } else { - // if the cluster is an older version and does not send back took time, then calculate it here on the coordinator - // and leave shard info unset, so it is not shown in the CCS metadata section of the JSON response - long remoteTook = System.nanoTime() - esqlExecutionInfo.getRelativeStartNanos(); - tookOnCluster = new TimeValue(remoteTook, TimeUnit.NANOSECONDS); - esqlExecutionInfo.swapCluster( - computeClusterAlias, - (k, v) -> new EsqlExecutionInfo.Cluster.Builder(v) - // for now ESQL doesn't return partial results, so set status to SUCCESSFUL - .setStatus(EsqlExecutionInfo.Cluster.Status.SUCCESSFUL) - .setTook(tookOnCluster) - .build() - ); - } - } - - /** - * Use this method when no execution metadata needs to be added to {@link EsqlExecutionInfo} - */ - ActionListener acquireCompute() { - return acquireCompute(null); - } - @Override public void close() { refs.close(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java index 2cb4b49ec3591..ecb1a52331ced 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java @@ -12,14 +12,17 @@ import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.concurrent.RunOnce; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverProfile; import org.elasticsearch.compute.operator.DriverTaskRunner; import org.elasticsearch.compute.operator.exchange.ExchangeService; import org.elasticsearch.compute.operator.exchange.ExchangeSourceHandler; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.Tuple; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.logging.LogManager; @@ -52,7 +55,9 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; import static org.elasticsearch.xpack.esql.plugin.EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME; @@ -63,6 +68,7 @@ public class ComputeService { public static final String DATA_ACTION_NAME = EsqlQueryAction.NAME + "/data"; public static final String CLUSTER_ACTION_NAME = EsqlQueryAction.NAME + "/cluster"; + private static final String LOCAL_CLUSTER = RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; private static final Logger LOGGER = LogManager.getLogger(ComputeService.class); private final SearchService searchService; @@ -137,6 +143,7 @@ public void execute( Map clusterToConcreteIndices = transportService.getRemoteClusterService() .groupIndices(SearchRequest.DEFAULT_INDICES_OPTIONS, PlannerUtils.planConcreteIndices(physicalPlan).toArray(String[]::new)); QueryPragmas queryPragmas = configuration.pragmas(); + Runnable cancelQueryOnFailure = cancelQueryOnFailure(rootTask); if (dataNodePlan == null) { if (clusterToConcreteIndices.values().stream().allMatch(v -> v.indices().length == 0) == false) { String error = "expected no concrete indices without data node plan; got " + clusterToConcreteIndices; @@ -146,20 +153,21 @@ public void execute( } var computeContext = new ComputeContext( newChildSession(sessionId), - RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, + LOCAL_CLUSTER, List.of(), configuration, foldContext, null, null ); - String local = RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; updateShardCountForCoordinatorOnlyQuery(execInfo); - try (var computeListener = ComputeListener.create(local, transportService, rootTask, execInfo, listener.map(r -> { - updateExecutionInfoAfterCoordinatorOnlyQuery(execInfo); - return new Result(physicalPlan.output(), collectedPages, r.getProfiles(), execInfo); - }))) { - runCompute(rootTask, computeContext, coordinatorPlan, computeListener.acquireCompute(local)); + try ( + var computeListener = new ComputeListener(transportService.getThreadPool(), cancelQueryOnFailure, listener.map(profiles -> { + updateExecutionInfoAfterCoordinatorOnlyQuery(execInfo); + return new Result(physicalPlan.output(), collectedPages, profiles, execInfo); + })) + ) { + runCompute(rootTask, computeContext, coordinatorPlan, computeListener.acquireCompute()); return; } } else { @@ -172,22 +180,18 @@ public void execute( } Map clusterToOriginalIndices = transportService.getRemoteClusterService() .groupIndices(SearchRequest.DEFAULT_INDICES_OPTIONS, PlannerUtils.planOriginalIndices(physicalPlan)); - var localOriginalIndices = clusterToOriginalIndices.remove(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY); - var localConcreteIndices = clusterToConcreteIndices.remove(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY); - String local = RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; + var localOriginalIndices = clusterToOriginalIndices.remove(LOCAL_CLUSTER); + var localConcreteIndices = clusterToConcreteIndices.remove(LOCAL_CLUSTER); /* * Grab the output attributes here, so we can pass them to * the listener without holding on to a reference to the * entire plan. */ List outputAttributes = physicalPlan.output(); - try ( - // this is the top level ComputeListener called once at the end (e.g., once all clusters have finished for a CCS) - var computeListener = ComputeListener.create(local, transportService, rootTask, execInfo, listener.map(r -> { - execInfo.markEndQuery(); // TODO: revisit this time recording model as part of INLINESTATS improvements - return new Result(outputAttributes, collectedPages, r.getProfiles(), execInfo); - })) - ) { + try (var computeListener = new ComputeListener(transportService.getThreadPool(), cancelQueryOnFailure, listener.map(profiles -> { + execInfo.markEndQuery(); // TODO: revisit this time recording model as part of INLINESTATS improvements + return new Result(outputAttributes, collectedPages, profiles, execInfo); + }))) { var exchangeSource = new ExchangeSourceHandler( queryPragmas.exchangeBufferSize(), transportService.getThreadPool().executor(ThreadPool.Names.SEARCH), @@ -195,50 +199,113 @@ public void execute( ); try (Releasable ignored = exchangeSource.addEmptySink()) { // run compute on the coordinator - runCompute( - rootTask, - new ComputeContext( - sessionId, - RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, - List.of(), - configuration, - foldContext, - exchangeSource, - null - ), - coordinatorPlan, - computeListener.acquireCompute(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY) - ); - // starts computes on data nodes on the main cluster - if (localConcreteIndices != null && localConcreteIndices.indices().length > 0) { - dataNodeComputeHandler.startComputeOnDataNodes( + final AtomicReference localResponse = new AtomicReference<>(new ComputeResponse(List.of())); + try ( + var localListener = new ComputeListener( + transportService.getThreadPool(), + cancelQueryOnFailure, + computeListener.acquireCompute().delegateFailure((l, profiles) -> { + if (execInfo.isCrossClusterSearch() && execInfo.clusterAliases().contains(LOCAL_CLUSTER)) { + var tookTime = TimeValue.timeValueNanos(System.nanoTime() - execInfo.getRelativeStartNanos()); + var r = localResponse.get(); + var merged = new ComputeResponse( + profiles, + tookTime, + r.totalShards, + r.successfulShards, + r.skippedShards, + r.failedShards + ); + updateExecutionInfo(execInfo, LOCAL_CLUSTER, merged); + } + l.onResponse(profiles); + }) + ) + ) { + runCompute( + rootTask, + new ComputeContext(sessionId, LOCAL_CLUSTER, List.of(), configuration, foldContext, exchangeSource, null), + coordinatorPlan, + localListener.acquireCompute() + ); + // starts computes on data nodes on the main cluster + if (localConcreteIndices != null && localConcreteIndices.indices().length > 0) { + dataNodeComputeHandler.startComputeOnDataNodes( + sessionId, + LOCAL_CLUSTER, + rootTask, + configuration, + dataNodePlan, + Set.of(localConcreteIndices.indices()), + localOriginalIndices, + exchangeSource, + cancelQueryOnFailure, + localListener.acquireCompute().map(r -> { + localResponse.set(r); + return r.getProfiles(); + }) + ); + } + } + // starts computes on remote clusters + final var remoteClusters = clusterComputeHandler.getRemoteClusters(clusterToConcreteIndices, clusterToOriginalIndices); + for (ClusterComputeHandler.RemoteCluster cluster : remoteClusters) { + clusterComputeHandler.startComputeOnRemoteCluster( sessionId, - RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, rootTask, configuration, dataNodePlan, - Set.of(localConcreteIndices.indices()), - localOriginalIndices, exchangeSource, - execInfo, - computeListener + cluster, + computeListener.acquireCompute().map(r -> { + updateExecutionInfo(execInfo, cluster.clusterAlias(), r); + return r.getProfiles(); + }) ); } - // starts computes on remote clusters - final var remoteClusters = clusterComputeHandler.getRemoteClusters(clusterToConcreteIndices, clusterToOriginalIndices); - clusterComputeHandler.startComputeOnRemoteClusters( - sessionId, - rootTask, - configuration, - dataNodePlan, - exchangeSource, - remoteClusters, - computeListener - ); } } } + private void updateExecutionInfo(EsqlExecutionInfo executionInfo, String clusterAlias, ComputeResponse resp) { + TimeValue tookOnCluster; + if (resp.getTook() != null) { + TimeValue remoteExecutionTime = resp.getTook(); + final long planningTime; + if (clusterAlias.equals(LOCAL_CLUSTER)) { + planningTime = 0L; + } else { + planningTime = executionInfo.planningTookTime().nanos(); + } + tookOnCluster = new TimeValue(planningTime + remoteExecutionTime.nanos(), TimeUnit.NANOSECONDS); + executionInfo.swapCluster( + clusterAlias, + (k, v) -> new EsqlExecutionInfo.Cluster.Builder(v) + // for now ESQL doesn't return partial results, so set status to SUCCESSFUL + .setStatus(EsqlExecutionInfo.Cluster.Status.SUCCESSFUL) + .setTook(tookOnCluster) + .setTotalShards(resp.getTotalShards()) + .setSuccessfulShards(resp.getSuccessfulShards()) + .setSkippedShards(resp.getSkippedShards()) + .setFailedShards(resp.getFailedShards()) + .build() + ); + } else { + // if the cluster is an older version and does not send back took time, then calculate it here on the coordinator + // and leave shard info unset, so it is not shown in the CCS metadata section of the JSON response + long remoteTook = System.nanoTime() - executionInfo.getRelativeStartNanos(); + tookOnCluster = new TimeValue(remoteTook, TimeUnit.NANOSECONDS); + executionInfo.swapCluster( + clusterAlias, + (k, v) -> new EsqlExecutionInfo.Cluster.Builder(v) + // for now ESQL doesn't return partial results, so set status to SUCCESSFUL + .setStatus(EsqlExecutionInfo.Cluster.Status.SUCCESSFUL) + .setTook(tookOnCluster) + .build() + ); + } + } + // For queries like: FROM logs* | LIMIT 0 (including cross-cluster LIMIT 0 queries) private static void updateShardCountForCoordinatorOnlyQuery(EsqlExecutionInfo execInfo) { if (execInfo.isCrossClusterSearch()) { @@ -272,7 +339,7 @@ private static void updateExecutionInfoAfterCoordinatorOnlyQuery(EsqlExecutionIn } } - void runCompute(CancellableTask task, ComputeContext context, PhysicalPlan plan, ActionListener listener) { + void runCompute(CancellableTask task, ComputeContext context, PhysicalPlan plan, ActionListener> listener) { listener = ActionListener.runBefore(listener, () -> Releasables.close(context.searchContexts())); List contexts = new ArrayList<>(context.searchContexts().size()); for (int i = 0; i < context.searchContexts().size(); i++) { @@ -328,10 +395,9 @@ public SourceProvider createSourceProvider() { } ActionListener listenerCollectingStatus = listener.map(ignored -> { if (context.configuration().profile()) { - return new ComputeResponse(drivers.stream().map(Driver::profile).toList()); + return drivers.stream().map(Driver::profile).toList(); } else { - final ComputeResponse response = new ComputeResponse(List.of()); - return response; + return List.of(); } }); listenerCollectingStatus = ActionListener.releaseAfter(listenerCollectingStatus, () -> Releasables.close(drivers)); @@ -357,4 +423,11 @@ static PhysicalPlan reductionPlan(ExchangeSinkExec plan, boolean enable) { String newChildSession(String session) { return session + "/" + childSessionIdGenerator.incrementAndGet(); } + + Runnable cancelQueryOnFailure(CancellableTask task) { + return new RunOnce(() -> { + LOGGER.debug("cancelling ESQL task {} on failure", task); + transportService.getTaskManager().cancelTaskAndDescendants(task, "cancelled on failure", false, ActionListener.noop()); + }); + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeHandler.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeHandler.java index 1a1e5726a487b..d5f0790802cf7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeHandler.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeHandler.java @@ -17,13 +17,14 @@ import org.elasticsearch.action.support.ChannelActionListener; import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.cluster.node.DiscoveryNode; -import org.elasticsearch.compute.EsqlRefCountingListener; +import org.elasticsearch.compute.operator.DriverProfile; import org.elasticsearch.compute.operator.exchange.ExchangeService; import org.elasticsearch.compute.operator.exchange.ExchangeSink; import org.elasticsearch.compute.operator.exchange.ExchangeSinkHandler; import org.elasticsearch.compute.operator.exchange.ExchangeSourceHandler; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.Index; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.shard.IndexShard; @@ -42,7 +43,6 @@ import org.elasticsearch.transport.TransportRequestHandler; import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.esql.action.EsqlExecutionInfo; import org.elasticsearch.xpack.esql.action.EsqlSearchShardsAction; import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.plan.physical.ExchangeSinkExec; @@ -95,39 +95,45 @@ void startComputeOnDataNodes( Set concreteIndices, OriginalIndices originalIndices, ExchangeSourceHandler exchangeSource, - EsqlExecutionInfo executionInfo, - ComputeListener computeListener + Runnable runOnTaskFailure, + ActionListener outListener ) { QueryBuilder requestFilter = PlannerUtils.requestTimestampFilter(dataNodePlan); - var lookupListener = ActionListener.releaseAfter(computeListener.acquireAvoid(), exchangeSource.addEmptySink()); - // SearchShards API can_match is done in lookupDataNodes + var listener = ActionListener.runAfter(outListener, exchangeSource.addEmptySink()::close); + final long startTimeInNanos = System.nanoTime(); lookupDataNodes(parentTask, clusterAlias, requestFilter, concreteIndices, originalIndices, ActionListener.wrap(dataNodeResult -> { - try (EsqlRefCountingListener refs = new EsqlRefCountingListener(lookupListener)) { - // update ExecutionInfo with shard counts (total and skipped) - executionInfo.swapCluster( - clusterAlias, - (k, v) -> new EsqlExecutionInfo.Cluster.Builder(v).setTotalShards(dataNodeResult.totalShards()) - // do not set successful or failed shard count here - do it when search is done - .setSkippedShards(dataNodeResult.skippedShards()) - .build() - ); - + try ( + ComputeListener computeListener = new ComputeListener( + transportService.getThreadPool(), + runOnTaskFailure, + listener.map(profiles -> { + TimeValue took = TimeValue.timeValueNanos(System.nanoTime() - startTimeInNanos); + return new ComputeResponse( + profiles, + took, + dataNodeResult.totalShards(), + dataNodeResult.totalShards() - dataNodeResult.skippedShards(), + dataNodeResult.skippedShards(), + 0 + ); + }) + ) + ) { // For each target node, first open a remote exchange on the remote node, then link the exchange source to // the new remote exchange sink, and initialize the computation on the target node via data-node-request. for (DataNode node : dataNodeResult.dataNodes()) { var queryPragmas = configuration.pragmas(); var childSessionId = computeService.newChildSession(sessionId); + ActionListener nodeListener = computeListener.acquireCompute().map(ComputeResponse::getProfiles); ExchangeService.openExchange( transportService, node.connection, childSessionId, queryPragmas.exchangeBufferSize(), esqlExecutor, - refs.acquire().delegateFailureAndWrap((l, unused) -> { + nodeListener.delegateFailureAndWrap((l, unused) -> { var remoteSink = exchangeService.newRemoteSink(parentTask, childSessionId, transportService, node.connection); exchangeSource.addRemoteSink(remoteSink, true, queryPragmas.concurrentExchangeClients(), ActionListener.noop()); - ActionListener computeResponseListener = computeListener.acquireCompute(clusterAlias); - var dataNodeListener = ActionListener.runBefore(computeResponseListener, () -> l.onResponse(null)); final boolean sameNode = transportService.getLocalNode().getId().equals(node.connection.getNode().getId()); var dataNodeRequest = new DataNodeRequest( childSessionId, @@ -146,13 +152,13 @@ void startComputeOnDataNodes( dataNodeRequest, parentTask, TransportRequestOptions.EMPTY, - new ActionListenerResponseHandler<>(dataNodeListener, ComputeResponse::new, esqlExecutor) + new ActionListenerResponseHandler<>(nodeListener, ComputeResponse::new, esqlExecutor) ); }) ); } } - }, lookupListener::onFailure)); + }, listener::onFailure)); } private void acquireSearchContexts( @@ -341,11 +347,11 @@ private void runBatch(int startBatchIndex) { final var sessionId = request.sessionId(); final int endBatchIndex = Math.min(startBatchIndex + maxConcurrentShards, request.shardIds().size()); List shardIds = request.shardIds().subList(startBatchIndex, endBatchIndex); - ActionListener batchListener = new ActionListener<>() { - final ActionListener ref = computeListener.acquireCompute(); + ActionListener> batchListener = new ActionListener<>() { + final ActionListener> ref = computeListener.acquireCompute(); @Override - public void onResponse(ComputeResponse result) { + public void onResponse(List result) { try { onBatchCompleted(endBatchIndex); } finally { @@ -396,54 +402,64 @@ private void runComputeOnDataNode( String externalId, PhysicalPlan reducePlan, DataNodeRequest request, - ComputeListener computeListener + ActionListener listener ) { - var parentListener = computeListener.acquireAvoid(); - try { - // run compute with target shards - var internalSink = exchangeService.createSinkHandler(request.sessionId(), request.pragmas().exchangeBufferSize()); - DataNodeRequestExecutor dataNodeRequestExecutor = new DataNodeRequestExecutor( - request, - task, - internalSink, - request.configuration().pragmas().maxConcurrentShardsPerNode(), - computeListener - ); - dataNodeRequestExecutor.start(); - // run the node-level reduction - var externalSink = exchangeService.getSinkHandler(externalId); - task.addListener(() -> exchangeService.finishSinkHandler(externalId, new TaskCancelledException(task.getReasonCancelled()))); - var exchangeSource = new ExchangeSourceHandler(1, esqlExecutor, computeListener.acquireAvoid()); - exchangeSource.addRemoteSink(internalSink::fetchPageAsync, true, 1, ActionListener.noop()); - ActionListener reductionListener = computeListener.acquireCompute(); - computeService.runCompute( - task, - new ComputeContext( - request.sessionId(), - request.clusterAlias(), - List.of(), - request.configuration(), - new FoldContext(request.pragmas().foldLimit().getBytes()), - exchangeSource, - externalSink - ), - reducePlan, - ActionListener.wrap(resp -> { - // don't return until all pages are fetched - externalSink.addCompletionListener(ActionListener.running(() -> { - exchangeService.finishSinkHandler(externalId, null); - reductionListener.onResponse(resp); - })); - }, e -> { - exchangeService.finishSinkHandler(externalId, e); - reductionListener.onFailure(e); - }) - ); - parentListener.onResponse(null); - } catch (Exception e) { - exchangeService.finishSinkHandler(externalId, e); - exchangeService.finishSinkHandler(request.sessionId(), e); - parentListener.onFailure(e); + try ( + ComputeListener computeListener = new ComputeListener( + transportService.getThreadPool(), + computeService.cancelQueryOnFailure(task), + listener.map(ComputeResponse::new) + ) + ) { + var parentListener = computeListener.acquireAvoid(); + try { + // run compute with target shards + var internalSink = exchangeService.createSinkHandler(request.sessionId(), request.pragmas().exchangeBufferSize()); + DataNodeRequestExecutor dataNodeRequestExecutor = new DataNodeRequestExecutor( + request, + task, + internalSink, + request.configuration().pragmas().maxConcurrentShardsPerNode(), + computeListener + ); + dataNodeRequestExecutor.start(); + // run the node-level reduction + var externalSink = exchangeService.getSinkHandler(externalId); + task.addListener( + () -> exchangeService.finishSinkHandler(externalId, new TaskCancelledException(task.getReasonCancelled())) + ); + var exchangeSource = new ExchangeSourceHandler(1, esqlExecutor, computeListener.acquireAvoid()); + exchangeSource.addRemoteSink(internalSink::fetchPageAsync, true, 1, ActionListener.noop()); + var reductionListener = computeListener.acquireCompute(); + computeService.runCompute( + task, + new ComputeContext( + request.sessionId(), + request.clusterAlias(), + List.of(), + request.configuration(), + new FoldContext(request.pragmas().foldLimit().getBytes()), + exchangeSource, + externalSink + ), + reducePlan, + ActionListener.wrap(resp -> { + // don't return until all pages are fetched + externalSink.addCompletionListener(ActionListener.running(() -> { + exchangeService.finishSinkHandler(externalId, null); + reductionListener.onResponse(resp); + })); + }, e -> { + exchangeService.finishSinkHandler(externalId, e); + reductionListener.onFailure(e); + }) + ); + parentListener.onResponse(null); + } catch (Exception e) { + exchangeService.finishSinkHandler(externalId, e); + exchangeService.finishSinkHandler(request.sessionId(), e); + parentListener.onFailure(e); + } } } @@ -469,8 +485,6 @@ public void messageReceived(DataNodeRequest request, TransportChannel channel, T request.indicesOptions(), request.runNodeLevelReduction() ); - try (var computeListener = ComputeListener.create(transportService, (CancellableTask) task, listener)) { - runComputeOnDataNode((CancellableTask) task, sessionId, reductionPlan, request, computeListener); - } + runComputeOnDataNode((CancellableTask) task, sessionId, reductionPlan, request, listener); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ComputeListenerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ComputeListenerTests.java index ae0db127491f7..7db3216d1736d 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ComputeListenerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ComputeListenerTests.java @@ -10,29 +10,18 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.cluster.node.VersionInformation; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.CircuitBreakingException; -import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.compute.operator.DriverProfile; import org.elasticsearch.compute.operator.DriverSleeps; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.tasks.CancellableTask; -import org.elasticsearch.tasks.TaskCancellationService; import org.elasticsearch.tasks.TaskCancelledException; -import org.elasticsearch.tasks.TaskId; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.test.TransportVersionUtils; -import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.transport.RemoteClusterAware; -import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.esql.action.EsqlExecutionInfo; import org.junit.After; import org.junit.Before; -import org.mockito.Mockito; import java.util.ArrayList; import java.util.HashMap; @@ -44,56 +33,30 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.IntStream; -import static org.elasticsearch.test.tasks.MockTaskManager.SPY_TASK_MANAGER_SETTING; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.lessThan; -import static org.hamcrest.Matchers.lessThanOrEqualTo; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; public class ComputeListenerTests extends ESTestCase { private ThreadPool threadPool; - private TransportService transportService; @Before public void setUpTransportService() { threadPool = new TestThreadPool(getTestName()); - transportService = MockTransportService.createNewService( - Settings.builder().put(SPY_TASK_MANAGER_SETTING.getKey(), true).build(), - VersionInformation.CURRENT, - TransportVersionUtils.randomVersion(), - threadPool - ); - transportService.start(); - TaskCancellationService cancellationService = new TaskCancellationService(transportService); - transportService.getTaskManager().setTaskCancellationService(cancellationService); - Mockito.clearInvocations(transportService.getTaskManager()); } @After public void shutdownTransportService() { - transportService.close(); terminate(threadPool); } - private CancellableTask newTask() { - return new CancellableTask( - randomIntBetween(1, 100), - "test-type", - "test-action", - "test-description", - TaskId.EMPTY_TASK_ID, - Map.of() - ); - } - - private ComputeResponse randomResponse(boolean includeExecutionInfo) { + private List randomProfiles() { int numProfiles = randomIntBetween(0, 2); List profiles = new ArrayList<>(numProfiles); for (int i = 0; i < numProfiles; i++) { @@ -109,51 +72,23 @@ private ComputeResponse randomResponse(boolean includeExecutionInfo) { ) ); } - if (includeExecutionInfo) { - return new ComputeResponse( - profiles, - new TimeValue(randomLongBetween(0, 50000), TimeUnit.NANOSECONDS), - 10, - 10, - randomIntBetween(0, 3), - 0 - ); - } else { - return new ComputeResponse(profiles); - } + return profiles; } public void testEmpty() { - PlainActionFuture results = new PlainActionFuture<>(); - EsqlExecutionInfo executionInfo = new EsqlExecutionInfo(randomBoolean()); - try ( - ComputeListener ignored = ComputeListener.create( - RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, - transportService, - newTask(), - executionInfo, - results - ) - ) { + PlainActionFuture> results = new PlainActionFuture<>(); + try (var ignored = new ComputeListener(threadPool, () -> {}, results)) { assertFalse(results.isDone()); } assertTrue(results.isDone()); - assertThat(results.actionGet(10, TimeUnit.SECONDS).getProfiles(), empty()); + assertThat(results.actionGet(10, TimeUnit.SECONDS), empty()); } public void testCollectComputeResults() { - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture> future = new PlainActionFuture<>(); List allProfiles = new ArrayList<>(); - EsqlExecutionInfo executionInfo = new EsqlExecutionInfo(randomBoolean()); - try ( - ComputeListener computeListener = ComputeListener.create( - RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, - transportService, - newTask(), - executionInfo, - future - ) - ) { + AtomicInteger onFailure = new AtomicInteger(); + try (var computeListener = new ComputeListener(threadPool, onFailure::incrementAndGet, future)) { int tasks = randomIntBetween(1, 100); for (int t = 0; t < tasks; t++) { if (randomBoolean()) { @@ -164,261 +99,23 @@ public void testCollectComputeResults() { threadPool.generic() ); } else { - ComputeResponse resp = randomResponse(false); - allProfiles.addAll(resp.getProfiles()); - ActionListener subListener = computeListener.acquireCompute(); + var profiles = randomProfiles(); + allProfiles.addAll(profiles); + ActionListener> subListener = computeListener.acquireCompute(); threadPool.schedule( - ActionRunnable.wrap(subListener, l -> l.onResponse(resp)), + ActionRunnable.wrap(subListener, l -> l.onResponse(profiles)), TimeValue.timeValueNanos(between(0, 100)), threadPool.generic() ); } } } - ComputeResponse response = future.actionGet(10, TimeUnit.SECONDS); - assertThat( - response.getProfiles().stream().collect(Collectors.toMap(p -> p, p -> 1, Integer::sum)), - equalTo(allProfiles.stream().collect(Collectors.toMap(p -> p, p -> 1, Integer::sum))) - ); - Mockito.verifyNoInteractions(transportService.getTaskManager()); - } - - /** - * Tests the acquireCompute functionality running on the querying ("local") cluster, that is waiting upon - * a ComputeResponse from a remote cluster. The acquireCompute code under test should fill in the - * {@link EsqlExecutionInfo.Cluster} with the information in the ComputeResponse from the remote cluster. - */ - public void testAcquireComputeCCSListener() { - PlainActionFuture future = new PlainActionFuture<>(); - List allProfiles = new ArrayList<>(); - String remoteAlias = "rc1"; - EsqlExecutionInfo executionInfo = new EsqlExecutionInfo(true); - executionInfo.swapCluster(remoteAlias, (k, v) -> new EsqlExecutionInfo.Cluster(remoteAlias, "logs*", false)); - executionInfo.markEndPlanning(); // set planning took time, so it can be used to calculate per-cluster took time - try ( - ComputeListener computeListener = ComputeListener.create( - // 'whereRunning' for this test is the local cluster, waiting for a response from the remote cluster - RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, - transportService, - newTask(), - executionInfo, - future - ) - ) { - int tasks = randomIntBetween(1, 5); - for (int t = 0; t < tasks; t++) { - ComputeResponse resp = randomResponse(true); - allProfiles.addAll(resp.getProfiles()); - // Use remoteAlias here to indicate what remote cluster alias the listener is waiting to hear back from - ActionListener subListener = computeListener.acquireCompute(remoteAlias); - threadPool.schedule( - ActionRunnable.wrap(subListener, l -> l.onResponse(resp)), - TimeValue.timeValueNanos(between(0, 100)), - threadPool.generic() - ); - } - } - ComputeResponse response = future.actionGet(10, TimeUnit.SECONDS); + List profiles = future.actionGet(10, TimeUnit.SECONDS); assertThat( - response.getProfiles().stream().collect(Collectors.toMap(p -> p, p -> 1, Integer::sum)), + profiles.stream().collect(Collectors.toMap(p -> p, p -> 1, Integer::sum)), equalTo(allProfiles.stream().collect(Collectors.toMap(p -> p, p -> 1, Integer::sum))) ); - - assertTrue(executionInfo.isCrossClusterSearch()); - EsqlExecutionInfo.Cluster rc1Cluster = executionInfo.getCluster(remoteAlias); - assertThat(rc1Cluster.getTook().millis(), greaterThanOrEqualTo(0L)); - assertThat(rc1Cluster.getTotalShards(), equalTo(10)); - assertThat(rc1Cluster.getSuccessfulShards(), equalTo(10)); - assertThat(rc1Cluster.getSkippedShards(), greaterThanOrEqualTo(0)); - assertThat(rc1Cluster.getSkippedShards(), lessThanOrEqualTo(3)); - assertThat(rc1Cluster.getFailedShards(), equalTo(0)); - assertThat(rc1Cluster.getStatus(), equalTo(EsqlExecutionInfo.Cluster.Status.SUCCESSFUL)); - - Mockito.verifyNoInteractions(transportService.getTaskManager()); - } - - /** - * Tests the acquireCompute functionality running on the querying ("local") cluster, that is waiting upon - * a ComputeResponse from a remote cluster where we simulate connecting to a remote cluster running a version - * of ESQL that does not record and return CCS metadata. Ensure that the local cluster {@link EsqlExecutionInfo} - * is properly updated with took time and shard info is left unset. - */ - public void testAcquireComputeCCSListenerWithComputeResponseFromOlderCluster() { - PlainActionFuture future = new PlainActionFuture<>(); - List allProfiles = new ArrayList<>(); - String remoteAlias = "rc1"; - EsqlExecutionInfo executionInfo = new EsqlExecutionInfo(true); - executionInfo.swapCluster(remoteAlias, (k, v) -> new EsqlExecutionInfo.Cluster(remoteAlias, "logs*", false)); - executionInfo.markEndPlanning(); // set planning took time, so it can be used to calculate per-cluster took time - try ( - ComputeListener computeListener = ComputeListener.create( - // 'whereRunning' for this test is the local cluster, waiting for a response from the remote cluster - RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, - transportService, - newTask(), - executionInfo, - future - ) - ) { - int tasks = randomIntBetween(1, 5); - for (int t = 0; t < tasks; t++) { - ComputeResponse resp = randomResponse(false); // older clusters will not return CCS metadata in response - allProfiles.addAll(resp.getProfiles()); - // Use remoteAlias here to indicate what remote cluster alias the listener is waiting to hear back from - ActionListener subListener = computeListener.acquireCompute(remoteAlias); - threadPool.schedule( - ActionRunnable.wrap(subListener, l -> l.onResponse(resp)), - TimeValue.timeValueNanos(between(0, 100)), - threadPool.generic() - ); - } - } - ComputeResponse response = future.actionGet(10, TimeUnit.SECONDS); - assertThat( - response.getProfiles().stream().collect(Collectors.toMap(p -> p, p -> 1, Integer::sum)), - equalTo(allProfiles.stream().collect(Collectors.toMap(p -> p, p -> 1, Integer::sum))) - ); - - assertTrue(executionInfo.isCrossClusterSearch()); - EsqlExecutionInfo.Cluster rc1Cluster = executionInfo.getCluster(remoteAlias); - assertThat(rc1Cluster.getTook().millis(), greaterThanOrEqualTo(0L)); - assertNull(rc1Cluster.getTotalShards()); - assertNull(rc1Cluster.getSuccessfulShards()); - assertNull(rc1Cluster.getSkippedShards()); - assertNull(rc1Cluster.getFailedShards()); - assertThat(rc1Cluster.getStatus(), equalTo(EsqlExecutionInfo.Cluster.Status.SUCCESSFUL)); - - Mockito.verifyNoInteractions(transportService.getTaskManager()); - } - - /** - * Run an acquireCompute cycle on the RemoteCluster. - * AcquireCompute will fill in the took time on the EsqlExecutionInfo (the shard info is filled in before this, - * so we just hard code them in the Cluster in this test) and then a ComputeResponse will be created in the refs - * listener and returned with the shard and took time info. - */ - public void testAcquireComputeRunningOnRemoteClusterFillsInTookTime() { - PlainActionFuture future = new PlainActionFuture<>(); - List allProfiles = new ArrayList<>(); - EsqlExecutionInfo executionInfo = new EsqlExecutionInfo(true); - String remoteAlias = "rc1"; - executionInfo.swapCluster( - remoteAlias, - (k, v) -> new EsqlExecutionInfo.Cluster( - remoteAlias, - "logs*", - false, - EsqlExecutionInfo.Cluster.Status.RUNNING, - 10, - 10, - 3, - 0, - null, - null // to be filled in the acquireCompute listener - ) - ); - try ( - ComputeListener computeListener = ComputeListener.create( - // whereRunning=remoteAlias simulates running on the remote cluster - remoteAlias, - transportService, - newTask(), - executionInfo, - future - ) - ) { - int tasks = randomIntBetween(1, 5); - for (int t = 0; t < tasks; t++) { - ComputeResponse resp = randomResponse(true); - allProfiles.addAll(resp.getProfiles()); - ActionListener subListener = computeListener.acquireCompute(remoteAlias); - threadPool.schedule( - ActionRunnable.wrap(subListener, l -> l.onResponse(resp)), - TimeValue.timeValueNanos(between(0, 100)), - threadPool.generic() - ); - } - } - ComputeResponse response = future.actionGet(10, TimeUnit.SECONDS); - assertThat( - response.getProfiles().stream().collect(Collectors.toMap(p -> p, p -> 1, Integer::sum)), - equalTo(allProfiles.stream().collect(Collectors.toMap(p -> p, p -> 1, Integer::sum))) - ); - assertThat(response.getTotalShards(), equalTo(10)); - assertThat(response.getSuccessfulShards(), equalTo(10)); - assertThat(response.getSkippedShards(), equalTo(3)); - assertThat(response.getFailedShards(), equalTo(0)); - // check that the took time was filled in on the ExecutionInfo for the remote cluster and put into the ComputeResponse to be - // sent back to the querying cluster - assertThat(response.getTook().millis(), greaterThanOrEqualTo(0L)); - assertThat(executionInfo.getCluster(remoteAlias).getTook().millis(), greaterThanOrEqualTo(0L)); - assertThat(executionInfo.getCluster(remoteAlias).getTook(), equalTo(response.getTook())); - assertThat(executionInfo.getCluster(remoteAlias).getStatus(), equalTo(EsqlExecutionInfo.Cluster.Status.SUCCESSFUL)); - - Mockito.verifyNoInteractions(transportService.getTaskManager()); - } - - /** - * Run an acquireCompute cycle on the RemoteCluster. - * AcquireCompute will fill in the took time on the EsqlExecutionInfo (the shard info is filled in before this, - * so we just hard code them in the Cluster in this test) and then a ComputeResponse will be created in the refs - * listener and returned with the shard and took time info. - */ - public void testAcquireComputeRunningOnQueryingClusterFillsInTookTime() { - PlainActionFuture future = new PlainActionFuture<>(); - List allProfiles = new ArrayList<>(); - EsqlExecutionInfo executionInfo = new EsqlExecutionInfo(true); - String localCluster = RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; - // we need a remote cluster in the ExecutionInfo in order to simulate a CCS, since ExecutionInfo is only - // fully filled in for cross-cluster searches - executionInfo.swapCluster(localCluster, (k, v) -> new EsqlExecutionInfo.Cluster(localCluster, "logs*", false)); - executionInfo.swapCluster("my_remote", (k, v) -> new EsqlExecutionInfo.Cluster("my_remote", "my_remote:logs*", false)); - - // before acquire-compute, can-match (SearchShards) runs filling in total shards and skipped shards, so simulate that here - executionInfo.swapCluster( - localCluster, - (k, v) -> new EsqlExecutionInfo.Cluster.Builder(v).setTotalShards(10).setSkippedShards(1).build() - ); - executionInfo.swapCluster( - "my_remote", - (k, v) -> new EsqlExecutionInfo.Cluster.Builder(v).setTotalShards(10).setSkippedShards(1).build() - ); - - try ( - ComputeListener computeListener = ComputeListener.create( - // whereRunning=localCluster simulates running on the querying cluster - localCluster, - transportService, - newTask(), - executionInfo, - future - ) - ) { - int tasks = randomIntBetween(1, 5); - for (int t = 0; t < tasks; t++) { - ComputeResponse resp = randomResponse(true); - allProfiles.addAll(resp.getProfiles()); - ActionListener subListener = computeListener.acquireCompute(localCluster); - threadPool.schedule( - ActionRunnable.wrap(subListener, l -> l.onResponse(resp)), - TimeValue.timeValueNanos(between(0, 100)), - threadPool.generic() - ); - } - } - ComputeResponse response = future.actionGet(10, TimeUnit.SECONDS); - assertThat( - response.getProfiles().stream().collect(Collectors.toMap(p -> p, p -> 1, Integer::sum)), - equalTo(allProfiles.stream().collect(Collectors.toMap(p -> p, p -> 1, Integer::sum))) - ); - // check that the took time was filled in on the ExecutionInfo for the remote cluster and put into the ComputeResponse to be - // sent back to the querying cluster - assertNull("took time is not added to the ComputeResponse on the querying cluster", response.getTook()); - assertThat(executionInfo.getCluster(localCluster).getTook().millis(), greaterThanOrEqualTo(0L)); - // once all the took times have been gathered from the tasks, the refs callback will set execution status to SUCCESSFUL - assertThat(executionInfo.getCluster(localCluster).getStatus(), equalTo(EsqlExecutionInfo.Cluster.Status.SUCCESSFUL)); - - Mockito.verifyNoInteractions(transportService.getTaskManager()); + assertThat(onFailure.get(), equalTo(0)); } public void testCancelOnFailure() throws Exception { @@ -429,22 +126,13 @@ public void testCancelOnFailure() throws Exception { ); int successTasks = between(1, 50); int failedTasks = between(1, 100); - PlainActionFuture rootListener = new PlainActionFuture<>(); - CancellableTask rootTask = newTask(); - EsqlExecutionInfo execInfo = new EsqlExecutionInfo(randomBoolean()); - try ( - ComputeListener computeListener = ComputeListener.create( - RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, - transportService, - rootTask, - execInfo, - rootListener - ) - ) { + PlainActionFuture> rootListener = new PlainActionFuture<>(); + final AtomicInteger onFailure = new AtomicInteger(); + try (var computeListener = new ComputeListener(threadPool, onFailure::incrementAndGet, rootListener)) { for (int i = 0; i < successTasks; i++) { - ActionListener subListener = computeListener.acquireCompute(); + ActionListener> subListener = computeListener.acquireCompute(); threadPool.schedule( - ActionRunnable.wrap(subListener, l -> l.onResponse(randomResponse(false))), + ActionRunnable.wrap(subListener, l -> l.onResponse(randomProfiles())), TimeValue.timeValueNanos(between(0, 100)), threadPool.generic() ); @@ -465,18 +153,17 @@ public void testCancelOnFailure() throws Exception { assertNotNull(failure); assertThat(cause, instanceOf(CircuitBreakingException.class)); assertThat(failure.getSuppressed().length, lessThan(10)); - Mockito.verify(transportService.getTaskManager(), Mockito.times(1)) - .cancelTaskAndDescendants(eq(rootTask), eq("cancelled on failure"), eq(false), any()); + assertThat(onFailure.get(), greaterThanOrEqualTo(1)); } public void testCollectWarnings() throws Exception { List allProfiles = new ArrayList<>(); Map> allWarnings = new HashMap<>(); - ActionListener rootListener = new ActionListener<>() { + ActionListener> rootListener = new ActionListener<>() { @Override - public void onResponse(ComputeResponse result) { + public void onResponse(List result) { assertThat( - result.getProfiles().stream().collect(Collectors.toMap(p -> p, p -> 1, Integer::sum)), + result.stream().collect(Collectors.toMap(p -> p, p -> 1, Integer::sum)), equalTo(allProfiles.stream().collect(Collectors.toMap(p -> p, p -> 1, Integer::sum))) ); Map> responseHeaders = threadPool.getThreadContext() @@ -492,14 +179,12 @@ public void onFailure(Exception e) { throw new AssertionError(e); } }; + AtomicInteger onFailure = new AtomicInteger(); CountDownLatch latch = new CountDownLatch(1); - EsqlExecutionInfo executionInfo = new EsqlExecutionInfo(randomBoolean()); try ( - ComputeListener computeListener = ComputeListener.create( - RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, - transportService, - newTask(), - executionInfo, + var computeListener = new ComputeListener( + threadPool, + onFailure::incrementAndGet, ActionListener.runAfter(rootListener, latch::countDown) ) ) { @@ -513,8 +198,8 @@ public void onFailure(Exception e) { threadPool.generic() ); } else { - ComputeResponse resp = randomResponse(false); - allProfiles.addAll(resp.getProfiles()); + var resp = randomProfiles(); + allProfiles.addAll(resp); int numWarnings = randomIntBetween(1, 5); Map warnings = new HashMap<>(); for (int i = 0; i < numWarnings; i++) { @@ -523,7 +208,7 @@ public void onFailure(Exception e) { for (Map.Entry e : warnings.entrySet()) { allWarnings.computeIfAbsent(e.getKey(), v -> new HashSet<>()).add(e.getValue()); } - ActionListener subListener = computeListener.acquireCompute(); + var subListener = computeListener.acquireCompute(); threadPool.schedule(ActionRunnable.wrap(subListener, l -> { for (Map.Entry e : warnings.entrySet()) { threadPool.getThreadContext().addResponseHeader(e.getKey(), e.getValue()); @@ -534,6 +219,6 @@ public void onFailure(Exception e) { } } assertTrue(latch.await(10, TimeUnit.SECONDS)); - Mockito.verifyNoInteractions(transportService.getTaskManager()); + assertThat(onFailure.get(), equalTo(0)); } }