
[Inference API] Add node-local rate limiting for the inference API #120400

Merged · 72 commits · Jan 29, 2025

Commits (changes shown from the first 8 commits)
e441ea8
Add node-local rate limiting for the inference API
timgrein Jan 17, 2025
75ee9f4
Merge branch 'main' into inference-api-adaptive-rate-limiting
timgrein Jan 20, 2025
73eb9d5
Fix integration tests by using new LocalStateInferencePlugin instead …
timgrein Jan 20, 2025
6126047
Correct feature flag name
timgrein Jan 21, 2025
091a276
Add more docs, reorganize methods and make some methods package private
timgrein Jan 21, 2025
0a849b9
Merge branch 'main' into inference-api-adaptive-rate-limiting
timgrein Jan 21, 2025
bbf5c0c
Clarify comment in BaseInferenceActionRequest
timgrein Jan 21, 2025
01647b6
Fix wrong merge
timgrein Jan 21, 2025
e1d017f
Fix checkstyle
timgrein Jan 21, 2025
e5b8768
Fix checkstyle in tests
timgrein Jan 21, 2025
bae9487
Check that the service we want to read the rate limit config for …
timgrein Jan 21, 2025
e5b8adf
[CI] Auto commit changes from spotless
elasticsearchmachine Jan 21, 2025
dc4a79b
Merge branch 'main' into inference-api-adaptive-rate-limiting
timgrein Jan 21, 2025
873d605
checkStyle apply
timgrein Jan 21, 2025
6fe7638
Update docs/changelog/120400.yaml
timgrein Jan 23, 2025
1f3ca4e
Merge branch 'main' into inference-api-adaptive-rate-limiting
timgrein Jan 23, 2025
74909f5
Merge branch 'main' into inference-api-adaptive-rate-limiting
timgrein Jan 24, 2025
be9c559
Move rate limit division logic to RequestExecutorService
timgrein Jan 27, 2025
858ca0e
Merge remote-tracking branch 'origin/inference-api-adaptive-rate-limi…
timgrein Jan 27, 2025
15fbb9f
Spotless apply
timgrein Jan 27, 2025
c0598ea
Remove debug sout
timgrein Jan 27, 2025
e6a0786
Adding a few suggestions
jonathan-buttner Jan 27, 2025
23a6785
Adam feedback
jonathan-buttner Jan 27, 2025
cd5686b
Merge pull request #1 from jonathan-buttner/jon-suggested-changes
timgrein Jan 28, 2025
04cf4eb
Merge branch 'main' into inference-api-adaptive-rate-limiting
timgrein Jan 28, 2025
9837304
Fix compilation error
timgrein Jan 28, 2025
fa45864
Merge branch 'main' into inference-api-adaptive-rate-limiting
timgrein Jan 28, 2025
3f88940
[CI] Auto commit changes from spotless
elasticsearchmachine Jan 28, 2025
a59c8df
Add BWC test case to InferenceActionRequestTests
timgrein Jan 28, 2025
b5aacec
Merge remote-tracking branch 'origin/inference-api-adaptive-rate-limi…
timgrein Jan 28, 2025
5d20011
Add BWC test case to UnifiedCompletionActionRequestTests
timgrein Jan 28, 2025
f265d52
Update x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/…
timgrein Jan 28, 2025
05745d3
Update x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/…
timgrein Jan 28, 2025
9a13041
Remove addressed TODO
timgrein Jan 28, 2025
5792e38
Merge remote-tracking branch 'origin/inference-api-adaptive-rate-limi…
timgrein Jan 28, 2025
90c20a2
Spotless apply
timgrein Jan 28, 2025
a869907
Only use new rate limit specific feature flag
timgrein Jan 28, 2025
8826038
Use ThreadLocalRandom
timgrein Jan 28, 2025
dccd816
[CI] Auto commit changes from spotless
elasticsearchmachine Jan 28, 2025
9387604
Use Randomness.get()
timgrein Jan 28, 2025
0594d6b
Merge remote-tracking branch 'origin/inference-api-adaptive-rate-limi…
timgrein Jan 28, 2025
7a83f60
[CI] Auto commit changes from spotless
elasticsearchmachine Jan 28, 2025
e5e6471
Fix import
timgrein Jan 28, 2025
850d3a8
Merge remote-tracking branch 'origin/inference-api-adaptive-rate-limi…
timgrein Jan 28, 2025
3d010a4
Use ConcurrentHashMap in InferenceServiceNodeLocalRateLimitCalculator
timgrein Jan 28, 2025
47244bd
Merge branch 'main' into inference-api-adaptive-rate-limiting
timgrein Jan 28, 2025
ea7ef83
Check for null value in getRateLimitAssignment and remove AtomicRefer…
timgrein Jan 28, 2025
920c5e0
Merge remote-tracking branch 'origin/inference-api-adaptive-rate-limi…
timgrein Jan 28, 2025
81913e2
Remove newAssignments
timgrein Jan 28, 2025
43b60e7
Up the default rate limit for completions
timgrein Jan 29, 2025
86531fd
Merge branch 'main' into inference-api-adaptive-rate-limiting
timgrein Jan 29, 2025
86db7a9
Merge branch 'main' into inference-api-adaptive-rate-limiting
timgrein Jan 29, 2025
d7acd02
Merge branch 'main' into inference-api-adaptive-rate-limiting
timgrein Jan 29, 2025
1f54382
Merge branch 'main' into inference-api-adaptive-rate-limiting
timgrein Jan 29, 2025
e8df11a
Merge branch 'main' into inference-api-adaptive-rate-limiting
timgrein Jan 29, 2025
9dc7c10
Merge branch 'main' into inference-api-adaptive-rate-limiting
timgrein Jan 29, 2025
71ae899
Merge branch 'main' into inference-api-adaptive-rate-limiting
timgrein Jan 29, 2025
5912c19
Merge branch 'main' into inference-api-adaptive-rate-limiting
timgrein Jan 29, 2025
67ecf1e
Merge branch 'main' into inference-api-adaptive-rate-limiting
timgrein Jan 29, 2025
d9f209a
Merge branch 'main' into inference-api-adaptive-rate-limiting
timgrein Jan 29, 2025
3906cb3
Merge branch 'main' into inference-api-adaptive-rate-limiting
timgrein Jan 29, 2025
40a6198
Put deprecated feature flag back in
timgrein Jan 29, 2025
c2a5391
Merge remote-tracking branch 'origin/inference-api-adaptive-rate-limi…
timgrein Jan 29, 2025
5052026
Merge branch 'main' into inference-api-adaptive-rate-limiting
timgrein Jan 29, 2025
f856552
Check feature flag in BaseTransportInferenceAction
timgrein Jan 29, 2025
e9a7cae
Merge remote-tracking branch 'origin/inference-api-adaptive-rate-limi…
timgrein Jan 29, 2025
57f28db
spotlessApply
timgrein Jan 29, 2025
aa33350
Export inference.common
timgrein Jan 29, 2025
059fc24
Do not export inference.common
timgrein Jan 29, 2025
2e0be6f
Provide noop rate limit calculator, if feature flag is disabled
timgrein Jan 29, 2025
2143652
Add proper dependency injection
timgrein Jan 29, 2025
4327df2
Merge branch 'main' into inference-api-adaptive-rate-limiting
timgrein Jan 29, 2025
@@ -159,6 +159,7 @@ static TransportVersion def(int id) {
public static final TransportVersion BYTE_SIZE_VALUE_ALWAYS_USES_BYTES_1 = def(8_825_00_0);
public static final TransportVersion REVERT_BYTE_SIZE_VALUE_ALWAYS_USES_BYTES_1 = def(8_826_00_0);
public static final TransportVersion ESQL_SKIP_ES_INDEX_SERIALIZATION = def(8_827_00_0);
public static final TransportVersion INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING_ADDED = def(8_828_00_0);

/*
* STOP! READ THIS FIRST! No, really,
@@ -7,25 +7,56 @@

package org.elasticsearch.xpack.core.inference.action;

import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.TaskType;

import java.io.IOException;

/**
* Base class for inference action requests. Tracks request routing state to prevent potential routing loops
* and supports both streaming and non-streaming inference operations.
*/
public abstract class BaseInferenceActionRequest extends ActionRequest {

private boolean hasBeenRerouted;

public BaseInferenceActionRequest() {
super();
}

public BaseInferenceActionRequest(StreamInput in) throws IOException {
super(in);
if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING_ADDED)) {
Contributor:

I know we've talked about this separately, but I wanted to call it out; please let me know if this is OK once you hear back. @davidkyle, is it allowed to have something like the following in 8.17:

in.readBoolean();

And then in 8.18 add something like this:

if (in.getTransportVersion().onOrAfter(8.18)) {
  this.taskType = TaskType.fromStream(in);
}
in.readBoolean();

Essentially what is happening here is we're adding a new read (depending on the protocol version) to the beginning since this is a base class. Technically I think this is correct because we have the if-block, but I feel like we've always added new fields to the end not the front of the stream.

Member:

Good point.

> Technically I think this is correct because we have the if-block, but I feel like we've always added new fields to the end not the front of the stream.

That's my opinion too; typically we append new fields to the end of the existing stream, but this is in the super method, and that is called first. The best way to be sure is to add a BWC streaming test case for the child classes, InferenceActionRequestTests and UnifiedCompletionActionRequestTests.

this.hasBeenRerouted = in.readBoolean();
} else {
// For backwards compatibility, we treat all inference requests coming from ES nodes having
Contributor:

Suggested change:
- // For backwards compatibility, we treat all inference requests coming from ES nodes having
+ // For backwards compatibility, we treat all inference requests coming into ES nodes having

Just to clarify that there is no rerouting in this case.

// a version pre-node-local-rate-limiting as already rerouted to maintain pre-node-local-rate-limiting behavior.
this.hasBeenRerouted = true;
}
}

public abstract boolean isStreaming();

public abstract TaskType getTaskType();

public abstract String getInferenceEntityId();

public void setHasBeenRerouted(boolean hasBeenRerouted) {
this.hasBeenRerouted = hasBeenRerouted;
}

public boolean hasBeenRerouted() {
return hasBeenRerouted;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING_ADDED)) {
out.writeBoolean(hasBeenRerouted);
}
}
}
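The version-gated read/write pattern discussed in the review thread above can be sketched in isolation. This is a minimal, self-contained illustration; the class names and the `Request`/`roundTrip` helpers are hypothetical stand-ins, not the real Elasticsearch `StreamInput`/`StreamOutput` types:

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

public class VersionedRequestDemo {
    // Hypothetical stand-in for the transport version constant added in this PR.
    static final int RATE_LIMITING_VERSION = 8_828_00_0;

    static class Request {
        boolean hasBeenRerouted;

        // The field is only written when the remote side is new enough to read it.
        void writeTo(DataOutputStream out, int remoteVersion) throws IOException {
            if (remoteVersion >= RATE_LIMITING_VERSION) {
                out.writeBoolean(hasBeenRerouted);
            }
        }

        static Request readFrom(DataInputStream in, int remoteVersion) throws IOException {
            Request request = new Request();
            if (remoteVersion >= RATE_LIMITING_VERSION) {
                request.hasBeenRerouted = in.readBoolean();
            } else {
                // Old nodes never reroute, so treat their requests as already
                // rerouted to preserve pre-rate-limiting behavior.
                request.hasBeenRerouted = true;
            }
            return request;
        }
    }

    static Request roundTrip(Request request, int remoteVersion) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        request.writeTo(new DataOutputStream(buffer), remoteVersion);
        return Request.readFrom(
            new DataInputStream(new ByteArrayInputStream(buffer.toByteArray())),
            remoteVersion
        );
    }

    public static void main(String[] args) throws IOException {
        Request request = new Request();
        request.hasBeenRerouted = false;

        // Old wire version: the flag is never written; the reader falls back to true.
        System.out.println(roundTrip(request, RATE_LIMITING_VERSION - 1).hasBeenRerouted);

        // New wire version: the flag round-trips exactly as written.
        System.out.println(roundTrip(request, RATE_LIMITING_VERSION).hasBeenRerouted);
    }
}
```

Because both writer and reader branch on the same version check, the stream stays aligned even though the new read happens at the front (in the base class) rather than at the end; this is exactly what the BWC test cases added in the later commits exercise.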
@@ -72,6 +72,7 @@
import org.elasticsearch.xpack.inference.action.TransportUnifiedCompletionInferenceAction;
import org.elasticsearch.xpack.inference.action.TransportUpdateInferenceModelAction;
import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter;
import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
@@ -133,6 +134,7 @@
import java.util.stream.Stream;

import static java.util.Collections.singletonList;
import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG;
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER;
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG;
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG;
@@ -243,6 +245,9 @@ public List<RestHandler> getRestHandlers(

@Override
public Collection<?> createComponents(PluginServices services) {
var components = new ArrayList<>();

var clusterService = services.clusterService();
Contributor:

Minor: services.clusterService() is a simple getter; IMO there's no need to introduce a variable for line 306.

var throttlerManager = new ThrottlerManager(settings, services.threadPool(), services.clusterService());
var truncator = new Truncator(settings, services.clusterService());
serviceComponents.set(new ServiceComponents(services.threadPool(), throttlerManager, settings, truncator));
@@ -298,26 +303,37 @@ public Collection<?> createComponents(PluginServices services) {
var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(
services.client(),
services.threadPool(),
services.clusterService(),
clusterService,
settings
);

// This must be done after the HttpRequestSenderFactory is created so that the services can get the
// reference correctly
var registry = new InferenceServiceRegistry(inferenceServices, factoryContext);
registry.init(services.client());
for (var service : registry.getServices().values()) {
var serviceRegistry = new InferenceServiceRegistry(inferenceServices, factoryContext);
serviceRegistry.init(services.client());
for (var service : serviceRegistry.getServices().values()) {
service.defaultConfigIds().forEach(modelRegistry::addDefaultIds);
}
inferenceServiceRegistry.set(registry);
inferenceServiceRegistry.set(serviceRegistry);

var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), registry, modelRegistry);
var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry);
shardBulkInferenceActionFilter.set(actionFilter);

var meterRegistry = services.telemetryProvider().getMeterRegistry();
var stats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));
var inferenceStats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));

components.add(serviceRegistry);
components.add(modelRegistry);
components.add(httpClientManager);
components.add(inferenceStats);

// Only add InferenceServiceNodeLocalRateLimitCalculator (which is a ClusterStateListener) for cluster aware rate limiting,
// if elastic inference service and the rate limiting feature flags are enabled
if (isElasticInferenceServiceEnabled() && INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG.isEnabled()) {
components.add(new InferenceServiceNodeLocalRateLimitCalculator(services.clusterService(), serviceRegistry));
jonathan-buttner marked this conversation as resolved.
}

return List.of(modelRegistry, registry, httpClientManager, stats);
return components;
}

@Override
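The feature-flag gating in createComponents, together with the commits "Provide noop rate limit calculator, if feature flag is disabled" and "Move rate limit division logic to RequestExecutorService", suggests a wiring shape like the following. All class and method names here are illustrative sketches under those assumptions, not the real Elasticsearch classes:

```java
public class RateLimitWiringDemo {
    interface RateLimitCalculator {
        long requestsPerMinute(String taskType);
    }

    // No-op fallback when the feature flag is off: effectively unlimited at this
    // layer, i.e. the pre-existing per-node rate limiting behavior is untouched.
    static class NoopRateLimitCalculator implements RateLimitCalculator {
        @Override
        public long requestsPerMinute(String taskType) {
            return Long.MAX_VALUE;
        }
    }

    // Node-local calculator: divides a cluster-wide limit across the nodes
    // currently assigned to handle requests for a given task type.
    static class NodeLocalRateLimitCalculator implements RateLimitCalculator {
        private final long clusterWideLimit;
        private final int assignedNodes;

        NodeLocalRateLimitCalculator(long clusterWideLimit, int assignedNodes) {
            this.clusterWideLimit = clusterWideLimit;
            this.assignedNodes = assignedNodes;
        }

        @Override
        public long requestsPerMinute(String taskType) {
            return clusterWideLimit / assignedNodes;
        }
    }

    // Mirrors the plugin wiring: callers always receive a calculator component,
    // but it only divides limits when the feature flag is enabled.
    static RateLimitCalculator create(boolean featureFlagEnabled, long limit, int nodes) {
        return featureFlagEnabled
            ? new NodeLocalRateLimitCalculator(limit, nodes)
            : new NoopRateLimitCalculator();
    }

    public static void main(String[] args) {
        System.out.println(create(true, 3000, 3).requestsPerMinute("completion"));
        System.out.println(create(false, 3000, 3).requestsPerMinute("completion"));
    }
}
```

Registering a single component either way keeps dependency injection simple: consumers such as the request executor depend only on the interface, and the flag decides the implementation at startup rather than at every call site.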