Switch the unified completion API from the [completion] task type to [chat_completion].

  * UnifiedCompletionAction and TransportUnifiedCompletionInferenceAction now require CHAT_COMPLETION.
  * ModelValidatorBuilder maps CHAT_COMPLETION to a ChatCompletionModelValidator backed by the new
    SimpleChatCompletionServiceIntegrationValidator, which validates through unifiedCompletionInfer.
  * OpenAiCompletionRequestManager.USER_ROLE becomes public so the new validator can import it.
  * The mock streaming service and the affected tests are updated to the new task type and messages.

Note: the blob id on the new file's "index" line was not regenerated after the review edits to that file.

diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java
index 8d121463fb465..f5c852a0450ae 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java
@@ -92,9 +92,9 @@ public ActionRequestValidationException validate() {
                 return e;
             }
 
-            if (taskType.isAnyOrSame(TaskType.COMPLETION) == false) {
+            if (taskType.isAnyOrSame(TaskType.CHAT_COMPLETION) == false) {
                 var e = new ActionRequestValidationException();
-                e.addValidationError("Field [taskType] must be [completion]");
+                e.addValidationError("Field [taskType] must be [chat_completion]");
                 return e;
             }
 
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java
index 1872ac3caa230..f548bfa0709ed 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java
@@ -52,7 +52,7 @@ public void testValidation_ReturnsException_When_TaskType_IsNot_Completion() {
             TimeValue.timeValueSeconds(10)
         );
         var exception = request.validate();
-        assertThat(exception.getMessage(), is("Validation Failed: 1: Field [taskType] must be [completion];"));
+        assertThat(exception.getMessage(), is("Validation Failed: 1: Field [taskType] must be [chat_completion];"));
     }
 
     public void testValidation_ReturnsNull_When_TaskType_IsAny() {
diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java
index d2be50cb5e841..9aadf764dce68 100644
--- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java
+++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java
@@ -519,14 +519,19 @@ public void testSupportedStream() throws Exception {
 
     public void testUnifiedCompletionInference() throws Exception {
         String modelId = "streaming";
-        putModel(modelId, mockCompletionServiceModelConfig(TaskType.COMPLETION));
+        putModel(modelId, mockCompletionServiceModelConfig(TaskType.CHAT_COMPLETION));
         var singleModel = getModel(modelId);
         assertEquals(modelId, singleModel.get("inference_id"));
-        assertEquals(TaskType.COMPLETION.toString(), singleModel.get("task_type"));
+        assertEquals(TaskType.CHAT_COMPLETION.toString(), singleModel.get("task_type"));
 
         var input = IntStream.range(1, 2 + randomInt(8)).mapToObj(i -> randomAlphanumericOfLength(5)).toList();
         try {
-            var events = unifiedCompletionInferOnMockService(modelId, TaskType.COMPLETION, input, VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER);
+            var events = unifiedCompletionInferOnMockService(
+                modelId,
+                TaskType.CHAT_COMPLETION,
+                input,
+                VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER
+            );
             var expectedResponses = expectedResultsIterator(input);
             assertThat(events.size(), equalTo((input.size() + 1) * 2));
             events.forEach(event -> {
diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java
index e071b704c233e..31e15a8c1c35b 100644
--- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java
+++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java
@@ -129,7 +129,7 @@ public void unifiedCompletionInfer(
             ActionListener listener
         ) {
             switch (model.getConfigurations().getTaskType()) {
-                case COMPLETION -> listener.onResponse(makeUnifiedResults(request));
+                case CHAT_COMPLETION -> listener.onResponse(makeUnifiedResults(request));
                 default -> listener.onFailure(
                     new ElasticsearchStatusException(
                         TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java
index 1478130f6a6c8..9354ac2a83182 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java
@@ -52,7 +52,7 @@ public TransportUnifiedCompletionInferenceAction(
 
     @Override
     protected boolean isInvalidTaskTypeForInferenceEndpoint(UnifiedCompletionAction.Request request, UnparsedModel unparsedModel) {
-        return request.getTaskType().isAnyOrSame(TaskType.COMPLETION) == false || unparsedModel.taskType() != TaskType.COMPLETION;
+        return request.getTaskType().isAnyOrSame(TaskType.CHAT_COMPLETION) == false || unparsedModel.taskType() != TaskType.CHAT_COMPLETION;
     }
 
     @Override
@@ -64,7 +64,7 @@ protected ElasticsearchStatusException createInvalidTaskTypeException(
             "Incompatible task_type for unified API, the requested type [{}] must be one of [{}]",
             RestStatus.BAD_REQUEST,
             request.getTaskType(),
-            TaskType.COMPLETION.toString()
+            TaskType.CHAT_COMPLETION.toString()
         );
     }
 
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java
index 4d730be6aa6bd..ca25b56953251 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java
@@ -26,7 +26,7 @@ public class OpenAiCompletionRequestManager extends OpenAiRequestManager {
 
     private static final Logger logger = LogManager.getLogger(OpenAiCompletionRequestManager.class);
     private static final ResponseHandler HANDLER = createCompletionHandler();
-    static final String USER_ROLE = "user";
+    public static final String USER_ROLE = "user";
 
     public static OpenAiCompletionRequestManager of(OpenAiChatCompletionModel model, ThreadPool threadPool) {
         return new OpenAiCompletionRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java
index b5bf77cbb3c7d..e4de3d6beb800 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java
@@ -23,6 +23,9 @@ public static ModelValidator buildModelValidator(TaskType taskType) {
             case COMPLETION -> {
                 return new ChatCompletionModelValidator(new SimpleServiceIntegrationValidator());
             }
+            case CHAT_COMPLETION -> {
+                return new ChatCompletionModelValidator(new SimpleChatCompletionServiceIntegrationValidator());
+            }
             case SPARSE_EMBEDDING, RERANK, ANY -> {
                 return new SimpleModelValidator(new SimpleServiceIntegrationValidator());
             }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java
new file mode 100644
index 0000000000000..2d3d3aeb5bed3
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java
@@ -0,0 +1,63 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.validation;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.inference.InferenceService;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
+
+import java.util.List;
+
+import static org.elasticsearch.xpack.inference.external.http.sender.OpenAiCompletionRequestManager.USER_ROLE;
+
+/**
+ * Validates a chat completion inference endpoint by sending one test message through the
+ * service's unified chat completion method.
+ */
+public class SimpleChatCompletionServiceIntegrationValidator implements ServiceIntegrationValidator {
+    private static final List<String> TEST_INPUT = List.of("how big");
+
+    /**
+     * Sends {@code TEST_INPUT} to the service as a unified chat completion request and relays the outcome.
+     *
+     * @param service  the inference service that executes the validation request
+     * @param model    the model the request is executed against
+     * @param listener notified with the service response, or with a {@link RestStatus#BAD_REQUEST}
+     *                 {@link ElasticsearchStatusException} when the response is null or the call fails
+     */
+    @Override
+    public void validate(InferenceService service, Model model, ActionListener<InferenceServiceResults> listener) {
+        var chatInput = new UnifiedChatInput(TEST_INPUT, USER_ROLE, false);
+        ActionListener<InferenceServiceResults> validationListener = ActionListener.wrap(response -> {
+            if (response != null) {
+                listener.onResponse(response);
+            } else {
+                listener.onFailure(
+                    new ElasticsearchStatusException(
+                        "Could not complete inference endpoint creation as validation call to service returned null response.",
+                        RestStatus.BAD_REQUEST
+                    )
+                );
+            }
+        }, exception -> {
+            listener.onFailure(
+                new ElasticsearchStatusException(
+                    "Could not complete inference endpoint creation as validation call to service threw an exception.",
+                    RestStatus.BAD_REQUEST,
+                    exception
+                )
+            );
+        });
+        service.unifiedCompletionInfer(model, chatInput.getRequest(), InferenceAction.Request.DEFAULT_TIMEOUT, validationListener);
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java
index 3856a3d111b6e..35142ff05a7e6 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java
@@ -71,7 +71,7 @@ public void testThrows_IncompatibleTaskTypeException_WhenUsingATextEmbeddingInfe
             assertThat(e, isA(ElasticsearchStatusException.class));
             assertThat(
                 e.getMessage(),
-                is("Incompatible task_type for unified API, the requested type [" + requestTaskType + "] must be one of [completion]")
+                is("Incompatible task_type for unified API, the requested type [" + requestTaskType + "] must be one of [chat_completion]")
             );
             assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST));
         }));
@@ -96,7 +96,7 @@ public void testThrows_IncompatibleTaskTypeException_WhenUsingRequestIsAny_Model
             assertThat(e, isA(ElasticsearchStatusException.class));
             assertThat(
                 e.getMessage(),
-                is("Incompatible task_type for unified API, the requested type [" + requestTaskType + "] must be one of [completion]")
+                is("Incompatible task_type for unified API, the requested type [" + requestTaskType + "] must be one of [chat_completion]")
             );
             assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST));
         }));
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java
index 0153113be75d9..a854bbdec507a 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java
@@ -35,6 +35,8 @@ private Map> taskTypeToModelValidatorC
             SimpleModelValidator.class,
             TaskType.COMPLETION,
             ChatCompletionModelValidator.class,
+            TaskType.CHAT_COMPLETION,
+            ChatCompletionModelValidator.class,
             TaskType.ANY,
             SimpleModelValidator.class
         );