diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceConfiguration.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceConfiguration.java index 41cf339c751d1..5004186d03848 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceConfiguration.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceConfiguration.java @@ -191,7 +191,7 @@ public Builder setName(String name) { } public Builder setTaskTypes(EnumSet taskTypes) { - this.taskTypes = taskTypes; + this.taskTypes = TaskType.copyOf(taskTypes); return this; } diff --git a/server/src/main/java/org/elasticsearch/inference/SettingsConfiguration.java b/server/src/main/java/org/elasticsearch/inference/SettingsConfiguration.java index 188b8a7e82b57..a19b6735536ef 100644 --- a/server/src/main/java/org/elasticsearch/inference/SettingsConfiguration.java +++ b/server/src/main/java/org/elasticsearch/inference/SettingsConfiguration.java @@ -28,10 +28,13 @@ import org.elasticsearch.xcontent.XContentType; import java.io.IOException; +import java.util.EnumSet; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -50,6 +53,7 @@ public class SettingsConfiguration implements Writeable, ToXContentObject { private final boolean sensitive; private final boolean updatable; private final SettingsConfigurationFieldType type; + private final EnumSet supportedTaskTypes; /** * Constructs a new {@link SettingsConfiguration} instance with specified properties. @@ -61,6 +65,7 @@ public class SettingsConfiguration implements Writeable, ToXContentObject { * @param sensitive A boolean indicating whether the configuration contains sensitive information. * @param updatable A boolean indicating whether the configuration can be updated. * @param type The type of the configuration field, defined by {@link SettingsConfigurationFieldType}. + * @param supportedTaskTypes The task types that support this field. */ private SettingsConfiguration( Object defaultValue, @@ -69,7 +74,8 @@ private SettingsConfiguration( boolean required, boolean sensitive, boolean updatable, - SettingsConfigurationFieldType type + SettingsConfigurationFieldType type, + EnumSet supportedTaskTypes ) { this.defaultValue = defaultValue; this.description = description; @@ -78,6 +84,7 @@ private SettingsConfiguration( this.sensitive = sensitive; this.updatable = updatable; this.type = type; + this.supportedTaskTypes = supportedTaskTypes; } public SettingsConfiguration(StreamInput in) throws IOException { @@ -88,6 +95,7 @@ public SettingsConfiguration(StreamInput in) throws IOException { this.sensitive = in.readBoolean(); this.updatable = in.readBoolean(); this.type = in.readEnum(SettingsConfigurationFieldType.class); + this.supportedTaskTypes = in.readEnumSet(TaskType.class); } static final ParseField DEFAULT_VALUE_FIELD = new ParseField("default_value"); @@ -97,6 +105,7 @@ public SettingsConfiguration(StreamInput in) throws IOException { static final ParseField SENSITIVE_FIELD = new ParseField("sensitive"); static final ParseField UPDATABLE_FIELD = new ParseField("updatable"); static final ParseField TYPE_FIELD = new ParseField("type"); + static final ParseField SUPPORTED_TASK_TYPES = new ParseField("supported_task_types"); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( @@ -104,7 +113,15 @@ public SettingsConfiguration(StreamInput in) throws IOException { true, args -> { int i = 0; - return new SettingsConfiguration.Builder().setDefaultValue(args[i++]) + + EnumSet supportedTaskTypes = EnumSet.noneOf(TaskType.class); + var supportedTaskTypesListOfStrings = (List) args[i++]; + + for (var supportedTaskTypeString : supportedTaskTypesListOfStrings) { + supportedTaskTypes.add(TaskType.fromString(supportedTaskTypeString)); + } + + return new SettingsConfiguration.Builder(supportedTaskTypes).setDefaultValue(args[i++]) .setDescription((String) args[i++]) .setLabel((String) args[i++]) .setRequired((Boolean) args[i++]) @@ -116,6 +133,7 @@ public SettingsConfiguration(StreamInput in) throws IOException { ); static { + PARSER.declareStringArray(constructorArg(), SUPPORTED_TASK_TYPES); PARSER.declareField(optionalConstructorArg(), (p, c) -> { if (p.currentToken() == XContentParser.Token.VALUE_STRING) { return p.text(); @@ -169,28 +187,8 @@ public SettingsConfigurationFieldType getType() { return type; } - /** - * Parses a configuration value from a parser context. - * This method can parse strings, numbers, booleans, objects, and null values, matching the types commonly - * supported in {@link SettingsConfiguration}. - * - * @param p the {@link org.elasticsearch.xcontent.XContentParser} instance from which to parse the configuration value. - */ - public static Object parseConfigurationValue(XContentParser p) throws IOException { - - if (p.currentToken() == XContentParser.Token.VALUE_STRING) { - return p.text(); - } else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) { - return p.numberValue(); - } else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) { - return p.booleanValue(); - } else if (p.currentToken() == XContentParser.Token.START_OBJECT) { - // Crawler expects the value to be an object - return p.map(); - } else if (p.currentToken() == XContentParser.Token.VALUE_NULL) { - return null; - } - throw new XContentParseException("Unsupported token [" + p.currentToken() + "]"); + public Set getSupportedTaskTypes() { + return supportedTaskTypes; } @Override @@ -211,6 +209,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (type != null) { builder.field(TYPE_FIELD.getPreferredName(), type.toString()); } + builder.field(SUPPORTED_TASK_TYPES.getPreferredName(), supportedTaskTypes); } builder.endObject(); return builder; @@ -237,6 +236,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(sensitive); out.writeBoolean(updatable); out.writeEnum(type); + out.writeEnumSet(supportedTaskTypes); } public Map toMap() { @@ -253,6 +253,7 @@ public Map toMap() { Optional.ofNullable(type).ifPresent(t -> map.put(TYPE_FIELD.getPreferredName(), t.toString())); + map.put(SUPPORTED_TASK_TYPES.getPreferredName(), supportedTaskTypes); return map; } @@ -267,12 +268,13 @@ public boolean equals(Object o) { && Objects.equals(defaultValue, that.defaultValue) && Objects.equals(description, that.description) && Objects.equals(label, that.label) - && type == that.type; + && type == that.type + && Objects.equals(supportedTaskTypes, that.supportedTaskTypes); } @Override public int hashCode() { - return Objects.hash(defaultValue, description, label, required, sensitive, updatable, type); + return Objects.hash(defaultValue, description, label, required, sensitive, updatable, type, supportedTaskTypes); } public static class Builder { @@ -284,6 +286,11 @@ public static class Builder { private boolean sensitive; private boolean updatable; private SettingsConfigurationFieldType type; + private final EnumSet supportedTaskTypes; + + public Builder(EnumSet supportedTaskTypes) { + this.supportedTaskTypes = TaskType.copyOf(Objects.requireNonNull(supportedTaskTypes)); + } public Builder setDefaultValue(Object defaultValue) { this.defaultValue = defaultValue; @@ -321,7 +328,7 @@ public Builder setType(SettingsConfigurationFieldType type) { } public SettingsConfiguration build() { - return new SettingsConfiguration(defaultValue, description, label, required, sensitive, updatable, type); + return new SettingsConfiguration(defaultValue, description, label, required, sensitive, updatable, type, supportedTaskTypes); } } } diff --git a/server/src/main/java/org/elasticsearch/inference/TaskType.java b/server/src/main/java/org/elasticsearch/inference/TaskType.java index 17e77be43bd1a..73a0e3cc8a774 100644 --- a/server/src/main/java/org/elasticsearch/inference/TaskType.java +++ b/server/src/main/java/org/elasticsearch/inference/TaskType.java @@ -16,6 +16,7 @@ import org.elasticsearch.rest.RestStatus; import java.io.IOException; +import java.util.EnumSet; import java.util.Locale; import java.util.Objects; @@ -78,4 +79,14 @@ public void writeTo(StreamOutput out) throws IOException { public static String unsupportedTaskTypeErrorMsg(TaskType taskType, String serviceName) { return "The [" + serviceName + "] service does not support task type [" + taskType + "]"; } + + /** + * Copies a {@link EnumSet} if non-empty, otherwise returns an empty {@link EnumSet}. This is essentially the same + * as {@link EnumSet#copyOf(EnumSet)}, except it does not throw for an empty set. + * @param taskTypes task types to copy + * @return a copy of the passed in {@link EnumSet} + */ + public static EnumSet copyOf(EnumSet taskTypes) { + return taskTypes.isEmpty() ? EnumSet.noneOf(TaskType.class) : EnumSet.copyOf(taskTypes); + } } diff --git a/server/src/test/java/org/elasticsearch/inference/SettingsConfigurationTestUtils.java b/server/src/test/java/org/elasticsearch/inference/SettingsConfigurationTestUtils.java index ed78baeb9abe6..7e68f41de1b2e 100644 --- a/server/src/test/java/org/elasticsearch/inference/SettingsConfigurationTestUtils.java +++ b/server/src/test/java/org/elasticsearch/inference/SettingsConfigurationTestUtils.java @@ -11,6 +11,8 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; +import java.util.EnumSet; + import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength; import static org.elasticsearch.test.ESTestCase.randomBoolean; import static org.elasticsearch.test.ESTestCase.randomInt; @@ -18,7 +20,9 @@ public class SettingsConfigurationTestUtils { public static SettingsConfiguration getRandomSettingsConfigurationField() { - return new SettingsConfiguration.Builder().setDefaultValue(randomAlphaOfLength(10)) + return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)).setDefaultValue( + randomAlphaOfLength(10) + ) .setDescription(randomAlphaOfLength(10)) .setLabel(randomAlphaOfLength(10)) .setRequired(randomBoolean()) diff --git a/server/src/test/java/org/elasticsearch/inference/SettingsConfigurationTests.java b/server/src/test/java/org/elasticsearch/inference/SettingsConfigurationTests.java index 551a25fe52f18..2b286cf86a1c8 100644 --- a/server/src/test/java/org/elasticsearch/inference/SettingsConfigurationTests.java +++ b/server/src/test/java/org/elasticsearch/inference/SettingsConfigurationTests.java @@ -34,7 +34,8 @@ public void testToXContent() throws IOException { "required": true, "sensitive": false, "updatable": true, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "completion", "sparse_embedding", "rerank"] } """); @@ -56,7 +57,8 @@ public void testToXContent_WithNumericSelectOptions() throws IOException { "required": true, "sensitive": false, "updatable": true, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding"] } """); @@ -74,7 +76,8 @@ public void testToXContentCrawlerConfig_WithNullValue() throws IOException { String content = XContentHelper.stripWhitespace(""" { "label": "nextSyncConfig", - "value": null + "value": null, + "supported_task_types": ["text_embedding", "completion", "sparse_embedding", "rerank"] } """); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 89c79dd148598..1f17e335462a7 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -257,7 +257,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( "model", - new SettingsConfiguration.Builder().setDescription("") + new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING)).setDescription("") .setLabel("Model") .setRequired(true) .setSensitive(true) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java index 77c762a38baaf..e79c8b9bad522 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java @@ -171,7 +171,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( "model", - new SettingsConfiguration.Builder().setDescription("") + new SettingsConfiguration.Builder(EnumSet.of(TaskType.RERANK)).setDescription("") .setLabel("Model") .setRequired(true) .setSensitive(true) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index bef0b1812beda..f700f6672fd63 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -205,7 +205,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( "model", - new SettingsConfiguration.Builder().setDescription("") + new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription("") .setLabel("Model") .setRequired(true) .setSensitive(false) @@ -215,7 +215,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( "hidden_field", - new SettingsConfiguration.Builder().setDescription("") + new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription("") .setLabel("Hidden Field") .setRequired(true) .setSensitive(false) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index b78586174dc1e..b0e43c8607078 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -257,7 +257,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( "model_id", - new SettingsConfiguration.Builder().setDescription("") + new SettingsConfiguration.Builder(EnumSet.of(TaskType.COMPLETION)).setDescription("") .setLabel("Model ID") .setRequired(true) .setSensitive(true) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 24f7fa182b7c2..0fd0c281d8bc6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -379,7 +379,9 @@ public static InferenceServiceConfiguration get() { configurationMap.put( SERVICE_ID, - new SettingsConfiguration.Builder().setDescription("The name of the model service to use for the {infer} task.") + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( + "The name of the model service to use for the {infer} task." + ) .setLabel("Project ID") .setRequired(true) .setSensitive(false) @@ -390,7 +392,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( HOST, - new SettingsConfiguration.Builder().setDescription( + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( "The name of the host address used for the {infer} task. You can find the host address at " + "https://opensearch.console.aliyun.com/cn-shanghai/rag/api-key[ the API keys section] " + "of the documentation." @@ -405,7 +407,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( HTTP_SCHEMA_NAME, - new SettingsConfiguration.Builder().setDescription("") + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("") .setLabel("HTTP Schema") .setRequired(true) .setSensitive(false) @@ -416,7 +418,9 @@ public static InferenceServiceConfiguration get() { configurationMap.put( WORKSPACE_NAME, - new SettingsConfiguration.Builder().setDescription("The name of the workspace used for the {infer} task.") + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( + "The name of the workspace used for the {infer} task." + ) .setLabel("Workspace") .setRequired(true) .setSensitive(false) @@ -426,9 +430,12 @@ public static InferenceServiceConfiguration get() { ); configurationMap.putAll( - DefaultSecretSettings.toSettingsConfigurationWithDescription("A valid API key for the AlibabaCloud AI Search API.") + DefaultSecretSettings.toSettingsConfigurationWithDescription( + "A valid API key for the AlibabaCloud AI Search API.", + supportedTaskTypes + ) ); - configurationMap.putAll(RateLimitSettings.toSettingsConfiguration()); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes)); return new InferenceServiceConfiguration.Builder().setService(NAME) .setName(SERVICE_NAME) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettings.java index 2105da235babe..80750063b120e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettings.java @@ -18,11 +18,13 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; import java.util.Collections; +import java.util.EnumSet; import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -128,7 +130,9 @@ public static Map get() { var configurationMap = new HashMap(); configurationMap.put( ACCESS_KEY_FIELD, - new SettingsConfiguration.Builder().setDescription("A valid AWS access key that has permissions to use Amazon Bedrock.") + new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).setDescription( + "A valid AWS access key that has permissions to use Amazon Bedrock." + ) .setLabel("Access Key") .setRequired(true) .setSensitive(true) @@ -138,7 +142,9 @@ public static Map get() { ); configurationMap.put( SECRET_KEY_FIELD, - new SettingsConfiguration.Builder().setDescription("A valid AWS secret key that is paired with the access_key.") + new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).setDescription( + "A valid AWS secret key that is paired with the access_key." + ) .setLabel("Secret Key") .setRequired(true) .setSensitive(true) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 07c5e91776192..e13c668197a8f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -378,7 +378,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( PROVIDER_FIELD, - new SettingsConfiguration.Builder().setDescription("The model provider for your deployment.") + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The model provider for your deployment.") .setLabel("Provider") .setRequired(true) .setSensitive(false) @@ -389,7 +389,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( MODEL_FIELD, - new SettingsConfiguration.Builder().setDescription( + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( "The base model ID or an ARN to a custom model based on a foundational model." ) .setLabel("Model") @@ -402,7 +402,9 @@ public static InferenceServiceConfiguration get() { configurationMap.put( REGION_FIELD, - new SettingsConfiguration.Builder().setDescription("The region that your model or ARN is deployed in.") + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( + "The region that your model or ARN is deployed in." + ) .setLabel("Region") .setRequired(true) .setSensitive(false) @@ -414,7 +416,8 @@ public static InferenceServiceConfiguration get() { configurationMap.putAll(AmazonBedrockSecretSettings.Configuration.get()); configurationMap.putAll( RateLimitSettings.toSettingsConfigurationWithDescription( - "By default, the amazonbedrock service sets the number of requests allowed per minute to 240." + "By default, the amazonbedrock service sets the number of requests allowed per minute to 240.", + supportedTaskTypes ) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index 9dbfb0732f463..64fe42fbbc171 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -256,7 +256,9 @@ public static InferenceServiceConfiguration get() { configurationMap.put( MODEL_ID, - new SettingsConfiguration.Builder().setDescription("The name of the model to use for the inference task.") + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( + "The name of the model to use for the inference task." + ) .setLabel("Model ID") .setRequired(true) .setSensitive(false) @@ -265,10 +267,11 @@ public static InferenceServiceConfiguration get() { .build() ); - configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration()); + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes)); configurationMap.putAll( RateLimitSettings.toSettingsConfigurationWithDescription( - "By default, the anthropic service sets the number of requests allowed per minute to 50." + "By default, the anthropic service sets the number of requests allowed per minute to 50.", + supportedTaskTypes ) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 649540f7efc5c..88d5b54398d06 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -406,7 +406,9 @@ public static InferenceServiceConfiguration get() { configurationMap.put( TARGET_FIELD, - new SettingsConfiguration.Builder().setDescription("The target URL of your Azure AI Studio model deployment.") + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( + "The target URL of your Azure AI Studio model deployment." + ) .setLabel("Target") .setRequired(true) .setSensitive(false) @@ -417,7 +419,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( ENDPOINT_TYPE_FIELD, - new SettingsConfiguration.Builder().setDescription( + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( "Specifies the type of endpoint that is used in your model deployment." ) .setLabel("Endpoint Type") @@ -430,7 +432,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( PROVIDER_FIELD, - new SettingsConfiguration.Builder().setDescription("The model provider for your deployment.") + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The model provider for your deployment.") .setLabel("Provider") .setRequired(true) .setSensitive(false) @@ -439,8 +441,8 @@ public static InferenceServiceConfiguration get() { .build() ); - configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration()); - configurationMap.putAll(RateLimitSettings.toSettingsConfiguration()); + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes)); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes)); return new InferenceServiceConfiguration.Builder().setService(NAME) .setName(SERVICE_NAME) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettings.java index 0601daf562ce9..8405b35a35d9a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettings.java @@ -18,11 +18,13 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; import java.util.Collections; +import java.util.EnumSet; import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -146,7 +148,9 @@ public static Map get() { var configurationMap = new HashMap(); configurationMap.put( API_KEY, - new SettingsConfiguration.Builder().setDescription("You must provide either an API key or an Entra ID.") + new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).setDescription( + "You must provide either an API key or an Entra ID." + ) .setLabel("API Key") .setRequired(false) .setSensitive(true) @@ -156,7 +160,9 @@ public static Map get() { ); configurationMap.put( ENTRA_ID, - new SettingsConfiguration.Builder().setDescription("You must provide either an API key or an Entra ID.") + new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).setDescription( + "You must provide either an API key or an Entra ID." + ) .setLabel("Entra ID") .setRequired(false) .setSensitive(true) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 4fca5a460a12a..5b622d68f2c25 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -351,7 +351,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( RESOURCE_NAME, - new SettingsConfiguration.Builder().setDescription("The name of your Azure OpenAI resource.") + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The name of your Azure OpenAI resource.") .setLabel("Resource Name") .setRequired(true) .setSensitive(false) @@ -362,7 +362,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( API_VERSION, - new SettingsConfiguration.Builder().setDescription("The Azure API version ID to use.") + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The Azure API version ID to use.") .setLabel("API Version") .setRequired(true) .setSensitive(false) @@ -373,7 +373,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( DEPLOYMENT_ID, - new SettingsConfiguration.Builder().setDescription("The deployment name of your deployed models.") + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The deployment name of your deployed models.") .setLabel("Deployment ID") .setRequired(true) .setSensitive(false) @@ -385,7 +385,8 @@ public static InferenceServiceConfiguration get() { configurationMap.putAll(AzureOpenAiSecretSettings.Configuration.get()); configurationMap.putAll( RateLimitSettings.toSettingsConfigurationWithDescription( - "The azureopenai service sets a default number of requests allowed per minute depending on the task type." + "The azureopenai service sets a default number of requests allowed per minute depending on the task type.", + supportedTaskTypes ) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index 60ab8ca68d5d9..60326a8a34ca3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -363,8 +363,8 @@ public static InferenceServiceConfiguration get() { () -> { var configurationMap = new HashMap(); - configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration()); - configurationMap.putAll(RateLimitSettings.toSettingsConfiguration()); + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes)); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes)); return new InferenceServiceConfiguration.Builder().setService(NAME) .setName(SERVICE_NAME) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 663b40bc14693..a8a0053796e8c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -379,7 +379,9 @@ public static InferenceServiceConfiguration get() { configurationMap.put( MODEL_ID, - new SettingsConfiguration.Builder().setDescription("The name of the model to use for the inference task.") + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES_FOR_SERVICES_API).setDescription( + "The name of the model to use for the inference task." + ) .setLabel("Model ID") .setRequired(true) .setSensitive(false) @@ -390,7 +392,9 @@ public static InferenceServiceConfiguration get() { configurationMap.put( MAX_INPUT_TOKENS, - new SettingsConfiguration.Builder().setDescription("Allows you to specify the maximum number of tokens per input.") + new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription( + "Allows you to specify the maximum number of tokens per input." + ) .setLabel("Maximum Input Tokens") .setRequired(false) .setSensitive(false) @@ -399,7 +403,7 @@ public static InferenceServiceConfiguration get() { .build() ); - configurationMap.putAll(RateLimitSettings.toSettingsConfiguration()); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES_FOR_SERVICES_API)); return new InferenceServiceConfiguration.Builder().setService(NAME) .setName(SERVICE_NAME) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 2931f2e23f12d..c538b9acf1321 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -1147,7 +1147,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( MODEL_ID, - new SettingsConfiguration.Builder().setDefaultValue(MULTILINGUAL_E5_SMALL_MODEL_ID) + new SettingsConfiguration.Builder(supportedTaskTypes).setDefaultValue(MULTILINGUAL_E5_SMALL_MODEL_ID) .setDescription("The name of the model to use for the inference task.") .setLabel("Model ID") .setRequired(true) @@ -1159,7 +1159,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( NUM_ALLOCATIONS, - new SettingsConfiguration.Builder().setDefaultValue(1) + new SettingsConfiguration.Builder(supportedTaskTypes).setDefaultValue(1) .setDescription("The total number of allocations this model is assigned across machine learning nodes.") .setLabel("Number Allocations") .setRequired(true) @@ -1171,7 +1171,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( NUM_THREADS, - new SettingsConfiguration.Builder().setDefaultValue(2) + new SettingsConfiguration.Builder(supportedTaskTypes).setDefaultValue(2) .setDescription("Sets the number of threads used by each model allocation during inference.") .setLabel("Number Threads") .setRequired(true) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 1dbf2ca3e2dad..205cc545a23f0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -351,7 +351,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( MODEL_ID, - new SettingsConfiguration.Builder().setDescription("ID of the LLM you're using.") + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("ID of the LLM you're using.") .setLabel("Model ID") .setRequired(true) .setSensitive(false) @@ -360,8 +360,8 @@ public static InferenceServiceConfiguration get() { .build() ); - configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration()); - configurationMap.putAll(RateLimitSettings.toSettingsConfiguration()); + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes)); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes)); return new InferenceServiceConfiguration.Builder().setService(NAME) .setName(SERVICE_NAME) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java index b185800ed75f4..9a39e200368cf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java @@ -18,11 +18,13 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; import java.util.Collections; +import java.util.EnumSet; import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -122,7 +124,9 @@ public static Map get() { var configurationMap = new HashMap(); configurationMap.put( SERVICE_ACCOUNT_JSON, - new SettingsConfiguration.Builder().setDescription("API Key for the provider you're connecting to.") + new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK)).setDescription( + "API Key for the provider you're connecting to." + ) .setLabel("Credentials JSON") .setRequired(true) .setSensitive(true) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 8fe9f29c73747..55397b2398d39 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -329,7 +329,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( MODEL_ID, - new SettingsConfiguration.Builder().setDescription("ID of the LLM you're using.") + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("ID of the LLM you're using.") .setLabel("Model ID") .setRequired(true) .setSensitive(false) @@ -340,7 +340,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( LOCATION, - new SettingsConfiguration.Builder().setDescription( + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( "Please provide the GCP region where the Vertex AI API(s) is enabled. " + "For more information, refer to the {geminiVertexAIDocs}." ) @@ -354,7 +354,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( PROJECT_ID, - new SettingsConfiguration.Builder().setDescription( + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( "The GCP Project ID which has Vertex AI API(s) enabled. For more information " + "on the URL, refer to the {geminiVertexAIDocs}." ) @@ -367,7 +367,7 @@ public static InferenceServiceConfiguration get() { ); configurationMap.putAll(GoogleVertexAiSecretSettings.Configuration.get()); - configurationMap.putAll(RateLimitSettings.toSettingsConfiguration()); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes)); return new InferenceServiceConfiguration.Builder().setService(NAME) .setName(SERVICE_NAME) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index ef6beb8ec2627..73c1446b9bb26 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -181,7 +181,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( URL, - new SettingsConfiguration.Builder().setDefaultValue("https://api.openai.com/v1/embeddings") + new SettingsConfiguration.Builder(supportedTaskTypes).setDefaultValue("https://api.openai.com/v1/embeddings") .setDescription("The URL endpoint to use for the requests.") .setLabel("URL") .setRequired(true) @@ -191,8 +191,8 @@ public static InferenceServiceConfiguration get() { .build() ); - configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration()); - configurationMap.putAll(RateLimitSettings.toSettingsConfiguration()); + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes)); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes)); return new InferenceServiceConfiguration.Builder().setService(NAME) .setName(SERVICE_NAME) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 52d42f570a413..9c76cc5c41fb1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -176,7 +176,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( URL, - new SettingsConfiguration.Builder().setDescription("The URL endpoint to use for the requests.") + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The URL endpoint to use for the requests.") .setLabel("URL") .setRequired(true) .setSensitive(false) @@ -185,8 +185,8 @@ public static InferenceServiceConfiguration get() { .build() ); - configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration()); - configurationMap.putAll(RateLimitSettings.toSettingsConfiguration()); + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes)); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes)); return new InferenceServiceConfiguration.Builder().setService(NAME) .setName(SERVICE_NAME) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index dd368f88a993c..477225f00d22b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -324,7 +324,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( API_VERSION, - new SettingsConfiguration.Builder().setDescription("The IBM Watsonx API version ID to use.") + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The IBM Watsonx API version ID to use.") .setLabel("API Version") .setRequired(true) .setSensitive(false) @@ -335,7 +335,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( PROJECT_ID, - new SettingsConfiguration.Builder().setDescription("") + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("") .setLabel("Project ID") .setRequired(true) .setSensitive(false) @@ -346,7 +346,9 @@ public static InferenceServiceConfiguration get() { configurationMap.put( MODEL_ID, - new SettingsConfiguration.Builder().setDescription("The name of the model to use for the inference task.") + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( + "The name of the model to use for the inference task." + ) .setLabel("Model ID") .setRequired(true) .setSensitive(false) @@ -357,7 +359,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( URL, - new SettingsConfiguration.Builder().setDescription("") + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("") .setLabel("URL") .setRequired(true) .setSensitive(false) @@ -368,7 +370,9 @@ public static InferenceServiceConfiguration get() { configurationMap.put( MAX_INPUT_TOKENS, - new SettingsConfiguration.Builder().setDescription("Allows you to specify the maximum number of tokens per input.") + new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING)).setDescription( + "Allows you to specify the maximum number of tokens per input." + ) .setLabel("Maximum Input Tokens") .setRequired(false) .setSensitive(false) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index ed76df5875562..7ad70fc88054d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -339,8 +339,8 @@ public static InferenceServiceConfiguration get() { () -> { var configurationMap = new HashMap(); - configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration()); - configurationMap.putAll(RateLimitSettings.toSettingsConfiguration()); + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes)); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes)); return new InferenceServiceConfiguration.Builder().setService(NAME) .setName(SERVICE_NAME) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index 129d9023a1ebc..3e40575e42faf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -318,7 +318,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( MODEL_FIELD, - new SettingsConfiguration.Builder().setDescription( + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( "Refer to the Mistral models documentation for the list of available text embedding models." ) .setLabel("Model") @@ -331,7 +331,9 @@ public static InferenceServiceConfiguration get() { configurationMap.put( MAX_INPUT_TOKENS, - new SettingsConfiguration.Builder().setDescription("Allows you to specify the maximum number of tokens per input.") + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( + "Allows you to specify the maximum number of tokens per input." + ) .setLabel("Maximum Input Tokens") .setRequired(false) .setSensitive(false) @@ -340,8 +342,8 @@ public static InferenceServiceConfiguration get() { .build() ); - configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration()); - configurationMap.putAll(RateLimitSettings.toSettingsConfiguration()); + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes)); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes)); return new InferenceServiceConfiguration.Builder().setService(NAME) .setName(SERVICE_NAME) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 3efd7c44c3e97..0ce5bc801b59f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -415,7 +415,9 @@ public static InferenceServiceConfiguration get() { configurationMap.put( MODEL_ID, - new SettingsConfiguration.Builder().setDescription("The name of the model to use for the inference task.") + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES_FOR_SERVICES_API).setDescription( + "The name of the model to use for the inference task." + ) .setLabel("Model ID") .setRequired(true) .setSensitive(false) @@ -426,7 +428,9 @@ public static InferenceServiceConfiguration get() { configurationMap.put( ORGANIZATION, - new SettingsConfiguration.Builder().setDescription("The unique identifier of your organization.") + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES_FOR_SERVICES_API).setDescription( + "The unique identifier of your organization." + ) .setLabel("Organization ID") .setRequired(false) .setSensitive(false) @@ -437,7 +441,9 @@ public static InferenceServiceConfiguration get() { configurationMap.put( URL, - new SettingsConfiguration.Builder().setDefaultValue("https://api.openai.com/v1/chat/completions") + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES_FOR_SERVICES_API).setDefaultValue( + "https://api.openai.com/v1/chat/completions" + ) .setDescription( "The OpenAI API endpoint URL. For more information on the URL, refer to the " + "https://platform.openai.com/docs/api-reference." @@ -453,12 +459,14 @@ public static InferenceServiceConfiguration get() { configurationMap.putAll( DefaultSecretSettings.toSettingsConfigurationWithDescription( "The OpenAI API authentication key. For more details about generating OpenAI API keys, " - + "refer to the https://platform.openai.com/account/api-keys." + + "refer to the https://platform.openai.com/account/api-keys.", + SUPPORTED_TASK_TYPES_FOR_SERVICES_API ) ); configurationMap.putAll( RateLimitSettings.toSettingsConfigurationWithDescription( - "Default number of requests allowed per minute. For text_embedding is 3000. For completion is 500." + "Default number of requests allowed per minute. For text_embedding is 3000. For completion is 500.", + SUPPORTED_TASK_TYPES_FOR_SERVICES_API ) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/DefaultSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/DefaultSecretSettings.java index 15e8128969ddb..d076c946889ed 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/DefaultSecretSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/DefaultSecretSettings.java @@ -17,10 +17,12 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; +import java.util.EnumSet; import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -51,11 +53,14 @@ public static DefaultSecretSettings fromMap(@Nullable Map map) { return new DefaultSecretSettings(secureApiToken); } - public static Map toSettingsConfigurationWithDescription(String description) { + public static Map toSettingsConfigurationWithDescription( + String description, + EnumSet supportedTaskTypes + ) { var configurationMap = new HashMap(); configurationMap.put( API_KEY, - new SettingsConfiguration.Builder().setDescription(description) + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(description) .setLabel("API Key") .setRequired(true) .setSensitive(true) @@ -66,8 +71,11 @@ public static Map toSettingsConfigurationWithDesc return configurationMap; } - public static Map toSettingsConfiguration() { - return DefaultSecretSettings.toSettingsConfigurationWithDescription("API Key for the provider you're connecting to."); + public static Map toSettingsConfiguration(EnumSet supportedTaskTypes) { + return DefaultSecretSettings.toSettingsConfigurationWithDescription( + "API Key for the provider you're connecting to.", + supportedTaskTypes + ); } public DefaultSecretSettings { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java index 30147a6d24a96..bc7e555120286 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java @@ -12,12 +12,14 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.xcontent.ToXContentFragment; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import java.io.IOException; +import java.util.EnumSet; import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -51,11 +53,14 @@ public static RateLimitSettings of( return requestsPerMinute == null ? defaultValue : new RateLimitSettings(requestsPerMinute); } - public static Map toSettingsConfigurationWithDescription(String description) { + public static Map toSettingsConfigurationWithDescription( + String description, + EnumSet supportedTaskTypes + ) { var configurationMap = new HashMap(); configurationMap.put( FIELD_NAME + "." + REQUESTS_PER_MINUTE_FIELD, - new SettingsConfiguration.Builder().setDescription(description) + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(description) .setLabel("Rate Limit") .setRequired(false) .setSensitive(false) @@ -66,8 +71,8 @@ public static Map toSettingsConfigurationWithDesc return configurationMap; } - public static Map toSettingsConfiguration() { - return RateLimitSettings.toSettingsConfigurationWithDescription("Minimize the number of rate limit errors."); + public static Map toSettingsConfiguration(EnumSet supportedTaskTypes) { + return RateLimitSettings.toSettingsConfigurationWithDescription("Minimize the number of rate limit errors.", supportedTaskTypes); } /** diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index cfe47098e0141..92544d5535acb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -448,7 +448,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "sparse_embedding", "rerank", "completion"] }, "api_key": { "description": "A valid API key for the AlibabaCloud AI Search API.", @@ -456,7 +457,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": true, "updatable": true, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "sparse_embedding", "rerank", "completion"] }, "service_id": { "description": "The name of the model service to use for the {infer} task.", @@ -464,7 +466,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "sparse_embedding", "rerank", "completion"] }, "host": { "description": "The name of the host address used for the {infer} task. You can find the host address at https://opensearch.console.aliyun.com/cn-shanghai/rag/api-key[ the API keys section] of the documentation.", @@ -472,7 +475,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "sparse_embedding", "rerank", "completion"] }, "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -480,7 +484,8 @@ public void testGetConfiguration() throws Exception { "required": false, "sensitive": false, "updatable": false, - "type": "int" + "type": "int", + "supported_task_types": ["text_embedding", "sparse_embedding", "rerank", "completion"] }, "http_schema": { "description": "", @@ -488,7 +493,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "sparse_embedding", "rerank", "completion"] } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index ce4d928458dca..c11d4b4c7923d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -166,7 +166,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": true, "updatable": true, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "completion"] }, "provider": { "description": "The model provider for your deployment.", @@ -174,7 +175,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "completion"] }, "access_key": { "description": "A valid AWS access key that has permissions to use Amazon Bedrock.", @@ -182,7 +184,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": true, "updatable": true, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "completion"] }, "model": { "description": "The base model ID or an ARN to a custom model based on a foundational model.", @@ -190,7 +193,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "completion"] }, "rate_limit.requests_per_minute": { "description": "By default, the amazonbedrock service sets the number of requests allowed per minute to 240.", @@ -198,7 +202,8 @@ public void testGetConfiguration() throws Exception { "required": false, "sensitive": false, "updatable": false, - "type": "int" + "type": "int", + "supported_task_types": ["text_embedding", "completion"] }, "region": { "description": "The region that your model or ARN is deployed in.", @@ -206,7 +211,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "completion"] } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java index 7eb7ad1d0a19e..33101a3e02661 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java @@ -614,7 +614,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": true, "updatable": true, - "type": "str" + "type": "str", + "supported_task_types": ["completion"] }, "rate_limit.requests_per_minute": { "description": "By default, the anthropic service sets the number of requests allowed per minute to 50.", @@ -622,7 +623,8 @@ public void testGetConfiguration() throws Exception { "required": false, "sensitive": false, "updatable": false, - "type": "int" + "type": "int", + "supported_task_types": ["completion"] }, "model_id": { "description": "The name of the model to use for the inference task.", @@ -630,7 +632,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["completion"] } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 72ebd0b96bdc1..d2e4652b96488 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -1401,7 +1401,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "completion"] }, "provider": { "description": "The model provider for your deployment.", @@ -1409,7 +1410,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "completion"] }, "api_key": { "description": "API Key for the provider you're connecting to.", @@ -1417,7 +1419,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": true, "updatable": true, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "completion"] }, "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -1425,7 +1428,8 @@ public void testGetConfiguration() throws Exception { "required": false, "sensitive": false, "updatable": false, - "type": "int" + "type": "int", + "supported_task_types": ["text_embedding", "completion"] }, "target": { "description": "The target URL of your Azure AI Studio model deployment.", @@ -1433,7 +1437,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "completion"] } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index 5c69815cbb0ab..52527d74aad19 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -1470,7 +1470,8 @@ public void testGetConfiguration() throws Exception { "required": false, "sensitive": true, "updatable": true, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "completion"] }, "entra_id": { "description": "You must provide either an API key or an Entra ID.", @@ -1478,7 +1479,8 @@ public void testGetConfiguration() throws Exception { "required": false, "sensitive": true, "updatable": true, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "completion"] }, "rate_limit.requests_per_minute": { "description": "The azureopenai service sets a default number of requests allowed per minute depending on the task type.", @@ -1486,7 +1488,8 @@ public void testGetConfiguration() throws Exception { "required": false, "sensitive": false, "updatable": false, - "type": "int" + "type": "int", + "supported_task_types": ["text_embedding", "completion"] }, "deployment_id": { "description": "The deployment name of your deployed models.", @@ -1494,7 +1497,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "completion"] }, "resource_name": { "description": "The name of your Azure OpenAI resource.", @@ -1502,7 +1506,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "completion"] }, "api_version": { "description": "The Azure API version ID to use.", @@ -1510,7 +1515,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "completion"] } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index dbad76fd46fc4..86b3edc4130da 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -1645,7 +1645,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": true, "updatable": true, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "rerank", "completion"] }, "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -1653,7 +1654,8 @@ public void testGetConfiguration() throws Exception { "required": false, "sensitive": false, "updatable": false, - "type": "int" + "type": "int", + "supported_task_types": ["text_embedding", "rerank", "completion"] } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 89bf1355ee767..5e7e93b1f5a75 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -582,7 +582,8 @@ public void testGetConfiguration() throws Exception { "required": false, "sensitive": false, "updatable": false, - "type": "int" + "type": "int", + "supported_task_types": ["sparse_embedding" , "chat_completion"] }, "model_id": { "description": "The name of the model to use for the inference task.", @@ -590,7 +591,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["sparse_embedding" , "chat_completion"] }, "max_input_tokens": { "description": "Allows you to specify the maximum number of tokens per input.", @@ -598,7 +600,8 @@ public void testGetConfiguration() throws Exception { "required": false, "sensitive": false, "updatable": false, - "type": "int" + "type": "int", + "supported_task_types": ["sparse_embedding"] } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 803151a5b3476..93b884a87fba2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -1564,7 +1564,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": true, - "type": "int" + "type": "int", + "supported_task_types": ["text_embedding", "sparse_embedding", "rerank"] }, "num_threads": { "default_value": 2, @@ -1573,7 +1574,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "int" + "type": "int", + "supported_task_types": ["text_embedding", "sparse_embedding", "rerank"] }, "model_id": { "default_value": ".multilingual-e5-small", @@ -1582,7 +1584,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "sparse_embedding", "rerank"] } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index aa45666eb0fb1..26dae5d172fb0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -1133,7 +1133,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": true, "updatable": true, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "completion"] }, "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -1141,7 +1142,8 @@ public void testGetConfiguration() throws Exception { "required": false, "sensitive": false, "updatable": false, - "type": "int" + "type": "int", + "supported_task_types": ["text_embedding", "completion"] }, "model_id": { "description": "ID of the LLM you're using.", @@ -1149,7 +1151,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "completion"] } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java index 555e9f0785fa2..932dfc21e9396 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java @@ -879,7 +879,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": true, "updatable": true, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "rerank"] }, "project_id": { "description": "The GCP Project ID which has Vertex AI API(s) enabled. For more information on the URL, refer to the {geminiVertexAIDocs}.", @@ -887,7 +888,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "rerank"] }, "location": { "description": "Please provide the GCP region where the Vertex AI API(s) is enabled. For more information, refer to the {geminiVertexAIDocs}.", @@ -895,7 +897,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "rerank"] }, "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -903,7 +906,8 @@ public void testGetConfiguration() throws Exception { "required": false, "sensitive": false, "updatable": false, - "type": "int" + "type": "int", + "supported_task_types": ["text_embedding", "rerank"] }, "model_id": { "description": "ID of the LLM you're using.", @@ -911,7 +915,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "rerank"] } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java index 0b774badd56b6..53e7c6c25fd47 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java @@ -150,7 +150,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": true, "updatable": true, - "type": "str" + "type": "str", + "supported_task_types": ["sparse_embedding"] }, "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -158,7 +159,8 @@ public void testGetConfiguration() throws Exception { "required": false, "sensitive": false, "updatable": false, - "type": "int" + "type": "int", + "supported_task_types": ["sparse_embedding"] }, "url": { "description": "The URL endpoint to use for the requests.", @@ -166,7 +168,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["sparse_embedding"] } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index 1399712450af1..f3137d7011cec 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -868,7 +868,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": true, "updatable": true, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "sparse_embedding"] }, "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -876,7 +877,8 @@ public void testGetConfiguration() throws Exception { "required": false, "sensitive": false, "updatable": false, - "type": "int" + "type": "int", + "supported_task_types": ["text_embedding", "sparse_embedding"] }, "url": { "default_value": "https://api.openai.com/v1/embeddings", @@ -885,7 +887,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "sparse_embedding"] } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 468b0f676ee85..ff99101fc4ee5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -985,7 +985,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding"] }, "model_id": { "description": "The name of the model to use for the inference task.", @@ -993,7 +994,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding"] }, "api_version": { "description": "The IBM Watsonx API version ID to use.", @@ -1001,7 +1003,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding"] }, "max_input_tokens": { "description": "Allows you to specify the maximum number of tokens per input.", @@ -1009,7 +1012,8 @@ public void testGetConfiguration() throws Exception { "required": false, "sensitive": false, "updatable": false, - "type": "int" + "type": "int", + "supported_task_types": ["text_embedding"] }, "url": { "description": "", @@ -1017,7 +1021,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding"] } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index 6e68bde6a1266..5fa14da4ba733 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -1843,7 +1843,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": true, "updatable": true, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "rerank"] }, "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -1851,7 +1852,8 @@ public void testGetConfiguration() throws Exception { "required": false, "sensitive": false, "updatable": false, - "type": "int" + "type": "int", + "supported_task_types": ["text_embedding", "rerank"] } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index 395e29240dfd4..95ac2cde0e31b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -758,7 +758,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": true, "updatable": true, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding"] }, "model": { "description": "Refer to the Mistral models documentation for the list of available text embedding models.", @@ -766,7 +767,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding"] }, "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -774,7 +776,8 @@ public void testGetConfiguration() throws Exception { "required": false, "sensitive": false, "updatable": false, - "type": "int" + "type": "int", + "supported_task_types": ["text_embedding"] }, "max_input_tokens": { "description": "Allows you to specify the maximum number of tokens per input.", @@ -782,7 +785,8 @@ public void testGetConfiguration() throws Exception { "required": false, "sensitive": false, "updatable": false, - "type": "int" + "type": "int", + "supported_task_types": ["text_embedding"] } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 4941d3ae49a23..6fddbf4450283 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -1749,7 +1749,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": true, "updatable": true, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "completion", "chat_completion"] }, "organization_id": { "description": "The unique identifier of your organization.", @@ -1757,7 +1758,8 @@ public void testGetConfiguration() throws Exception { "required": false, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "completion", "chat_completion"] }, "rate_limit.requests_per_minute": { "description": "Default number of requests allowed per minute. For text_embedding is 3000. For completion is 500.", @@ -1765,7 +1767,8 @@ public void testGetConfiguration() throws Exception { "required": false, "sensitive": false, "updatable": false, - "type": "int" + "type": "int", + "supported_task_types": ["text_embedding", "completion", "chat_completion"] }, "model_id": { "description": "The name of the model to use for the inference task.", @@ -1773,7 +1776,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "completion", "chat_completion"] }, "url": { "default_value": "https://api.openai.com/v1/chat/completions", @@ -1782,7 +1786,8 @@ public void testGetConfiguration() throws Exception { "required": true, "sensitive": false, "updatable": false, - "type": "str" + "type": "str", + "supported_task_types": ["text_embedding", "completion", "chat_completion"] } } }