diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java index 8d121463fb465..f5c852a0450ae 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java @@ -92,9 +92,9 @@ public ActionRequestValidationException validate() { return e; } - if (taskType.isAnyOrSame(TaskType.COMPLETION) == false) { + if (taskType.isAnyOrSame(TaskType.CHAT_COMPLETION) == false) { var e = new ActionRequestValidationException(); - e.addValidationError("Field [taskType] must be [completion]"); + e.addValidationError("Field [taskType] must be [chat_completion]"); return e; } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java index 1872ac3caa230..f548bfa0709ed 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java @@ -52,7 +52,7 @@ public void testValidation_ReturnsException_When_TaskType_IsNot_Completion() { TimeValue.timeValueSeconds(10) ); var exception = request.validate(); - assertThat(exception.getMessage(), is("Validation Failed: 1: Field [taskType] must be [completion];")); + assertThat(exception.getMessage(), is("Validation Failed: 1: Field [taskType] must be [chat_completion];")); } public void testValidation_ReturnsNull_When_TaskType_IsAny() { diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index d2be50cb5e841..610fafb8390da 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -272,9 +272,9 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException { List services = getServices(TaskType.CHAT_COMPLETION); if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled() || ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) { - assertThat(services.size(), equalTo(2)); + assertThat(services.size(), equalTo(3)); } else { - assertThat(services.size(), equalTo(1)); + assertThat(services.size(), equalTo(2)); } String[] providers = new String[services.size()]; @@ -283,7 +283,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException { providers[i] = (String) serviceConfig.get("service"); } - var providerList = new ArrayList<>(List.of("openai")); + var providerList = new ArrayList<>(List.of("openai", "streaming_completion_test_service")); if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled() || ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) { @@ -519,14 +519,19 @@ public void testSupportedStream() throws Exception { public void testUnifiedCompletionInference() throws Exception { String modelId = "streaming"; - putModel(modelId, mockCompletionServiceModelConfig(TaskType.COMPLETION)); + putModel(modelId, mockCompletionServiceModelConfig(TaskType.CHAT_COMPLETION)); var singleModel = getModel(modelId); assertEquals(modelId, singleModel.get("inference_id")); - assertEquals(TaskType.COMPLETION.toString(), singleModel.get("task_type")); + assertEquals(TaskType.CHAT_COMPLETION.toString(), singleModel.get("task_type")); var input = IntStream.range(1, 2 + randomInt(8)).mapToObj(i -> randomAlphanumericOfLength(5)).toList(); try { - var events = unifiedCompletionInferOnMockService(modelId, TaskType.COMPLETION, input, VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER); + var events = unifiedCompletionInferOnMockService( + modelId, + TaskType.CHAT_COMPLETION, + input, + VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER + ); var expectedResponses = expectedResultsIterator(input); assertThat(events.size(), equalTo((input.size() + 1) * 2)); events.forEach(event -> { diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index e071b704c233e..b78586174dc1e 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -55,9 +55,9 @@ public List getInferenceServiceFactories() { public static class TestInferenceService extends AbstractTestInferenceService { private static final String NAME = "streaming_completion_test_service"; - private static final Set supportedStreamingTasks = Set.of(TaskType.COMPLETION); + private static final Set supportedStreamingTasks = Set.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); - private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.COMPLETION); + private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {} @@ -129,7 +129,7 @@ public void unifiedCompletionInfer( ActionListener listener ) { switch (model.getConfigurations().getTaskType()) { - case COMPLETION -> listener.onResponse(makeUnifiedResults(request)); + case CHAT_COMPLETION -> listener.onResponse(makeUnifiedResults(request)); default -> listener.onFailure( new ElasticsearchStatusException( TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java index 1478130f6a6c8..9354ac2a83182 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java @@ -52,7 +52,7 @@ public TransportUnifiedCompletionInferenceAction( @Override protected boolean isInvalidTaskTypeForInferenceEndpoint(UnifiedCompletionAction.Request request, UnparsedModel unparsedModel) { - return request.getTaskType().isAnyOrSame(TaskType.COMPLETION) == false || unparsedModel.taskType() != TaskType.COMPLETION; + return request.getTaskType().isAnyOrSame(TaskType.CHAT_COMPLETION) == false || unparsedModel.taskType() != TaskType.CHAT_COMPLETION; } @Override @@ -64,7 +64,7 @@ protected ElasticsearchStatusException createInvalidTaskTypeException( "Incompatible task_type for unified API, the requested type [{}] must be one of [{}]", RestStatus.BAD_REQUEST, request.getTaskType(), - TaskType.COMPLETION.toString() + TaskType.CHAT_COMPLETION.toString() ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java index 4d730be6aa6bd..ca25b56953251 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java @@ -26,7 +26,7 @@ public class OpenAiCompletionRequestManager extends OpenAiRequestManager { private static final Logger logger = LogManager.getLogger(OpenAiCompletionRequestManager.class); private static final ResponseHandler HANDLER = createCompletionHandler(); - static final String USER_ROLE = "user"; + public static final String USER_ROLE = "user"; public static OpenAiCompletionRequestManager of(OpenAiChatCompletionModel model, ThreadPool threadPool) { return new OpenAiCompletionRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index cb554cf288121..663b40bc14693 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -251,7 +251,7 @@ private static ElasticInferenceServiceModel createModel( eisServiceComponents, context ); - case COMPLETION -> new ElasticInferenceServiceCompletionModel( + case CHAT_COMPLETION -> new ElasticInferenceServiceCompletionModel( inferenceEntityId, taskType, NAME, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java index b5bf77cbb3c7d..e4de3d6beb800 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java @@ -23,6 +23,9 @@ public static ModelValidator buildModelValidator(TaskType taskType) { case COMPLETION -> { return new ChatCompletionModelValidator(new SimpleServiceIntegrationValidator()); } + case CHAT_COMPLETION -> { + return new ChatCompletionModelValidator(new SimpleChatCompletionServiceIntegrationValidator()); + } case SPARSE_EMBEDDING, RERANK, ANY -> { return new SimpleModelValidator(new SimpleServiceIntegrationValidator()); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java new file mode 100644 index 0000000000000..1092d84a6ef6b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java @@ -0,0 +1,59 @@ + +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.validation; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; + +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.http.sender.OpenAiCompletionRequestManager.USER_ROLE; + +/** + * This class uses the unified chat completion method to perform validation. + */ +public class SimpleChatCompletionServiceIntegrationValidator implements ServiceIntegrationValidator { + private static final List TEST_INPUT = List.of("how big"); + + @Override + public void validate(InferenceService service, Model model, ActionListener listener) { + var chatCompletionInput = new UnifiedChatInput(TEST_INPUT, USER_ROLE, false); + service.unifiedCompletionInfer( + model, + chatCompletionInput.getRequest(), + InferenceAction.Request.DEFAULT_TIMEOUT, + ActionListener.wrap(r -> { + if (r != null) { + listener.onResponse(r); + } else { + listener.onFailure( + new ElasticsearchStatusException( + "Could not complete inference endpoint creation as validation call to service returned null response.", + RestStatus.BAD_REQUEST + ) + ); + } + }, e -> { + listener.onFailure( + new ElasticsearchStatusException( + "Could not complete inference endpoint creation as validation call to service threw an exception.", + RestStatus.BAD_REQUEST, + e + ) + ); + }) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java index a723e5a9dffdf..c0fc818e421d0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java @@ -57,11 +57,15 @@ public abstract class BaseTransportInferenceActionTestCase action; protected static final String serviceId = "serviceId"; - protected static final TaskType taskType = TaskType.COMPLETION; + protected final TaskType taskType; protected static final String inferenceId = "inferenceEntityId"; protected InferenceServiceRegistry serviceRegistry; protected InferenceStats inferenceStats; + public BaseTransportInferenceActionTestCase(TaskType taskType) { + this.taskType = taskType; + } + @Before public void setUp() throws Exception { super.setUp(); @@ -377,7 +381,7 @@ protected void mockModelAndServiceRegistry(InferenceService service) { when(serviceRegistry.getService(any())).thenReturn(Optional.of(service)); } - protected void mockValidLicenseState(){ + protected void mockValidLicenseState() { when(licenseState.isAllowed(InferencePlugin.INFERENCE_API_FEATURE)).thenReturn(true); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java index a5efe04c22c04..c303e029cb415 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.license.MockLicenseState; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; @@ -20,6 +21,10 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionTestCase { + public TransportInferenceActionTests() { + super(TaskType.COMPLETION); + } + @Override protected BaseTransportInferenceAction createAction( TransportService transportService, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java index 3856a3d111b6e..e8e7d9ac21bed 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java @@ -33,6 +33,10 @@ public class TransportUnifiedCompletionActionTests extends BaseTransportInferenceActionTestCase { + public TransportUnifiedCompletionActionTests() { + super(TaskType.CHAT_COMPLETION); + } + @Override protected BaseTransportInferenceAction createAction( TransportService transportService, @@ -71,7 +75,7 @@ public void testThrows_IncompatibleTaskTypeException_WhenUsingATextEmbeddingInfe assertThat(e, isA(ElasticsearchStatusException.class)); assertThat( e.getMessage(), - is("Incompatible task_type for unified API, the requested type [" + requestTaskType + "] must be one of [completion]") + is("Incompatible task_type for unified API, the requested type [" + requestTaskType + "] must be one of [chat_completion]") ); assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); })); @@ -96,7 +100,7 @@ public void testThrows_IncompatibleTaskTypeException_WhenUsingRequestIsAny_Model assertThat(e, isA(ElasticsearchStatusException.class)); assertThat( e.getMessage(), - is("Incompatible task_type for unified API, the requested type [" + requestTaskType + "] must be one of [completion]") + is("Incompatible task_type for unified API, the requested type [" + requestTaskType + "] must be one of [chat_completion]") ); assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); })); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java index fe076eb721ea2..8ea7e6c2bdb8d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java @@ -44,7 +44,7 @@ import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; -import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createChatCompletionModel; +import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel; import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionRequestTaskSettingsTests.getChatCompletionRequestTaskSettingsMap; import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests.createModel; import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsRequestTaskSettingsTests.createRequestTaskSettingsMap; @@ -325,7 +325,7 @@ public void testCreate_OpenAiChatCompletionModel() throws IOException { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); + var model = createCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); var actionCreator = new OpenAiActionCreator(sender, createWithEmptySettings(threadPool)); var overriddenTaskSettings = getChatCompletionRequestTaskSettingsMap("overridden_user"); var action = actionCreator.create(model, overriddenTaskSettings); @@ -389,7 +389,7 @@ public void testCreate_OpenAiChatCompletionModel_WithoutUser() throws IOExceptio webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createChatCompletionModel(getUrl(webServer), "org", "secret", "model", null); + var model = createCompletionModel(getUrl(webServer), "org", "secret", "model", null); var actionCreator = new OpenAiActionCreator(sender, createWithEmptySettings(threadPool)); var overriddenTaskSettings = getChatCompletionRequestTaskSettingsMap(null); var action = actionCreator.create(model, overriddenTaskSettings); @@ -452,7 +452,7 @@ public void testCreate_OpenAiChatCompletionModel_WithoutOrganization() throws IO webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createChatCompletionModel(getUrl(webServer), null, "secret", "model", null); + var model = createCompletionModel(getUrl(webServer), null, "secret", "model", null); var actionCreator = new OpenAiActionCreator(sender, createWithEmptySettings(threadPool)); var overriddenTaskSettings = getChatCompletionRequestTaskSettingsMap("overridden_user"); var action = actionCreator.create(model, overriddenTaskSettings); @@ -521,7 +521,7 @@ public void testCreate_OpenAiChatCompletionModel_FailsFromInvalidResponseFormat( """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createChatCompletionModel(getUrl(webServer), null, "secret", "model", null); + var model = createCompletionModel(getUrl(webServer), null, "secret", "model", null); var actionCreator = new OpenAiActionCreator(sender, createWithEmptySettings(threadPool)); var overriddenTaskSettings = getChatCompletionRequestTaskSettingsMap("overridden_user"); var action = actionCreator.create(model, overriddenTaskSettings); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java index ba74d2ab42c21..e248f77fe7728 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java @@ -51,7 +51,7 @@ import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; -import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createChatCompletionModel; +import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -284,7 +284,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc } private ExecutableAction createAction(String url, String org, String apiKey, String modelName, @Nullable String user, Sender sender) { - var model = createChatCompletionModel(url, org, apiKey, modelName, user); + var model = createCompletionModel(url, org, apiKey, modelName, user); var requestCreator = OpenAiCompletionRequestManager.of(model, threadPool); var errorMessage = constructFailedToSendRequestMessage(model.getServiceSettings().uri(), "OpenAI chat completions"); return new SingleInputSenderExecutableAction(sender, requestCreator, errorMessage, "OpenAI chat completions"); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests.java index 15b4898650784..068e84fae35df 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests.java @@ -21,7 +21,7 @@ import java.util.ArrayList; import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals; -import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createChatCompletionModel; +import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel; public class ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests extends ESTestCase { @@ -40,7 +40,7 @@ public void testModelUserFieldsSerialization() throws IOException { var unifiedRequest = UnifiedCompletionRequest.of(messageList); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null); + OpenAiChatCompletionModel model = createCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null); OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java index b0c58f3e94af8..ee0cdad5d552e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java @@ -20,7 +20,7 @@ import java.util.ArrayList; import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals; -import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createChatCompletionModel; +import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel; public class OpenAiUnifiedChatCompletionRequestEntityTests extends ESTestCase { @@ -40,7 +40,7 @@ public void testModelUserFieldsSerialization() throws IOException { var unifiedRequest = UnifiedCompletionRequest.of(messageList); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", USER); + OpenAiChatCompletionModel model = createCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", USER); OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java index 2be12c9b12e0b..ec4231bd73154 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java @@ -150,7 +150,7 @@ public static OpenAiUnifiedChatCompletionRequest createRequest( @Nullable String user, boolean stream ) { - var chatCompletionModel = OpenAiChatCompletionModelTests.createChatCompletionModel(url, org, apiKey, model, user); + var chatCompletionModel = OpenAiChatCompletionModelTests.createCompletionModel(url, org, apiKey, model, user); return new OpenAiUnifiedChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntityTests.java index d9388cab0e1ec..024b7aa532d90 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntityTests.java @@ -28,7 +28,7 @@ import java.util.Random; import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals; -import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createChatCompletionModel; +import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel; public class UnifiedChatCompletionRequestEntityTests extends ESTestCase { @@ -46,7 +46,7 @@ public void testBasicSerialization() throws IOException { UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(messageList, null, null, null, null, null, null, null); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null); + OpenAiChatCompletionModel model = createCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null); OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); @@ -111,7 +111,7 @@ public void testSerializationWithAllFields() throws IOException { UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + OpenAiChatCompletionModel model = createCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); @@ -204,7 +204,7 @@ public void testSerializationWithNullOptionalFields() throws IOException { UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + OpenAiChatCompletionModel model = createCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); @@ -253,7 +253,7 @@ public void testSerializationWithEmptyLists() throws IOException { UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + OpenAiChatCompletionModel model = createCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); @@ -334,7 +334,7 @@ public void testSerializationWithNestedObjects() throws IOException { UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", randomModel, null); + OpenAiChatCompletionModel model = createCompletionModel("test-endpoint", "organizationId", "api-key", randomModel, null); OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); @@ -452,7 +452,7 @@ public void testSerializationWithDifferentContentTypes() throws IOException { UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + OpenAiChatCompletionModel model = createCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); @@ -516,7 +516,7 @@ public void testSerializationWithSpecialCharacters() throws IOException { UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + OpenAiChatCompletionModel model = createCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); @@ -574,7 +574,7 @@ public void testSerializationWithBooleanFields() throws IOException { null // topP ); - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + OpenAiChatCompletionModel model = createCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); UnifiedChatInput unifiedChatInputTrue = new UnifiedChatInput(unifiedRequest, true); OpenAiUnifiedChatCompletionRequestEntity entityTrue = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInputTrue, model); @@ -642,7 +642,7 @@ public void testSerializationWithoutContentField() throws IOException { UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(messageList, null, null, null, null, null, null, null); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null); + OpenAiChatCompletionModel model = createCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null); OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index e2cb93f731162..cfe47098e0141 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -331,7 +331,7 @@ public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IO var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { - var model = OpenAiChatCompletionModelTests.createChatCompletionModel( + var model = OpenAiChatCompletionModelTests.createCompletionModel( randomAlphaOfLength(10), randomAlphaOfLength(10), randomAlphaOfLength(10), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 6e744fffbff41..468b0f676ee85 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -925,7 +925,7 @@ public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IO var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new IbmWatsonxServiceWithoutAuth(senderFactory, createWithEmptySettings(threadPool))) { - var model = OpenAiChatCompletionModelTests.createChatCompletionModel( + var model = OpenAiChatCompletionModelTests.createCompletionModel( randomAlphaOfLength(10), randomAlphaOfLength(10), randomAlphaOfLength(10), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index da912cd6e5d14..4941d3ae49a23 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -74,7 +74,7 @@ import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; -import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createChatCompletionModel; +import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel; import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettingsTests.getServiceSettingsMap; import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettingsTests.getTaskSettingsMap; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; @@ -1084,16 +1084,16 @@ public void testInfer_StreamRequest() throws Exception { """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var result = streamChatCompletion(); + var result = streamCompletion(); InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(""" {"completion":[{"delta":"hello, world"}]}"""); } - private InferenceServiceResults streamChatCompletion() throws IOException { + private InferenceServiceResults streamCompletion() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); + var model = OpenAiChatCompletionModelTests.createCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( model, @@ -1122,7 +1122,7 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception { }"""; webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); - var result = streamChatCompletion(); + var result = streamCompletion(); InferenceEventsAssertion.assertThat(result) .hasFinishedStream() @@ -1527,7 +1527,7 @@ public void testCheckModelConfig_ReturnsNewModelReference_DoesNotOverrideSimilar public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { try (var service = createOpenAiService()) { - var model = createChatCompletionModel( + var model = createCompletionModel( randomAlphaOfLength(10), randomAlphaOfLength(10), randomAlphaOfLength(10), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java index 2a5415f45c6d9..e6d23012fae35 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java @@ -24,16 +24,16 @@ public class OpenAiChatCompletionModelTests extends ESTestCase { public void testOverrideWith_OverridesUser() { - var model = createChatCompletionModel("url", "org", "api_key", "model_name", null); + var model = createCompletionModel("url", "org", "api_key", "model_name", null); var requestTaskSettingsMap = getChatCompletionRequestTaskSettingsMap("user_override"); var overriddenModel = OpenAiChatCompletionModel.of(model, requestTaskSettingsMap); - assertThat(overriddenModel, is(createChatCompletionModel("url", "org", "api_key", "model_name", "user_override"))); + assertThat(overriddenModel, is(createCompletionModel("url", "org", "api_key", "model_name", "user_override"))); } public void testOverrideWith_EmptyMap() { - var model = createChatCompletionModel("url", "org", "api_key", "model_name", null); + var model = createCompletionModel("url", "org", "api_key", "model_name", null); var requestTaskSettingsMap = Map.of(); @@ -42,14 +42,14 @@ public void testOverrideWith_EmptyMap() { } public void testOverrideWith_NullMap() { - var model = createChatCompletionModel("url", "org", "api_key", "model_name", null); + var model = createCompletionModel("url", "org", "api_key", "model_name", null); var overriddenModel = OpenAiChatCompletionModel.of(model, (Map) null); assertThat(overriddenModel, sameInstance(model)); } public void testOverrideWith_UnifiedCompletionRequest_OverridesModelId() { - var model = createChatCompletionModel("url", "org", "api_key", "model_name", "user"); + var model = createCompletionModel("url", "org", "api_key", "model_name", "user"); var request = new UnifiedCompletionRequest( List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), "different_model", @@ -63,12 +63,12 @@ public void testOverrideWith_UnifiedCompletionRequest_OverridesModelId() { assertThat( OpenAiChatCompletionModel.of(model, request), - is(createChatCompletionModel("url", "org", "api_key", "different_model", "user")) + is(createCompletionModel("url", "org", "api_key", "different_model", "user")) ); } public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() { - var model = createChatCompletionModel("url", "org", "api_key", "model_name", "user"); + var model = createCompletionModel("url", "org", "api_key", "model_name", "user"); var request = new UnifiedCompletionRequest( List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), null, // not overriding model @@ -80,10 +80,17 @@ public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenReques null ); - assertThat( - OpenAiChatCompletionModel.of(model, request), - is(createChatCompletionModel("url", "org", "api_key", "model_name", "user")) - ); + assertThat(OpenAiChatCompletionModel.of(model, request), is(createCompletionModel("url", "org", "api_key", "model_name", "user"))); + } + + public static OpenAiChatCompletionModel createCompletionModel( + String url, + @Nullable String org, + String apiKey, + String modelName, + @Nullable String user + ) { + return createModelWithTaskType(url, org, apiKey, modelName, user, TaskType.COMPLETION); } public static OpenAiChatCompletionModel createChatCompletionModel( @@ -92,10 +99,21 @@ public static OpenAiChatCompletionModel createChatCompletionModel( String apiKey, String modelName, @Nullable String user + ) { + return createModelWithTaskType(url, org, apiKey, modelName, user, TaskType.CHAT_COMPLETION); + } + + public static OpenAiChatCompletionModel createModelWithTaskType( + String url, + @Nullable String org, + String apiKey, + String modelName, + @Nullable String user, + TaskType taskType ) { return new OpenAiChatCompletionModel( "id", - TaskType.COMPLETION, + taskType, "service", new OpenAiChatCompletionServiceSettings(modelName, url, org, null, null), new OpenAiChatCompletionTaskSettings(user), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java index 0153113be75d9..a854bbdec507a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java @@ -35,6 +35,8 @@ private Map> taskTypeToModelValidatorC SimpleModelValidator.class, TaskType.COMPLETION, ChatCompletionModelValidator.class, + TaskType.CHAT_COMPLETION, + ChatCompletionModelValidator.class, TaskType.ANY, SimpleModelValidator.class ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidatorTests.java new file mode 100644 index 0000000000000..f02c4662d49e4 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidatorTests.java @@ -0,0 +1,150 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.validation; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; + +import java.util.List; + +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; +import static org.mockito.MockitoAnnotations.openMocks; + +public class SimpleChatCompletionServiceIntegrationValidatorTests extends ESTestCase { + + private static final UnifiedCompletionRequest EXPECTED_REQUEST = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("how big"), "user", null, null)), + null, + null, + null, + null, + null, + null, + null + ); + + @Mock + private InferenceService mockInferenceService; + @Mock + private Model mockModel; + @Mock + private ActionListener mockActionListener; + @Mock + private InferenceServiceResults mockInferenceServiceResults; + + private SimpleChatCompletionServiceIntegrationValidator underTest; + + @Before + public void setup() { + openMocks(this); + + underTest = new SimpleChatCompletionServiceIntegrationValidator(); + + when(mockActionListener.delegateFailureAndWrap(any())).thenCallRealMethod(); + } + + public void testValidate_ServiceThrowsException() { + doThrow(ElasticsearchStatusException.class).when(mockInferenceService) + .unifiedCompletionInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(InferenceAction.Request.DEFAULT_TIMEOUT), any()); + + assertThrows(ElasticsearchStatusException.class, () -> underTest.validate(mockInferenceService, mockModel, mockActionListener)); + + verifyCallToService(); + } + + public void testValidate_SuccessfulCallToService() { + mockSuccessfulCallToService(mockInferenceServiceResults); + verify(mockActionListener).onResponse(mockInferenceServiceResults); + verifyCallToService(); + } + + public void testValidate_CallsListenerOnFailure_WhenServiceResponseIsNull() { + mockNullResponseFromService(); + + var captor = ArgumentCaptor.forClass(ElasticsearchStatusException.class); + verify(mockActionListener).onFailure(captor.capture()); + + assertThat( + captor.getValue().getMessage(), + is("Could not complete inference endpoint creation as validation call to service returned null response.") + ); + assertThat(captor.getValue().status(), is(RestStatus.BAD_REQUEST)); + + verifyCallToService(); + } + + public void testValidate_CallsListenerOnFailure_WhenServiceThrowsException() { + var returnedException = new IllegalStateException("bad state"); + mockFailureResponseFromService(returnedException); + + var captor = ArgumentCaptor.forClass(ElasticsearchStatusException.class); + verify(mockActionListener).onFailure(captor.capture()); + + assertThat( + captor.getValue().getMessage(), + is("Could not complete inference endpoint creation as validation call to service threw an exception.") + ); + assertThat(captor.getValue().status(), is(RestStatus.BAD_REQUEST)); + assertThat(captor.getValue().getCause(), is(returnedException)); + + verifyCallToService(); + } + + private void mockSuccessfulCallToService(InferenceServiceResults result) { + doAnswer(ans -> { + ActionListener responseListener = ans.getArgument(3); + responseListener.onResponse(result); + return null; + }).when(mockInferenceService) + .unifiedCompletionInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(InferenceAction.Request.DEFAULT_TIMEOUT), any()); + + underTest.validate(mockInferenceService, mockModel, mockActionListener); + } + + private void mockNullResponseFromService() { + mockSuccessfulCallToService(null); + } + + private void mockFailureResponseFromService(Exception exception) { + doAnswer(ans -> { + ActionListener responseListener = ans.getArgument(3); + responseListener.onFailure(exception); + return null; + }).when(mockInferenceService) + .unifiedCompletionInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(InferenceAction.Request.DEFAULT_TIMEOUT), any()); + + underTest.validate(mockInferenceService, mockModel, mockActionListener); + } + + private void verifyCallToService() { + verify(mockInferenceService).unifiedCompletionInfer( + eq(mockModel), + eq(EXPECTED_REQUEST), + eq(InferenceAction.Request.DEFAULT_TIMEOUT), + any() + ); + verifyNoMoreInteractions(mockInferenceService, mockModel, mockActionListener, mockInferenceServiceResults); + } +}