diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 688d2aaf905a6..60ff354c18f5d 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -206,6 +206,7 @@ static TransportVersion def(int id) { public static final TransportVersion INGEST_PIPELINE_CONFIGURATION_AS_MAP = def(8_797_00_0); public static final TransportVersion INDEXING_PRESSURE_THROTTLING_STATS = def(8_798_00_0); public static final TransportVersion REINDEX_DATA_STREAMS = def(8_799_00_0); + public static final TransportVersion INFERENCE_API_CLUSTER_WIDE_RATE_LIMITS = def(8_800_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 62405a2e9f7de..d1e02271f22ac 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -59,6 +59,7 @@ import org.elasticsearch.xpack.inference.action.TransportUpdateInferenceModelAction; import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter; import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.common.UpdateRateLimitsClusterService; import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.HttpSettings; @@ -142,6 +143,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP private final SetOnce eisComponents = new SetOnce<>(); private final SetOnce inferenceServiceRegistry = new 
SetOnce<>(); private final SetOnce shardBulkInferenceActionFilter = new SetOnce<>(); + private final SetOnce updateRateLimitsClusterService = new SetOnce<>(); private List inferenceServiceExtensions; public InferencePlugin(Settings settings) { @@ -188,6 +190,8 @@ public List getRestHandlers( @Override public Collection createComponents(PluginServices services) { + var clusterService = services.clusterService(); + var throttlerManager = new ThrottlerManager(settings, services.threadPool(), services.clusterService()); var truncator = new Truncator(settings, services.clusterService()); serviceComponents.set(new ServiceComponents(services.threadPool(), throttlerManager, settings, truncator)); @@ -238,7 +242,9 @@ public Collection createComponents(PluginServices services) { var meterRegistry = services.telemetryProvider().getMeterRegistry(); var stats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry)); - return List.of(modelRegistry, registry, httpClientManager, stats); + updateRateLimitsClusterService.set(new UpdateRateLimitsClusterService(clusterService, registry)); + + return List.of(modelRegistry, registry, httpClientManager, stats, updateRateLimitsClusterService.get()); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/UpdateRateLimitsClusterService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/UpdateRateLimitsClusterService.java new file mode 100644 index 0000000000000..06cd479427a79 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/UpdateRateLimitsClusterService.java @@ -0,0 +1,97 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.xpack.inference.common;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.cluster.ClusterChangedEvent;
+import org.elasticsearch.cluster.ClusterStateListener;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.gateway.GatewayService;
+import org.elasticsearch.inference.InferenceServiceRegistry;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorService;
+import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
+
+/**
+ * Cluster state listener that rescales the per-node request rate limits of the Elastic Inference Service
+ * whenever nodes are added to or removed from the cluster.
+ * <p>
+ * Every rate limiting endpoint handler exposes the requests-per-time-unit value it was created with. On each
+ * membership change that value is divided by the current number of nodes and the result is applied to the
+ * handler. NOTE(review): this presumably keeps the cluster-wide total close to the configured limit, assuming
+ * every node runs this listener and traffic is spread evenly - confirm.
+ */
+public class UpdateRateLimitsClusterService implements ClusterStateListener {
+
+    private static final Logger LOGGER = LogManager.getLogger(UpdateRateLimitsClusterService.class);
+
+    private final InferenceServiceRegistry inferenceServiceRegistry;
+
+    /**
+     * Creates the listener and registers it with the given cluster service.
+     *
+     * @param clusterService           cluster service this instance adds itself to as a listener
+     * @param inferenceServiceRegistry registry used to look up the Elastic Inference Service on each change
+     */
+    public UpdateRateLimitsClusterService(ClusterService clusterService, InferenceServiceRegistry inferenceServiceRegistry) {
+        this.inferenceServiceRegistry = inferenceServiceRegistry;
+        clusterService.addListener(this);
+        LOGGER.info("Added UpdateRateLimitsClusterService as a ClusterStateListener");
+    }
+
+    /**
+     * Recomputes the per-node rate limit of every Elastic Inference Service endpoint handler when the set of
+     * cluster nodes changed. Events that neither add nor remove nodes are ignored.
+     *
+     * @param event the cluster change to react to
+     */
+    @Override
+    public void clusterChanged(ClusterChangedEvent event) {
+        LOGGER.info("Received cluster changed event {}", event.source());
+
+        // TODO: check what this actually does and if it's necessary
+        // (presumably the cluster state has not been recovered yet while this block is present - confirm)
+        if (event.state().blocks().hasGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK)) {
+            return;
+        }
+
+        // TODO: check if the cluster is ready?
+
+        // TODO: other sanity checks?
+
+        if (event.nodesAdded() == false && event.nodesRemoved() == false) {
+            return;
+        }
+        LOGGER.info("Received nodesAdded or nodesRemoved event");
+
+        var numNodes = event.state().nodes().getSize();
+        LOGGER.info("Number of nodes in the cluster: {}", numNodes);
+
+        if (numNodes <= 0) {
+            // Nothing sensible to divide by; leave the current limits untouched.
+            LOGGER.warn("Skipping rate limit update because the cluster state contains no nodes");
+            return;
+        }
+
+        var elasticInferenceServiceOptional = inferenceServiceRegistry.getService(ElasticInferenceService.NAME);
+
+        if (elasticInferenceServiceOptional.isPresent() == false) {
+            // TODO: adapt
+            LOGGER.info("ElasticInferenceService is not present");
+            return;
+        }
+
+        // Verify the type before casting instead of failing with a ClassCastException inside the listener.
+        if (elasticInferenceServiceOptional.get() instanceof ElasticInferenceService == false) {
+            LOGGER.warn("service registered as [{}] is not type ElasticInferenceService", ElasticInferenceService.NAME);
+            return;
+        }
+
+        var elasticInferenceService = (ElasticInferenceService) elasticInferenceServiceOptional.get();
+        var sender = elasticInferenceService.getSender();
+
+        if (sender instanceof HttpRequestSender == false) {
+            // TODO: adapt
+            LOGGER.warn("sender is not type HttpRequestSender");
+            return;
+        }
+
+        var rateLimitingEndpointHandlers = ((HttpRequestSender) sender).rateLimitingEndpointHandlers();
+
+        LOGGER.info("Updating rate limits for {} endpoints", rateLimitingEndpointHandlers.size());
+        for (RequestExecutorService.RateLimitingEndpointHandler rateLimitingEndpointHandler : rateLimitingEndpointHandlers) {
+            // Divide in floating point: integer division would truncate the result and yield 0 whenever the
+            // configured limit is smaller than the number of nodes.
+            double clusterAwareTokenLimit = (double) rateLimitingEndpointHandler.originalRequestsPerTimeUnit() / numNodes;
+
+            if (event.nodesAdded()) {
+                LOGGER.info(
+                    "Decreasing per node rate limit for endpoint {} from {} to {} tokens per time unit (node added)",
+                    rateLimitingEndpointHandler.id(),
+                    rateLimitingEndpointHandler.currentRequestsPerTimeUnit(),
+                    clusterAwareTokenLimit
+                );
+            } else {
+                // Only reachable when nodes were removed, see the guard above.
+                LOGGER.info(
+                    "Increasing per node rate limit for endpoint {} from {} to {} tokens per time unit (node removed)",
+                    rateLimitingEndpointHandler.id(),
+                    rateLimitingEndpointHandler.currentRequestsPerTimeUnit(),
+                    clusterAwareTokenLimit
+                );
+            }
+
+            rateLimitingEndpointHandler.updateTokensPerTimeUnit(clusterAwareTokenLimit);
+        }
+    }
+}
diff --git
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java index d1e309a774ab7..56b66217d5c3b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java @@ -22,6 +22,8 @@ import org.elasticsearch.xpack.inference.services.ServiceComponents; import java.io.IOException; +import java.util.Collection; +import java.util.List; import java.util.Objects; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -92,6 +94,15 @@ private HttpRequestSender( ); } + public Collection rateLimitingEndpointHandlers() { + // TODO: there's probably a better way; just for the POC + if (service instanceof RequestExecutorService) { + return ((RequestExecutorService) service).rateLimitingEndpointHandlers(); + } + + return List.of(); + } + /** * Start various internal services. This is required before sending requests. 
*/ diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java index ad1324d0a315f..0dbc906288a17 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java @@ -27,6 +27,7 @@ import java.time.Clock; import java.time.Instant; import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.Objects; import java.util.concurrent.BlockingQueue; @@ -53,7 +54,7 @@ * attempting to execute a task (aka waiting for the connection manager to lease a connection). See * {@link org.apache.http.client.config.RequestConfig.Builder#setConnectionRequestTimeout} for more info. */ -class RequestExecutorService implements RequestExecutor { +public class RequestExecutorService implements RequestExecutor { /** * Provides dependency injection mainly for testing @@ -174,6 +175,10 @@ public int queueSize() { return rateLimitGroupings.values().stream().mapToInt(RateLimitingEndpointHandler::queueSize).sum(); } + public Collection rateLimitingEndpointHandlers() { + return rateLimitGroupings.values(); + } + /** * Begin servicing tasks. *

@@ -311,10 +316,10 @@ public void execute( /** * Provides rate limiting functionality for requests. A single {@link RateLimitingEndpointHandler} governs a group of requests. - * This allows many requests to be serialized if they are being sent too fast. If the rate limit has not been met they will be sent - * as soon as a thread is available. + * This allows many requests to be queued up and processed sequentially (serialized) if they are being sent too fast. + * If the rate limit has not been met they will be sent as soon as a thread is available. */ - private static class RateLimitingEndpointHandler { + public static class RateLimitingEndpointHandler { private static final TimeValue NO_TASKS_AVAILABLE = TimeValue.MAX_VALUE; private static final TimeValue EXECUTED_A_TASK = TimeValue.ZERO; @@ -329,6 +334,9 @@ private static class RateLimitingEndpointHandler { private final Clock clock; private final RateLimiter rateLimiter; private final RequestExecutorServiceSettings requestExecutorServiceSettings; + private final RateLimitSettings rateLimitSettings; + private final Long originalRequestsPerTimeUnit; + private Double currentRequestsPerTimeUnit; RateLimitingEndpointHandler( String id, @@ -346,8 +354,10 @@ private static class RateLimitingEndpointHandler { this.requestSender = Objects.requireNonNull(requestSender); this.clock = Objects.requireNonNull(clock); this.isShutdownMethod = Objects.requireNonNull(isShutdownMethod); + this.rateLimitSettings = Objects.requireNonNull(rateLimitSettings); + this.originalRequestsPerTimeUnit = rateLimitSettings.requestsPerTimeUnit(); + this.currentRequestsPerTimeUnit = (double) rateLimitSettings.requestsPerTimeUnit(); - Objects.requireNonNull(rateLimitSettings); Objects.requireNonNull(rateLimiterCreator); rateLimiter = rateLimiterCreator.create( ACCUMULATED_TOKENS_LIMIT, @@ -371,6 +381,23 @@ private void onCapacityChange(int capacity) { } } + public synchronized void updateTokensPerTimeUnit(double newTokensPerTimeUnit) { + 
rateLimiter.setRate(ACCUMULATED_TOKENS_LIMIT, newTokensPerTimeUnit, rateLimitSettings.timeUnit()); + this.currentRequestsPerTimeUnit = newTokensPerTimeUnit; + } + + public synchronized long originalRequestsPerTimeUnit() { + return originalRequestsPerTimeUnit; + } + + public synchronized long currentRequestsPerTimeUnit(){ + return currentRequestsPerTimeUnit.longValue(); + } + + public String id() { + return id; + } + public int queueSize() { return queue.size(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index b8a99227cf517..61869cd839875 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -42,7 +42,7 @@ public SenderService(HttpRequestSender.Factory factory, ServiceComponents servic this.serviceComponents = Objects.requireNonNull(serviceComponents); } - protected Sender getSender() { + public Sender getSender() { return sender; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java index 8525710c6cf23..ab34ffe2847c4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java @@ -14,7 +14,10 @@ public class ElasticInferenceServiceSettings { - static final Setting EIS_GATEWAY_URL = Setting.simpleString("xpack.inference.eis.gateway.url", Setting.Property.NodeScope); + public static final Setting EIS_GATEWAY_URL = 
Setting.simpleString( + "xpack.inference.eis.gateway.url", + Setting.Property.NodeScope + ); // Adjust this variable to be volatile, if the setting can be updated at some point in time private final String eisGatewayUrl; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/UpdateRateLimitsClusterServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/UpdateRateLimitsClusterServiceTests.java new file mode 100644 index 0000000000000..24b9f6f3694fa --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/UpdateRateLimitsClusterServiceTests.java @@ -0,0 +1,131 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.common; + +import org.apache.logging.log4j.Level; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.test.InternalTestCluster; +import org.elasticsearch.test.MockLog; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import 
java.util.List;
+
+/**
+ * Integration test for {@link UpdateRateLimitsClusterService}. It starts a node configured with an Elastic
+ * Inference Service gateway URL, creates an inference endpoint and sends one inference request to it, then
+ * adds and removes a second node while log output of the listener and the transport service is captured.
+ */
+@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE, numDataNodes = 0)
+public class UpdateRateLimitsClusterServiceTests extends ESIntegTestCase {
+
+    private MockLog mockLog;
+
+    @Override
+    public void setUp() throws Exception {
+        super.setUp();
+        mockLog = MockLog.capture(UpdateRateLimitsClusterService.class, TransportService.class);
+    }
+
+    @Override
+    public void tearDown() throws Exception {
+        mockLog.close();
+        super.tearDown();
+    }
+
+    /**
+     * Starts one node, creates and uses an inference endpoint, then lets a second node join and leave the
+     * cluster. Currently only the transport service's publish address message is asserted.
+     */
+    public void testNodeJoinsRateLimitsUpdated() throws IOException {
+        // Text of the message UpdateRateLimitsClusterService logs right before it updates the endpoint handlers.
+        // The previous pattern ("Updating rate limit for endpoint") does not occur in any message the listener logs.
+        final String rateLimitUpdatedPattern = "Updating rate limits for";
+        final MockLog.LoggingExpectation rateLimitUpdatedExpectation = new MockLog.SeenEventExpectation(
+            "rate limit updated",
+            UpdateRateLimitsClusterService.class.getCanonicalName(),
+            Level.INFO,
+            rateLimitUpdatedPattern
+        );
+
+        // This expectation works
+        final String publishAddressMessage = "publish_address";
+        final MockLog.LoggingExpectation publishAddressExpectation = new MockLog.SeenEventExpectation(
+            "publish_address",
+            TransportService.class.getCanonicalName(),
+            Level.INFO,
+            publishAddressMessage
+        );
+
+        // TODO: re-enable once the corrected pattern has been verified; it is still unclear whether logs of newly
+        // joined nodes are captured as well
+        // mockLog.addExpectation(rateLimitUpdatedExpectation);
+        mockLog.addExpectation(publishAddressExpectation);
+
+        var nodeSettings = Settings.builder()
+            .put(ElasticInferenceServiceSettings.EIS_GATEWAY_URL.getKey(), "http://localhost:8080")
+            .build();
+
+        var oneClusterNode = 1;
+        var putInferenceModelActionPayload = """
+            {
+              "service": "elastic",
+              "service_settings": {
+                "model_id": ".elser_model_2"
+              }
+            }
+            """;
+
+        InternalTestCluster internalCluster = internalCluster();
+
+        // We need to start a non-master-only-node to be able to store an inference endpoint
+        internalCluster.startNode(nodeSettings);
+        ensureStableCluster(oneClusterNode);
+
+        var inferenceEntityId = "inference-endpoint-id";
+        var createInferenceEndpointRequest = new PutInferenceModelAction.Request(
+            TaskType.SPARSE_EMBEDDING,
+            inferenceEntityId,
+            new BytesArray(putInferenceModelActionPayload),
+            XContentType.JSON
+        );
+        client().execute(PutInferenceModelAction.INSTANCE, createInferenceEndpointRequest).actionGet();
+
+        // Perform inference so the rate limiting endpoint handler is stored
+        var performInferenceRequest = new InferenceAction.Request(
+            TaskType.SPARSE_EMBEDDING,
+            inferenceEntityId,
+            null,
+            List.of("some input"),
+            new HashMap<>(),
+            InputType.UNSPECIFIED,
+            TimeValue.THIRTY_SECONDS,
+            false
+        );
+        client().execute(InferenceAction.INSTANCE, performInferenceRequest).actionGet();
+
+        // Start node two -> rate limit should be halved
+        var nodeTwoName = internalCluster.startNode();
+        ensureStableCluster(2);
+        client().admin().cluster().prepareNodesStats().get(TimeValue.timeValueSeconds(10));
+
+        // Stop node two -> rate limit should be back at original value
+        internalCluster.stopNode(nodeTwoName);
+        ensureStableCluster(1);
+
+        mockLog.assertAllExpectationsMatched();
+    }
+
+    @Override
+    protected Collection<Class<? extends Plugin>> nodePlugins() {
+        // TODO: InternalSettingsPlugin needed?
+        return Arrays.asList(InferencePlugin.class);
+    }
+}