-
Notifications
You must be signed in to change notification settings - Fork 25k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Inference API] Add node-local rate limiting for the inference API #120400
Changes from 8 commits
e441ea8
75ee9f4
73eb9d5
6126047
091a276
0a849b9
bbf5c0c
01647b6
e1d017f
e5b8768
bae9487
e5b8adf
dc4a79b
873d605
6fe7638
1f3ca4e
74909f5
be9c559
858ca0e
15fbb9f
c0598ea
e6a0786
23a6785
cd5686b
04cf4eb
9837304
fa45864
3f88940
a59c8df
b5aacec
5d20011
f265d52
05745d3
9a13041
5792e38
90c20a2
a869907
8826038
dccd816
9387604
0594d6b
7a83f60
e5e6471
850d3a8
3d010a4
47244bd
ea7ef83
920c5e0
81913e2
43b60e7
86531fd
86db7a9
d7acd02
1f54382
e8df11a
9dc7c10
71ae899
5912c19
67ecf1e
d9f209a
3906cb3
40a6198
c2a5391
5052026
f856552
e9a7cae
57f28db
aa33350
059fc24
2e0be6f
2143652
4327df2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -7,25 +7,56 @@ | |||||
|
||||||
package org.elasticsearch.xpack.core.inference.action; | ||||||
|
||||||
import org.elasticsearch.TransportVersions; | ||||||
import org.elasticsearch.action.ActionRequest; | ||||||
import org.elasticsearch.common.io.stream.StreamInput; | ||||||
import org.elasticsearch.common.io.stream.StreamOutput; | ||||||
import org.elasticsearch.inference.TaskType; | ||||||
|
||||||
import java.io.IOException; | ||||||
|
||||||
/** | ||||||
* Base class for inference action requests. Tracks request routing state to prevent potential routing loops | ||||||
* and supports both streaming and non-streaming inference operations. | ||||||
*/ | ||||||
public abstract class BaseInferenceActionRequest extends ActionRequest { | ||||||
|
||||||
private boolean hasBeenRerouted; | ||||||
|
||||||
public BaseInferenceActionRequest() { | ||||||
super(); | ||||||
} | ||||||
|
||||||
public BaseInferenceActionRequest(StreamInput in) throws IOException { | ||||||
super(in); | ||||||
if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING_ADDED)) { | ||||||
this.hasBeenRerouted = in.readBoolean(); | ||||||
} else { | ||||||
// For backwards compatibility, we treat all inference requests coming from ES nodes having | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
Just to clarify that there is no rerouting in this case. |
||||||
// a version pre-node-local-rate-limiting as already rerouted to maintain pre-node-local-rate-limiting behavior. | ||||||
this.hasBeenRerouted = true; | ||||||
} | ||||||
} | ||||||
|
||||||
public abstract boolean isStreaming(); | ||||||
|
||||||
public abstract TaskType getTaskType(); | ||||||
|
||||||
public abstract String getInferenceEntityId(); | ||||||
|
||||||
public void setHasBeenRerouted(boolean hasBeenRerouted) { | ||||||
this.hasBeenRerouted = hasBeenRerouted; | ||||||
} | ||||||
|
||||||
public boolean hasBeenRerouted() { | ||||||
return hasBeenRerouted; | ||||||
} | ||||||
|
||||||
@Override | ||||||
public void writeTo(StreamOutput out) throws IOException { | ||||||
super.writeTo(out); | ||||||
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING_ADDED)) { | ||||||
out.writeBoolean(hasBeenRerouted); | ||||||
} | ||||||
} | ||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -72,6 +72,7 @@ | |
import org.elasticsearch.xpack.inference.action.TransportUnifiedCompletionInferenceAction; | ||
import org.elasticsearch.xpack.inference.action.TransportUpdateInferenceModelAction; | ||
import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter; | ||
import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator; | ||
import org.elasticsearch.xpack.inference.common.Truncator; | ||
import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender; | ||
import org.elasticsearch.xpack.inference.external.http.HttpClientManager; | ||
|
@@ -133,6 +134,7 @@ | |
import java.util.stream.Stream; | ||
|
||
import static java.util.Collections.singletonList; | ||
import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.*; | ||
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER; | ||
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG; | ||
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG; | ||
|
@@ -243,6 +245,9 @@ public List<RestHandler> getRestHandlers( | |
|
||
@Override | ||
public Collection<?> createComponents(PluginServices services) { | ||
var components = new ArrayList<>(); | ||
|
||
var clusterService = services.clusterService(); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Minor: |
||
var throttlerManager = new ThrottlerManager(settings, services.threadPool(), services.clusterService()); | ||
var truncator = new Truncator(settings, services.clusterService()); | ||
serviceComponents.set(new ServiceComponents(services.threadPool(), throttlerManager, settings, truncator)); | ||
|
@@ -298,26 +303,37 @@ public Collection<?> createComponents(PluginServices services) { | |
var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext( | ||
services.client(), | ||
services.threadPool(), | ||
services.clusterService(), | ||
clusterService, | ||
settings | ||
); | ||
|
||
// This must be done after the HttpRequestSenderFactory is created so that the services can get the | ||
// reference correctly | ||
var registry = new InferenceServiceRegistry(inferenceServices, factoryContext); | ||
registry.init(services.client()); | ||
for (var service : registry.getServices().values()) { | ||
var serviceRegistry = new InferenceServiceRegistry(inferenceServices, factoryContext); | ||
serviceRegistry.init(services.client()); | ||
for (var service : serviceRegistry.getServices().values()) { | ||
service.defaultConfigIds().forEach(modelRegistry::addDefaultIds); | ||
} | ||
inferenceServiceRegistry.set(registry); | ||
inferenceServiceRegistry.set(serviceRegistry); | ||
|
||
var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), registry, modelRegistry); | ||
var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry); | ||
shardBulkInferenceActionFilter.set(actionFilter); | ||
|
||
var meterRegistry = services.telemetryProvider().getMeterRegistry(); | ||
var stats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry)); | ||
var inferenceStats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry)); | ||
|
||
components.add(serviceRegistry); | ||
components.add(modelRegistry); | ||
components.add(httpClientManager); | ||
components.add(inferenceStats); | ||
|
||
// Only add InferenceServiceNodeLocalRateLimitCalculator (which is a ClusterStateListener) for cluster aware rate limiting, | ||
// if elastic inference service and the rate limiting feature flags are enabled | ||
if (isElasticInferenceServiceEnabled() && INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG.isEnabled()) { | ||
components.add(new InferenceServiceNodeLocalRateLimitCalculator(services.clusterService(), serviceRegistry)); | ||
jonathan-buttner marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
return List.of(modelRegistry, registry, httpClientManager, stats); | ||
return components; | ||
} | ||
|
||
@Override | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know we've talked about this separately but wanted to call it out, if you could let me know if this is ok once you hear. @davidkyle is it allowed to have something like the following in 8.17:
And then in 8.18 add something like this:
Essentially what is happening here is we're adding a new read (depending on the protocol version) to the beginning since this is a base class. Technically I think this is correct because we have the if-block, but I feel like we've always added new fields to the end not the front of the stream.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point.
That's my opinion too; typically we append new fields to the end of the existing code, but this is in the `super` method and that is called first. The best way to be sure is to add a BWC streaming test case for the child classes: `InferenceActionRequestTests` and `UnifiedCompletionActionRequestTests`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added with Add BWC test case to InferenceActionRequestTests and Add BWC test case to UnifiedCompletionActionRequestTests for old ES version sends to new ES version