-
Notifications
You must be signed in to change notification settings - Fork 25k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ML] Adding dynamic filtering for EIS configuration #120235
base: main
Are you sure you want to change the base?
Changes from all commits
1ca8b32
5c0b35d
f7d978f
bbf6693
5bc83f2
8945b93
75bbf96
72ac2b0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,13 +24,11 @@ | |
import org.elasticsearch.xcontent.XContentType; | ||
|
||
import java.io.IOException; | ||
import java.util.ArrayList; | ||
import java.util.EnumSet; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
import java.util.stream.Collectors; | ||
|
||
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; | ||
|
||
|
@@ -80,14 +78,11 @@ public InferenceServiceConfiguration(StreamInput in) throws IOException { | |
private static final ConstructingObjectParser<InferenceServiceConfiguration, Void> PARSER = new ConstructingObjectParser<>( | ||
"inference_service_configuration", | ||
true, | ||
args -> { | ||
List<String> taskTypes = (ArrayList<String>) args[2]; | ||
return new InferenceServiceConfiguration.Builder().setService((String) args[0]) | ||
.setName((String) args[1]) | ||
.setTaskTypes(EnumSet.copyOf(taskTypes.stream().map(TaskType::fromString).collect(Collectors.toList()))) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
.setConfigurations((Map<String, SettingsConfiguration>) args[3]) | ||
.build(); | ||
} | ||
args -> new InferenceServiceConfiguration.Builder().setService((String) args[0]) | ||
.setName((String) args[1]) | ||
.setTaskTypes((List<String>) args[2]) | ||
.setConfigurations((Map<String, SettingsConfiguration>) args[3]) | ||
.build() | ||
); | ||
|
||
static { | ||
|
@@ -195,6 +190,16 @@ public Builder setTaskTypes(EnumSet<TaskType> taskTypes) { | |
return this; | ||
} | ||
|
||
public Builder setTaskTypes(List<String> taskTypes) { | ||
var enumTaskTypes = EnumSet.noneOf(TaskType.class); | ||
|
||
for (var supportedTaskTypeString : taskTypes) { | ||
enumTaskTypes.add(TaskType.fromString(supportedTaskTypeString)); | ||
} | ||
this.taskTypes = enumTaskTypes; | ||
return this; | ||
} | ||
|
||
public Builder setConfigurations(Map<String, SettingsConfiguration> configurations) { | ||
this.configurations = configurations; | ||
return this; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,28 +7,20 @@ | |
|
||
package org.elasticsearch.xpack.inference.external.http.sender; | ||
|
||
import org.elasticsearch.ElasticsearchStatusException; | ||
import org.elasticsearch.action.ActionListener; | ||
import org.elasticsearch.action.support.ListenerTimeouts; | ||
import org.elasticsearch.common.Strings; | ||
import org.elasticsearch.core.Nullable; | ||
import org.elasticsearch.core.TimeValue; | ||
import org.elasticsearch.inference.InferenceServiceResults; | ||
import org.elasticsearch.rest.RestStatus; | ||
import org.elasticsearch.threadpool.ThreadPool; | ||
|
||
import java.util.Objects; | ||
import java.util.concurrent.atomic.AtomicBoolean; | ||
import java.util.function.Supplier; | ||
|
||
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; | ||
|
||
class RequestTask implements RejectableTask { | ||
|
||
private final AtomicBoolean finished = new AtomicBoolean(); | ||
private final RequestManager requestCreator; | ||
private final InferenceInputs inferenceInputs; | ||
private final ActionListener<InferenceServiceResults> listener; | ||
private final TimedListener<InferenceServiceResults> timedListener; | ||
|
||
RequestTask( | ||
RequestManager requestCreator, | ||
|
@@ -38,44 +30,13 @@ class RequestTask implements RejectableTask { | |
ActionListener<InferenceServiceResults> listener | ||
) { | ||
this.requestCreator = Objects.requireNonNull(requestCreator); | ||
this.listener = getListener(Objects.requireNonNull(listener), timeout, Objects.requireNonNull(threadPool)); | ||
this.timedListener = new TimedListener<>(timeout, listener, threadPool); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I moved this into |
||
this.inferenceInputs = Objects.requireNonNull(inferenceInputs); | ||
} | ||
|
||
private ActionListener<InferenceServiceResults> getListener( | ||
ActionListener<InferenceServiceResults> origListener, | ||
@Nullable TimeValue timeout, | ||
ThreadPool threadPool | ||
) { | ||
ActionListener<InferenceServiceResults> notificationListener = ActionListener.wrap(result -> { | ||
finished.set(true); | ||
origListener.onResponse(result); | ||
}, e -> { | ||
finished.set(true); | ||
origListener.onFailure(e); | ||
}); | ||
|
||
if (timeout == null) { | ||
return notificationListener; | ||
} | ||
|
||
return ListenerTimeouts.wrapWithTimeout( | ||
threadPool, | ||
timeout, | ||
threadPool.executor(UTILITY_THREAD_POOL_NAME), | ||
notificationListener, | ||
(ignored) -> notificationListener.onFailure( | ||
new ElasticsearchStatusException( | ||
Strings.format("Request timed out waiting to be sent after [%s]", timeout), | ||
RestStatus.REQUEST_TIMEOUT | ||
) | ||
) | ||
); | ||
} | ||
|
||
@Override | ||
public boolean hasCompleted() { | ||
return finished.get(); | ||
return timedListener.hasCompleted(); | ||
} | ||
|
||
@Override | ||
|
@@ -90,12 +51,12 @@ public InferenceInputs getInferenceInputs() { | |
|
||
@Override | ||
public ActionListener<InferenceServiceResults> getListener() { | ||
return listener; | ||
return timedListener.getListener(); | ||
} | ||
|
||
@Override | ||
public void onRejection(Exception e) { | ||
listener.onFailure(e); | ||
timedListener.getListener().onFailure(e); | ||
} | ||
|
||
@Override | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some refactoring, I think we can use a primitive here since I don't believe we'll ever want to return null.