Skip to content

Commit

Permalink
Adding acl call
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-buttner committed Jan 17, 2025
1 parent bbf6693 commit 5bc83f2
Show file tree
Hide file tree
Showing 13 changed files with 522 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.support.MappedActionFilter;
Expand Down Expand Up @@ -46,6 +47,7 @@
import org.elasticsearch.search.fetch.subphase.highlight.Highlighter;
import org.elasticsearch.search.rank.RankBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ExecutorBuilder;
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
import org.elasticsearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -74,11 +76,15 @@
import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender;
import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceResponseHandler;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.HttpSettings;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings;
import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceAclRequest;
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceAclResponseEntity;
import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper;
Expand Down Expand Up @@ -124,17 +130,20 @@
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;

import java.util.ArrayList;
import java.util.Collection;
import java.util.EnumSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Stream;

import static java.util.Collections.singletonList;
import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Request.DEFAULT_TIMEOUT;
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER;
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG;
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG;
Expand Down Expand Up @@ -279,8 +288,8 @@ public Collection<?> createComponents(PluginServices services) {
String elasticInferenceUrl = this.getElasticInferenceServiceUrl(inferenceServiceSettings);
elasticInferenceServiceComponents.set(
new ElasticInferenceServiceComponents(
elasticInferenceUrl,
new ElasticInferenceServiceACL(Map.of("model-abc", EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)))
elasticInferenceUrl
// new ElasticInferenceServiceACL(Map.of("model-abc", EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)))
)
);

Expand All @@ -289,7 +298,8 @@ public Collection<?> createComponents(PluginServices services) {
context -> new ElasticInferenceService(
elasicInferenceServiceFactory.get(),
serviceComponents.get(),
elasticInferenceServiceComponents.get()
elasticInferenceServiceComponents.get(),
modelRegistry
)
)
);
Expand Down Expand Up @@ -320,6 +330,13 @@ public Collection<?> createComponents(PluginServices services) {
return List.of(modelRegistry, registry, httpClientManager, stats);
}

private TraceContext getCurrentTraceInfo() {
var traceParent = threadPoolSetOnce.get().getThreadContext().getHeader(Task.TRACE_PARENT);
var traceState = threadPoolSetOnce.get().getThreadContext().getHeader(Task.TRACE_STATE);

return new TraceContext(traceParent, traceState);
}

