diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 224b2dd7a073b..e6987dc36212a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.util.SetOnce; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.support.MappedActionFilter; @@ -46,6 +47,7 @@ import org.elasticsearch.search.fetch.subphase.highlight.Highlighter; import org.elasticsearch.search.rank.RankBuilder; import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ExecutorBuilder; import org.elasticsearch.threadpool.ScalingExecutorBuilder; import org.elasticsearch.threadpool.ThreadPool; @@ -74,11 +76,15 @@ import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender; +import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceResponseHandler; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.HttpSettings; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings; 
+import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceAclRequest; +import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceAclResponseEntity; import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper; @@ -124,17 +130,20 @@ import org.elasticsearch.xpack.inference.services.mistral.MistralService; import org.elasticsearch.xpack.inference.services.openai.OpenAiService; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; +import org.elasticsearch.xpack.inference.telemetry.TraceContext; import java.util.ArrayList; import java.util.Collection; import java.util.EnumSet; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.function.Predicate; import java.util.function.Supplier; import java.util.stream.Stream; import static java.util.Collections.singletonList; +import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Request.DEFAULT_TIMEOUT; import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER; import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG; import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG; @@ -279,8 +288,8 @@ public Collection createComponents(PluginServices services) { String elasticInferenceUrl = this.getElasticInferenceServiceUrl(inferenceServiceSettings); elasticInferenceServiceComponents.set( new ElasticInferenceServiceComponents( - elasticInferenceUrl, - new ElasticInferenceServiceACL(Map.of("model-abc", EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))) + elasticInferenceUrl +// new 
ElasticInferenceServiceACL(Map.of("model-abc", EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))) ) ); @@ -289,7 +298,8 @@ public Collection createComponents(PluginServices services) { context -> new ElasticInferenceService( elasicInferenceServiceFactory.get(), serviceComponents.get(), - elasticInferenceServiceComponents.get() + elasticInferenceServiceComponents.get(), + modelRegistry ) ) ); @@ -320,6 +330,13 @@ public Collection createComponents(PluginServices services) { return List.of(modelRegistry, registry, httpClientManager, stats); } + private TraceContext getCurrentTraceInfo() { + var traceParent = threadPoolSetOnce.get().getThreadContext().getHeader(Task.TRACE_PARENT); + var traceState = threadPoolSetOnce.get().getThreadContext().getHeader(Task.TRACE_STATE); + + return new TraceContext(traceParent, traceState); + } + @Override public void loadExtensions(ExtensionLoader loader) { inferenceServiceExtensions = loader.loadExtensions(InferenceServiceExtension.class); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java index a8d85d896d684..ec4550b036d23 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.external.amazonbedrock; +import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.service.ClusterService; @@ -14,12 +15,14 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; import 
org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockRequestExecutorService; import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings; import org.elasticsearch.xpack.inference.external.http.sender.RequestManager; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.ServiceComponents; import java.io.IOException; @@ -123,6 +126,17 @@ public void send( listener.onFailure(new ElasticsearchException("Amazon Bedrock request sender did not receive a valid request request manager")); } + @Override + public void sendWithoutQueuing( + Logger logger, + Request request, + ResponseHandler responseHandler, + TimeValue timeout, + ActionListener listener + ) { + throw new UnsupportedOperationException("not implemented"); + } + @Override public void close() throws IOException { executorService.shutdown(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java index d1e309a774ab7..7ebbd9dd83cd3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java @@ -7,7 +7,9 @@ package org.elasticsearch.xpack.inference.external.http.sender; +import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; +import 
org.elasticsearch.action.support.ContextPreservingActionListener; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; @@ -17,8 +19,10 @@ import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.RequestExecutor; import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings; import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender; +import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.ServiceComponents; import java.io.IOException; @@ -74,6 +78,7 @@ public Sender createSender() { private final RequestExecutor service; private final AtomicBoolean started = new AtomicBoolean(false); private final CountDownLatch startCompleted = new CountDownLatch(1); + private final RequestSender requestSender; private HttpRequestSender( ThreadPool threadPool, @@ -84,6 +89,7 @@ private HttpRequestSender( ) { this.threadPool = Objects.requireNonNull(threadPool); this.manager = Objects.requireNonNull(httpClientManager); + this.requestSender = Objects.requireNonNull(requestSender); service = new RequestExecutorService( threadPool, startCompleted, @@ -141,4 +147,32 @@ public void send( waitForStartToComplete(); service.execute(requestCreator, inferenceInputs, timeout, listener); } + + /** + * This method sends a request and parses the response. It does not leverage any queuing or + * rate limiting logic. This method should only be used for requests that are not sent often. 
+ * + * @param logger A logger to use for messages + * @param request A request to be sent + * @param responseHandler A handler for parsing the response + * @param timeout the maximum time the request should wait for a response before timing out. If null, the timeout is ignored. + * When it expires, the listener's onFailure is called with a REQUEST_TIMEOUT error. + * @param listener a listener to handle the response + */ + public void sendWithoutQueuing( + Logger logger, + Request request, + ResponseHandler responseHandler, + @Nullable TimeValue timeout, + ActionListener listener + ) { + assert started.get() : "call start() before sending a request"; + waitForStartToComplete(); + + var preservedListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext()); + var timedListener = new TimedListener<>(timeout, preservedListener, threadPool); + + threadPool.executor(UTILITY_THREAD_POOL_NAME) + .execute(() -> requestSender.send(logger, request, timedListener::hasCompleted, responseHandler, timedListener.getListener())); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java index e5c29adeb9176..cba9bf73a9e99 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java @@ -7,28 +7,20 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.ListenerTimeouts; -import org.elasticsearch.common.Strings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import 
org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.threadpool.ThreadPool; import java.util.Objects; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Supplier; -import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; - class RequestTask implements RejectableTask { - private final AtomicBoolean finished = new AtomicBoolean(); private final RequestManager requestCreator; private final InferenceInputs inferenceInputs; - private final ActionListener listener; + private final TimedListener timedListener; RequestTask( RequestManager requestCreator, @@ -38,44 +30,13 @@ class RequestTask implements RejectableTask { ActionListener listener ) { this.requestCreator = Objects.requireNonNull(requestCreator); - this.listener = getListener(Objects.requireNonNull(listener), timeout, Objects.requireNonNull(threadPool)); + this.timedListener = new TimedListener<>(timeout, listener, threadPool); this.inferenceInputs = Objects.requireNonNull(inferenceInputs); } - private ActionListener getListener( - ActionListener origListener, - @Nullable TimeValue timeout, - ThreadPool threadPool - ) { - ActionListener notificationListener = ActionListener.wrap(result -> { - finished.set(true); - origListener.onResponse(result); - }, e -> { - finished.set(true); - origListener.onFailure(e); - }); - - if (timeout == null) { - return notificationListener; - } - - return ListenerTimeouts.wrapWithTimeout( - threadPool, - timeout, - threadPool.executor(UTILITY_THREAD_POOL_NAME), - notificationListener, - (ignored) -> notificationListener.onFailure( - new ElasticsearchStatusException( - Strings.format("Request timed out waiting to be sent after [%s]", timeout), - RestStatus.REQUEST_TIMEOUT - ) - ) - ); - } - @Override public boolean hasCompleted() { - return finished.get(); + return timedListener.hasCompleted(); } @Override @@ -90,12 +51,12 @@ public InferenceInputs 
getInferenceInputs() { @Override public ActionListener getListener() { - return listener; + return timedListener.getListener(); } @Override public void onRejection(Exception e) { - listener.onFailure(e); + timedListener.getListener().onFailure(e); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/Sender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/Sender.java index 5a3af3d4a377f..3975a554586b7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/Sender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/Sender.java @@ -7,10 +7,13 @@ package org.elasticsearch.xpack.inference.external.http.sender; +import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.Request; import java.io.Closeable; @@ -23,4 +26,12 @@ void send( @Nullable TimeValue timeout, ActionListener listener ); + + void sendWithoutQueuing( + Logger logger, + Request request, + ResponseHandler responseHandler, + @Nullable TimeValue timeout, + ActionListener listener + ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TimedListener.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TimedListener.java new file mode 100644 index 0000000000000..d455f2ef5180c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TimedListener.java @@ -0,0 +1,74 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ListenerTimeouts; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.threadpool.ThreadPool; + +import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; + +/** + * Provides a way to set a timeout on the listener. If the time expires, the original listener's + * {@link ActionListener#onFailure(Exception)} is called with an error indicating there was a timeout. 
+ * + * @param the type of the value that is passed in {@link ActionListener#onResponse(Object)} + */ +public class TimedListener { + + private final ActionListener listenerWithTimeout; + private final AtomicBoolean finished = new AtomicBoolean(); + + public TimedListener(@Nullable TimeValue timeout, ActionListener listener, ThreadPool threadPool) { + listenerWithTimeout = getListener(Objects.requireNonNull(listener), timeout, Objects.requireNonNull(threadPool)); + } + + private ActionListener getListener( + ActionListener origListener, + @Nullable TimeValue timeout, + ThreadPool threadPool + ) { + ActionListener notificationListener = ActionListener.wrap(result -> { + finished.set(true); + origListener.onResponse(result); + }, e -> { + finished.set(true); + origListener.onFailure(e); + }); + + if (timeout == null) { + return notificationListener; + } + + return ListenerTimeouts.wrapWithTimeout( + threadPool, + timeout, + threadPool.executor(UTILITY_THREAD_POOL_NAME), + notificationListener, + (ignored) -> notificationListener.onFailure( + new ElasticsearchStatusException(Strings.format("Request timed out after [%s]", timeout), RestStatus.REQUEST_TIMEOUT) + ) + ); + } + + public boolean hasCompleted() { + return finished.get(); + } + + public ActionListener getListener() { + return listenerWithTimeout; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAclRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAclRequest.java new file mode 100644 index 0000000000000..b8d1577b9dacc --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAclRequest.java @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.elastic; + +import org.apache.http.client.methods.HttpGet; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.telemetry.TraceContext; +import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Objects; + +public class ElasticInferenceServiceAclRequest implements ElasticInferenceServiceRequest { + + private final URI uri; + private final TraceContextHandler traceContextHandler; + + public ElasticInferenceServiceAclRequest(String url, TraceContext traceContext) { + this.uri = createUri(Objects.requireNonNull(url)); + this.traceContextHandler = new TraceContextHandler(traceContext); + } + + private URI createUri(String url) throws ElasticsearchStatusException { + try { + // TODO, consider transforming the base URL into a URI for better error handling. 
+ return new URI(url + "/api/v1/allowed-models"); + } catch (URISyntaxException e) { + throw new ElasticsearchStatusException( + "Failed to create URI for service [" + ElasticInferenceService.NAME + "]: " + e.getMessage(), + RestStatus.BAD_REQUEST, + e + ); + } + } + + @Override + public HttpRequest createHttpRequest() { + var httpGet = new HttpGet(uri); + traceContextHandler.propagateTraceContext(httpGet); + + return new HttpRequest(httpGet, getInferenceEntityId()); + } + + public TraceContext getTraceContext() { + return traceContextHandler.traceContext(); + } + + @Override + public String getInferenceEntityId() { + // TODO look into refactoring so we don't even need to return this, look at the RetryingHttpSender to fix this + return ""; + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + public Request truncate() { + return this; + } + + @Override + public boolean[] getTruncationInfo() { + return null; + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceAclResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceAclResponseEntity.java new file mode 100644 index 0000000000000..cb6287a3b6660 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceAclResponseEntity.java @@ -0,0 +1,154 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.response.elastic; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; + +public class ElasticInferenceServiceAclResponseEntity implements InferenceServiceResults { + + public static final String NAME = "elastic_inference_service_acl_results"; + public static final String COMPLETION = TaskType.COMPLETION.name().toLowerCase(Locale.ROOT); + + @SuppressWarnings("unchecked") + public static ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + 
ElasticInferenceServiceAclResponseEntity.class.getSimpleName(), + args -> new ElasticInferenceServiceAclResponseEntity((List) args[0]) + ); + + static { + PARSER.declareObjectArray(constructorArg(), AllowedModel.ALLOWED_MODEL_PARSER::apply, new ParseField("allowed_models")); + } + + public record AllowedModel(String modelName, EnumSet taskTypes) implements Writeable, ToXContentObject { + + @SuppressWarnings("unchecked") + public static ConstructingObjectParser ALLOWED_MODEL_PARSER = new ConstructingObjectParser<>( + AllowedModel.class.getSimpleName(), + args -> new AllowedModel((String) args[0], toTaskTypes((List) args[1])) + ); + + static { + ALLOWED_MODEL_PARSER.declareString(constructorArg(), new ParseField("model_name")); + ALLOWED_MODEL_PARSER.declareStringArray(constructorArg(), new ParseField("task_types")); + } + + private static EnumSet toTaskTypes(List stringTaskTypes) { + var taskTypes = EnumSet.noneOf(TaskType.class); + for (String taskType : stringTaskTypes) { + taskTypes.add(TaskType.fromStringOrStatusException(taskType)); + } + + return taskTypes; + } + + public AllowedModel(StreamInput in) throws IOException { + this(in.readString(), in.readEnumSet(TaskType.class)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelName); + out.writeEnumSet(taskTypes); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.field("model_name", modelName); + builder.field("task_types", taskTypes.stream().map(TaskType::toString).collect(Collectors.toList())); + + builder.endObject(); + + return builder; + } + } + + private final List allowedModels; + + public ElasticInferenceServiceAclResponseEntity(List allowedModels) { + this.allowedModels = Objects.requireNonNull(allowedModels); + } + + public ElasticInferenceServiceAclResponseEntity(StreamInput in) throws IOException { + 
this(in.readCollectionAsList(AllowedModel::new)); + } + + public static ElasticInferenceServiceAclResponseEntity fromResponse(Request request, HttpResult response) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + return PARSER.apply(jsonParser, null); + } + } + + public List getAllowedModels() { + return allowedModels; + } + + @Override + public Iterator toXContentChunked(ToXContent.Params params) { + throw new UnsupportedOperationException(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(allowedModels); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public List transformToCoordinationFormat() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public List transformToLegacyFormat() { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public Map asMap() { + throw new UnsupportedOperationException("Not implemented"); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index d857471e538bf..c51c7fc27aa48 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.inference.services.elastic; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; import 
org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; @@ -34,11 +36,16 @@ import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.action.elastic.ElasticInferenceServiceActionCreator; +import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.ElasticInferenceServiceUnifiedCompletionRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceAclRequest; +import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceAclResponseEntity; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -52,8 +59,10 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; import java.util.Set; +import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Request.DEFAULT_TIMEOUT; import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static 
org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; @@ -70,37 +79,86 @@ public class ElasticInferenceService extends SenderService { public static final String NAME = "elastic"; public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service"; - private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; - + private static final Logger logger = LogManager.getLogger(ElasticInferenceService.class); private static final EnumSet IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION); private static final String SERVICE_NAME = "Elastic"; + private static final ResponseHandler aclResponseHandler = createAclResponseHandler(); + /** * The task types that the {@link InferenceAction.Request} can accept. */ private static final EnumSet SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING); - private final Configuration configuration; - private final EnumSet enabledTaskTypes; + private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; + private Configuration configuration; + private EnumSet enabledTaskTypes; + private final ModelRegistry modelRegistry; public ElasticInferenceService( HttpRequestSender.Factory factory, ServiceComponents serviceComponents, - ElasticInferenceServiceComponents elasticInferenceServiceComponents + ElasticInferenceServiceComponents elasticInferenceServiceComponents, + ModelRegistry modelRegistry + ) { + super(factory, serviceComponents); + this.elasticInferenceServiceComponents = Objects.requireNonNull(elasticInferenceServiceComponents); + this.modelRegistry = Objects.requireNonNull(modelRegistry); + + enabledTaskTypes = EnumSet.noneOf(TaskType.class); + configuration = new Configuration(enabledTaskTypes); + + getAcl(elasticInferenceServiceComponents.elasticInferenceServiceUrl()); + } + + // TODO consider removing this should only be used for testing maybe mock getAcl instead? 
+ ElasticInferenceService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + ElasticInferenceServiceComponents elasticInferenceServiceComponents, + ModelRegistry modelRegistry, + ElasticInferenceServiceACL acl ) { super(factory, serviceComponents); - this.elasticInferenceServiceComponents = elasticInferenceServiceComponents; - enabledTaskTypes = enabledTaskTypes(this.elasticInferenceServiceComponents.acl()); + this.elasticInferenceServiceComponents = Objects.requireNonNull(elasticInferenceServiceComponents); + this.modelRegistry = Objects.requireNonNull(modelRegistry); + + setEnabledTaskTypes(acl); + } + + private void getAcl(String baseEISUrl) { + ActionListener listener = ActionListener.wrap(r -> { + if (r instanceof ElasticInferenceServiceAclResponseEntity aclResponseEntity) { + setEnabledTaskTypes(ElasticInferenceServiceACL.of(aclResponseEntity)); + } + }, e -> logger.warn("Failed to retrieve ACL information for the Elastic Inference Service gateway", e)); + + var request = new ElasticInferenceServiceAclRequest(baseEISUrl, getCurrentTraceInfo()); + + getSender().sendWithoutQueuing(logger, request, aclResponseHandler, DEFAULT_TIMEOUT, listener); + } + + private static ResponseHandler createAclResponseHandler() { + return new ElasticInferenceServiceResponseHandler( + String.format(Locale.ROOT, "%s ACL", ELASTIC_INFERENCE_SERVICE_IDENTIFIER), + ElasticInferenceServiceAclResponseEntity::fromResponse + ); + } + + private synchronized void setEnabledTaskTypes(ElasticInferenceServiceACL acl) { + enabledTaskTypes = filterTaskTypesByAcl(acl); configuration = new Configuration(enabledTaskTypes); + + defaultConfigIds().forEach(modelRegistry::addDefaultIds); } - private static EnumSet enabledTaskTypes(ElasticInferenceServiceACL acl) { + private static EnumSet filterTaskTypesByAcl(ElasticInferenceServiceACL acl) { var implementedTaskTypes = EnumSet.copyOf(IMPLEMENTED_TASK_TYPES); 
implementedTaskTypes.retainAll(acl.enabledTaskTypes()); return implementedTaskTypes; } @Override - public Set supportedStreamingTasks() { + public synchronized Set supportedStreamingTasks() { var enabledStreamingTaskTypes = EnumSet.of(TaskType.CHAT_COMPLETION); enabledStreamingTaskTypes.retainAll(enabledTaskTypes); @@ -111,6 +169,12 @@ public Set supportedStreamingTasks() { return enabledStreamingTaskTypes; } + @Override + public synchronized List defaultConfigIds() { + // TODO once we have the enabledTaskTypes figure out which default endpoints we should expose + return List.of(); + } + @Override protected void doUnifiedCompletionInfer( Model model, @@ -235,17 +299,17 @@ public void parseRequestConfig( } @Override - public InferenceServiceConfiguration getConfiguration() { + public synchronized InferenceServiceConfiguration getConfiguration() { return configuration.get(); } @Override - public EnumSet supportedTaskTypes() { + public synchronized EnumSet supportedTaskTypes() { return enabledTaskTypes; } @Override - public boolean hideFromConfigurationApi() { + public synchronized boolean hideFromConfigurationApi() { return enabledTaskTypes.isEmpty(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceACL.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceACL.java index 270aced63bcb5..fe2083a5cdb6d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceACL.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceACL.java @@ -8,8 +8,10 @@ package org.elasticsearch.xpack.inference.services.elastic; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceAclResponseEntity; import java.util.EnumSet; +import 
java.util.HashMap; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -21,6 +23,22 @@ */ public record ElasticInferenceServiceACL(Map> enabledModels) { + /** + * Converts an ACL response into the {@link ElasticInferenceServiceACL} format. + * + * @param responseEntity the response from the upstream gateway. + * @return a new {@link ElasticInferenceServiceACL} + */ + public static ElasticInferenceServiceACL of(ElasticInferenceServiceAclResponseEntity responseEntity) { + var enabledModels = new HashMap>(); + + for (var model : responseEntity.getAllowedModels()) { + enabledModels.put(model.modelName(), model.taskTypes()); + } + + return new ElasticInferenceServiceACL(enabledModels); + } + /** * Returns an object indicating that the cluster has no access to EIS. */ diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java index 7901e7d7b4177..c5b2cb693df13 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java @@ -7,4 +7,4 @@ package org.elasticsearch.xpack.inference.services.elastic; -public record ElasticInferenceServiceComponents(String elasticInferenceServiceUrl, ElasticInferenceServiceACL acl) {} +public record ElasticInferenceServiceComponents(String elasticInferenceServiceUrl) {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/TraceContextHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/TraceContextHandler.java index 92fe214d821db..7452317189208 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/TraceContextHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/TraceContextHandler.java @@ -7,12 +7,12 @@ package org.elasticsearch.xpack.inference.telemetry; -import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpRequestBase; import org.elasticsearch.tasks.Task; public record TraceContextHandler(TraceContext traceContext) { - public void propagateTraceContext(HttpPost httpPost) { + public void propagateTraceContext(HttpRequestBase httpRequest) { if (traceContext == null) { return; } @@ -21,11 +21,11 @@ public void propagateTraceContext(HttpPost httpPost) { var traceState = traceContext.traceState(); if (traceParent != null) { - httpPost.setHeader(Task.TRACE_PARENT_HTTP_HEADER, traceParent); + httpRequest.setHeader(Task.TRACE_PARENT_HTTP_HEADER, traceParent); } if (traceState != null) { - httpPost.setHeader(Task.TRACE_STATE, traceState); + httpRequest.setHeader(Task.TRACE_STATE, traceState); } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 860e0d84011ab..dc65ecf3bf0c9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import 
org.elasticsearch.common.settings.Settings; @@ -38,6 +39,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; @@ -301,7 +303,9 @@ public void testCheckModelConfig_ReturnsNewModelReference() throws IOException { var service = new ElasticInferenceService( senderFactory, createWithEmptySettings(threadPool), - new ElasticInferenceServiceComponents(getUrl(webServer), ElasticInferenceServiceACLTests.createEnabledAcl()) + new ElasticInferenceServiceComponents(getUrl(webServer)), + mockModelRegistry(), + ElasticInferenceServiceACLTests.createEnabledAcl() ) ) { var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer)); @@ -325,7 +329,9 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException var service = new ElasticInferenceService( factory, createWithEmptySettings(threadPool), - new ElasticInferenceServiceComponents(null, ElasticInferenceServiceACLTests.createEnabledAcl()) + new ElasticInferenceServiceComponents(null), + mockModelRegistry(), + ElasticInferenceServiceACLTests.createEnabledAcl() ) ) { PlainActionFuture listener = new PlainActionFuture<>(); @@ -355,6 +361,12 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException verifyNoMoreInteractions(sender); } + private ModelRegistry mockModelRegistry() { + var client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + return new ModelRegistry(client); + } + public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException { var sender = 
mock(Sender.class); @@ -367,7 +379,9 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException { var service = new ElasticInferenceService( factory, createWithEmptySettings(threadPool), - new ElasticInferenceServiceComponents(null, ElasticInferenceServiceACLTests.createEnabledAcl()) + new ElasticInferenceServiceComponents(null), + mockModelRegistry(), + ElasticInferenceServiceACLTests.createEnabledAcl() ) ) { PlainActionFuture listener = new PlainActionFuture<>(); @@ -412,7 +426,9 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws var service = new ElasticInferenceService( factory, createWithEmptySettings(threadPool), - new ElasticInferenceServiceComponents(null, ElasticInferenceServiceACLTests.createEnabledAcl()) + new ElasticInferenceServiceComponents(null), + mockModelRegistry(), + ElasticInferenceServiceACLTests.createEnabledAcl() ) ) { PlainActionFuture listener = new PlainActionFuture<>(); @@ -455,7 +471,9 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { var service = new ElasticInferenceService( senderFactory, createWithEmptySettings(threadPool), - new ElasticInferenceServiceComponents(eisGatewayUrl, ElasticInferenceServiceACLTests.createEnabledAcl()) + new ElasticInferenceServiceComponents(eisGatewayUrl), + mockModelRegistry(), + ElasticInferenceServiceACLTests.createEnabledAcl() ) ) { String responseJson = """ @@ -512,7 +530,9 @@ public void testChunkedInfer_PassesThrough() throws IOException { var service = new ElasticInferenceService( senderFactory, createWithEmptySettings(threadPool), - new ElasticInferenceServiceComponents(eisGatewayUrl, ElasticInferenceServiceACLTests.createEnabledAcl()) + new ElasticInferenceServiceComponents(eisGatewayUrl), + mockModelRegistry(), + ElasticInferenceServiceACLTests.createEnabledAcl() ) ) { String responseJson = """ @@ -757,7 +777,9 @@ private ElasticInferenceService createServiceWithMockSender() { return new ElasticInferenceService( 
mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), - new ElasticInferenceServiceComponents(null, ElasticInferenceServiceACLTests.createEnabledAcl()) + new ElasticInferenceServiceComponents(null), + mockModelRegistry(), + ElasticInferenceServiceACLTests.createEnabledAcl() ); } @@ -765,7 +787,9 @@ private ElasticInferenceService createServiceWithMockSender(ElasticInferenceServ return new ElasticInferenceService( mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), - new ElasticInferenceServiceComponents(null, acl) + new ElasticInferenceServiceComponents(null), + mockModelRegistry(), + acl ); } }