[ML] Adding dynamic filtering for EIS configuration #120235

Open · wants to merge 8 commits into base: main
@@ -78,8 +78,8 @@ default void init(Client client) {}
* Whether this service should be hidden from the API. Should be used for services
* that are not ready to be used.
*/
- default Boolean hideFromConfigurationApi() {
- return Boolean.FALSE;
+ default boolean hideFromConfigurationApi() {

Contributor Author:
Some refactoring, I think we can use a primitive here since I don't believe we'll ever want to return null.

+ return false;
}

/**
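Since the method now returns a primitive with a false default, a service that is not ready simply overrides it. A self-contained sketch; the interface and classes below are simplified stand-ins for illustration, not code from this changeset:

```java
// Simplified stand-in types; only the overridden default method mirrors the PR's change.
interface ConfigurableService {
    // Mirrors the refactor from "default Boolean ... Boolean.FALSE" to a primitive boolean.
    default boolean hideFromConfigurationApi() {
        return false;
    }
}

class ExperimentalService implements ConfigurableService {
    @Override
    public boolean hideFromConfigurationApi() {
        return true; // not ready to be exposed through the configuration API
    }
}

public class HideFromConfigApiDemo {
    public static void main(String[] args) {
        System.out.println(new ExperimentalService().hideFromConfigurationApi()); // true
        System.out.println(new ConfigurableService() {
        }.hideFromConfigurationApi()); // false, the primitive default
    }
}
```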
@@ -24,13 +24,11 @@
import org.elasticsearch.xcontent.XContentType;

import java.io.IOException;
- import java.util.ArrayList;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
- import java.util.stream.Collectors;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;

@@ -80,14 +78,11 @@ public InferenceServiceConfiguration(StreamInput in) throws IOException {
private static final ConstructingObjectParser<InferenceServiceConfiguration, Void> PARSER = new ConstructingObjectParser<>(
"inference_service_configuration",
true,
- args -> {
- List<String> taskTypes = (ArrayList<String>) args[2];
- return new InferenceServiceConfiguration.Builder().setService((String) args[0])
- .setName((String) args[1])
- .setTaskTypes(EnumSet.copyOf(taskTypes.stream().map(TaskType::fromString).collect(Collectors.toList())))

Contributor Author:
EnumSet.copyOf throws if it receives an empty collection, so I modified this to allow an empty set via the builder. An empty set should be unlikely in production because we shouldn't be getting the configuration at all if no task types are supported, but it helps the tests.

- .setConfigurations((Map<String, SettingsConfiguration>) args[3])
- .build();
- }
+ args -> new InferenceServiceConfiguration.Builder().setService((String) args[0])
+ .setName((String) args[1])
+ .setTaskTypes((List<String>) args[2])
+ .setConfigurations((Map<String, SettingsConfiguration>) args[3])
+ .build()
);

static {
@@ -195,6 +190,16 @@ public Builder setTaskTypes(EnumSet<TaskType> taskTypes) {
return this;
}

+ public Builder setTaskTypes(List<String> taskTypes) {
+ var enumTaskTypes = EnumSet.noneOf(TaskType.class);
+
+ for (var supportedTaskTypeString : taskTypes) {
+ enumTaskTypes.add(TaskType.fromString(supportedTaskTypeString));
+ }
+ this.taskTypes = enumTaskTypes;
+ return this;
+ }
+
public Builder setConfigurations(Map<String, SettingsConfiguration> configurations) {
this.configurations = configurations;
return this;
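The reviewer's note about EnumSet.copyOf is easy to reproduce in isolation. A minimal, self-contained sketch; the TaskType enum below is a simplified stand-in, not the real org.elasticsearch.inference.TaskType:

```java
import java.util.EnumSet;
import java.util.List;

public class EmptyEnumSetDemo {

    // Simplified stand-in for the real TaskType enum.
    enum TaskType { TEXT_EMBEDDING, SPARSE_EMBEDDING, CHAT_COMPLETION }

    public static void main(String[] args) {
        List<TaskType> empty = List.of();

        // EnumSet.copyOf(Collection) infers the element type from the first element,
        // so an empty collection that is not itself an EnumSet throws.
        try {
            EnumSet.copyOf(empty);
        } catch (IllegalArgumentException e) {
            System.out.println("copyOf rejected the empty list: " + e.getMessage());
        }

        // The builder's approach: start from noneOf and add elements one by one,
        // which tolerates an empty input and yields an empty EnumSet.
        EnumSet<TaskType> taskTypes = EnumSet.noneOf(TaskType.class);
        for (TaskType taskType : empty) {
            taskTypes.add(taskType);
        }
        System.out.println("noneOf-based set: " + taskTypes); // []
    }
}
```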
@@ -66,6 +66,47 @@ public void testToXContent() throws IOException {
assertToXContentEquivalent(originalBytes, toXContent(parsed, XContentType.JSON, humanReadable), XContentType.JSON);
}

+ public void testToXContent_EmptyTaskTypes() throws IOException {
+ String content = XContentHelper.stripWhitespace("""
+ {
+ "service": "some_provider",
+ "name": "Some Provider",
+ "task_types": [],
+ "configurations": {
+ "text_field_configuration": {
+ "description": "Wow, this tooltip is useful.",
+ "label": "Very important field",
+ "required": true,
+ "sensitive": true,
+ "updatable": false,
+ "type": "str"
+ },
+ "numeric_field_configuration": {
+ "default_value": 3,
+ "description": "Wow, this tooltip is useful.",
+ "label": "Very important numeric field",
+ "required": true,
+ "sensitive": false,
+ "updatable": true,
+ "type": "int"
+ }
+ }
+ }
+ """);
+
+ InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes(
+ new BytesArray(content),
+ XContentType.JSON
+ );
+ boolean humanReadable = true;
+ BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable);
+ InferenceServiceConfiguration parsed;
+ try (XContentParser parser = createParser(XContentType.JSON.xContent(), originalBytes)) {
+ parsed = InferenceServiceConfiguration.fromXContent(parser);
+ }
+ assertToXContentEquivalent(originalBytes, toXContent(parsed, XContentType.JSON, humanReadable), XContentType.JSON);
+ }

public void testToMap() {
InferenceServiceConfiguration configField = InferenceServiceConfigurationTestUtils.getRandomServiceConfigurationField();
Map<String, Object> configFieldAsMap = configField.toMap();
@@ -45,6 +45,7 @@
import org.elasticsearch.search.fetch.subphase.highlight.Highlighter;
import org.elasticsearch.search.rank.RankBuilder;
import org.elasticsearch.search.rank.RankDoc;
+ import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ExecutorBuilder;
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
import org.elasticsearch.threadpool.ThreadPool;
@@ -122,6 +123,7 @@
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
+ import org.elasticsearch.xpack.inference.telemetry.TraceContext;

import java.util.ArrayList;
import java.util.Collection;
@@ -274,14 +276,17 @@ public Collection<?> createComponents(PluginServices services) {

ElasticInferenceServiceSettings inferenceServiceSettings = new ElasticInferenceServiceSettings(settings);
String elasticInferenceUrl = this.getElasticInferenceServiceUrl(inferenceServiceSettings);
- elasticInferenceServiceComponents.set(new ElasticInferenceServiceComponents(elasticInferenceUrl));
+ elasticInferenceServiceComponents.set(new ElasticInferenceServiceComponents(elasticInferenceUrl
+ // new ElasticInferenceServiceACL(Map.of("model-abc", EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)))
+ ));

inferenceServices.add(
() -> List.of(
context -> new ElasticInferenceService(
elasicInferenceServiceFactory.get(),
serviceComponents.get(),
- elasticInferenceServiceComponents.get()
+ elasticInferenceServiceComponents.get(),
+ modelRegistry
)
)
);
@@ -312,6 +317,13 @@ public Collection<?> createComponents(PluginServices services) {
return List.of(modelRegistry, registry, httpClientManager, stats);
}

+ private TraceContext getCurrentTraceInfo() {
+ var traceParent = threadPoolSetOnce.get().getThreadContext().getHeader(Task.TRACE_PARENT);
+ var traceState = threadPoolSetOnce.get().getThreadContext().getHeader(Task.TRACE_STATE);
+
+ return new TraceContext(traceParent, traceState);
+ }

@Override
public void loadExtensions(ExtensionLoader loader) {
inferenceServiceExtensions = loader.loadExtensions(InferenceServiceExtension.class);
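getCurrentTraceInfo pulls the W3C trace headers out of the thread context and wraps them in the new TraceContext type, whose definition is not part of this excerpt. Judging only from the constructor call above, it is plausibly a small value holder along these lines (an assumption, not the PR's actual source):

```java
// Hedged sketch: both values come from ThreadContext.getHeader(...), which returns the raw
// String header value or null when the header is absent.
public record TraceContext(String traceParent, String traceState) {}
```

Capturing the two headers as a value presumably lets them be re-applied to requests dispatched on other threads.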
@@ -7,19 +7,22 @@

package org.elasticsearch.xpack.inference.external.amazonbedrock;

+ import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
+ import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockRequestExecutorService;
import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings;
import org.elasticsearch.xpack.inference.external.http.sender.RequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+ import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.ServiceComponents;

import java.io.IOException;
@@ -123,6 +126,17 @@ public void send(
listener.onFailure(new ElasticsearchException("Amazon Bedrock request sender did not receive a valid request request manager"));
}

+ @Override
+ public void sendWithoutQueuing(
+ Logger logger,
+ Request request,
+ ResponseHandler responseHandler,
+ TimeValue timeout,
+ ActionListener<InferenceServiceResults> listener
+ ) {
+ throw new UnsupportedOperationException("not implemented");
+ }

@Override
public void close() throws IOException {
executorService.shutdown();
@@ -7,7 +7,9 @@

package org.elasticsearch.xpack.inference.external.http.sender;

+ import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
+ import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Nullable;
@@ -17,8 +19,10 @@
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.RequestExecutor;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
+ import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings;
import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;
+ import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.ServiceComponents;

import java.io.IOException;
@@ -74,6 +78,7 @@ public Sender createSender() {
private final RequestExecutor service;
private final AtomicBoolean started = new AtomicBoolean(false);
private final CountDownLatch startCompleted = new CountDownLatch(1);
+ private final RequestSender requestSender;

private HttpRequestSender(
ThreadPool threadPool,
@@ -84,6 +89,7 @@
) {
this.threadPool = Objects.requireNonNull(threadPool);
this.manager = Objects.requireNonNull(httpClientManager);
+ this.requestSender = Objects.requireNonNull(requestSender);
service = new RequestExecutorService(
threadPool,
startCompleted,
@@ -141,4 +147,32 @@ public void send(
waitForStartToComplete();
service.execute(requestCreator, inferenceInputs, timeout, listener);
}

+ /**
+ * This method sends a request and parses the response. It does not leverage any queuing or
+ * rate limiting logic. This method should only be used for requests that are not sent often.
+ *
+ * @param logger A logger to use for messages
+ * @param request A request to be sent
+ * @param responseHandler A handler for parsing the response
+ * @param timeout the maximum time the request should wait for a response before timing out. If null, the timeout is ignored.
+ * The queuing logic may still throw a timeout if it fails to send the request because it couldn't get a leased
+ * @param listener a listener to handle the response
+ */
+ public void sendWithoutQueuing(
+ Logger logger,
+ Request request,
+ ResponseHandler responseHandler,
+ @Nullable TimeValue timeout,
+ ActionListener<InferenceServiceResults> listener
+ ) {
+ assert started.get() : "call start() before sending a request";
+ waitForStartToComplete();
+
+ var preservedListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext());
+ var timedListener = new TimedListener<>(timeout, preservedListener, threadPool);
+
+ threadPool.executor(UTILITY_THREAD_POOL_NAME)
+ .execute(() -> requestSender.send(logger, request, timedListener::hasCompleted, responseHandler, timedListener.getListener()));
+ }
}
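For orientation, this is roughly how a caller might invoke the new method; the request and responseHandler variables are assumed to be already-constructed Request and ResponseHandler instances (for example, for an infrequent authorization or metadata call), and the surrounding names are illustrative, not part of this PR:

```java
// Hypothetical caller sketch (not code from this changeset). Assumes `sender` is a started
// HttpRequestSender and `request` / `responseHandler` already exist.
sender.sendWithoutQueuing(
    logger,
    request,
    responseHandler,
    TimeValue.timeValueSeconds(30),
    ActionListener.wrap(
        results -> logger.info("received inference results [{}]", results),
        e -> logger.warn("request failed", e)
    )
);
```

Because the call bypasses the request queue and its rate limiting, it is only appropriate for requests that are issued rarely, as the javadoc above notes.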
@@ -7,28 +7,20 @@

package org.elasticsearch.xpack.inference.external.http.sender;

- import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
- import org.elasticsearch.action.support.ListenerTimeouts;
- import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
- import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.threadpool.ThreadPool;

import java.util.Objects;
- import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;

- import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;

class RequestTask implements RejectableTask {

- private final AtomicBoolean finished = new AtomicBoolean();
private final RequestManager requestCreator;
private final InferenceInputs inferenceInputs;
- private final ActionListener<InferenceServiceResults> listener;
+ private final TimedListener<InferenceServiceResults> timedListener;

RequestTask(
RequestManager requestCreator,
@@ -38,44 +30,13 @@ class RequestTask implements RejectableTask {
ActionListener<InferenceServiceResults> listener
) {
this.requestCreator = Objects.requireNonNull(requestCreator);
this.listener = getListener(Objects.requireNonNull(listener), timeout, Objects.requireNonNull(threadPool));
this.timedListener = new TimedListener<>(timeout, listener, threadPool);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved this into TimedListener so we could access it in the new send method of HttpRequestSender.

this.inferenceInputs = Objects.requireNonNull(inferenceInputs);
}

- private ActionListener<InferenceServiceResults> getListener(
- ActionListener<InferenceServiceResults> origListener,
- @Nullable TimeValue timeout,
- ThreadPool threadPool
- ) {
- ActionListener<InferenceServiceResults> notificationListener = ActionListener.wrap(result -> {
- finished.set(true);
- origListener.onResponse(result);
- }, e -> {
- finished.set(true);
- origListener.onFailure(e);
- });
-
- if (timeout == null) {
- return notificationListener;
- }
-
- return ListenerTimeouts.wrapWithTimeout(
- threadPool,
- timeout,
- threadPool.executor(UTILITY_THREAD_POOL_NAME),
- notificationListener,
- (ignored) -> notificationListener.onFailure(
- new ElasticsearchStatusException(
- Strings.format("Request timed out waiting to be sent after [%s]", timeout),
- RestStatus.REQUEST_TIMEOUT
- )
- )
- );
- }

@Override
public boolean hasCompleted() {
- return finished.get();
+ return timedListener.hasCompleted();
}

@Override
@@ -90,12 +51,12 @@ public InferenceInputs getInferenceInputs() {

@Override
public ActionListener<InferenceServiceResults> getListener() {
- return listener;
+ return timedListener.getListener();
}

@Override
public void onRejection(Exception e) {
- listener.onFailure(e);
+ timedListener.getListener().onFailure(e);
}

@Override
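RequestTask now delegates completion tracking and timeout wrapping to the new TimedListener, whose source is not part of this excerpt. Based on the removed getListener logic above and the call sites (new TimedListener<>(timeout, listener, threadPool), hasCompleted(), getListener()), it plausibly looks roughly like the sketch below; treat the names and details as assumptions rather than the PR's actual implementation:

```java
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ListenerTimeouts;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.threadpool.ThreadPool;

import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;

import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;

// Hedged reconstruction of TimedListener, inferred from the removed RequestTask.getListener
// code and the way HttpRequestSender and RequestTask use it. Not the PR's actual source.
class TimedListener<Response> {

    private final AtomicBoolean completed = new AtomicBoolean();
    private final ActionListener<Response> listener;

    TimedListener(@Nullable TimeValue timeout, ActionListener<Response> delegate, ThreadPool threadPool) {
        this.listener = wrapWithTimeout(Objects.requireNonNull(delegate), timeout, Objects.requireNonNull(threadPool));
    }

    private ActionListener<Response> wrapWithTimeout(
        ActionListener<Response> delegate,
        @Nullable TimeValue timeout,
        ThreadPool threadPool
    ) {
        // Mark completion before notifying the delegate, mirroring the removed code.
        ActionListener<Response> notificationListener = ActionListener.wrap(result -> {
            completed.set(true);
            delegate.onResponse(result);
        }, e -> {
            completed.set(true);
            delegate.onFailure(e);
        });

        if (timeout == null) {
            return notificationListener;
        }

        return ListenerTimeouts.wrapWithTimeout(
            threadPool,
            timeout,
            threadPool.executor(UTILITY_THREAD_POOL_NAME),
            notificationListener,
            ignored -> notificationListener.onFailure(
                new ElasticsearchStatusException(
                    Strings.format("Request timed out waiting to be sent after [%s]", timeout),
                    RestStatus.REQUEST_TIMEOUT
                )
            )
        );
    }

    boolean hasCompleted() {
        return completed.get();
    }

    ActionListener<Response> getListener() {
        return listener;
    }
}
```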
@@ -7,10 +7,13 @@

package org.elasticsearch.xpack.inference.external.http.sender;

+ import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
+ import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
+ import org.elasticsearch.xpack.inference.external.request.Request;

import java.io.Closeable;

@@ -23,4 +26,12 @@ void send(
@Nullable TimeValue timeout,
ActionListener<InferenceServiceResults> listener
);

+ void sendWithoutQueuing(
+ Logger logger,
+ Request request,
+ ResponseHandler responseHandler,
+ @Nullable TimeValue timeout,
+ ActionListener<InferenceServiceResults> listener
+ );
}