Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Adding dynamic filtering for EIS configuration #120235

Open
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

jonathan-buttner
Copy link
Contributor

@jonathan-buttner jonathan-buttner commented Jan 15, 2025

WIP

This PR adds the ability to determine which models and task types will be supported by the cluster at the node bootup time.

This is my suggestion for the format of the response from the gateway:

GET /allowed-models
{
  "allowed-models": [
    {
      "model-name": "model-a",
      "task-types": ["text_embedding", "chat_completion"]
    },
    ...
  ]
}

My reasoning for a list instead of a single entry is that openai's gpt4-o supports completions and image generation which I'm guess would be two separate task types for us in the future. So best to allow multiple entries here.

@jonathan-buttner jonathan-buttner added >refactoring :ml Machine learning Team:ML Meta label for the ML team auto-backport Automatically create backport pull requests when merged v9.0.0 v8.18.0 labels Jan 15, 2025
@@ -78,8 +78,8 @@ default void init(Client client) {}
* Whether this service should be hidden from the API. Should be used for services
* that are not ready to be used.
*/
default Boolean hideFromConfigurationApi() {
return Boolean.FALSE;
default boolean hideFromConfigurationApi() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some refactoring, I think we can use a primitive here since I don't believe we'll ever want to return null.

List<String> taskTypes = (ArrayList<String>) args[2];
return new InferenceServiceConfiguration.Builder().setService((String) args[0])
.setName((String) args[1])
.setTaskTypes(EnumSet.copyOf(taskTypes.stream().map(TaskType::fromString).collect(Collectors.toList())))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copyOf throws if it receives an empty set so I modified this to allow an empty set via the builder. An empty set should be unlikely in production because we shouldn't be getting the configuration at all if no task types are supported but it helps the tests.

@@ -274,7 +277,12 @@ public Collection<?> createComponents(PluginServices services) {

ElasticInferenceServiceSettings inferenceServiceSettings = new ElasticInferenceServiceSettings(settings);
String elasticInferenceUrl = this.getElasticInferenceServiceUrl(inferenceServiceSettings);
elasticInferenceServiceComponents.set(new ElasticInferenceServiceComponents(elasticInferenceUrl));
elasticInferenceServiceComponents.set(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@brendan-jugan-elastic this is where we'll need the logic to retrieve the actual enabled models and task types from the EIS gateway.

var enabledStreamingTaskTypes = EnumSet.of(TaskType.COMPLETION);
enabledStreamingTaskTypes.retainAll(enabledTaskTypes);

if (enabledStreamingTaskTypes.isEmpty() == false) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there are no enabled task types we won't add any since we don't want to support anything.

}

private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration = new LazyInitializable<>(
() -> {
private LazyInitializable<InferenceServiceConfiguration, RuntimeException> initConfiguration() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing static here because this depends on a field initialized in the constructor.

@jonathan-buttner jonathan-buttner marked this pull request as ready for review January 16, 2025 01:08
@elasticsearchmachine
Copy link
Collaborator

Pinging @elastic/ml-core (Team:ML)

@@ -38,44 +30,13 @@ class RequestTask implements RejectableTask {
ActionListener<InferenceServiceResults> listener
) {
this.requestCreator = Objects.requireNonNull(requestCreator);
this.listener = getListener(Objects.requireNonNull(listener), timeout, Objects.requireNonNull(threadPool));
this.timedListener = new TimedListener<>(timeout, listener, threadPool);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved this into TimedListener so we could access it in the new send method of HttpRequestSender.

private static EnumSet<TaskType> toTaskTypes(List<String> stringTaskTypes) {
var taskTypes = EnumSet.noneOf(TaskType.class);
for (String taskType : stringTaskTypes) {
taskTypes.add(TaskType.fromStringOrStatusException(taskType));
Copy link
Contributor Author

@jonathan-buttner jonathan-buttner Jan 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: If the task type is invalid we should ignore it, that could result in an empty task_types array. If that happens we should remove the model entry.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
auto-backport Automatically create backport pull requests when merged :ml Machine learning >refactoring Team:ML Meta label for the ML team v8.18.0 v9.0.0
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants