Skip to content

Commit

Permalink
[ML] Adds a new field supported_task_types in the configuration respo…
Browse files Browse the repository at this point in the history
…nse (#120150) (#120410)

* Adding new field to settings class

* adding new available_for_task_types field

* Update docs/changelog/120150.yaml

* Delete docs/changelog/120150.yaml

* Fixing tests and task types

* Renaming field to supported_task_types

* Pulling in chat_completion addition

* Addressing feedback
  • Loading branch information
jonathan-buttner authored Jan 17, 2025
1 parent d1f9a0a commit 51204b5
Show file tree
Hide file tree
Showing 46 changed files with 325 additions and 173 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ public Builder setName(String name) {
}

public Builder setTaskTypes(EnumSet<TaskType> taskTypes) {
this.taskTypes = taskTypes;
this.taskTypes = TaskType.copyOf(taskTypes);
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@
import org.elasticsearch.xcontent.XContentType;

import java.io.IOException;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
Expand All @@ -50,6 +53,7 @@ public class SettingsConfiguration implements Writeable, ToXContentObject {
private final boolean sensitive;
private final boolean updatable;
private final SettingsConfigurationFieldType type;
private final EnumSet<TaskType> supportedTaskTypes;

/**
* Constructs a new {@link SettingsConfiguration} instance with specified properties.
Expand All @@ -61,6 +65,7 @@ public class SettingsConfiguration implements Writeable, ToXContentObject {
* @param sensitive A boolean indicating whether the configuration contains sensitive information.
* @param updatable A boolean indicating whether the configuration can be updated.
* @param type The type of the configuration field, defined by {@link SettingsConfigurationFieldType}.
* @param supportedTaskTypes The task types that support this field.
*/
private SettingsConfiguration(
Object defaultValue,
Expand All @@ -69,7 +74,8 @@ private SettingsConfiguration(
boolean required,
boolean sensitive,
boolean updatable,
SettingsConfigurationFieldType type
SettingsConfigurationFieldType type,
EnumSet<TaskType> supportedTaskTypes
) {
this.defaultValue = defaultValue;
this.description = description;
Expand All @@ -78,6 +84,7 @@ private SettingsConfiguration(
this.sensitive = sensitive;
this.updatable = updatable;
this.type = type;
this.supportedTaskTypes = supportedTaskTypes;
}

public SettingsConfiguration(StreamInput in) throws IOException {
Expand All @@ -88,6 +95,7 @@ public SettingsConfiguration(StreamInput in) throws IOException {
this.sensitive = in.readBoolean();
this.updatable = in.readBoolean();
this.type = in.readEnum(SettingsConfigurationFieldType.class);
this.supportedTaskTypes = in.readEnumSet(TaskType.class);
}

static final ParseField DEFAULT_VALUE_FIELD = new ParseField("default_value");
Expand All @@ -97,14 +105,23 @@ public SettingsConfiguration(StreamInput in) throws IOException {
static final ParseField SENSITIVE_FIELD = new ParseField("sensitive");
static final ParseField UPDATABLE_FIELD = new ParseField("updatable");
static final ParseField TYPE_FIELD = new ParseField("type");
static final ParseField SUPPORTED_TASK_TYPES = new ParseField("supported_task_types");

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<SettingsConfiguration, Void> PARSER = new ConstructingObjectParser<>(
"service_configuration",
true,
args -> {
int i = 0;
return new SettingsConfiguration.Builder().setDefaultValue(args[i++])

EnumSet<TaskType> supportedTaskTypes = EnumSet.noneOf(TaskType.class);
var supportedTaskTypesListOfStrings = (List<String>) args[i++];

for (var supportedTaskTypeString : supportedTaskTypesListOfStrings) {
supportedTaskTypes.add(TaskType.fromString(supportedTaskTypeString));
}

return new SettingsConfiguration.Builder(supportedTaskTypes).setDefaultValue(args[i++])
.setDescription((String) args[i++])
.setLabel((String) args[i++])
.setRequired((Boolean) args[i++])
Expand All @@ -116,6 +133,7 @@ public SettingsConfiguration(StreamInput in) throws IOException {
);

static {
PARSER.declareStringArray(constructorArg(), SUPPORTED_TASK_TYPES);
PARSER.declareField(optionalConstructorArg(), (p, c) -> {
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
return p.text();
Expand Down Expand Up @@ -169,28 +187,8 @@ public SettingsConfigurationFieldType getType() {
return type;
}

/**
* Parses a configuration value from a parser context.
* This method can parse strings, numbers, booleans, objects, and null values, matching the types commonly
* supported in {@link SettingsConfiguration}.
*
* @param p the {@link org.elasticsearch.xcontent.XContentParser} instance from which to parse the configuration value.
*/
public static Object parseConfigurationValue(XContentParser p) throws IOException {

if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
return p.text();
} else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
return p.numberValue();
} else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
return p.booleanValue();
} else if (p.currentToken() == XContentParser.Token.START_OBJECT) {
// Crawler expects the value to be an object
return p.map();
} else if (p.currentToken() == XContentParser.Token.VALUE_NULL) {
return null;
}
throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
public Set<TaskType> getSupportedTaskTypes() {
return supportedTaskTypes;
}

@Override
Expand All @@ -211,6 +209,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (type != null) {
builder.field(TYPE_FIELD.getPreferredName(), type.toString());
}
builder.field(SUPPORTED_TASK_TYPES.getPreferredName(), supportedTaskTypes);
}
builder.endObject();
return builder;
Expand All @@ -237,6 +236,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(sensitive);
out.writeBoolean(updatable);
out.writeEnum(type);
out.writeEnumSet(supportedTaskTypes);
}

public Map<String, Object> toMap() {
Expand All @@ -253,6 +253,7 @@ public Map<String, Object> toMap() {

Optional.ofNullable(type).ifPresent(t -> map.put(TYPE_FIELD.getPreferredName(), t.toString()));

map.put(SUPPORTED_TASK_TYPES.getPreferredName(), supportedTaskTypes);
return map;
}

Expand All @@ -267,12 +268,13 @@ public boolean equals(Object o) {
&& Objects.equals(defaultValue, that.defaultValue)
&& Objects.equals(description, that.description)
&& Objects.equals(label, that.label)
&& type == that.type;
&& type == that.type
&& Objects.equals(supportedTaskTypes, that.supportedTaskTypes);
}

@Override
public int hashCode() {
return Objects.hash(defaultValue, description, label, required, sensitive, updatable, type);
return Objects.hash(defaultValue, description, label, required, sensitive, updatable, type, supportedTaskTypes);
}

public static class Builder {
Expand All @@ -284,6 +286,11 @@ public static class Builder {
private boolean sensitive;
private boolean updatable;
private SettingsConfigurationFieldType type;
private final EnumSet<TaskType> supportedTaskTypes;

public Builder(EnumSet<TaskType> supportedTaskTypes) {
this.supportedTaskTypes = TaskType.copyOf(Objects.requireNonNull(supportedTaskTypes));
}

public Builder setDefaultValue(Object defaultValue) {
this.defaultValue = defaultValue;
Expand Down Expand Up @@ -321,7 +328,7 @@ public Builder setType(SettingsConfigurationFieldType type) {
}

public SettingsConfiguration build() {
return new SettingsConfiguration(defaultValue, description, label, required, sensitive, updatable, type);
return new SettingsConfiguration(defaultValue, description, label, required, sensitive, updatable, type, supportedTaskTypes);
}
}
}
11 changes: 11 additions & 0 deletions server/src/main/java/org/elasticsearch/inference/TaskType.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.rest.RestStatus;

import java.io.IOException;
import java.util.EnumSet;
import java.util.Locale;
import java.util.Objects;

Expand Down Expand Up @@ -78,4 +79,14 @@ public void writeTo(StreamOutput out) throws IOException {
public static String unsupportedTaskTypeErrorMsg(TaskType taskType, String serviceName) {
return "The [" + serviceName + "] service does not support task type [" + taskType + "]";
}

/**
* Copies a {@link EnumSet<TaskType>} if non-empty, otherwise returns an empty {@link EnumSet<TaskType>}. This is essentially the same
* as {@link EnumSet#copyOf(EnumSet)}, except it does not throw for an empty set.
* @param taskTypes task types to copy
* @return a copy of the passed in {@link EnumSet<TaskType>}
*/
public static EnumSet<TaskType> copyOf(EnumSet<TaskType> taskTypes) {
return taskTypes.isEmpty() ? EnumSet.noneOf(TaskType.class) : EnumSet.copyOf(taskTypes);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,18 @@

import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;

import java.util.EnumSet;

import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength;
import static org.elasticsearch.test.ESTestCase.randomBoolean;
import static org.elasticsearch.test.ESTestCase.randomInt;

public class SettingsConfigurationTestUtils {

public static SettingsConfiguration getRandomSettingsConfigurationField() {
return new SettingsConfiguration.Builder().setDefaultValue(randomAlphaOfLength(10))
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)).setDefaultValue(
randomAlphaOfLength(10)
)
.setDescription(randomAlphaOfLength(10))
.setLabel(randomAlphaOfLength(10))
.setRequired(randomBoolean())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ public void testToXContent() throws IOException {
"required": true,
"sensitive": false,
"updatable": true,
"type": "str"
"type": "str",
"supported_task_types": ["text_embedding", "completion", "sparse_embedding", "rerank"]
}
""");

Expand All @@ -56,7 +57,8 @@ public void testToXContent_WithNumericSelectOptions() throws IOException {
"required": true,
"sensitive": false,
"updatable": true,
"type": "str"
"type": "str",
"supported_task_types": ["text_embedding"]
}
""");

Expand All @@ -74,7 +76,8 @@ public void testToXContentCrawlerConfig_WithNullValue() throws IOException {
String content = XContentHelper.stripWhitespace("""
{
"label": "nextSyncConfig",
"value": null
"value": null,
"supported_task_types": ["text_embedding", "completion", "sparse_embedding", "rerank"]
}
""");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ public static InferenceServiceConfiguration get() {

configurationMap.put(
"model",
new SettingsConfiguration.Builder().setDescription("")
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING)).setDescription("")
.setLabel("Model")
.setRequired(true)
.setSensitive(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ public static InferenceServiceConfiguration get() {

configurationMap.put(
"model",
new SettingsConfiguration.Builder().setDescription("")
new SettingsConfiguration.Builder(EnumSet.of(TaskType.RERANK)).setDescription("")
.setLabel("Model")
.setRequired(true)
.setSensitive(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ public static InferenceServiceConfiguration get() {

configurationMap.put(
"model",
new SettingsConfiguration.Builder().setDescription("")
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription("")
.setLabel("Model")
.setRequired(true)
.setSensitive(false)
Expand All @@ -215,7 +215,7 @@ public static InferenceServiceConfiguration get() {

configurationMap.put(
"hidden_field",
new SettingsConfiguration.Builder().setDescription("")
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription("")
.setLabel("Hidden Field")
.setRequired(true)
.setSensitive(false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ public static InferenceServiceConfiguration get() {

configurationMap.put(
"model_id",
new SettingsConfiguration.Builder().setDescription("")
new SettingsConfiguration.Builder(EnumSet.of(TaskType.COMPLETION)).setDescription("")
.setLabel("Model ID")
.setRequired(true)
.setSensitive(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,9 @@ public static InferenceServiceConfiguration get() {

configurationMap.put(
SERVICE_ID,
new SettingsConfiguration.Builder().setDescription("The name of the model service to use for the {infer} task.")
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
"The name of the model service to use for the {infer} task."
)
.setLabel("Project ID")
.setRequired(true)
.setSensitive(false)
Expand All @@ -390,7 +392,7 @@ public static InferenceServiceConfiguration get() {

configurationMap.put(
HOST,
new SettingsConfiguration.Builder().setDescription(
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
"The name of the host address used for the {infer} task. You can find the host address at "
+ "https://opensearch.console.aliyun.com/cn-shanghai/rag/api-key[ the API keys section] "
+ "of the documentation."
Expand All @@ -405,7 +407,7 @@ public static InferenceServiceConfiguration get() {

configurationMap.put(
HTTP_SCHEMA_NAME,
new SettingsConfiguration.Builder().setDescription("")
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("")
.setLabel("HTTP Schema")
.setRequired(true)
.setSensitive(false)
Expand All @@ -416,7 +418,9 @@ public static InferenceServiceConfiguration get() {

configurationMap.put(
WORKSPACE_NAME,
new SettingsConfiguration.Builder().setDescription("The name of the workspace used for the {infer} task.")
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
"The name of the workspace used for the {infer} task."
)
.setLabel("Workspace")
.setRequired(true)
.setSensitive(false)
Expand All @@ -426,9 +430,12 @@ public static InferenceServiceConfiguration get() {
);

configurationMap.putAll(
DefaultSecretSettings.toSettingsConfigurationWithDescription("A valid API key for the AlibabaCloud AI Search API.")
DefaultSecretSettings.toSettingsConfigurationWithDescription(
"A valid API key for the AlibabaCloud AI Search API.",
supportedTaskTypes
)
);
configurationMap.putAll(RateLimitSettings.toSettingsConfiguration());
configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes));

return new InferenceServiceConfiguration.Builder().setService(NAME)
.setName(SERVICE_NAME)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.SecretSettings;
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -128,7 +130,9 @@ public static Map<String, SettingsConfiguration> get() {
var configurationMap = new HashMap<String, SettingsConfiguration>();
configurationMap.put(
ACCESS_KEY_FIELD,
new SettingsConfiguration.Builder().setDescription("A valid AWS access key that has permissions to use Amazon Bedrock.")
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).setDescription(
"A valid AWS access key that has permissions to use Amazon Bedrock."
)
.setLabel("Access Key")
.setRequired(true)
.setSensitive(true)
Expand All @@ -138,7 +142,9 @@ public static Map<String, SettingsConfiguration> get() {
);
configurationMap.put(
SECRET_KEY_FIELD,
new SettingsConfiguration.Builder().setDescription("A valid AWS secret key that is paired with the access_key.")
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).setDescription(
"A valid AWS secret key that is paired with the access_key."
)
.setLabel("Secret Key")
.setRequired(true)
.setSensitive(true)
Expand Down
Loading

0 comments on commit 51204b5

Please sign in to comment.