From f0a5e25fcac5d6f82b2084d92361657aac064596 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Thu, 30 Jan 2025 12:09:25 +0100 Subject: [PATCH] [8.x] [Inference API] Add node-local rate limiting for the inference API (#120400) (#121251) * [Inference API] Add node-local rate limiting for the inference API (#120400) * Add node-local rate limiting for the inference API * Fix integration tests by using new LocalStateInferencePlugin instead of InferencePlugin and adjust formatting. * Correct feature flag name * Add more docs, reorganize methods and make some methods package private * Clarify comment in BaseInferenceActionRequest * Fix wrong merge * Fix checkstyle * Fix checkstyle in tests * Check that the service we want to read the rate limit config for actually exists * [CI] Auto commit changes from spotless * checkStyle apply * Update docs/changelog/120400.yaml * Move rate limit division logic to RequestExecutorService * Spotless apply * Remove debug sout * Adding a few suggestions * Adam feedback * Fix compilation error * [CI] Auto commit changes from spotless * Add BWC test case to InferenceActionRequestTests * Add BWC test case to UnifiedCompletionActionRequestTests * Update x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculator.java Co-authored-by: Adam Demjen * Update x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculator.java Co-authored-by: Adam Demjen * Remove addressed TODO * Spotless apply * Only use new rate limit specific feature flag * Use ThreadLocalRandom * [CI] Auto commit changes from spotless * Use Randomness.get() * [CI] Auto commit changes from spotless * Fix import * Use ConcurrentHashMap in InferenceServiceNodeLocalRateLimitCalculator * Check for null value in getRateLimitAssignment and remove AtomicReference * Remove newAssignments * Up the default rate limit for completions * Put deprecated feature flag 
back in * Check feature flag in BaseTransportInferenceAction * spotlessApply * Export inference.common * Do not export inference.common * Provide noop rate limit calculator, if feature flag is disabled * Add proper dependency injection --------- Co-authored-by: elasticsearchmachine Co-authored-by: Jonathan Buttner Co-authored-by: Adam Demjen * Use .get(0) as getFirst() doesn't exist in 8.18 (probably JDK difference?) --------- Co-authored-by: elasticsearchmachine Co-authored-by: Jonathan Buttner Co-authored-by: Adam Demjen --- docs/changelog/120400.yaml | 5 + .../org/elasticsearch/TransportVersions.java | 1 + .../action/BaseInferenceActionRequest.java | 31 +++ .../action/InferenceActionRequestTests.java | 23 ++ .../UnifiedCompletionActionRequestTests.java | 20 ++ .../xpack/inference/InferencePlugin.java | 37 +++- .../action/BaseTransportInferenceAction.java | 178 +++++++++++++-- .../action/TransportInferenceAction.java | 13 +- ...sportUnifiedCompletionInferenceAction.java | 13 +- ...nceAPIClusterAwareRateLimitingFeature.java | 28 +++ ...ceServiceNodeLocalRateLimitCalculator.java | 197 +++++++++++++++++ .../InferenceServiceRateLimitCalculator.java | 18 ++ .../NoopNodeLocalRateLimitCalculator.java | 27 +++ .../inference/common/RateLimitAssignment.java | 19 ++ .../xpack/inference/common/RateLimiter.java | 2 +- .../AmazonBedrockRequestSender.java | 5 + .../external/http/RequestExecutor.java | 2 + .../http/sender/HttpRequestSender.java | 4 + .../http/sender/RequestExecutorService.java | 56 ++++- .../external/http/sender/RequestManager.java | 2 + .../external/http/sender/Sender.java | 2 + .../inference/services/SenderService.java | 2 +- ...renceServiceCompletionServiceSettings.java | 2 +- .../BaseTransportInferenceActionTestCase.java | 22 +- .../action/TransportInferenceActionTests.java | 130 ++++++++++- ...TransportUnifiedCompletionActionTests.java | 13 +- ...viceNodeLocalRateLimitCalculatorTests.java | 205 ++++++++++++++++++ .../AmazonBedrockMockRequestSender.java | 5 
+ ...ServiceCompletionServiceSettingsTests.java | 2 +- 29 files changed, 1015 insertions(+), 49 deletions(-) create mode 100644 docs/changelog/120400.yaml create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceAPIClusterAwareRateLimitingFeature.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculator.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceRateLimitCalculator.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/NoopNodeLocalRateLimitCalculator.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimitAssignment.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculatorTests.java diff --git a/docs/changelog/120400.yaml b/docs/changelog/120400.yaml new file mode 100644 index 0000000000000..57d40730e0c8d --- /dev/null +++ b/docs/changelog/120400.yaml @@ -0,0 +1,5 @@ +pr: 120400 +summary: "[Inference API] Add node-local rate limiting for the inference API" +area: Machine Learning +type: feature +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index dd81e8f10445a..603fc6a32f078 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -176,6 +176,7 @@ static TransportVersion def(int id) { public static final TransportVersion RESOURCE_DEPRECATION_CHECKS = def(8_836_00_0); public static final TransportVersion LINEAR_RETRIEVER_SUPPORT = def(8_837_00_0); public static final TransportVersion TIMEOUT_GET_PARAM_FOR_RESOLVE_CLUSTER = def(8_838_00_0); + public static final 
TransportVersion INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING = def(8_839_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java index e426574c52ce6..855b0bdebb417 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java @@ -7,20 +7,35 @@ package org.elasticsearch.xpack.core.inference.action; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.inference.TaskType; import java.io.IOException; +/** + * Base class for inference action requests. Tracks request routing state to prevent potential routing loops + * and supports both streaming and non-streaming inference operations. + */ public abstract class BaseInferenceActionRequest extends ActionRequest { + private boolean hasBeenRerouted; + public BaseInferenceActionRequest() { super(); } public BaseInferenceActionRequest(StreamInput in) throws IOException { super(in); + if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) { + this.hasBeenRerouted = in.readBoolean(); + } else { + // For backwards compatibility, we treat all inference requests coming from ES nodes having + // a version pre-node-local-rate-limiting as already rerouted to maintain pre-node-local-rate-limiting behavior. 
+ this.hasBeenRerouted = true; + } } public abstract boolean isStreaming(); @@ -28,4 +43,20 @@ public BaseInferenceActionRequest(StreamInput in) throws IOException { public abstract TaskType getTaskType(); public abstract String getInferenceEntityId(); + + public void setHasBeenRerouted(boolean hasBeenRerouted) { + this.hasBeenRerouted = hasBeenRerouted; + } + + public boolean hasBeenRerouted() { + return hasBeenRerouted; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) { + out.writeBoolean(hasBeenRerouted); + } + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java index 01c0ff88be222..e9f4df7a523ad 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java @@ -386,6 +386,29 @@ public void testWriteTo_WhenVersionIsBeforeInputTypeAdded_ShouldSetInputTypeToUn assertThat(deserializedInstance.getInputType(), is(InputType.UNSPECIFIED)); } + public void testWriteTo_WhenVersionIsBeforeAdaptiveRateLimiting_ShouldSetHasBeenReroutedToTrue() throws IOException { + var instance = new InferenceAction.Request( + TaskType.TEXT_EMBEDDING, + "model", + null, + List.of("input"), + Map.of(), + InputType.UNSPECIFIED, + InferenceAction.Request.DEFAULT_TIMEOUT, + false + ); + + InferenceAction.Request deserializedInstance = copyWriteable( + instance, + getNamedWriteableRegistry(), + instanceReader(), + TransportVersions.V_8_13_0 + ); + + // Verify that hasBeenRerouted is true after deserializing a request coming from an older transport version + 
assertTrue(deserializedInstance.hasBeenRerouted()); + } + public void testGetInputTypeToWrite_ReturnsIngest_WhenInputTypeIsUnspecified_VersionBeforeUnspecifiedIntroduced() { assertThat(getInputTypeToWrite(InputType.UNSPECIFIED, TransportVersions.V_8_12_1), is(InputType.INGEST)); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java index f548bfa0709ed..ceb7c9853a0f4 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.inference.action; import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.TimeValue; @@ -65,6 +66,25 @@ public void testValidation_ReturnsNull_When_TaskType_IsAny() { assertNull(request.validate()); } + public void testWriteTo_WhenVersionIsBeforeAdaptiveRateLimiting_ShouldSetHasBeenReroutedToTrue() throws IOException { + var instance = new UnifiedCompletionAction.Request( + "model", + TaskType.ANY, + UnifiedCompletionRequest.of(List.of(UnifiedCompletionRequestTests.randomMessage())), + TimeValue.timeValueSeconds(10) + ); + + UnifiedCompletionAction.Request deserializedInstance = copyWriteable( + instance, + getNamedWriteableRegistry(), + instanceReader(), + TransportVersions.ELASTIC_INFERENCE_SERVICE_UNIFIED_CHAT_COMPLETIONS_INTEGRATION + ); + + // Verify that hasBeenRerouted is true after deserializing a request coming from an older transport version + assertTrue(deserializedInstance.hasBeenRerouted()); + } + @Override 
protected UnifiedCompletionAction.Request mutateInstanceForVersion(UnifiedCompletionAction.Request instance, TransportVersion version) { return instance; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index acb8b962fcb4d..a3aaf8127d935 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -72,6 +72,9 @@ import org.elasticsearch.xpack.inference.action.TransportUnifiedCompletionInferenceAction; import org.elasticsearch.xpack.inference.action.TransportUpdateInferenceModelAction; import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter; +import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator; +import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; +import org.elasticsearch.xpack.inference.common.NoopNodeLocalRateLimitCalculator; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; @@ -133,6 +136,7 @@ import java.util.function.Supplier; import static java.util.Collections.singletonList; +import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG; public class InferencePlugin extends Plugin implements @@ -229,6 +233,7 @@ public List getRestHandlers( @Override public Collection createComponents(PluginServices services) { + var components = new ArrayList<>(); var throttlerManager = new ThrottlerManager(settings, services.threadPool(), services.clusterService()); var truncator = new Truncator(settings, 
services.clusterService()); serviceComponents.set(new ServiceComponents(services.threadPool(), throttlerManager, settings, truncator)); @@ -297,20 +302,38 @@ public Collection createComponents(PluginServices services) { // This must be done after the HttpRequestSenderFactory is created so that the services can get the // reference correctly - var registry = new InferenceServiceRegistry(inferenceServices, factoryContext); - registry.init(services.client()); - for (var service : registry.getServices().values()) { + var serviceRegistry = new InferenceServiceRegistry(inferenceServices, factoryContext); + serviceRegistry.init(services.client()); + for (var service : serviceRegistry.getServices().values()) { service.defaultConfigIds().forEach(modelRegistry::addDefaultIds); } - inferenceServiceRegistry.set(registry); + inferenceServiceRegistry.set(serviceRegistry); - var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), registry, modelRegistry); + var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry); shardBulkInferenceActionFilter.set(actionFilter); var meterRegistry = services.telemetryProvider().getMeterRegistry(); - var stats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry)); + var inferenceStats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry)); + + components.add(serviceRegistry); + components.add(modelRegistry); + components.add(httpClientManager); + components.add(inferenceStats); + + // Only add InferenceServiceNodeLocalRateLimitCalculator (which is a ClusterStateListener) for cluster aware rate limiting, + // if the rate limiting feature flags are enabled, otherwise provide noop implementation + InferenceServiceRateLimitCalculator calculator; + if (INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG.isEnabled()) { + calculator = new 
InferenceServiceNodeLocalRateLimitCalculator(services.clusterService(), serviceRegistry); + } else { + calculator = new NoopNodeLocalRateLimitCalculator(); + } + + // Add binding for interface -> implementation + components.add(new PluginComponentBinding<>(InferenceServiceRateLimitCalculator.class, calculator)); + components.add(calculator); - return List.of(modelRegistry, registry, httpClientManager, stats); + return components; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java index b6c7d26b36f9a..08d74a36d6503 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -13,6 +13,10 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.xcontent.ChunkedToXContent; @@ -27,24 +31,42 @@ import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportException; +import org.elasticsearch.transport.TransportResponseHandler; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.XPackField; import 
org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.DelegatingProcessor; +import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator; +import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import org.elasticsearch.xpack.inference.telemetry.InferenceTimer; +import java.io.IOException; +import java.util.Random; +import java.util.concurrent.Executor; import java.util.function.Supplier; import java.util.stream.Collectors; import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; +import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG; import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes; import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; +/** + * Base class for transport actions that handle inference requests. + * Works in conjunction with {@link InferenceServiceNodeLocalRateLimitCalculator} to + * route requests to specific nodes, iff they support "node-local" rate limiting, which is described in detail + * in {@link InferenceServiceNodeLocalRateLimitCalculator}. 
+ * + * @param The specific type of inference request being handled + */ public abstract class BaseTransportInferenceAction extends HandledTransportAction< Request, InferenceAction.Response> { @@ -57,6 +79,11 @@ public abstract class BaseTransportInferenceAction requestReader + Writeable.Reader requestReader, + InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator, + NodeClient nodeClient, + ThreadPool threadPool ) { super(inferenceActionName, transportService, actionFilters, requestReader, EsExecutors.DIRECT_EXECUTOR_SERVICE); this.licenseState = licenseState; @@ -75,8 +105,24 @@ public BaseTransportInferenceAction( this.serviceRegistry = serviceRegistry; this.inferenceStats = inferenceStats; this.streamingTaskManager = streamingTaskManager; + this.inferenceServiceRateLimitCalculator = inferenceServiceNodeLocalRateLimitCalculator; + this.nodeClient = nodeClient; + this.threadPool = threadPool; + this.transportService = transportService; + this.random = Randomness.get(); } + protected abstract boolean isInvalidTaskTypeForInferenceEndpoint(Request request, UnparsedModel unparsedModel); + + protected abstract ElasticsearchStatusException createInvalidTaskTypeException(Request request, UnparsedModel unparsedModel); + + protected abstract void doInference( + Model model, + Request request, + InferenceService service, + ActionListener listener + ); + @Override protected void doExecute(Task task, Request request, ActionListener listener) { if (INFERENCE_API_FEATURE.check(licenseState) == false) { @@ -87,31 +133,32 @@ protected void doExecute(Task task, Request request, ActionListener { - var service = serviceRegistry.getService(unparsedModel.service()); + var serviceName = unparsedModel.service(); + try { - validationHelper(service::isEmpty, () -> unknownServiceException(unparsedModel.service(), request.getInferenceEntityId())); - validationHelper( - () -> request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false, - () -> 
requestModelTaskTypeMismatchException(request.getTaskType(), unparsedModel.taskType()) - ); - validationHelper( - () -> isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel), - () -> createInvalidTaskTypeException(request, unparsedModel) - ); + validateRequest(request, unparsedModel); } catch (Exception e) { recordMetrics(unparsedModel, timer, e); listener.onFailure(e); return; } - var model = service.get() - .parsePersistedConfigWithSecrets( + var service = serviceRegistry.getService(serviceName).get(); + var routingDecision = determineRouting(serviceName, request, unparsedModel); + + if (routingDecision.currentNodeShouldHandleRequest()) { + var model = service.parsePersistedConfigWithSecrets( unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings(), unparsedModel.secrets() ); - inferOnServiceWithMetrics(model, request, service.get(), timer, listener); + inferOnServiceWithMetrics(model, request, service, timer, listener); + } else { + // Reroute request + request.setHasBeenRerouted(true); + rerouteRequest(request, listener, routingDecision.targetNode); + } }, e -> { try { inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e)); @@ -124,15 +171,95 @@ protected void doExecute(Task task, Request request, ActionListener unknownServiceException(serviceName, request.getInferenceEntityId())); + validationHelper( + () -> request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false, + () -> requestModelTaskTypeMismatchException(requestTaskType, unparsedModel.taskType()) + ); + validationHelper( + () -> isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel), + () -> createInvalidTaskTypeException(request, unparsedModel) + ); + } + + private NodeRoutingDecision determineRouting(String serviceName, Request request, UnparsedModel unparsedModel) { + if (INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG.isEnabled() == false) { + return NodeRoutingDecision.handleLocally(); + } + + var 
modelTaskType = unparsedModel.taskType(); + + // Rerouting not supported or request was already rerouted + if (inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceName, modelTaskType) == false + || request.hasBeenRerouted()) { + return NodeRoutingDecision.handleLocally(); + } + + var rateLimitAssignment = inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceName, modelTaskType); + + // No assignment yet + if (rateLimitAssignment == null) { + return NodeRoutingDecision.handleLocally(); + } + + var responsibleNodes = rateLimitAssignment.responsibleNodes(); + + // Empty assignment + if (responsibleNodes == null || responsibleNodes.isEmpty()) { + return NodeRoutingDecision.handleLocally(); + } + + var nodeToHandleRequest = responsibleNodes.get(random.nextInt(responsibleNodes.size())); + String localNodeId = nodeClient.getLocalNodeId(); + + // The drawn node is the current node + if (nodeToHandleRequest.getId().equals(localNodeId)) { + return NodeRoutingDecision.handleLocally(); + } + + // Reroute request + return NodeRoutingDecision.routeTo(nodeToHandleRequest); + } + private static void validationHelper(Supplier validationFailure, Supplier exceptionCreator) { if (validationFailure.get()) { throw exceptionCreator.get(); } } - protected abstract boolean isInvalidTaskTypeForInferenceEndpoint(Request request, UnparsedModel unparsedModel); - - protected abstract ElasticsearchStatusException createInvalidTaskTypeException(Request request, UnparsedModel unparsedModel); + private void rerouteRequest(Request request, ActionListener listener, DiscoveryNode nodeToHandleRequest) { + transportService.sendRequest( + nodeToHandleRequest, + InferenceAction.NAME, + request, + new TransportResponseHandler() { + @Override + public Executor executor() { + return threadPool.executor(InferencePlugin.UTILITY_THREAD_POOL_NAME); + } + + @Override + public void handleResponse(InferenceAction.Response response) { + listener.onResponse(response); + } + + 
@Override + public void handleException(TransportException exp) { + listener.onFailure(exp); + } + + @Override + public InferenceAction.Response read(StreamInput in) throws IOException { + return new InferenceAction.Response(in); + } + } + ); + } private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) { try { @@ -185,13 +312,6 @@ private void inferOnService(Model model, Request request, InferenceService servi } } - protected abstract void doInference( - Model model, - Request request, - InferenceService service, - ActionListener listener - ); - private ElasticsearchStatusException unsupportedStreamingTaskException(Request request, InferenceService service) { var supportedTasks = service.supportedStreamingTasks(); if (supportedTasks.isEmpty()) { @@ -259,4 +379,14 @@ public void onComplete() { super.onComplete(); } } + + private record NodeRoutingDecision(boolean currentNodeShouldHandleRequest, DiscoveryNode targetNode) { + static NodeRoutingDecision handleLocally() { + return new NodeRoutingDecision(true, null); + } + + static NodeRoutingDecision routeTo(DiscoveryNode node) { + return new NodeRoutingDecision(false, node); + } + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index 24ef0d7d610d0..e8f52e42f5708 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -10,6 +10,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.inference.InferenceService; import 
org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InferenceServiceResults; @@ -17,9 +18,11 @@ import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; @@ -33,7 +36,10 @@ public TransportInferenceAction( ModelRegistry modelRegistry, InferenceServiceRegistry serviceRegistry, InferenceStats inferenceStats, - StreamingTaskManager streamingTaskManager + StreamingTaskManager streamingTaskManager, + InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator, + NodeClient nodeClient, + ThreadPool threadPool ) { super( InferenceAction.NAME, @@ -44,7 +50,10 @@ public TransportInferenceAction( serviceRegistry, inferenceStats, streamingTaskManager, - InferenceAction.Request::new + InferenceAction.Request::new, + inferenceServiceNodeLocalRateLimitCalculator, + nodeClient, + threadPool ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java index 9354ac2a83182..2e3090f2afd59 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java @@ -10,6 +10,7 @@ 
import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InferenceServiceResults; @@ -19,9 +20,11 @@ import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; @@ -35,7 +38,10 @@ public TransportUnifiedCompletionInferenceAction( ModelRegistry modelRegistry, InferenceServiceRegistry serviceRegistry, InferenceStats inferenceStats, - StreamingTaskManager streamingTaskManager + StreamingTaskManager streamingTaskManager, + InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator, + NodeClient nodeClient, + ThreadPool threadPool ) { super( UnifiedCompletionAction.NAME, @@ -46,7 +52,10 @@ public TransportUnifiedCompletionInferenceAction( serviceRegistry, inferenceStats, streamingTaskManager, - UnifiedCompletionAction.Request::new + UnifiedCompletionAction.Request::new, + inferenceServiceNodeLocalRateLimitCalculator, + nodeClient, + threadPool ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceAPIClusterAwareRateLimitingFeature.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceAPIClusterAwareRateLimitingFeature.java new file mode 100644 index 0000000000000..22de92526ba89 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceAPIClusterAwareRateLimitingFeature.java @@ -0,0 +1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.common; + +import org.elasticsearch.common.util.FeatureFlag; +import org.elasticsearch.xpack.inference.InferencePlugin; + +/** + * Cluster aware rate limiting feature flag. When the feature is complete and fully rolled out, this flag will be removed. + * Enable the feature via JVM option: `-Des.inference_cluster_aware_rate_limiting_feature_flag_enabled=true`. + * + * This controls whether {@link InferenceServiceNodeLocalRateLimitCalculator} gets instantiated and + * added as an injectable {@link InferencePlugin} component. + */ +public class InferenceAPIClusterAwareRateLimitingFeature { + + public static final FeatureFlag INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG = new FeatureFlag( + "inference_cluster_aware_rate_limiting" + ); + + private InferenceAPIClusterAwareRateLimitingFeature() {} + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculator.java new file mode 100644 index 0000000000000..4778e4cc6d30c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculator.java @@ -0,0 +1,197 @@ +/* + * Copyright Elasticsearch B.V.
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.common; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.injection.guice.Inject; +import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest; +import org.elasticsearch.xpack.inference.action.BaseTransportInferenceAction; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Note: {@link InferenceAPIClusterAwareRateLimitingFeature} needs to be enabled for this class to get + * instantiated inside {@link org.elasticsearch.xpack.inference.InferencePlugin} and be available via dependency injection. + * + * Calculates and manages node-local rate limits for inference services based on changes in the cluster topology. 
+ * This calculator calculates a "node-local" rate-limit, which essentially divides the rate limit for a service/task type + * by the number of nodes assigned to this service/task type pair. Without this calculator the rate limit stored + * in the inference endpoint configuration would get effectively multiplied by the number of nodes in a cluster (assuming a roughly uniform + * distribution of requests to the nodes in the cluster). + * + * The calculator works in conjunction with several other components: + * - {@link BaseTransportInferenceAction} - Uses the calculator to determine whether to reroute a request + * - {@link BaseInferenceActionRequest} - Tracks whether the request (an instance of a subclass of {@link BaseInferenceActionRequest}) + * has already been re-routed at least once + * - {@link HttpRequestSender} - Provides the original rate limits that this calculator divides by the number of nodes + * responsible for a service/task type + */ +public class InferenceServiceNodeLocalRateLimitCalculator implements InferenceServiceRateLimitCalculator { + + public static final Integer DEFAULT_MAX_NODES_PER_GROUPING = 3; + + /** + * Configuration mapping services to their task type rate limiting settings.
+ * Each service can have multiple configs defining: + * - Which task types support request re-routing and "node-local" rate limit calculation + * - How many nodes should handle requests for each task type, based on cluster size (dynamically calculated or statically provided) + **/ + static final Map> SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS = Map.of( + ElasticInferenceService.NAME, + // TODO: should probably be a map/set + List.of(new NodeLocalRateLimitConfig(TaskType.SPARSE_EMBEDDING, (numNodesInCluster) -> DEFAULT_MAX_NODES_PER_GROUPING)) + ); + + record NodeLocalRateLimitConfig(TaskType taskType, MaxNodesPerGroupingStrategy maxNodesPerGroupingStrategy) {} + + @FunctionalInterface + private interface MaxNodesPerGroupingStrategy { + + Integer calculate(Integer numberOfNodesInCluster); + + } + + private static final Logger logger = LogManager.getLogger(InferenceServiceNodeLocalRateLimitCalculator.class); + + private final InferenceServiceRegistry serviceRegistry; + + private final ConcurrentHashMap> serviceAssignments; + + @Inject + public InferenceServiceNodeLocalRateLimitCalculator(ClusterService clusterService, InferenceServiceRegistry serviceRegistry) { + clusterService.addListener(this); + this.serviceRegistry = serviceRegistry; + this.serviceAssignments = new ConcurrentHashMap<>(); + } + + @Override + public void clusterChanged(ClusterChangedEvent event) { + boolean clusterTopologyChanged = event.nodesChanged(); + + // TODO: feature flag per node? 
We should not reroute to nodes not having eis and/or the inference plugin enabled + // Every node should land on the same grouping by calculation, so no need to put anything into the cluster state + if (clusterTopologyChanged) { + updateAssignments(event); + } + } + + public boolean isTaskTypeReroutingSupported(String serviceName, TaskType taskType) { + return SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.getOrDefault(serviceName, Collections.emptyList()) + .stream() + .anyMatch(rateLimitConfig -> taskType.equals(rateLimitConfig.taskType)); + } + + public RateLimitAssignment getRateLimitAssignment(String service, TaskType taskType) { + var assignmentsPerTaskType = serviceAssignments.get(service); + + if (assignmentsPerTaskType == null) { + return null; + } + + return assignmentsPerTaskType.get(taskType); + } + + /** + * Updates instances of {@link RateLimitAssignment} for each service and task type when the cluster topology changes. + * For each service and supported task type, calculates which nodes should handle requests + * and what their local rate limits should be per inference endpoint. 
+ */ + private void updateAssignments(ClusterChangedEvent event) { + var newClusterState = event.state(); + var nodes = newClusterState.nodes().getAllNodes(); + + // Sort nodes by id (every node lands on the same result) + var sortedNodes = nodes.stream().sorted(Comparator.comparing(DiscoveryNode::getId)).toList(); + + // Sort inference services by name (every node lands on the same result) + var sortedServices = new ArrayList<>(serviceRegistry.getServices().values()); + sortedServices.sort(Comparator.comparing(InferenceService::name)); + + for (String serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet()) { + Optional service = serviceRegistry.getService(serviceName); + + if (service.isPresent()) { + var inferenceService = service.get(); + + for (NodeLocalRateLimitConfig rateLimitConfig : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.get(serviceName)) { + Map perTaskTypeAssignments = new HashMap<>(); + TaskType taskType = rateLimitConfig.taskType(); + + // Calculate node assignments needed for re-routing + var assignedNodes = calculateServiceAssignment(rateLimitConfig.maxNodesPerGroupingStrategy(), sortedNodes); + + // Update rate limits to be "node-local" + var numAssignedNodes = assignedNodes.size(); + updateRateLimits(inferenceService, numAssignedNodes); + + perTaskTypeAssignments.put(taskType, new RateLimitAssignment(assignedNodes)); + serviceAssignments.put(serviceName, perTaskTypeAssignments); + } + } else { + logger.warn( + "Service [{}] is configured for node-local rate limiting but was not found in the service registry", + serviceName + ); + } + } + } + + private List calculateServiceAssignment( + MaxNodesPerGroupingStrategy maxNodesPerGroupingStrategy, + List sortedNodes + ) { + int numberOfNodes = sortedNodes.size(); + int nodesPerGrouping = Math.min(numberOfNodes, maxNodesPerGroupingStrategy.calculate(numberOfNodes)); + + List assignedNodes = new ArrayList<>(); + + // TODO: here we can probably be smarter: if |num nodes in cluster| > |num nodes per task 
types| + // -> make sure a service provider is not assigned the same nodes for all task types; only relevant as soon as we support more task + // types + for (int j = 0; j < nodesPerGrouping; j++) { + var assignedNode = sortedNodes.get(j % numberOfNodes); + assignedNodes.add(assignedNode); + } + + return assignedNodes; + } + + private void updateRateLimits(InferenceService service, int responsibleNodes) { + if ((service instanceof SenderService) == false) { + return; + } + + SenderService senderService = (SenderService) service; + Sender sender = senderService.getSender(); + // TODO: this needs to take in service and task type as soon as multiple services/task types are supported + sender.updateRateLimitDivisor(responsibleNodes); + } + + InferenceServiceRegistry serviceRegistry() { + return serviceRegistry; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceRateLimitCalculator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceRateLimitCalculator.java new file mode 100644 index 0000000000000..e05637f629ec6 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceRateLimitCalculator.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.common; + +import org.elasticsearch.cluster.ClusterStateListener; +import org.elasticsearch.inference.TaskType; + +public interface InferenceServiceRateLimitCalculator extends ClusterStateListener { + + boolean isTaskTypeReroutingSupported(String serviceName, TaskType taskType); + + RateLimitAssignment getRateLimitAssignment(String service, TaskType taskType); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/NoopNodeLocalRateLimitCalculator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/NoopNodeLocalRateLimitCalculator.java new file mode 100644 index 0000000000000..a07217d9e9af7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/NoopNodeLocalRateLimitCalculator.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.common; + +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.inference.TaskType; + +public class NoopNodeLocalRateLimitCalculator implements InferenceServiceRateLimitCalculator { + + @Override + public void clusterChanged(ClusterChangedEvent event) { + // Do nothing + } + + public boolean isTaskTypeReroutingSupported(String serviceName, TaskType taskType) { + return false; + } + + public RateLimitAssignment getRateLimitAssignment(String service, TaskType taskType) { + return null; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimitAssignment.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimitAssignment.java new file mode 100644 index 0000000000000..de8d85c96271c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimitAssignment.java @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.common; + +import org.elasticsearch.cluster.node.DiscoveryNode; + +import java.util.List; + +/** + * Record for storing rate limit assignment information. 
+ * + * @param responsibleNodes - nodes responsible for a certain service and task type + */ +public record RateLimitAssignment(List responsibleNodes) {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java index bbc5082d45004..7fde672501ed6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java @@ -55,7 +55,7 @@ public RateLimiter(double accumulatedTokensLimit, double tokensPerTimeUnit, Time setRate(accumulatedTokensLimit, tokensPerTimeUnit, unit); } - public final synchronized void setRate(double newAccumulatedTokensLimit, double newTokensPerTimeUnit, TimeUnit newUnit) { + public synchronized void setRate(double newAccumulatedTokensLimit, double newTokensPerTimeUnit, TimeUnit newUnit) { Objects.requireNonNull(newUnit); if (newAccumulatedTokensLimit < 0) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java index ec4550b036d23..c8e544c26f293 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java @@ -88,6 +88,11 @@ protected AmazonBedrockRequestSender( ); } + @Override + public void updateRateLimitDivisor(int rateLimitDivisor) { + executorService.updateRateLimitDivisor(rateLimitDivisor); + } + @Override public void start() { if (started.compareAndSet(false, true)) { diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/RequestExecutor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/RequestExecutor.java index 63c042ce8a623..6c7c6e0d114c7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/RequestExecutor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/RequestExecutor.java @@ -21,6 +21,8 @@ public interface RequestExecutor { void shutdown(); + void updateRateLimitDivisor(int newDivisor); + boolean isShutdown(); boolean isTerminated(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java index 42671b8166537..689c9e2ec8fc1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java @@ -111,6 +111,10 @@ public void start() { } } + public void updateRateLimitDivisor(int rateLimitDivisor) { + service.updateRateLimitDivisor(rateLimitDivisor); + } + private void waitForStartToComplete() { try { if (startCompleted.await(START_COMPLETED_WAIT_TIME.getSeconds(), TimeUnit.SECONDS) == false) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java index ad1324d0a315f..5ec2acab70596 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java @@ -19,6 +19,7 @@ import org.elasticsearch.threadpool.Scheduler; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.common.AdjustableCapacityBlockingQueue; +import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator; import org.elasticsearch.xpack.inference.common.RateLimiter; import org.elasticsearch.xpack.inference.external.http.RequestExecutor; import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; @@ -36,6 +37,7 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; @@ -92,12 +94,22 @@ interface RateLimiterCreator { RateLimiter create(double accumulatedTokensLimit, double tokensPerTimeUnit, TimeUnit unit); } + // TODO: for later (after 8.18) + // TODO: pass in divisor to RateLimiterCreator + // TODO: another map for service/task-type-key -> set of RateLimitingEndpointHandler (used for updates; update divisor and then update + // all endpoint handlers) + // TODO: one map for service/task-type-key -> divisor (this gets also read when we create an inference endpoint) + // TODO: divisor value read/writes need to be synchronized in some way + // default for testing static final RateLimiterCreator DEFAULT_RATE_LIMIT_CREATOR = RateLimiter::new; private static final Logger logger = LogManager.getLogger(RequestExecutorService.class); private static final TimeValue RATE_LIMIT_GROUP_CLEANUP_INTERVAL = TimeValue.timeValueDays(1); private final ConcurrentMap rateLimitGroupings = new ConcurrentHashMap<>(); + // TODO: add one atomic integer (number of nodes); also explain the assumption and why this works + // TODO: document that this impacts 
chat completion (and increase the default rate limit) + private final AtomicInteger rateLimitDivisor = new AtomicInteger(1); private final ThreadPool threadPool; private final CountDownLatch startupLatch; private final CountDownLatch terminationLatch = new CountDownLatch(1); @@ -174,6 +186,19 @@ public int queueSize() { return rateLimitGroupings.values().stream().mapToInt(RateLimitingEndpointHandler::queueSize).sum(); } + @Override + public void updateRateLimitDivisor(int numResponsibleNodes) { + // in the unlikely case where we get an invalid value, we'll just ignore it + if (numResponsibleNodes <= 0) { + return; + } + + rateLimitDivisor.set(numResponsibleNodes); + for (var rateLimitingEndpointHandler : rateLimitGroupings.values()) { + rateLimitingEndpointHandler.updateTokensPerTimeUnit(rateLimitDivisor.get()); + } + } + /** * Begin servicing tasks. *

@@ -299,9 +324,12 @@ public void execute( clock, requestManager.rateLimitSettings(), this::isShutdown, - rateLimiterCreator + rateLimiterCreator, + rateLimitDivisor.get() ); + // TODO: add or create/compute if absent set for new map (service/task-type-key -> rate limit endpoint handler) + endpointHandler.init(); return endpointHandler; }); @@ -314,7 +342,7 @@ public void execute( * This allows many requests to be serialized if they are being sent too fast. If the rate limit has not been met they will be sent * as soon as a thread is available. */ - private static class RateLimitingEndpointHandler { + static class RateLimitingEndpointHandler { private static final TimeValue NO_TASKS_AVAILABLE = TimeValue.MAX_VALUE; private static final TimeValue EXECUTED_A_TASK = TimeValue.ZERO; @@ -329,6 +357,8 @@ private static class RateLimitingEndpointHandler { private final Clock clock; private final RateLimiter rateLimiter; private final RequestExecutorServiceSettings requestExecutorServiceSettings; + private final RateLimitSettings rateLimitSettings; + private final Long originalRequestsPerTimeUnit; RateLimitingEndpointHandler( String id, @@ -338,7 +368,8 @@ private static class RateLimitingEndpointHandler { Clock clock, RateLimitSettings rateLimitSettings, Supplier isShutdownMethod, - RateLimiterCreator rateLimiterCreator + RateLimiterCreator rateLimiterCreator, + Integer rateLimitDivisor ) { this.requestExecutorServiceSettings = Objects.requireNonNull(settings); this.id = Objects.requireNonNull(id); @@ -346,6 +377,8 @@ private static class RateLimitingEndpointHandler { this.requestSender = Objects.requireNonNull(requestSender); this.clock = Objects.requireNonNull(clock); this.isShutdownMethod = Objects.requireNonNull(isShutdownMethod); + this.rateLimitSettings = Objects.requireNonNull(rateLimitSettings); + this.originalRequestsPerTimeUnit = rateLimitSettings.requestsPerTimeUnit(); Objects.requireNonNull(rateLimitSettings); Objects.requireNonNull(rateLimiterCreator); @@ 
-355,12 +388,29 @@ private static class RateLimitingEndpointHandler { rateLimitSettings.timeUnit() ); + this.updateTokensPerTimeUnit(rateLimitDivisor); } public void init() { requestExecutorServiceSettings.registerQueueCapacityCallback(id, this::onCapacityChange); } + /** + * Rescales the rate limiter so that the configured rate limit becomes "node-local": the original requests + * per time unit are divided by the given divisor. The divisor is derived from the node assignments computed by + * {@link InferenceServiceNodeLocalRateLimitCalculator}, which describes the general idea in more detail. + * + * @param divisor - divisor to divide the initial requests per time unit by + */ + public synchronized void updateTokensPerTimeUnit(Integer divisor) { + double updatedTokensPerTimeUnit = (double) originalRequestsPerTimeUnit / divisor; + rateLimiter.setRate(ACCUMULATED_TOKENS_LIMIT, updatedTokensPerTimeUnit, rateLimitSettings.timeUnit()); + } + + public String id() { + return id; + } + private void onCapacityChange(int capacity) { logger.debug(() -> Strings.format("Executor service grouping [%s] setting queue capacity to [%s]", id, capacity)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java index 853d6fdcb2473..aa606e8c7cc5c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java @@ -30,4 +30,6 @@ void execute( // executePreparedRequest() which will execute all prepared requests aka sends the batch String inferenceEntityId(); + + // TODO: add service() and taskType() } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/Sender.java
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/Sender.java index 3975a554586b7..fed92263f9999 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/Sender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/Sender.java @@ -27,6 +27,8 @@ void send( ActionListener listener ); + void updateRateLimitDivisor(int rateLimitDivisor); + void sendWithoutQueuing( Logger logger, Request request, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 7c28df4cc0dc4..18378ce9f06b2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -47,7 +47,7 @@ public SenderService(HttpRequestSender.Factory factory, ServiceComponents servic this.serviceComponents = Objects.requireNonNull(serviceComponents); } - protected Sender getSender() { + public Sender getSender() { return sender; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java index 3c8182a7d41a4..293ca1bcb41c0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java @@ -36,7 +36,7 @@ public class 
ElasticInferenceServiceCompletionServiceSettings extends FilteredXC public static final String NAME = "elastic_inference_service_completion_service_settings"; // TODO what value do we put here? - private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(240L); + private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(720L); public static ElasticInferenceServiceCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java index c0fc818e421d0..4fa0a1ec49c74 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java @@ -10,6 +10,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; @@ -21,11 +22,13 @@ import org.elasticsearch.license.MockLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import 
org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import org.junit.Before; @@ -61,6 +64,9 @@ public abstract class BaseTransportInferenceActionTestCase createAction( ModelRegistry modelRegistry, InferenceServiceRegistry serviceRegistry, InferenceStats inferenceStats, - StreamingTaskManager streamingTaskManager + StreamingTaskManager streamingTaskManager, + InferenceServiceNodeLocalRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator, + NodeClient nodeClient, + ThreadPool threadPool ); protected abstract Request createRequest(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java index c303e029cb415..e71d15dbe0420 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java @@ -8,16 +8,32 @@ package org.elasticsearch.xpack.inference.action; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.TaskType; import org.elasticsearch.license.MockLicenseState; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportException; +import org.elasticsearch.transport.TransportResponseHandler; import org.elasticsearch.transport.TransportService; import 
org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator; +import org.elasticsearch.xpack.inference.common.RateLimitAssignment; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; +import java.util.List; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class TransportInferenceActionTests extends BaseTransportInferenceActionTestCase { @@ -33,7 +49,10 @@ protected BaseTransportInferenceAction createAction( ModelRegistry modelRegistry, InferenceServiceRegistry serviceRegistry, InferenceStats inferenceStats, - StreamingTaskManager streamingTaskManager + StreamingTaskManager streamingTaskManager, + InferenceServiceNodeLocalRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator, + NodeClient nodeClient, + ThreadPool threadPool ) { return new TransportInferenceAction( transportService, @@ -42,7 +61,10 @@ protected BaseTransportInferenceAction createAction( modelRegistry, serviceRegistry, inferenceStats, - streamingTaskManager + streamingTaskManager, + inferenceServiceNodeLocalRateLimitCalculator, + nodeClient, + threadPool ); } @@ -50,4 +72,108 @@ protected BaseTransportInferenceAction createAction( protected InferenceAction.Request createRequest() { return mock(); } + + public void testNoRerouting_WhenTaskTypeNotSupported() { + TaskType unsupportedTaskType = TaskType.COMPLETION; + mockService(listener -> listener.onResponse(mock())); + + 
when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, unsupportedTaskType)).thenReturn(false); + + var listener = doExecute(unsupportedTaskType); + + verify(listener).onResponse(any()); + // Verify request was handled locally (not rerouted using TransportService) + verify(transportService, never()).sendRequest(any(), any(), any(), any()); + } + + public void testNoRerouting_WhenNoGroupingCalculatedYet() { + mockService(listener -> listener.onResponse(mock())); + + when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true); + when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(null); + + var listener = doExecute(taskType); + + verify(listener).onResponse(any()); + // Verify request was handled locally (not rerouted using TransportService) + verify(transportService, never()).sendRequest(any(), any(), any(), any()); + } + + public void testNoRerouting_WhenEmptyNodeList() { + mockService(listener -> listener.onResponse(mock())); + + when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true); + when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn( + new RateLimitAssignment(List.of()) + ); + + var listener = doExecute(taskType); + + verify(listener).onResponse(any()); + // Verify request was handled locally (not rerouted using TransportService) + verify(transportService, never()).sendRequest(any(), any(), any(), any()); + } + + public void testRerouting_ToOtherNode() { + DiscoveryNode otherNode = mock(DiscoveryNode.class); + when(otherNode.getId()).thenReturn("other-node"); + + // The local node is different to the "other-node" responsible for serviceId + when(nodeClient.getLocalNodeId()).thenReturn("local-node"); + when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, 
taskType)).thenReturn(true); + // Requests for serviceId are always routed to "other-node" + var assignment = new RateLimitAssignment(List.of(otherNode)); + when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment); + + mockService(listener -> listener.onResponse(mock())); + var listener = doExecute(taskType); + + // Verify request was rerouted + verify(transportService).sendRequest(same(otherNode), eq(InferenceAction.NAME), any(), any()); + // Verify local execution didn't happen + verify(listener, never()).onResponse(any()); + } + + public void testRerouting_ToLocalNode_WithoutGoingThroughTransportLayerAgain() { + DiscoveryNode localNode = mock(DiscoveryNode.class); + String localNodeId = "local-node"; + when(localNode.getId()).thenReturn(localNodeId); + + // The local node is the only one responsible for serviceId + when(nodeClient.getLocalNodeId()).thenReturn(localNodeId); + when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true); + var assignment = new RateLimitAssignment(List.of(localNode)); + when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment); + + mockService(listener -> listener.onResponse(mock())); + var listener = doExecute(taskType); + + verify(listener).onResponse(any()); + // Verify request was handled locally (not rerouted using TransportService) + verify(transportService, never()).sendRequest(any(), any(), any(), any()); + } + + public void testRerouting_HandlesTransportException_FromOtherNode() { + DiscoveryNode otherNode = mock(DiscoveryNode.class); + when(otherNode.getId()).thenReturn("other-node"); + + when(nodeClient.getLocalNodeId()).thenReturn("local-node"); + when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true); + var assignment = new RateLimitAssignment(List.of(otherNode)); + 
when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment); + + mockService(listener -> listener.onResponse(mock())); + + TransportException expectedException = new TransportException("Failed to route"); + doAnswer(invocation -> { + TransportResponseHandler handler = invocation.getArgument(3); + handler.handleException(expectedException); + return null; + }).when(transportService).sendRequest(any(), any(), any(), any()); + + var listener = doExecute(taskType); + + // Verify exception was propagated from "other-node" to "local-node" + verify(listener).onFailure(same(expectedException)); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java index e8e7d9ac21bed..4ed69e5abe537 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java @@ -9,13 +9,16 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.TaskType; import org.elasticsearch.license.MockLicenseState; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator; import 
org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; @@ -45,7 +48,10 @@ protected BaseTransportInferenceAction createAc ModelRegistry modelRegistry, InferenceServiceRegistry serviceRegistry, InferenceStats inferenceStats, - StreamingTaskManager streamingTaskManager + StreamingTaskManager streamingTaskManager, + InferenceServiceNodeLocalRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator, + NodeClient nodeClient, + ThreadPool threadPool ) { return new TransportUnifiedCompletionInferenceAction( transportService, @@ -54,7 +60,10 @@ protected BaseTransportInferenceAction createAc modelRegistry, serviceRegistry, inferenceStats, - streamingTaskManager + streamingTaskManager, + inferenceServiceNodeLocalRateLimitCalculator, + nodeClient, + threadPool ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculatorTests.java new file mode 100644 index 0000000000000..05ee936c23fd7 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculatorTests.java @@ -0,0 +1,205 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.common; + +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.Set; + +import static org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator.DEFAULT_MAX_NODES_PER_GROUPING; +import static org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator.SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS; +import static org.hamcrest.Matchers.equalTo; + +@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE, numDataNodes = 0) +public class InferenceServiceNodeLocalRateLimitCalculatorTests extends ESIntegTestCase { + + public void setUp() throws Exception { + super.setUp(); + } + + public void testInitialClusterGrouping_Correct() { + // Start with 2-5 nodes + var numNodes = randomIntBetween(2, 5); + var nodeNames = internalCluster().startNodes(numNodes); + ensureStableCluster(numNodes); + + RateLimitAssignment firstAssignment = null; + + for (String nodeName : nodeNames) { + var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeName); + + // Check first node's assignments + if (firstAssignment == null) { + // Get assignment for a specific service (e.g., EIS) + firstAssignment = calculator.getRateLimitAssignment(ElasticInferenceService.NAME, TaskType.SPARSE_EMBEDDING); + + assertNotNull(firstAssignment); + // Verify there are assignments for this service + assertFalse(firstAssignment.responsibleNodes().isEmpty()); + } else { + // Verify other 
nodes see the same assignment + var currentAssignment = calculator.getRateLimitAssignment(ElasticInferenceService.NAME, TaskType.SPARSE_EMBEDDING); + assertEquals(firstAssignment, currentAssignment); + } + } + } + + public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws IOException { + // Start with 3-5 nodes + var numNodes = randomIntBetween(3, 5); + var nodeNames = internalCluster().startNodes(numNodes); + ensureStableCluster(numNodes); + + var nodeLeftInCluster = nodeNames.get(0); + var currentNumberOfNodes = numNodes; + + // Stop all nodes except one + for (String nodeName : nodeNames) { + if (nodeName.equals(nodeLeftInCluster)) { + continue; + } + internalCluster().stopNode(nodeName); + currentNumberOfNodes--; + ensureStableCluster(currentNumberOfNodes); + } + + var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeLeftInCluster); + + Set supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet(); + + // Check assignments for each supported service + for (var service : supportedServices) { + var assignment = calculator.getRateLimitAssignment(service, TaskType.SPARSE_EMBEDDING); + + assertNotNull(assignment); + // Should have exactly one responsible node + assertEquals(1, assignment.responsibleNodes().size()); + // That node should be our remaining node + assertEquals(nodeLeftInCluster, assignment.responsibleNodes().get(0).getName()); + } + } + + public void testGrouping_RespectsMaxNodesPerGroupingLimit() { + // Start with more nodes possible per grouping + var numNodes = DEFAULT_MAX_NODES_PER_GROUPING + randomIntBetween(1, 3); + var nodeNames = internalCluster().startNodes(numNodes); + ensureStableCluster(numNodes); + + var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.get(0)); + + Set supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet(); + + for (var service : supportedServices) { + var assignment = 
calculator.getRateLimitAssignment(service, TaskType.SPARSE_EMBEDDING); + + assertNotNull(assignment); + assertThat(DEFAULT_MAX_NODES_PER_GROUPING, equalTo(assignment.responsibleNodes().size())); + } + } + + public void testInitialRateLimitsCalculation_Correct() throws IOException { + // Start with max nodes per grouping (=3) + int numNodes = DEFAULT_MAX_NODES_PER_GROUPING; + var nodeNames = internalCluster().startNodes(numNodes); + ensureStableCluster(numNodes); + + var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.get(0)); + + Set supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet(); + + for (var serviceName : supportedServices) { + try (var serviceRegistry = calculator.serviceRegistry()) { + var serviceOptional = serviceRegistry.getService(serviceName); + assertTrue(serviceOptional.isPresent()); + var service = serviceOptional.get(); + + if ((service instanceof SenderService senderService)) { + var sender = senderService.getSender(); + if (sender instanceof HttpRequestSender httpSender) { + var assignment = calculator.getRateLimitAssignment(service.name(), TaskType.SPARSE_EMBEDDING); + + assertNotNull(assignment); + assertThat(DEFAULT_MAX_NODES_PER_GROUPING, equalTo(assignment.responsibleNodes().size())); + } + } + } + + } + } + + public void testRateLimits_Decrease_OnNodeJoin() { + // Start with 2 nodes + var initialNodes = 2; + var nodeNames = internalCluster().startNodes(initialNodes); + ensureStableCluster(initialNodes); + + var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.get(0)); + + for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet()) { + var configs = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.get(serviceName); + for (var config : configs) { + // Get initial assignments and rate limits + var initialAssignment = calculator.getRateLimitAssignment(serviceName, config.taskType()); + assertEquals(2, 
initialAssignment.responsibleNodes().size()); + + // Add a new node + internalCluster().startNode(); + ensureStableCluster(initialNodes + 1); + + // Get updated assignments + var updatedAssignment = calculator.getRateLimitAssignment(serviceName, config.taskType()); + + // Verify number of responsible nodes increased + assertEquals(3, updatedAssignment.responsibleNodes().size()); + } + } + } + + public void testRateLimits_Increase_OnNodeLeave() throws IOException { + // Start with max nodes per grouping (=3) + int numNodes = DEFAULT_MAX_NODES_PER_GROUPING; + var nodeNames = internalCluster().startNodes(numNodes); + ensureStableCluster(numNodes); + + var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.get(0)); + + for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet()) { + var configs = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.get(serviceName); + for (var config : configs) { + // Get initial assignments and rate limits + var initialAssignment = calculator.getRateLimitAssignment(serviceName, config.taskType()); + assertThat(DEFAULT_MAX_NODES_PER_GROUPING, equalTo(initialAssignment.responsibleNodes().size())); + + // Remove a node + var nodeToRemove = nodeNames.get(numNodes - 1); + internalCluster().stopNode(nodeToRemove); + ensureStableCluster(numNodes - 1); + + // Get updated assignments + var updatedAssignment = calculator.getRateLimitAssignment(serviceName, config.taskType()); + + // Verify number of responsible nodes decreased + assertThat(2, equalTo(updatedAssignment.responsibleNodes().size())); + } + } + } + + @Override + protected Collection> nodePlugins() { + return Arrays.asList(LocalStateInferencePlugin.class); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java 
index ed5aa5ba7bea9..57b9b03b9781b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java @@ -63,6 +63,11 @@ public void start() { // do nothing } + @Override + public void updateRateLimitDivisor(int rateLimitDivisor) { + // do nothing + } + @Override public void send( RequestManager requestCreator, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettingsTests.java index 0f6386f670338..c530ff5c03482 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettingsTests.java @@ -53,7 +53,7 @@ public void testFromMap() { ConfigurationParseContext.REQUEST ); - assertThat(serviceSettings, is(new ElasticInferenceServiceCompletionServiceSettings(modelId, new RateLimitSettings(240L)))); + assertThat(serviceSettings, is(new ElasticInferenceServiceCompletionServiceSettings(modelId, new RateLimitSettings(720L)))); } public void testFromMap_MissingModelId_ThrowsException() {