Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[POC] "cluster-aware" rate limiting approximation inside the inference API #117505

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ static TransportVersion def(int id) {
public static final TransportVersion INGEST_PIPELINE_CONFIGURATION_AS_MAP = def(8_797_00_0);
public static final TransportVersion INDEXING_PRESSURE_THROTTLING_STATS = def(8_798_00_0);
public static final TransportVersion REINDEX_DATA_STREAMS = def(8_799_00_0);
public static final TransportVersion INFERENCE_API_CLUSTER_WIDE_RATE_LIMITS = def(8_800_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import org.elasticsearch.xpack.inference.action.TransportUpdateInferenceModelAction;
import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.common.UpdateRateLimitsClusterService;
import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.HttpSettings;
Expand Down Expand Up @@ -142,6 +143,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
private final SetOnce<ElasticInferenceServiceComponents> eisComponents = new SetOnce<>();
private final SetOnce<InferenceServiceRegistry> inferenceServiceRegistry = new SetOnce<>();
private final SetOnce<ShardBulkInferenceActionFilter> shardBulkInferenceActionFilter = new SetOnce<>();
private final SetOnce<UpdateRateLimitsClusterService> updateRateLimitsClusterService = new SetOnce<>();
private List<InferenceServiceExtension> inferenceServiceExtensions;

public InferencePlugin(Settings settings) {
Expand Down Expand Up @@ -188,6 +190,8 @@ public List<RestHandler> getRestHandlers(

@Override
public Collection<?> createComponents(PluginServices services) {
var clusterService = services.clusterService();

var throttlerManager = new ThrottlerManager(settings, services.threadPool(), services.clusterService());
var truncator = new Truncator(settings, services.clusterService());
serviceComponents.set(new ServiceComponents(services.threadPool(), throttlerManager, settings, truncator));
Expand Down Expand Up @@ -238,7 +242,9 @@ public Collection<?> createComponents(PluginServices services) {
var meterRegistry = services.telemetryProvider().getMeterRegistry();
var stats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));

return List.of(modelRegistry, registry, httpClientManager, stats);
updateRateLimitsClusterService.set(new UpdateRateLimitsClusterService(clusterService, registry));

return List.of(modelRegistry, registry, httpClientManager, stats, updateRateLimitsClusterService.get());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.common;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterStateListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.gateway.GatewayService;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;

public class UpdateRateLimitsClusterService implements ClusterStateListener {

private static final Logger LOGGER = LogManager.getLogger(UpdateRateLimitsClusterService.class);

private final InferenceServiceRegistry inferenceServiceRegistry;

public UpdateRateLimitsClusterService(ClusterService clusterService, InferenceServiceRegistry inferenceServiceRegistry) {
this.inferenceServiceRegistry = inferenceServiceRegistry;
clusterService.addListener(this);
LOGGER.info("Added UpdateRateLimitsClusterService as a ClusterStateListener");
}

@Override
public void clusterChanged(ClusterChangedEvent event) {
LOGGER.info("Received cluster changed event {}", event.source());

// TODO: check what this actually does and if it's necessary
if (event.state().blocks().hasGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK)) {
return;
}

// TODO: check if the cluster is ready?

// TODO: other sanity checks?

if (event.nodesAdded() || event.nodesRemoved()) {
LOGGER.info("Received nodesAdded or nodesRemoved event");

var numNodes = event.state().nodes().getSize();

LOGGER.info("Number of nodes in the cluster: {}", numNodes);
var elasticInferenceServiceOptional = inferenceServiceRegistry.getService(ElasticInferenceService.NAME);

if (elasticInferenceServiceOptional.isPresent() == false) {
// TODO: adapt
LOGGER.info("ElasticInferenceService is not present");
return;
}

ElasticInferenceService elasticInferenceService = (ElasticInferenceService) elasticInferenceServiceOptional.get();
var sender = elasticInferenceService.getSender();

if (sender instanceof HttpRequestSender == false) {
// TODO: adapt
LOGGER.warn("sender is not type HttpRequestSender");
return;
}

HttpRequestSender httpRequestSender = (HttpRequestSender) sender;

LOGGER.info("Updating rate limits for {} endpoints", httpRequestSender.rateLimitingEndpointHandlers().size());
for (RequestExecutorService.RateLimitingEndpointHandler rateLimitingEndpointHandler : httpRequestSender
.rateLimitingEndpointHandlers()) {
var originalRequestsPerTimeUnit = rateLimitingEndpointHandler.originalRequestsPerTimeUnit();
var clusterAwareTokenLimit = originalRequestsPerTimeUnit / numNodes;

if(event.nodesAdded()){
LOGGER.info(
"Decreasing per node rate limit for endpoint {} from {} to {} tokens per time unit (node added)",
rateLimitingEndpointHandler.id(),
rateLimitingEndpointHandler.currentRequestsPerTimeUnit(),
clusterAwareTokenLimit
);
} else if(event.nodesRemoved()){
LOGGER.info(
"Increasing per node rate limit for endpoint {} from {} to {} tokens per time unit (node removed)",
rateLimitingEndpointHandler.id(),
rateLimitingEndpointHandler.currentRequestsPerTimeUnit(),
clusterAwareTokenLimit
);
}

rateLimitingEndpointHandler.updateTokensPerTimeUnit(clusterAwareTokenLimit);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import org.elasticsearch.xpack.inference.services.ServiceComponents;

import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -92,6 +94,15 @@ private HttpRequestSender(
);
}

public Collection<RequestExecutorService.RateLimitingEndpointHandler> rateLimitingEndpointHandlers() {
// TODO: there's probably a better way; just for the POC
if (service instanceof RequestExecutorService) {
return ((RequestExecutorService) service).rateLimitingEndpointHandlers();
}

return List.of();
}

/**
* Start various internal services. This is required before sending requests.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.time.Clock;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.BlockingQueue;
Expand All @@ -53,7 +54,7 @@
* attempting to execute a task (aka waiting for the connection manager to lease a connection). See
* {@link org.apache.http.client.config.RequestConfig.Builder#setConnectionRequestTimeout} for more info.
*/
class RequestExecutorService implements RequestExecutor {
public class RequestExecutorService implements RequestExecutor {

/**
* Provides dependency injection mainly for testing
Expand Down Expand Up @@ -174,6 +175,10 @@ public int queueSize() {
return rateLimitGroupings.values().stream().mapToInt(RateLimitingEndpointHandler::queueSize).sum();
}

public Collection<RateLimitingEndpointHandler> rateLimitingEndpointHandlers() {
return rateLimitGroupings.values();
}

/**
* Begin servicing tasks.
* <p>
Expand Down Expand Up @@ -311,10 +316,10 @@ public void execute(

/**
* Provides rate limiting functionality for requests. A single {@link RateLimitingEndpointHandler} governs a group of requests.
* This allows many requests to be serialized if they are being sent too fast. If the rate limit has not been met they will be sent
* as soon as a thread is available.
* This allows many requests to be queued up and processed sequentially (serialized) if they are being sent too fast.
* If the rate limit has not been met they will be sent as soon as a thread is available.
*/
private static class RateLimitingEndpointHandler {
public static class RateLimitingEndpointHandler {

private static final TimeValue NO_TASKS_AVAILABLE = TimeValue.MAX_VALUE;
private static final TimeValue EXECUTED_A_TASK = TimeValue.ZERO;
Expand All @@ -329,6 +334,9 @@ private static class RateLimitingEndpointHandler {
private final Clock clock;
private final RateLimiter rateLimiter;
private final RequestExecutorServiceSettings requestExecutorServiceSettings;
private final RateLimitSettings rateLimitSettings;
private final Long originalRequestsPerTimeUnit;
private Double currentRequestsPerTimeUnit;

RateLimitingEndpointHandler(
String id,
Expand All @@ -346,8 +354,10 @@ private static class RateLimitingEndpointHandler {
this.requestSender = Objects.requireNonNull(requestSender);
this.clock = Objects.requireNonNull(clock);
this.isShutdownMethod = Objects.requireNonNull(isShutdownMethod);
this.rateLimitSettings = Objects.requireNonNull(rateLimitSettings);
this.originalRequestsPerTimeUnit = rateLimitSettings.requestsPerTimeUnit();
this.currentRequestsPerTimeUnit = (double) rateLimitSettings.requestsPerTimeUnit();

Objects.requireNonNull(rateLimitSettings);
Objects.requireNonNull(rateLimiterCreator);
rateLimiter = rateLimiterCreator.create(
ACCUMULATED_TOKENS_LIMIT,
Expand All @@ -371,6 +381,23 @@ private void onCapacityChange(int capacity) {
}
}

public synchronized void updateTokensPerTimeUnit(double newTokensPerTimeUnit) {
rateLimiter.setRate(ACCUMULATED_TOKENS_LIMIT, newTokensPerTimeUnit, rateLimitSettings.timeUnit());
this.currentRequestsPerTimeUnit = newTokensPerTimeUnit;
}

public synchronized long originalRequestsPerTimeUnit() {
return originalRequestsPerTimeUnit;
}

public synchronized long currentRequestsPerTimeUnit(){
return currentRequestsPerTimeUnit.longValue();
}

public String id() {
return id;
}

public int queueSize() {
return queue.size();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public SenderService(HttpRequestSender.Factory factory, ServiceComponents servic
this.serviceComponents = Objects.requireNonNull(serviceComponents);
}

protected Sender getSender() {
public Sender getSender() {
return sender;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

public class ElasticInferenceServiceSettings {

static final Setting<String> EIS_GATEWAY_URL = Setting.simpleString("xpack.inference.eis.gateway.url", Setting.Property.NodeScope);
public static final Setting<String> EIS_GATEWAY_URL = Setting.simpleString(
"xpack.inference.eis.gateway.url",
Setting.Property.NodeScope
);

// Adjust this variable to be volatile, if the setting can be updated at some point in time
private final String eisGatewayUrl;
Expand Down
Loading