[8.x] [Inference API] Add node-local rate limiting for the inference API (#120400) (#121251)

* [Inference API] Add node-local rate limiting for the inference API (#120400)

* Add node-local rate limiting for the inference API

* Fix integration tests by using new LocalStateInferencePlugin instead of InferencePlugin and adjust formatting.

* Correct feature flag name

* Add more docs, reorganize methods and make some methods package private

* Clarify comment in BaseInferenceActionRequest

* Fix wrong merge

* Fix checkstyle

* Fix checkstyle in tests

* Check that the service we want to read the rate limit config for actually exists

* [CI] Auto commit changes from spotless

* checkStyle apply

* Update docs/changelog/120400.yaml

* Move rate limit division logic to RequestExecutorService (the division idea is sketched below, after this commit message)

* Spotless apply

* Remove debug sout

* Adding a few suggestions

* Adam feedback

* Fix compilation error

* [CI] Auto commit changes from spotless

* Add BWC test case to InferenceActionRequestTests

* Add BWC test case to UnifiedCompletionActionRequestTests

* Update x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculator.java

Co-authored-by: Adam Demjen <[email protected]>

* Update x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculator.java

Co-authored-by: Adam Demjen <[email protected]>

* Remove addressed TODO

* Spotless apply

* Only use new rate limit specific feature flag

* Use ThreadLocalRandom

* [CI] Auto commit changes from spotless

* Use Randomness.get()

* [CI] Auto commit changes from spotless

* Fix import

* Use ConcurrentHashMap in InferenceServiceNodeLocalRateLimitCalculator

* Check for null value in getRateLimitAssignment and remove AtomicReference

* Remove newAssignments

* Up the default rate limit for completions

* Put deprecated feature flag back in

* Check feature flag in BaseTransportInferenceAction

* spotlessApply

* Export inference.common

* Do not export inference.common

* Provide noop rate limit calculator if feature flag is disabled

* Add proper dependency injection

---------

Co-authored-by: elasticsearchmachine <[email protected]>
Co-authored-by: Jonathan Buttner <[email protected]>
Co-authored-by: Adam Demjen <[email protected]>

* Use .get(0) as getFirst() doesn't exist in 8.18 (probably JDK difference?)

---------

Co-authored-by: elasticsearchmachine <[email protected]>
Co-authored-by: Jonathan Buttner <[email protected]>
Co-authored-by: Adam Demjen <[email protected]>
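
For context, "node-local" rate limiting means the service-level rate limit is divided among the nodes responsible for a service, instead of every node enforcing the full limit on its own. A minimal sketch of the division implied by the commit log ("Move rate limit division logic to RequestExecutorService"), assuming an even split with a floor of one request per node; all names here are illustrative, not the PR's actual API:

public final class NodeLocalRateLimitMath {

    /**
     * Sketch only: splits a service-level requests-per-time-unit limit evenly
     * across the nodes responsible for that service, never dropping below one.
     */
    static long perNodeRateLimit(long serviceLevelLimit, int responsibleNodeCount) {
        if (responsibleNodeCount <= 0) {
            throw new IllegalArgumentException("at least one responsible node is required");
        }
        return Math.max(1L, serviceLevelLimit / responsibleNodeCount);
    }

    public static void main(String[] args) {
        // e.g. a 3_000 requests/minute service-level limit split across 3 responsible nodes
        System.out.println(perNodeRateLimit(3_000L, 3)); // prints 1000
    }
}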
4 people authored Jan 30, 2025
1 parent 1261557 commit f0a5e25
Showing 29 changed files with 1,015 additions and 49 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/120400.yaml
@@ -0,0 +1,5 @@
pr: 120400
summary: "[Inference API] Add node-local rate limiting for the inference API"
area: Machine Learning
type: feature
issues: []
server/src/main/java/org/elasticsearch/TransportVersions.java
@@ -176,6 +176,7 @@ static TransportVersion def(int id) {
    public static final TransportVersion RESOURCE_DEPRECATION_CHECKS = def(8_836_00_0);
    public static final TransportVersion LINEAR_RETRIEVER_SUPPORT = def(8_837_00_0);
    public static final TransportVersion TIMEOUT_GET_PARAM_FOR_RESOLVE_CLUSTER = def(8_838_00_0);
    public static final TransportVersion INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING = def(8_839_00_0);

    /*
     * STOP! READ THIS FIRST! No, really,
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java
@@ -7,25 +7,56 @@

package org.elasticsearch.xpack.core.inference.action;

import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.TaskType;

import java.io.IOException;

/**
 * Base class for inference action requests. Tracks request routing state to prevent potential routing loops
 * and supports both streaming and non-streaming inference operations.
 */
public abstract class BaseInferenceActionRequest extends ActionRequest {

    private boolean hasBeenRerouted;

    public BaseInferenceActionRequest() {
        super();
    }

    public BaseInferenceActionRequest(StreamInput in) throws IOException {
        super(in);
        if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) {
            this.hasBeenRerouted = in.readBoolean();
        } else {
            // For backwards compatibility, we treat all inference requests coming from ES nodes having
            // a version pre-node-local-rate-limiting as already rerouted to maintain pre-node-local-rate-limiting behavior.
            this.hasBeenRerouted = true;
        }
    }

    public abstract boolean isStreaming();

    public abstract TaskType getTaskType();

    public abstract String getInferenceEntityId();

    public void setHasBeenRerouted(boolean hasBeenRerouted) {
        this.hasBeenRerouted = hasBeenRerouted;
    }

    public boolean hasBeenRerouted() {
        return hasBeenRerouted;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) {
            out.writeBoolean(hasBeenRerouted);
        }
    }
}
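
The hasBeenRerouted flag is what prevents the routing loops mentioned in the class javadoc: a request is forwarded at most once, and anything arriving from a node on a pre-rate-limiting transport version deserializes as already rerouted. A minimal sketch of the resulting dispatch decision, where handleLocally, pickResponsibleNode and rerouteTo are hypothetical stand-ins for the real transport logic (which, per the commit log, involves BaseTransportInferenceAction):

abstract class ReroutingDispatchSketch {

    void dispatch(BaseInferenceActionRequest request) {
        if (request.hasBeenRerouted()) {
            // Second hop, or a request from an older node: handle locally, never forward again.
            handleLocally(request);
        } else {
            // First hop: mark the request so the receiving node cannot bounce it back.
            request.setHasBeenRerouted(true);
            rerouteTo(pickResponsibleNode(request.getInferenceEntityId()), request);
        }
    }

    abstract void handleLocally(BaseInferenceActionRequest request);

    abstract String pickResponsibleNode(String inferenceEntityId);

    abstract void rerouteTo(String nodeId, BaseInferenceActionRequest request);
}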
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java
@@ -386,6 +386,29 @@ public void testWriteTo_WhenVersionIsBeforeInputTypeAdded_ShouldSetInputTypeToUn
        assertThat(deserializedInstance.getInputType(), is(InputType.UNSPECIFIED));
    }

    public void testWriteTo_WhenVersionIsBeforeAdaptiveRateLimiting_ShouldSetHasBeenReroutedToTrue() throws IOException {
        var instance = new InferenceAction.Request(
            TaskType.TEXT_EMBEDDING,
            "model",
            null,
            List.of("input"),
            Map.of(),
            InputType.UNSPECIFIED,
            InferenceAction.Request.DEFAULT_TIMEOUT,
            false
        );

        InferenceAction.Request deserializedInstance = copyWriteable(
            instance,
            getNamedWriteableRegistry(),
            instanceReader(),
            TransportVersions.V_8_13_0
        );

        // Verify that hasBeenRerouted is true after deserializing a request coming from an older transport version
        assertTrue(deserializedInstance.hasBeenRerouted());
    }

    public void testGetInputTypeToWrite_ReturnsIngest_WhenInputTypeIsUnspecified_VersionBeforeUnspecifiedIntroduced() {
        assertThat(getInputTypeToWrite(InputType.UNSPECIFIED, TransportVersions.V_8_12_1), is(InputType.INGEST));
    }
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.core.inference.action;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.TimeValue;
@@ -65,6 +66,25 @@ public void testValidation_ReturnsNull_When_TaskType_IsAny() {
        assertNull(request.validate());
    }

    public void testWriteTo_WhenVersionIsBeforeAdaptiveRateLimiting_ShouldSetHasBeenReroutedToTrue() throws IOException {
        var instance = new UnifiedCompletionAction.Request(
            "model",
            TaskType.ANY,
            UnifiedCompletionRequest.of(List.of(UnifiedCompletionRequestTests.randomMessage())),
            TimeValue.timeValueSeconds(10)
        );

        UnifiedCompletionAction.Request deserializedInstance = copyWriteable(
            instance,
            getNamedWriteableRegistry(),
            instanceReader(),
            TransportVersions.ELASTIC_INFERENCE_SERVICE_UNIFIED_CHAT_COMPLETIONS_INTEGRATION
        );

        // Verify that hasBeenRerouted is true after deserializing a request coming from an older transport version
        assertTrue(deserializedInstance.hasBeenRerouted());
    }

    @Override
    protected UnifiedCompletionAction.Request mutateInstanceForVersion(UnifiedCompletionAction.Request instance, TransportVersion version) {
        return instance;
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
@@ -72,6 +72,9 @@
import org.elasticsearch.xpack.inference.action.TransportUnifiedCompletionInferenceAction;
import org.elasticsearch.xpack.inference.action.TransportUpdateInferenceModelAction;
import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter;
import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
import org.elasticsearch.xpack.inference.common.NoopNodeLocalRateLimitCalculator;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
@@ -133,6 +136,7 @@
import java.util.function.Supplier;

import static java.util.Collections.singletonList;
import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG;

public class InferencePlugin extends Plugin
    implements
@@ -229,6 +233,7 @@ public List<RestHandler> getRestHandlers(

    @Override
    public Collection<?> createComponents(PluginServices services) {
        var components = new ArrayList<>();
        var throttlerManager = new ThrottlerManager(settings, services.threadPool(), services.clusterService());
        var truncator = new Truncator(settings, services.clusterService());
        serviceComponents.set(new ServiceComponents(services.threadPool(), throttlerManager, settings, truncator));
@@ -297,20 +302,38 @@ public Collection<?> createComponents(PluginServices services) {

        // This must be done after the HttpRequestSenderFactory is created so that the services can get the
        // reference correctly
-       var registry = new InferenceServiceRegistry(inferenceServices, factoryContext);
-       registry.init(services.client());
-       for (var service : registry.getServices().values()) {
+       var serviceRegistry = new InferenceServiceRegistry(inferenceServices, factoryContext);
+       serviceRegistry.init(services.client());
+       for (var service : serviceRegistry.getServices().values()) {
            service.defaultConfigIds().forEach(modelRegistry::addDefaultIds);
        }
-       inferenceServiceRegistry.set(registry);
+       inferenceServiceRegistry.set(serviceRegistry);

-       var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), registry, modelRegistry);
+       var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry);
        shardBulkInferenceActionFilter.set(actionFilter);

        var meterRegistry = services.telemetryProvider().getMeterRegistry();
-       var stats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));
+       var inferenceStats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));

+       components.add(serviceRegistry);
+       components.add(modelRegistry);
+       components.add(httpClientManager);
+       components.add(inferenceStats);
+
+       // Only add InferenceServiceNodeLocalRateLimitCalculator (which is a ClusterStateListener) for cluster-aware
+       // rate limiting if the rate limiting feature flag is enabled; otherwise provide a noop implementation.
+       InferenceServiceRateLimitCalculator calculator;
+       if (INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG.isEnabled()) {
+           calculator = new InferenceServiceNodeLocalRateLimitCalculator(services.clusterService(), serviceRegistry);
+       } else {
+           calculator = new NoopNodeLocalRateLimitCalculator();
+       }
+
+       // Add binding for interface -> implementation
+       components.add(new PluginComponentBinding<>(InferenceServiceRateLimitCalculator.class, calculator));
+       components.add(calculator);

-       return List.of(modelRegistry, registry, httpClientManager, stats);
+       return components;
    }
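
The noop branch above keeps the dependency-injection contract identical with the feature flag on or off: consumers always get an InferenceServiceRateLimitCalculator binding. A minimal sketch of what that contract might look like; the interface, record and method below are inferred from the commit log's mention of getRateLimitAssignment and are assumptions, not the actual classes in inference.common:

import java.util.List;

// Assumed shape: which nodes share responsibility (and thus the divided limit) for a service.
record RateLimitAssignment(List<String> responsibleNodeIds) {}

interface InferenceServiceRateLimitCalculator {
    // Returns the assignment for a service, or null when cluster-aware
    // rate limiting does not apply (callers must handle null, per the
    // "Check for null value in getRateLimitAssignment" commit).
    RateLimitAssignment getRateLimitAssignment(String serviceName);
}

// With the flag disabled every lookup returns null, so callers keep the
// pre-existing single-node rate limiting behavior.
class NoopNodeLocalRateLimitCalculator implements InferenceServiceRateLimitCalculator {
    @Override
    public RateLimitAssignment getRateLimitAssignment(String serviceName) {
        return null;
    }
}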

(Diffs for the remaining changed files are not rendered here.)