@Override
public void loadExtensions(ExtensionLoader loader) {
inferenceServiceExtensions = loader.loadExtensions(InferenceServiceExtension.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,22 @@

package org.elasticsearch.xpack.inference.external.amazonbedrock;

import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockRequestExecutorService;
import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings;
import org.elasticsearch.xpack.inference.external.http.sender.RequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.ServiceComponents;

import java.io.IOException;
Expand Down Expand Up @@ -123,6 +126,17 @@ public void send(
listener.onFailure(new ElasticsearchException("Amazon Bedrock request sender did not receive a valid request request manager"));
}

@Override
public void sendWithoutQueuing(
Logger logger,
Request request,
ResponseHandler responseHandler,
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
throw new UnsupportedOperationException("not implemented");
}

@Override
public void close() throws IOException {
executorService.shutdown();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Nullable;
Expand All @@ -17,8 +19,10 @@
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.RequestExecutor;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings;
import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.ServiceComponents;

import java.io.IOException;
Expand Down Expand Up @@ -74,6 +78,7 @@ public Sender createSender() {
private final RequestExecutor service;
private final AtomicBoolean started = new AtomicBoolean(false);
private final CountDownLatch startCompleted = new CountDownLatch(1);
private final RequestSender requestSender;

private HttpRequestSender(
ThreadPool threadPool,
Expand All @@ -84,6 +89,7 @@ private HttpRequestSender(
) {
this.threadPool = Objects.requireNonNull(threadPool);
this.manager = Objects.requireNonNull(httpClientManager);
this.requestSender = Objects.requireNonNull(requestSender);
service = new RequestExecutorService(
threadPool,
startCompleted,
Expand Down Expand Up @@ -141,4 +147,32 @@ public void send(
waitForStartToComplete();
service.execute(requestCreator, inferenceInputs, timeout, listener);
}

/**
* This method sends a request and parses the response. It does not leverage any queuing or
* rate limiting logic. This method should only be used for requests that are not sent often.
*
* @param logger A logger to use for messages
* @param request A request to be sent
* @param responseHandler A handler for parsing the response
* @param timeout the maximum time the request should wait for a response before timing out. If null, the timeout is ignored.
* The queuing logic may still throw a timeout if it fails to send the request because it couldn't get a leased
* @param listener a listener to handle the response
*/
public void sendWithoutQueuing(
Logger logger,
Request request,
ResponseHandler responseHandler,
@Nullable TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
assert started.get() : "call start() before sending a request";
waitForStartToComplete();

var preservedListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext());
var timedListener = new TimedListener<>(timeout, preservedListener, threadPool);

threadPool.executor(UTILITY_THREAD_POOL_NAME)
.execute(() -> requestSender.send(logger, request, timedListener::hasCompleted, responseHandler, timedListener.getListener()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,20 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ListenerTimeouts;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.threadpool.ThreadPool;

import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;

import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;

class RequestTask implements RejectableTask {

private final AtomicBoolean finished = new AtomicBoolean();
private final RequestManager requestCreator;
private final InferenceInputs inferenceInputs;
private final ActionListener<InferenceServiceResults> listener;
private final TimedListener<InferenceServiceResults> timedListener;

RequestTask(
RequestManager requestCreator,
Expand All @@ -38,44 +30,13 @@ class RequestTask implements RejectableTask {
ActionListener<InferenceServiceResults> listener
) {
this.requestCreator = Objects.requireNonNull(requestCreator);
this.listener = getListener(Objects.requireNonNull(listener), timeout, Objects.requireNonNull(threadPool));
this.timedListener = new TimedListener<>(timeout, listener, threadPool);
this.inferenceInputs = Objects.requireNonNull(inferenceInputs);
}

private ActionListener<InferenceServiceResults> getListener(
ActionListener<InferenceServiceResults> origListener,
@Nullable TimeValue timeout,
ThreadPool threadPool
) {
ActionListener<InferenceServiceResults> notificationListener = ActionListener.wrap(result -> {
finished.set(true);
origListener.onResponse(result);
}, e -> {
finished.set(true);
origListener.onFailure(e);
});

if (timeout == null) {
return notificationListener;
}

return ListenerTimeouts.wrapWithTimeout(
threadPool,
timeout,
threadPool.executor(UTILITY_THREAD_POOL_NAME),
notificationListener,
(ignored) -> notificationListener.onFailure(
new ElasticsearchStatusException(
Strings.format("Request timed out waiting to be sent after [%s]", timeout),
RestStatus.REQUEST_TIMEOUT
)
)
);
}

@Override
public boolean hasCompleted() {
return finished.get();
return timedListener.hasCompleted();
}

@Override
Expand All @@ -90,12 +51,12 @@ public InferenceInputs getInferenceInputs() {

@Override
public ActionListener<InferenceServiceResults> getListener() {
return listener;
return timedListener.getListener();
}

@Override
public void onRejection(Exception e) {
listener.onFailure(e);
timedListener.getListener().onFailure(e);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.request.Request;

import java.io.Closeable;

Expand All @@ -23,4 +26,12 @@ void send(
@Nullable TimeValue timeout,
ActionListener<InferenceServiceResults> listener
);

void sendWithoutQueuing(
Logger logger,
Request request,
ResponseHandler responseHandler,
@Nullable TimeValue timeout,
ActionListener<InferenceServiceResults> listener
);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ListenerTimeouts;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.threadpool.ThreadPool;

import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;

import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;

/**
* Provides a way to set a timeout on the listener. If the time expires, the original listener's
* {@link ActionListener#onFailure(Exception)} is called with an error indicating there was a timeout.
*
* @param <Response> the type of the value that is passed in {@link ActionListener#onResponse(Object)}
*/
public class TimedListener<Response> {

private final ActionListener<Response> listenerWithTimeout;
private final AtomicBoolean finished = new AtomicBoolean();

public TimedListener(@Nullable TimeValue timeout, ActionListener<Response> listener, ThreadPool threadPool) {
listenerWithTimeout = getListener(Objects.requireNonNull(listener), timeout, Objects.requireNonNull(threadPool));
}

private ActionListener<Response> getListener(
ActionListener<Response> origListener,
@Nullable TimeValue timeout,
ThreadPool threadPool
) {
ActionListener<Response> notificationListener = ActionListener.wrap(result -> {
finished.set(true);
origListener.onResponse(result);
}, e -> {
finished.set(true);
origListener.onFailure(e);
});

if (timeout == null) {
return notificationListener;
}

return ListenerTimeouts.wrapWithTimeout(
threadPool,
timeout,
threadPool.executor(UTILITY_THREAD_POOL_NAME),
notificationListener,
(ignored) -> notificationListener.onFailure(
new ElasticsearchStatusException(Strings.format("Request timed out after [%s]", timeout), RestStatus.REQUEST_TIMEOUT)
)
);
}

public boolean hasCompleted() {
return finished.get();
}

public ActionListener<Response> getListener() {
return listenerWithTimeout;
}
}
Loading

0 comments on commit 5bc83f2

Please sign in to comment.