Skip to content

Commit

Permalink
[Inference API] Remove second calculator instance as component and update tests (#121284)
Browse files Browse the repository at this point in the history
  • Loading branch information
timgrein authored Jan 31, 2025
1 parent a4455d4 commit 2993998
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,6 @@ public Collection<?> createComponents(PluginServices services) {

// Add binding for interface -> implementation
components.add(new PluginComponentBinding<>(InferenceServiceRateLimitCalculator.class, calculator));
components.add(calculator);

return components;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.InternalTestCluster;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.services.SenderService;
Expand All @@ -23,6 +24,7 @@
import static org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator.DEFAULT_MAX_NODES_PER_GROUPING;
import static org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator.SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;

@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE, numDataNodes = 0)
public class InferenceServiceNodeLocalRateLimitCalculatorTests extends ESIntegTestCase {
Expand All @@ -39,7 +41,7 @@ public void testInitialClusterGrouping_Correct() throws Exception {
var nodeNames = internalCluster().startNodes(numNodes);
ensureStableCluster(numNodes);

var firstCalculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.getFirst());
var firstCalculator = getCalculatorInstance(internalCluster(), nodeNames.getFirst());
waitForRateLimitingAssignments(firstCalculator);

RateLimitAssignment firstAssignment = firstCalculator.getRateLimitAssignment(
Expand All @@ -49,7 +51,7 @@ public void testInitialClusterGrouping_Correct() throws Exception {

// Verify that all other nodes land on the same assignment
for (String nodeName : nodeNames.subList(1, nodeNames.size())) {
var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeName);
var calculator = getCalculatorInstance(internalCluster(), nodeName);
waitForRateLimitingAssignments(calculator);
var currentAssignment = calculator.getRateLimitAssignment(ElasticInferenceService.NAME, TaskType.SPARSE_EMBEDDING);
assertEquals(firstAssignment, currentAssignment);
Expand All @@ -75,7 +77,7 @@ public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws
ensureStableCluster(currentNumberOfNodes);
}

var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeLeftInCluster);
var calculator = getCalculatorInstance(internalCluster(), nodeLeftInCluster);
waitForRateLimitingAssignments(calculator);

Set<String> supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet();
Expand All @@ -98,7 +100,7 @@ public void testGrouping_RespectsMaxNodesPerGroupingLimit() throws Exception {
var nodeNames = internalCluster().startNodes(numNodes);
ensureStableCluster(numNodes);

var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.getFirst());
var calculator = getCalculatorInstance(internalCluster(), nodeNames.getFirst());
waitForRateLimitingAssignments(calculator);

Set<String> supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet();
Expand All @@ -117,7 +119,7 @@ public void testInitialRateLimitsCalculation_Correct() throws Exception {
var nodeNames = internalCluster().startNodes(numNodes);
ensureStableCluster(numNodes);

var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.getFirst());
var calculator = getCalculatorInstance(internalCluster(), nodeNames.getFirst());
waitForRateLimitingAssignments(calculator);

Set<String> supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet();
Expand Down Expand Up @@ -148,7 +150,7 @@ public void testRateLimits_Decrease_OnNodeJoin() throws Exception {
var nodeNames = internalCluster().startNodes(initialNodes);
ensureStableCluster(initialNodes);

var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.getFirst());
var calculator = getCalculatorInstance(internalCluster(), nodeNames.getFirst());
waitForRateLimitingAssignments(calculator);

for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet()) {
Expand Down Expand Up @@ -178,7 +180,7 @@ public void testRateLimits_Increase_OnNodeLeave() throws Exception {
var nodeNames = internalCluster().startNodes(numNodes);
ensureStableCluster(numNodes);

var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.getFirst());
var calculator = getCalculatorInstance(internalCluster(), nodeNames.getFirst());
waitForRateLimitingAssignments(calculator);

for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet()) {
Expand Down Expand Up @@ -208,6 +210,27 @@ protected Collection<Class<? extends Plugin>> nodePlugins() {
return Arrays.asList(LocalStateInferencePlugin.class);
}

/**
 * Looks up the {@link InferenceServiceRateLimitCalculator} instance of the given node and narrows it to
 * {@link InferenceServiceNodeLocalRateLimitCalculator}.
 * <p>
 * The lookup goes through the interface type; the test then asserts that the returned object is the
 * node-local implementation before casting, so a differently bound implementation fails with a descriptive
 * message instead of a bare {@link ClassCastException}.
 *
 * @param internalTestCluster the test cluster to fetch the instance from
 * @param nodeName            the name of the node whose calculator instance is requested
 * @return the calculator instance of that node, cast to the node-local implementation
 */
private InferenceServiceNodeLocalRateLimitCalculator getCalculatorInstance(InternalTestCluster internalTestCluster, String nodeName) {
    InferenceServiceRateLimitCalculator boundCalculator = internalTestCluster.getInstance(
        InferenceServiceRateLimitCalculator.class,
        nodeName
    );

    // Class names that make up the failure message shown when the wrong implementation is bound.
    String testClassName = InferenceServiceNodeLocalRateLimitCalculatorTests.class.getName();
    String expectedImplementationName = InferenceServiceNodeLocalRateLimitCalculator.class.getName();
    String interfaceName = InferenceServiceRateLimitCalculator.class.getName();
    String providedImplementationName = boundCalculator.getClass().getName();

    String failureMessage = "["
        + testClassName
        + "] should use ["
        + expectedImplementationName
        + "] as implementation for ["
        + interfaceName
        + "]. Provided implementation was ["
        + providedImplementationName
        + "].";

    assertThat(failureMessage, boundCalculator, instanceOf(InferenceServiceNodeLocalRateLimitCalculator.class));

    return (InferenceServiceNodeLocalRateLimitCalculator) boundCalculator;
}

private void waitForRateLimitingAssignments(InferenceServiceNodeLocalRateLimitCalculator calculator) throws Exception {
assertBusy(() -> {
var assignment = calculator.getRateLimitAssignment(ElasticInferenceService.NAME, TaskType.SPARSE_EMBEDDING);
Expand Down

0 comments on commit 2993998

Please sign in to comment.