diff --git a/.buildkite/pipelines/pull-request/part-1.yml b/.buildkite/pipelines/pull-request/part-1.yml
index 3d467c6c41e43..7a09e2a162ff8 100644
--- a/.buildkite/pipelines/pull-request/part-1.yml
+++ b/.buildkite/pipelines/pull-request/part-1.yml
@@ -1,6 +1,8 @@
 steps:
   - label: part-1
-    command: .ci/scripts/run-gradle.sh -Dignore.tests.seed checkPart1
+    command: |
+      .buildkite/scripts/spotless.sh # This doesn't have to be part of part-1, it was just a convenient place to put it
+      .ci/scripts/run-gradle.sh -Dignore.tests.seed checkPart1
     timeout_in_minutes: 300
     agents:
       provider: gcp
diff --git a/.buildkite/pipelines/pull-request/precommit.yml b/.buildkite/pipelines/pull-request/precommit.yml
index f6548dfeed9b2..1763758932581 100644
--- a/.buildkite/pipelines/pull-request/precommit.yml
+++ b/.buildkite/pipelines/pull-request/precommit.yml
@@ -3,7 +3,9 @@ config:
   skip-labels: []
 steps:
   - label: precommit
-    command: .ci/scripts/run-gradle.sh -Dignore.tests.seed precommit
+    command: |
+      .buildkite/scripts/spotless.sh
+      .ci/scripts/run-gradle.sh -Dignore.tests.seed precommit
     timeout_in_minutes: 300
     agents:
       provider: gcp
diff --git a/.buildkite/scripts/spotless.sh b/.buildkite/scripts/spotless.sh
new file mode 100755
index 0000000000000..b9e6094edb2c7
--- /dev/null
+++ b/.buildkite/scripts/spotless.sh
@@ -0,0 +1,44 @@
+#!/bin/bash
+
+if [[ -z "${BUILDKITE_PULL_REQUEST:-}" ]]; then
+  echo "Not a pull request, skipping spotless"
+  exit 0
+fi
+
+if ! git diff --exit-code; then
+  echo "Changes are present before running spotless, not running"
+  git status
+  exit 0
+fi
+
+NEW_COMMIT_MESSAGE="[CI] Auto commit changes from spotless"
+PREVIOUS_COMMIT_MESSAGE="$(git log -1 --pretty=%B)"
+
+echo "--- Running spotless"
+.ci/scripts/run-gradle.sh -Dscan.tag.NESTED spotlessApply
+
+if git diff --exit-code; then
+  echo "No changes found after running spotless. Don't need to auto commit."
+  exit 0
+fi
+
+if [[ "$NEW_COMMIT_MESSAGE" == "$PREVIOUS_COMMIT_MESSAGE" ]]; then
+  echo "Changes found after running spotless"
+  echo "CI already attempted to commit these changes, but the file(s) seem to have changed again."
+  echo "Please review and fix manually."
+  exit 1
+fi
+
+git config --global user.name elasticsearchmachine
+git config --global user.email 'infra-root+elasticsearchmachine@elastic.co'
+
+gh pr checkout "${BUILDKITE_PULL_REQUEST}"
+git add -u .
+git commit -m "$NEW_COMMIT_MESSAGE"
+git push
+
+# After the git push, the new commit will trigger a new build within a few seconds and this build should get cancelled
+# So, let's just sleep to give the build time to cancel itself without an error
+# If it doesn't get cancelled for some reason, then exit with an error, because we don't want this build to be green (we just don't want it to generate an error either)
+sleep 300
+exit 1
diff --git a/distribution/docker/src/docker/iron_bank/hardening_manifest.yaml b/distribution/docker/src/docker/iron_bank/hardening_manifest.yaml
index f4364c5008c09..e3bdac51cc5c5 100644
--- a/distribution/docker/src/docker/iron_bank/hardening_manifest.yaml
+++ b/distribution/docker/src/docker/iron_bank/hardening_manifest.yaml
@@ -50,9 +50,12 @@ resources:
 # List of project maintainers
 maintainers:
-  - name: "Rory Hunter"
-    email: "rory.hunter@elastic.co"
-    username: "rory"
+  - name: "Mark Vieira"
+    email: "mark.vieira@elastic.co"
+    username: "mark-vieira"
+  - name: "Rene Gröschke"
+    email: "rene.groschke@elastic.co"
+    username: "breskeby"
   - email: "klepal_alexander@bah.com"
     name: "Alexander Klepal"
     username: "alexander.klepal"
diff --git a/docs/changelog/117643.yaml b/docs/changelog/117643.yaml
new file mode 100644
index 0000000000000..9105749377d2c
--- /dev/null
+++ b/docs/changelog/117643.yaml
@@ -0,0 +1,6 @@
+pr: 117643
+summary: Drop null columns in text formats
+area: ES|QL
+type: bug
+issues:
+ - 116848
diff --git a/docs/changelog/117851.yaml b/docs/changelog/117851.yaml
new file mode 100644
index 0000000000000..21888cd6fb80f
--- /dev/null
+++ b/docs/changelog/117851.yaml
@@ -0,0 +1,5 @@
+pr: 117851
+summary: Addition of `tier_preference`, `creation_date` and `version` fields in Elasticsearch monitoring template
+area: Monitoring
+type: enhancement
+issues: []
diff --git a/docs/changelog/118454.yaml b/docs/changelog/118454.yaml
new file mode 100644
index 0000000000000..9a19ede64d705
--- /dev/null
+++ b/docs/changelog/118454.yaml
@@ -0,0 +1,5 @@
+pr: 118454
+summary: Fix RLIKE folding with (unsupported) case insensitive pattern
+area: ES|QL
+type: bug
+issues: []
diff --git a/docs/changelog/118474.yaml b/docs/changelog/118474.yaml
new file mode 100644
index 0000000000000..1b0c6942eb323
--- /dev/null
+++ b/docs/changelog/118474.yaml
@@ -0,0 +1,6 @@
+pr: 118474
+summary: Esql bucket function for date nanos
+area: ES|QL
+type: enhancement
+issues:
+ - 118031
diff --git a/docs/reference/esql/functions/kibana/definition/bucket.json b/docs/reference/esql/functions/kibana/definition/bucket.json
index 660e1be49fda9..18802f5ff8fef 100644
--- a/docs/reference/esql/functions/kibana/definition/bucket.json
+++ b/docs/reference/esql/functions/kibana/definition/bucket.json
@@ -310,6 +310,312 @@
       "variadic" : false,
       "returnType" : "date"
     },
+    {
+      "params" : [
+        {
+          "name" : "field",
+          "type" : "date_nanos",
+          "optional" : false,
+          "description" : "Numeric or date expression from which to derive buckets."
+        },
+        {
+          "name" : "buckets",
+          "type" : "date_period",
+          "optional" : false,
+          "description" : "Target number of buckets, or desired bucket size if `from` and `to` parameters are omitted."
+        }
+      ],
+      "variadic" : false,
+      "returnType" : "date_nanos"
+    },
+    {
+      "params" : [
+        {
+          "name" : "field",
+          "type" : "date_nanos",
+          "optional" : false,
+          "description" : "Numeric or date expression from which to derive buckets."
+        },
+        {
+          "name" : "buckets",
+          "type" : "integer",
+          "optional" : false,
+          "description" : "Target number of buckets, or desired bucket size if `from` and `to` parameters are omitted."
+        },
+        {
+          "name" : "from",
+          "type" : "date",
+          "optional" : true,
+          "description" : "Start of the range. Can be a number, a date or a date expressed as a string."
+        },
+        {
+          "name" : "to",
+          "type" : "date",
+          "optional" : true,
+          "description" : "End of the range. Can be a number, a date or a date expressed as a string."
+        }
+      ],
+      "variadic" : false,
+      "returnType" : "date_nanos"
+    },
+    {
+      "params" : [
+        {
+          "name" : "field",
+          "type" : "date_nanos",
+          "optional" : false,
+          "description" : "Numeric or date expression from which to derive buckets."
+        },
+        {
+          "name" : "buckets",
+          "type" : "integer",
+          "optional" : false,
+          "description" : "Target number of buckets, or desired bucket size if `from` and `to` parameters are omitted."
+        },
+        {
+          "name" : "from",
+          "type" : "date",
+          "optional" : true,
+          "description" : "Start of the range. Can be a number, a date or a date expressed as a string."
+        },
+        {
+          "name" : "to",
+          "type" : "keyword",
+          "optional" : true,
+          "description" : "End of the range. Can be a number, a date or a date expressed as a string."
+        }
+      ],
+      "variadic" : false,
+      "returnType" : "date_nanos"
+    },
+    {
+      "params" : [
+        {
+          "name" : "field",
+          "type" : "date_nanos",
+          "optional" : false,
+          "description" : "Numeric or date expression from which to derive buckets."
+        },
+        {
+          "name" : "buckets",
+          "type" : "integer",
+          "optional" : false,
+          "description" : "Target number of buckets, or desired bucket size if `from` and `to` parameters are omitted."
+        },
+        {
+          "name" : "from",
+          "type" : "date",
+          "optional" : true,
+          "description" : "Start of the range. Can be a number, a date or a date expressed as a string."
+        },
+        {
+          "name" : "to",
+          "type" : "text",
+          "optional" : true,
+          "description" : "End of the range. Can be a number, a date or a date expressed as a string."
+        }
+      ],
+      "variadic" : false,
+      "returnType" : "date_nanos"
+    },
+    {
+      "params" : [
+        {
+          "name" : "field",
+          "type" : "date_nanos",
+          "optional" : false,
+          "description" : "Numeric or date expression from which to derive buckets."
+        },
+        {
+          "name" : "buckets",
+          "type" : "integer",
+          "optional" : false,
+          "description" : "Target number of buckets, or desired bucket size if `from` and `to` parameters are omitted."
+        },
+        {
+          "name" : "from",
+          "type" : "keyword",
+          "optional" : true,
+          "description" : "Start of the range. Can be a number, a date or a date expressed as a string."
+        },
+        {
+          "name" : "to",
+          "type" : "date",
+          "optional" : true,
+          "description" : "End of the range. Can be a number, a date or a date expressed as a string."
+        }
+      ],
+      "variadic" : false,
+      "returnType" : "date_nanos"
+    },
+    {
+      "params" : [
+        {
+          "name" : "field",
+          "type" : "date_nanos",
+          "optional" : false,
+          "description" : "Numeric or date expression from which to derive buckets."
+        },
+        {
+          "name" : "buckets",
+          "type" : "integer",
+          "optional" : false,
+          "description" : "Target number of buckets, or desired bucket size if `from` and `to` parameters are omitted."
+        },
+        {
+          "name" : "from",
+          "type" : "keyword",
+          "optional" : true,
+          "description" : "Start of the range. Can be a number, a date or a date expressed as a string."
+        },
+        {
+          "name" : "to",
+          "type" : "keyword",
+          "optional" : true,
+          "description" : "End of the range. Can be a number, a date or a date expressed as a string."
+        }
+      ],
+      "variadic" : false,
+      "returnType" : "date_nanos"
+    },
+    {
+      "params" : [
+        {
+          "name" : "field",
+          "type" : "date_nanos",
+          "optional" : false,
+          "description" : "Numeric or date expression from which to derive buckets."
+        },
+        {
+          "name" : "buckets",
+          "type" : "integer",
+          "optional" : false,
+          "description" : "Target number of buckets, or desired bucket size if `from` and `to` parameters are omitted."
+        },
+        {
+          "name" : "from",
+          "type" : "keyword",
+          "optional" : true,
+          "description" : "Start of the range. Can be a number, a date or a date expressed as a string."
+        },
+        {
+          "name" : "to",
+          "type" : "text",
+          "optional" : true,
+          "description" : "End of the range. Can be a number, a date or a date expressed as a string."
+        }
+      ],
+      "variadic" : false,
+      "returnType" : "date_nanos"
+    },
+    {
+      "params" : [
+        {
+          "name" : "field",
+          "type" : "date_nanos",
+          "optional" : false,
+          "description" : "Numeric or date expression from which to derive buckets."
+        },
+        {
+          "name" : "buckets",
+          "type" : "integer",
+          "optional" : false,
+          "description" : "Target number of buckets, or desired bucket size if `from` and `to` parameters are omitted."
+        },
+        {
+          "name" : "from",
+          "type" : "text",
+          "optional" : true,
+          "description" : "Start of the range. Can be a number, a date or a date expressed as a string."
+        },
+        {
+          "name" : "to",
+          "type" : "date",
+          "optional" : true,
+          "description" : "End of the range. Can be a number, a date or a date expressed as a string."
+        }
+      ],
+      "variadic" : false,
+      "returnType" : "date_nanos"
+    },
+    {
+      "params" : [
+        {
+          "name" : "field",
+          "type" : "date_nanos",
+          "optional" : false,
+          "description" : "Numeric or date expression from which to derive buckets."
+        },
+        {
+          "name" : "buckets",
+          "type" : "integer",
+          "optional" : false,
+          "description" : "Target number of buckets, or desired bucket size if `from` and `to` parameters are omitted."
+        },
+        {
+          "name" : "from",
+          "type" : "text",
+          "optional" : true,
+          "description" : "Start of the range. Can be a number, a date or a date expressed as a string."
+        },
+        {
+          "name" : "to",
+          "type" : "keyword",
+          "optional" : true,
+          "description" : "End of the range. Can be a number, a date or a date expressed as a string."
+        }
+      ],
+      "variadic" : false,
+      "returnType" : "date_nanos"
+    },
+    {
+      "params" : [
+        {
+          "name" : "field",
+          "type" : "date_nanos",
+          "optional" : false,
+          "description" : "Numeric or date expression from which to derive buckets."
+        },
+        {
+          "name" : "buckets",
+          "type" : "integer",
+          "optional" : false,
+          "description" : "Target number of buckets, or desired bucket size if `from` and `to` parameters are omitted."
+        },
+        {
+          "name" : "from",
+          "type" : "text",
+          "optional" : true,
+          "description" : "Start of the range. Can be a number, a date or a date expressed as a string."
+        },
+        {
+          "name" : "to",
+          "type" : "text",
+          "optional" : true,
+          "description" : "End of the range. Can be a number, a date or a date expressed as a string."
+        }
+      ],
+      "variadic" : false,
+      "returnType" : "date_nanos"
+    },
+    {
+      "params" : [
+        {
+          "name" : "field",
+          "type" : "date_nanos",
+          "optional" : false,
+          "description" : "Numeric or date expression from which to derive buckets."
+        },
+        {
+          "name" : "buckets",
+          "type" : "time_duration",
+          "optional" : false,
+          "description" : "Target number of buckets, or desired bucket size if `from` and `to` parameters are omitted."
+        }
+      ],
+      "variadic" : false,
+      "returnType" : "date_nanos"
+    },
     {
       "params" : [
         {
diff --git a/docs/reference/esql/functions/types/bucket.asciidoc b/docs/reference/esql/functions/types/bucket.asciidoc
index 172e84b6f7860..2e6985e6bc4ed 100644
--- a/docs/reference/esql/functions/types/bucket.asciidoc
+++ b/docs/reference/esql/functions/types/bucket.asciidoc
@@ -16,6 +16,17 @@ date | integer | text | date | date
 date | integer | text | keyword | date
 date | integer | text | text | date
 date | time_duration | | | date
+date_nanos | date_period | | | date_nanos
+date_nanos | integer | date | date | date_nanos
+date_nanos | integer | date | keyword | date_nanos
+date_nanos | integer | date | text | date_nanos
+date_nanos | integer | keyword | date | date_nanos
+date_nanos | integer | keyword | keyword | date_nanos
+date_nanos | integer | keyword | text | date_nanos
+date_nanos | integer | text | date | date_nanos
+date_nanos | integer | text | keyword | date_nanos
+date_nanos | integer | text | text | date_nanos
+date_nanos | time_duration | | | date_nanos
 double | double | | | double
 double | integer | double | double | double
 double | integer | double | integer | double
diff --git a/muted-tests.yml b/muted-tests.yml
index 240d9d245eee5..7c2724f46fc81 100644
--- a/muted-tests.yml
+++ b/muted-tests.yml
@@ -428,9 +428,6 @@ tests:
 - class: org.elasticsearch.xpack.searchablesnapshots.RetrySearchIntegTests
   method: testRetryPointInTime
   issue: https://github.com/elastic/elasticsearch/issues/118514
-- class: org.elasticsearch.xpack.esql.qa.mixed.MixedClusterEsqlSpecIT
-  method: test {stats.ByDateAndKeywordAndIntWithAlias SYNC}
-  issue: https://github.com/elastic/elasticsearch/issues/118668
 - class: org.elasticsearch.xpack.application.OpenAiServiceUpgradeIT
   method: testOpenAiEmbeddings {upgradedNodes=1}
   issue: https://github.com/elastic/elasticsearch/issues/118156
diff --git a/server/src/main/java/org/elasticsearch/inference/ChunkedInferenceServiceResults.java b/server/src/main/java/org/elasticsearch/inference/ChunkedInference.java
similarity index 73%
rename from server/src/main/java/org/elasticsearch/inference/ChunkedInferenceServiceResults.java
rename to server/src/main/java/org/elasticsearch/inference/ChunkedInference.java
index 10e00e9860200..c54e5a98d56cc 100644
--- a/server/src/main/java/org/elasticsearch/inference/ChunkedInferenceServiceResults.java
+++ b/server/src/main/java/org/elasticsearch/inference/ChunkedInference.java
@@ -12,23 +12,27 @@
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.xcontent.XContent;
 
+import java.io.IOException;
 import java.util.Iterator;
 
-public interface ChunkedInferenceServiceResults extends InferenceServiceResults {
+public interface ChunkedInference {
 
     /**
      * Implementations of this function serialize their embeddings to {@link BytesReference} for storage in semantic text fields.
-     * The iterator iterates over all the chunks stored in the {@link ChunkedInferenceServiceResults}.
      *
      * @param xcontent provided by the SemanticTextField
      * @return an iterator of the serialized {@link Chunk} which includes the matched text (input) and bytes reference (output/embedding).
      */
-    Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent);
+    Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException;
 
     /**
-     * A chunk of inference results containing matched text and the bytes reference.
+     * A chunk of inference results containing matched text, the substring location
+     * in the original text and the bytes reference.
      * @param matchedText
+     * @param textOffset
      * @param bytesReference
      */
-    record Chunk(String matchedText, BytesReference bytesReference) {}
+    record Chunk(String matchedText, TextOffset textOffset, BytesReference bytesReference) {}
+
+    record TextOffset(int start, int end) {}
 }
diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java
index c2d690d8160ac..c8ed9e6b230ce 100644
--- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java
+++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java
@@ -144,7 +144,7 @@ void chunkedInfer(
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
-        ActionListener<List<ChunkedInferenceServiceResults>> listener
+        ActionListener<List<ChunkedInference>> listener
     );
 
     /**
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingByte.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingByte.java
new file mode 100644
index 0000000000000..c2f70b0be2916
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingByte.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.inference.results;
+
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.inference.ChunkedInference;
+import org.elasticsearch.xcontent.XContent;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+public record ChunkedInferenceEmbeddingByte(List<ByteEmbeddingChunk> chunks) implements ChunkedInference {
+
+    @Override
+    public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException {
+        var asChunk = new ArrayList<Chunk>();
+        for (var chunk : chunks) {
+            asChunk.add(new Chunk(chunk.matchedText(), chunk.offset(), toBytesReference(xcontent, chunk.embedding())));
+        }
+        return asChunk.iterator();
+    }
+
+    /**
+     * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
+     */
+    private static BytesReference toBytesReference(XContent xContent, byte[] value) throws IOException {
+        XContentBuilder builder = XContentBuilder.builder(xContent);
+        builder.startArray();
+        for (byte v : value) {
+            builder.value(v);
+        }
+        builder.endArray();
+        return BytesReference.bytes(builder);
+    }
+
+    public record ByteEmbeddingChunk(byte[] embedding, String matchedText, TextOffset offset) {}
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingFloat.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingFloat.java
new file mode 100644
index 0000000000000..651d135b761dd
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingFloat.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.inference.results;
+
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.inference.ChunkedInference;
+import org.elasticsearch.xcontent.XContent;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+public record ChunkedInferenceEmbeddingFloat(List<FloatEmbeddingChunk> chunks) implements ChunkedInference {
+
+    @Override
+    public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException {
+        var asChunk = new ArrayList<Chunk>();
+        for (var chunk : chunks) {
+            asChunk.add(new Chunk(chunk.matchedText(), chunk.offset(), toBytesReference(xcontent, chunk.embedding())));
+        }
+        return asChunk.iterator();
+    }
+
+    /**
+     * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
+     */
+    private static BytesReference toBytesReference(XContent xContent, float[] value) throws IOException {
+        XContentBuilder b = XContentBuilder.builder(xContent);
+        b.startArray();
+        for (float v : value) {
+            b.value(v);
+        }
+        b.endArray();
+        return BytesReference.bytes(b);
+    }
+
+    public record FloatEmbeddingChunk(float[] embedding, String matchedText, TextOffset offset) {}
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingSparse.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingSparse.java
new file mode 100644
index 0000000000000..37bf92e0dbfce
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingSparse.java
@@ -0,0 +1,67 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.inference.results;
+
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.inference.ChunkedInference;
+import org.elasticsearch.xcontent.ToXContent;
+import org.elasticsearch.xcontent.XContent;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.core.ml.search.WeightedToken;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings;
+
+public record ChunkedInferenceEmbeddingSparse(List<SparseEmbeddingChunk> chunks) implements ChunkedInference {
+
+    public static List<ChunkedInference> listOf(List<String> inputs, SparseEmbeddingResults sparseEmbeddingResults) {
+        validateInputSizeAgainstEmbeddings(inputs, sparseEmbeddingResults.embeddings().size());
+
+        var results = new ArrayList<ChunkedInference>(inputs.size());
+        for (int i = 0; i < inputs.size(); i++) {
+            results.add(
+                new ChunkedInferenceEmbeddingSparse(
+                    List.of(
+                        new SparseEmbeddingChunk(
+                            sparseEmbeddingResults.embeddings().get(i).tokens(),
+                            inputs.get(i),
+                            new TextOffset(0, inputs.get(i).length())
+                        )
+                    )
+                )
+            );
+        }
+
+        return results;
+    }
+
+    @Override
+    public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException {
+        var asChunk = new ArrayList<Chunk>();
+        for (var chunk : chunks) {
+            asChunk.add(new Chunk(chunk.matchedText(), chunk.offset(), toBytesReference(xcontent, chunk.weightedTokens())));
+        }
+        return asChunk.iterator();
+    }
+
+    private static BytesReference toBytesReference(XContent xContent, List<WeightedToken> tokens) throws IOException {
+        XContentBuilder b = XContentBuilder.builder(xContent);
+        b.startObject();
+        for (var weightedToken : tokens) {
+            weightedToken.toXContent(b, ToXContent.EMPTY_PARAMS);
+        }
+        b.endObject();
+        return BytesReference.bytes(b);
+    }
+
+    public record SparseEmbeddingChunk(List<WeightedToken> weightedTokens, String matchedText, TextOffset offset) {}
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceError.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceError.java
new file mode 100644
index 0000000000000..65be9f12d7686
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceError.java
@@ -0,0 +1,23 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.inference.results;
+
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.inference.ChunkedInference;
+import org.elasticsearch.xcontent.XContent;
+
+import java.util.Iterator;
+import java.util.stream.Stream;
+
+public record ChunkedInferenceError(Exception exception) implements ChunkedInference {
+
+    @Override
+    public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) {
+        return Stream.of(exception).map(e -> new Chunk(e.getMessage(), new TextOffset(0, 0), BytesArray.EMPTY)).iterator();
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ErrorChunkedInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ErrorChunkedInferenceResults.java
deleted file mode 100644
index 18f88a8ff022a..0000000000000
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ErrorChunkedInferenceResults.java
+++ /dev/null
@@ -1,106 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.core.inference.results;
-
-import org.elasticsearch.common.Strings;
-import org.elasticsearch.common.bytes.BytesArray;
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
-import org.elasticsearch.inference.InferenceResults;
-import org.elasticsearch.xcontent.ToXContent;
-import org.elasticsearch.xcontent.XContent;
-
-import java.io.IOException;
-import java.util.Iterator;
-import java.util.LinkedHashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
-import java.util.stream.Stream;
-
-public class ErrorChunkedInferenceResults implements ChunkedInferenceServiceResults {
-
-    public static final String NAME = "error_chunked";
-
-    private final Exception exception;
-
-    public ErrorChunkedInferenceResults(Exception exception) {
-        this.exception = Objects.requireNonNull(exception);
-    }
-
-    public ErrorChunkedInferenceResults(StreamInput in) throws IOException {
-        this.exception = in.readException();
-    }
-
-    public Exception getException() {
-        return exception;
-    }
-
-    @Override
-    public void writeTo(StreamOutput out) throws IOException {
-        out.writeException(exception);
-    }
-
-    @Override
-    public boolean equals(Object object) {
-        if (object == this) {
-            return true;
-        }
-        if (object == null || getClass() != object.getClass()) {
-            return false;
-        }
-        ErrorChunkedInferenceResults that = (ErrorChunkedInferenceResults) object;
-        // Just compare the message for serialization test purposes
-        return Objects.equals(exception.getMessage(), that.exception.getMessage());
-    }
-
-    @Override
-    public int hashCode() {
-        // Just compare the message for serialization test purposes
-        return Objects.hash(exception.getMessage());
-    }
-
-    @Override
-    public List<? extends InferenceResults> transformToCoordinationFormat() {
-        return null;
-    }
-
-    @Override
-    public List<? extends InferenceResults> transformToLegacyFormat() {
-        return null;
-    }
-
-    @Override
-    public Map<String, Object> asMap() {
-        Map<String, Object> asMap = new LinkedHashMap<>();
-        asMap.put(NAME, exception.getMessage());
-        return asMap;
-    }
-
-    @Override
-    public String toString() {
-        return Strings.toString(this);
-    }
-
-    @Override
-    public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
-        return ChunkedToXContentHelper.field(NAME, exception.getMessage());
-    }
-
-    @Override
-    public String getWriteableName() {
-        return NAME;
-    }
-
-    @Override
-    public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) {
-        return Stream.of(exception).map(e -> new Chunk(e.getMessage(), BytesArray.EMPTY)).iterator();
-    }
-}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedSparseEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedSparseEmbeddingResults.java
deleted file mode 100644
index c961050acefdb..0000000000000
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedSparseEmbeddingResults.java
+++ /dev/null
@@ -1,150 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.core.inference.results;
-
-import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.xcontent.ChunkedToXContent;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
-import org.elasticsearch.inference.InferenceResults;
-import org.elasticsearch.xcontent.ToXContent;
-import org.elasticsearch.xcontent.XContent;
-import org.elasticsearch.xcontent.XContentBuilder;
-import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults;
-import org.elasticsearch.xpack.core.ml.search.WeightedToken;
-
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
-import java.util.stream.Collectors;
-
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings;
-
-public class InferenceChunkedSparseEmbeddingResults implements ChunkedInferenceServiceResults {
-
-    public static final String NAME = "chunked_sparse_embedding_results";
-    public static final String FIELD_NAME = "sparse_embedding_chunk";
-
-    public static InferenceChunkedSparseEmbeddingResults ofMlResult(MlChunkedTextExpansionResults mlInferenceResults) {
-        return new InferenceChunkedSparseEmbeddingResults(mlInferenceResults.getChunks());
-    }
-
-    /**
-     * Returns a list of {@link InferenceChunkedSparseEmbeddingResults}. The number of entries in the list will match the input list size.
-     * Each {@link InferenceChunkedSparseEmbeddingResults} will have a single chunk containing the entire results from the
-     * {@link SparseEmbeddingResults}.
-     */
-    public static List<ChunkedInferenceServiceResults> listOf(List<String> inputs, SparseEmbeddingResults sparseEmbeddingResults) {
-        validateInputSizeAgainstEmbeddings(inputs, sparseEmbeddingResults.embeddings().size());
-
-        var results = new ArrayList<ChunkedInferenceServiceResults>(inputs.size());
-        for (int i = 0; i < inputs.size(); i++) {
-            results.add(ofSingle(inputs.get(i), sparseEmbeddingResults.embeddings().get(i)));
-        }
-
-        return results;
-    }
-
-    private static InferenceChunkedSparseEmbeddingResults ofSingle(String input, SparseEmbeddingResults.Embedding embedding) {
-        var weightedTokens = embedding.tokens()
-            .stream()
-            .map(weightedToken -> new WeightedToken(weightedToken.token(), weightedToken.weight()))
-            .toList();
-
-        return new InferenceChunkedSparseEmbeddingResults(List.of(new MlChunkedTextExpansionResults.ChunkedResult(input, weightedTokens)));
-    }
-
-    private final List<MlChunkedTextExpansionResults.ChunkedResult> chunkedResults;
-
-    public InferenceChunkedSparseEmbeddingResults(List<MlChunkedTextExpansionResults.ChunkedResult> chunks) {
-        this.chunkedResults = chunks;
-    }
-
-    public InferenceChunkedSparseEmbeddingResults(StreamInput in) throws IOException {
-        this.chunkedResults = in.readCollectionAsList(MlChunkedTextExpansionResults.ChunkedResult::new);
-    }
-
-    public List<MlChunkedTextExpansionResults.ChunkedResult> getChunkedResults() {
-        return chunkedResults;
-    }
-
-    @Override
-    public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
-        return ChunkedToXContent.builder(params).array(FIELD_NAME, chunkedResults.iterator());
-    }
-
-    @Override
-    public String getWriteableName() {
-        return NAME;
-    }
-
-    @Override
-    public void writeTo(StreamOutput out) throws IOException {
-        out.writeCollection(chunkedResults);
-    }
-
-    @Override
-    public List<? extends InferenceResults> transformToCoordinationFormat() {
-        throw new UnsupportedOperationException("Chunked results are not returned in the coordindated action");
-    }
-
-    @Override
-    public List<? extends InferenceResults> transformToLegacyFormat() {
-        throw new UnsupportedOperationException("Chunked results are not returned in the legacy format");
-    }
-
-    @Override
-    public Map<String, Object> asMap() {
-        return Map.of(
-            FIELD_NAME,
-            chunkedResults.stream().map(MlChunkedTextExpansionResults.ChunkedResult::asMap).collect(Collectors.toList())
-        );
-    }
-
-    @Override
-    public boolean equals(Object o) {
-        if (this == o) return true;
-        if (o == null || getClass() != o.getClass()) return false;
-        InferenceChunkedSparseEmbeddingResults that = (InferenceChunkedSparseEmbeddingResults) o;
-        return Objects.equals(chunkedResults, that.chunkedResults);
-    }
-
-    @Override
-    public int hashCode() {
-        return Objects.hash(chunkedResults);
-    }
-
-    @Override
-    public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) {
-        return chunkedResults.stream()
-            .map(chunk -> new Chunk(chunk.matchedText(), toBytesReference(xcontent, chunk.weightedTokens())))
-            .iterator();
-    }
-
-    /**
-     * Serialises the {@link WeightedToken} list, according to the provided {@link XContent},
-     * into a {@link BytesReference}.
-     */
-    private static BytesReference toBytesReference(XContent xContent, List<WeightedToken> tokens) {
-        try {
-            XContentBuilder b = XContentBuilder.builder(xContent);
-            b.startObject();
-            for (var weightedToken : tokens) {
-                weightedToken.toXContent(b, ToXContent.EMPTY_PARAMS);
-            }
-            b.endObject();
-            return BytesReference.bytes(b);
-        } catch (IOException exc) {
-            throw new RuntimeException(exc);
-        }
-    }
-}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingByteResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingByteResults.java
deleted file mode 100644
index 6bd66664068d5..0000000000000
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingByteResults.java
+++ /dev/null
@@ -1,179 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.core.inference.results;
-
-import org.elasticsearch.common.Strings;
-import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.common.xcontent.ChunkedToXContent;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
-import org.elasticsearch.inference.InferenceResults;
-import org.elasticsearch.xcontent.ToXContent;
-import org.elasticsearch.xcontent.ToXContentObject;
-import org.elasticsearch.xcontent.XContent;
-import org.elasticsearch.xcontent.XContentBuilder;
-import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults;
-
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
-
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings;
-
-public record InferenceChunkedTextEmbeddingByteResults(List<InferenceByteEmbeddingChunk> chunks, boolean isTruncated)
-    implements
-        ChunkedInferenceServiceResults {
-
-    public static final String NAME = "chunked_text_embedding_service_byte_results";
-    public static final String FIELD_NAME = "text_embedding_byte_chunk";
-
-    /**
-     * Returns a list of {@link InferenceChunkedTextEmbeddingByteResults}. The number of entries in the list will match the input list size.
-     * Each {@link InferenceChunkedTextEmbeddingByteResults} will have a single chunk containing the entire results from the
-     * {@link InferenceTextEmbeddingByteResults}.
-     */
-    public static List<ChunkedInferenceServiceResults> listOf(List<String> inputs, InferenceTextEmbeddingByteResults textEmbeddings) {
-        validateInputSizeAgainstEmbeddings(inputs, textEmbeddings.embeddings().size());
-
-        var results = new ArrayList<ChunkedInferenceServiceResults>(inputs.size());
-        for (int i = 0; i < inputs.size(); i++) {
-            results.add(ofSingle(inputs.get(i), textEmbeddings.embeddings().get(i).values()));
-        }
-
-        return results;
-    }
-
-    private static InferenceChunkedTextEmbeddingByteResults ofSingle(String input, byte[] byteEmbeddings) {
-        return new InferenceChunkedTextEmbeddingByteResults(List.of(new InferenceByteEmbeddingChunk(input, byteEmbeddings)), false);
-    }
-
-    public InferenceChunkedTextEmbeddingByteResults(StreamInput in) throws IOException {
-        this(in.readCollectionAsList(InferenceByteEmbeddingChunk::new), in.readBoolean());
-    }
-
-    @Override
-    public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
-        return ChunkedToXContent.builder(params).array(FIELD_NAME, chunks.iterator());
-    }
-
-    @Override
-    public void writeTo(StreamOutput out) throws IOException {
-        out.writeCollection(chunks);
-        out.writeBoolean(isTruncated);
-    }
-
-    @Override
-    public List<? extends InferenceResults> transformToCoordinationFormat() {
-        throw new UnsupportedOperationException("Chunked results are not returned in the coordinated action");
-    }
-
-    @Override
-    public List<? extends InferenceResults> transformToLegacyFormat() {
-        throw new UnsupportedOperationException("Chunked results are not returned in the legacy format");
-    }
-
-    @Override
-    public Map<String, Object> asMap() {
-        return Map.of(FIELD_NAME, chunks);
-    }
-
-    @Override
-    public String getWriteableName() {
-        return NAME;
-    }
-
-    public List<InferenceByteEmbeddingChunk> getChunks() {
-        return chunks;
-    }
-
-    @Override
-    public boolean equals(Object o) {
-        if (this == o) return true;
-        if (o == null || getClass() != o.getClass()) return false;
-        InferenceChunkedTextEmbeddingByteResults that = (InferenceChunkedTextEmbeddingByteResults) o;
-        return isTruncated == that.isTruncated && Objects.equals(chunks, that.chunks);
-    }
-
-    @Override
-    public int hashCode() {
-        return Objects.hash(chunks, isTruncated);
-    }
-
-    public record InferenceByteEmbeddingChunk(String matchedText, byte[] embedding) implements Writeable, ToXContentObject {
-
-        public InferenceByteEmbeddingChunk(StreamInput in) throws IOException {
-            this(in.readString(), in.readByteArray());
-        }
-
-        @Override
-        public void writeTo(StreamOutput out) throws IOException {
-            out.writeString(matchedText);
-            out.writeByteArray(embedding);
-        }
-
-        @Override
-        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-            builder.startObject();
-            builder.field(ChunkedNlpInferenceResults.TEXT, matchedText);
-
-            builder.startArray(ChunkedNlpInferenceResults.INFERENCE);
-            for (byte value : embedding) {
-                builder.value(value);
-            }
-            builder.endArray();
-
-            builder.endObject();
-            return builder;
-        }
-
-        @Override
-        public String toString() {
-            return Strings.toString(this);
-        }
-
-        @Override
-        public boolean equals(Object o) {
-            if (this == o) return true;
-            if (o == null || getClass() != o.getClass()) return false;
-            InferenceByteEmbeddingChunk that = (InferenceByteEmbeddingChunk) o;
-            return Objects.equals(matchedText, that.matchedText) && Arrays.equals(embedding, that.embedding);
-        }
-
-        @Override
-        public int hashCode() {
-            int result = Objects.hash(matchedText);
-            result = 31 * result + Arrays.hashCode(embedding);
-            return result;
-        }
-    }
-
-    @Override
-    public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) {
-        return chunks.stream().map(chunk -> new Chunk(chunk.matchedText(), toBytesReference(xcontent, chunk.embedding()))).iterator();
-    }
-
-    private static BytesReference toBytesReference(XContent xContent, byte[] value) {
-        try {
-            XContentBuilder b = XContentBuilder.builder(xContent);
-            b.startArray();
-            for (byte v : value) {
-                b.value(v);
-            }
-            b.endArray();
-            return BytesReference.bytes(b);
-        } catch (IOException exc) {
-            throw new RuntimeException(exc);
-        }
-    }
-}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResults.java
deleted file mode 100644
index 369f22a807913..0000000000000
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResults.java
+++ /dev/null
@@ -1,198 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.core.inference.results;
-
-import org.elasticsearch.common.Strings;
-import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.common.xcontent.ChunkedToXContent;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
-import org.elasticsearch.inference.InferenceResults;
-import org.elasticsearch.xcontent.ToXContent;
-import org.elasticsearch.xcontent.ToXContentObject;
-import org.elasticsearch.xcontent.XContent;
-import org.elasticsearch.xcontent.XContentBuilder;
-import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.utils.FloatConversionUtils;
-
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
-
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings;
-
-public record InferenceChunkedTextEmbeddingFloatResults(List<InferenceFloatEmbeddingChunk> chunks)
-    implements
-        ChunkedInferenceServiceResults {
-
-    public static final String NAME = "chunked_text_embedding_service_float_results";
-    public static final String FIELD_NAME = "text_embedding_float_chunk";
-
-    public InferenceChunkedTextEmbeddingFloatResults(StreamInput in) throws IOException {
-        this(in.readCollectionAsList(InferenceFloatEmbeddingChunk::new));
-    }
-
-    /**
-     * Returns a list of {@link InferenceChunkedTextEmbeddingFloatResults}.
-     * Each {@link InferenceChunkedTextEmbeddingFloatResults} contain a single chunk with the text and the
-     * {@link InferenceTextEmbeddingFloatResults}.
-     */
-    public static List<ChunkedInferenceServiceResults> listOf(List<String> inputs, InferenceTextEmbeddingFloatResults textEmbeddings) {
-        validateInputSizeAgainstEmbeddings(inputs, textEmbeddings.embeddings().size());
-
-        var results = new ArrayList<ChunkedInferenceServiceResults>(inputs.size());
-
-        for (int i = 0; i < inputs.size(); i++) {
-            results.add(
-                new InferenceChunkedTextEmbeddingFloatResults(
-                    List.of(new InferenceFloatEmbeddingChunk(inputs.get(i), textEmbeddings.embeddings().get(i).values()))
-                )
-            );
-        }
-
-        return results;
-    }
-
-    public static InferenceChunkedTextEmbeddingFloatResults ofMlResults(MlChunkedTextEmbeddingFloatResults mlInferenceResult) {
-        return new InferenceChunkedTextEmbeddingFloatResults(
-            mlInferenceResult.getChunks()
-                .stream()
-                .map(chunk -> new InferenceFloatEmbeddingChunk(chunk.matchedText(), FloatConversionUtils.floatArrayOf(chunk.embedding())))
-                .toList()
-        );
-    }
-
-    @Override
-    public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
-        // TODO add isTruncated flag
-        return ChunkedToXContent.builder(params).array(FIELD_NAME, chunks.iterator());
-    }
-
-    @Override
-    public void writeTo(StreamOutput out) throws IOException {
-        out.writeCollection(chunks);
-    }
-
-    @Override
-    public List<? extends InferenceResults> transformToCoordinationFormat() {
-        throw new UnsupportedOperationException("Chunked results are not returned in the coordinated action");
-    }
-
-    @Override
-    public List<? extends InferenceResults> transformToLegacyFormat() {
-        throw new UnsupportedOperationException("Chunked results are not returned in the legacy format");
-    }
-
-    @Override
-    public Map<String, Object> asMap() {
-        return Map.of(FIELD_NAME, chunks);
-    }
-
-    @Override
-    public String getWriteableName() {
-        return NAME;
-    }
-
-    public List<InferenceFloatEmbeddingChunk> getChunks() {
-        return chunks;
-    }
-
-    @Override
-    public boolean equals(Object o) {
-        if (this == o) return true;
-        if (o == null || getClass() != o.getClass()) return false;
-        InferenceChunkedTextEmbeddingFloatResults that = (InferenceChunkedTextEmbeddingFloatResults) o;
-        return Objects.equals(chunks, that.chunks);
-    }
-
-    @Override
-    public int hashCode() {
-        return Objects.hash(chunks);
-    }
-
-    public record InferenceFloatEmbeddingChunk(String matchedText, float[] embedding) implements Writeable, ToXContentObject {
-
-        public InferenceFloatEmbeddingChunk(StreamInput in) throws IOException {
-            this(in.readString(), in.readFloatArray());
-        }
-
-        public static InferenceFloatEmbeddingChunk of(String matchedText, double[] doubleEmbedding) {
-            return new InferenceFloatEmbeddingChunk(matchedText, FloatConversionUtils.floatArrayOf(doubleEmbedding));
-        }
-
-        @Override
-        public void writeTo(StreamOutput out) throws IOException {
-            out.writeString(matchedText);
-            out.writeFloatArray(embedding);
-        }
-
-        @Override
-        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-            builder.startObject();
-            builder.field(ChunkedNlpInferenceResults.TEXT, matchedText);
-
-            builder.startArray(ChunkedNlpInferenceResults.INFERENCE);
-            for (float value : embedding) {
-                builder.value(value);
-            }
-            builder.endArray();
-
-            builder.endObject();
-            return builder;
-        }
-
-        @Override
-        public String toString() {
-            return Strings.toString(this);
-        }
-
-        @Override
-        public boolean equals(Object o) {
-            if (this == o) return true;
-            if (o == null || getClass() != o.getClass()) return false;
-            InferenceFloatEmbeddingChunk that = (InferenceFloatEmbeddingChunk) o;
-            return Objects.equals(matchedText, that.matchedText) && Arrays.equals(embedding, that.embedding);
-        }
-
-        @Override
-        public int hashCode() {
-            int result = Objects.hash(matchedText);
-            result = 31 * result + Arrays.hashCode(embedding);
-            return result;
-        }
-    }
-
-    @Override
-    public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) {
-        return chunks.stream().map(chunk -> new Chunk(chunk.matchedText(), toBytesReference(xcontent, chunk.embedding()))).iterator();
-    }
-
-    /**
-     * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
-     */
-    private static BytesReference toBytesReference(XContent xContent, float[] value) {
-        try {
-            XContentBuilder b = XContentBuilder.builder(xContent);
-            b.startArray();
-            for (float v : value) {
-                b.value(v);
-            }
-            b.endArray();
-            return BytesReference.bytes(b);
-        } catch (IOException exc) {
-            throw new RuntimeException(exc);
-        }
-    }
-}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java
index 4c68d02264457..cb69f1e403e9c 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java
@@ -31,7 +31,7 @@ public static int getFirstEmbeddingSize(List<EmbeddingInt> embeddings) throws Il
      * Throws an exception if the number of elements in the input text list is different than the results in text embedding
      * response.
      */
-    static void validateInputSizeAgainstEmbeddings(List<String> inputs, int embeddingSize) {
+    public static void validateInputSizeAgainstEmbeddings(List<String> inputs, int embeddingSize) {
         if (inputs.size() != embeddingSize) {
             throw new IllegalArgumentException(
                 Strings.format("The number of inputs [%s] does not match the embeddings [%s]", inputs.size(), embeddingSize)
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResultsTests.java
deleted file mode 100644
index 83678cd030bc2..0000000000000
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResultsTests.java
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.core.inference.results;
-
-import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults;
-
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.stream.Collectors;
-
-import static org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults.INFERENCE;
-import static org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults.TEXT;
-
-public class InferenceChunkedTextEmbeddingFloatResultsTests extends ESTestCase {
-    /**
-     * Similar to {@link org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults#asMap()} but it converts the
-     * embeddings float array into a list of floats to make testing equality easier.
-     */
-    public static Map<String, Object> asMapWithListsInsteadOfArrays(InferenceChunkedTextEmbeddingFloatResults result) {
-        return Map.of(
-            InferenceChunkedTextEmbeddingFloatResults.FIELD_NAME,
-            result.getChunks()
-                .stream()
-                .map(InferenceChunkedTextEmbeddingFloatResultsTests::inferenceFloatEmbeddingChunkAsMapWithListsInsteadOfArrays)
-                .collect(Collectors.toList())
-        );
-    }
-
-    /**
-     * Similar to {@link MlChunkedTextEmbeddingFloatResults.EmbeddingChunk#asMap()} but it converts the double array into a list of doubles
-     * to make testing equality easier.
-     */
-    public static Map<String, Object> inferenceFloatEmbeddingChunkAsMapWithListsInsteadOfArrays(
-        InferenceChunkedTextEmbeddingFloatResults.InferenceFloatEmbeddingChunk chunk
-    ) {
-        var chunkAsList = new ArrayList<Float>(chunk.embedding().length);
-        for (double embedding : chunk.embedding()) {
-            chunkAsList.add((float) embedding);
-        }
-        var map = new HashMap<String, Object>();
-        map.put(TEXT, chunk.matchedText());
-        map.put(INFERENCE, chunkAsList);
-        return map;
-    }
-}
diff --git a/x-pack/plugin/core/template-resources/src/main/resources/monitoring-es-mb.json b/x-pack/plugin/core/template-resources/src/main/resources/monitoring-es-mb.json
index 27262507518d2..572f814688dc4 100644
--- a/x-pack/plugin/core/template-resources/src/main/resources/monitoring-es-mb.json
+++ b/x-pack/plugin/core/template-resources/src/main/resources/monitoring-es-mb.json
@@ -1517,6 +1517,17 @@
             "ignore_above": 1024,
             "type": "keyword"
           },
+          "tier_preference": {
+            "ignore_above": 1024,
+            "type": "keyword"
+          },
+          "creation_date": {
+            "type": "date"
+          },
+          "version": {
+            "ignore_above": 1024,
+            "type": "keyword"
+          },
           "recovery": {
             "properties": {
               "stop_time": {
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RLike.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RLike.java
index 5f095a654fc89..b4bccf162d9e4 100644
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RLike.java
+++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RLike.java
@@ -8,12 +8,11 @@
 
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
-import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 
 import java.io.IOException;
 
-public class RLike extends RegexMatch<RLikePattern> {
+public abstract class RLike extends RegexMatch<RLikePattern> {
 
     public RLike(Source source, Expression value, RLikePattern pattern) {
         super(source, value, pattern, false);
@@ -33,13 +32,4 @@ public String getWriteableName() {
         throw new UnsupportedOperationException();
     }
 
-    @Override
-    protected NodeInfo<RLike> info() {
-        return NodeInfo.create(this, RLike::new, field(), pattern(), caseInsensitive());
-    }
-
-    @Override
-    protected RLike replaceChild(Expression newChild) {
-        return new RLike(source(), newChild, pattern(), caseInsensitive());
-    }
-
 }
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RegexMatch.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RegexMatch.java
index 32e8b04573d2d..0f9116ade5a31 100644
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RegexMatch.java
+++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RegexMatch.java
@@ -7,7 +7,6 @@
 package org.elasticsearch.xpack.esql.core.expression.predicate.regex;
 
-import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
 import org.elasticsearch.xpack.esql.core.expression.Nullability;
 import org.elasticsearch.xpack.esql.core.expression.function.scalar.UnaryScalarFunction;
@@ -64,11 +63,7 @@ public boolean foldable() {
 
     @Override
     public Boolean fold() {
-        Object val = field().fold();
-        if (val instanceof BytesRef br) {
-            val = br.utf8ToString();
-        }
-        return RegexOperation.match(val, pattern().asJavaRegex());
+        throw new UnsupportedOperationException();
     }
 
     @Override
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/WildcardLike.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/WildcardLike.java
index bf54744667217..05027707326bd 100644
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/WildcardLike.java
+++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/WildcardLike.java
@@ -8,12 +8,11 @@
 
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
-import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 
 import java.io.IOException;
 
-public class WildcardLike extends RegexMatch<WildcardPattern> {
+public abstract class WildcardLike extends RegexMatch<WildcardPattern> {
 
     public WildcardLike(Source source, Expression left, WildcardPattern pattern) {
         this(source, left, pattern, false);
@@ -33,14 +32,4 @@ public String getWriteableName() {
         throw new UnsupportedOperationException();
     }
 
-    @Override
-    protected NodeInfo<WildcardLike> info() {
-        return NodeInfo.create(this, WildcardLike::new, field(), pattern(), caseInsensitive());
-    }
-
-    @Override
-    protected WildcardLike replaceChild(Expression newLeft) {
-        return new WildcardLike(source(), newLeft, pattern(), caseInsensitive());
-    }
-
 }
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/planner/TranslatorHandler.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/planner/TranslatorHandler.java
index 1ccbb04f7a69c..b85544905595a 100644
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/planner/TranslatorHandler.java
+++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/planner/TranslatorHandler.java
@@ -12,7 +12,6 @@
 import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
 import org.elasticsearch.xpack.esql.core.expression.function.scalar.ScalarFunction;
 import org.elasticsearch.xpack.esql.core.querydsl.query.Query;
-import org.elasticsearch.xpack.esql.core.type.DataType;
 
 import java.util.function.Supplier;
 
@@ -34,5 +33,4 @@ default Query wrapFunctionQuery(ScalarFunction sf, Expression field, Supplier
 assertThat(initialValue, emptyOrNullString());
             case "csv" -> {
-                assertEquals(initialValue, "\r\n");
+                assertEquals("\r\n", initialValue);
                 initialValue = "";
             }
             case "tsv" -> {
-                assertEquals(initialValue, "\n");
+                assertEquals("\n", initialValue);
                 initialValue = "";
             }
         }
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvAssert.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvAssert.java
index 1a2aa122c85ca..4e31916d5328e 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvAssert.java
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvAssert.java
@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.esql;
 
 import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.time.DateFormatter;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.logging.Logger;
@@ -39,9 +40,9 @@
 import static org.elasticsearch.xpack.esql.core.util.NumericUtils.unsignedLongAsNumber;
 import static org.elasticsearch.xpack.esql.core.util.SpatialCoordinateTypes.CARTESIAN;
 import static org.elasticsearch.xpack.esql.core.util.SpatialCoordinateTypes.GEO;
+import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertThat;
 import static org.junit.Assert.fail;
 
 public final class CsvAssert {
@@ -197,11 +198,13 @@ public static void assertData(
         for (int row = 0; row < expectedValues.size(); row++) {
             try {
                 if (row >= actualValues.size()) {
-                    if (dataFailures.isEmpty()) {
-                        fail("Expected more data but no more entries found after [" + row + "]");
-                    } else {
-                        dataFailure(dataFailures, "Expected more data but no more entries found after [" + row + "]\n");
-                    }
+                    dataFailure(
+                        "Expected more data but no more entries found after [" + row + "]",
+                        dataFailures,
+                        expected,
+                        actualValues,
+                        valueTransformer
+                    );
                 }
 
                 if (logger != null) {
@@ -212,51 +215,28 @@ public static void assertData(
                 var actualRow = actualValues.get(row);
 
                 for (int column = 0; column < expectedRow.size(); column++) {
-                    var expectedValue = expectedRow.get(column);
-                    var actualValue = actualRow.get(column);
                     var expectedType = expected.columnTypes().get(column);
+                    var expectedValue = convertExpectedValue(expectedType, expectedRow.get(column));
+                    var actualValue = actualRow.get(column);
 
-                    if (expectedValue != null) {
-                        // convert the long from CSV back to its STRING form
-                        if (expectedType == Type.DATETIME) {
-                            expectedValue = rebuildExpected(expectedValue, Long.class, x -> UTC_DATE_TIME_FORMATTER.formatMillis((long) x));
-                        } else if (expectedType == Type.DATE_NANOS) {
-                            expectedValue = rebuildExpected(
-                                expectedValue,
-                                Long.class,
-                                x -> DateFormatter.forPattern("strict_date_optional_time_nanos").formatNanos((long) x)
-                            );
-                        } else if (expectedType == Type.GEO_POINT) {
-                            expectedValue = rebuildExpected(expectedValue, BytesRef.class, x -> GEO.wkbToWkt((BytesRef) x));
-                        } else if (expectedType == Type.CARTESIAN_POINT) {
-                            expectedValue = rebuildExpected(expectedValue, BytesRef.class, x -> CARTESIAN.wkbToWkt((BytesRef) x));
-                        } else if (expectedType == Type.GEO_SHAPE) {
-                            expectedValue = rebuildExpected(expectedValue, BytesRef.class, x -> GEO.wkbToWkt((BytesRef) x));
-                        } else if (expectedType == Type.CARTESIAN_SHAPE) {
-                            expectedValue = rebuildExpected(expectedValue, BytesRef.class, x -> CARTESIAN.wkbToWkt((BytesRef) x));
-                        } else if (expectedType == Type.IP) {
-                            // convert BytesRef-packed IP to String, allowing subsequent comparison with what's expected
-                            expectedValue = rebuildExpected(expectedValue, BytesRef.class, x -> DocValueFormat.IP.format((BytesRef) x));
-                        } else if (expectedType == Type.VERSION) {
-                            // convert BytesRef-packed Version to String
-                            expectedValue = rebuildExpected(expectedValue, BytesRef.class, x -> new Version((BytesRef) x).toString());
-                        } else if (expectedType == UNSIGNED_LONG) {
-                            expectedValue = rebuildExpected(expectedValue, Long.class, x -> unsignedLongAsNumber((long) x));
-                        }
-                    }
                     var transformedExpected = valueTransformer.apply(expectedType, expectedValue);
                     var transformedActual = valueTransformer.apply(expectedType, actualValue);
                     if (Objects.equals(transformedExpected, transformedActual) == false) {
                         dataFailures.add(new DataFailure(row, column, transformedExpected, transformedActual));
                     }
                     if (dataFailures.size() > 10) {
-                        dataFailure(dataFailures);
+                        dataFailure("", dataFailures, expected, actualValues, valueTransformer);
                     }
                 }
 
-                var delta = actualRow.size() - expectedRow.size();
-                if (delta > 0) {
-                    fail("Plan has extra columns, returned [" + actualRow.size() + "], expected [" + expectedRow.size() + "]");
+                if (actualRow.size() != expectedRow.size()) {
+                    dataFailure(
+                        "Plan has extra columns, returned [" + actualRow.size() + "], expected [" + expectedRow.size() + "]",
+                        dataFailures,
+                        expected,
+                        actualValues,
+                        valueTransformer
+                    );
                 }
             } catch (AssertionError ae) {
                 if (logger != null && row + 1 < actualValues.size()) {
@@ -267,21 +247,95 @@ public static void assertData(
             }
         }
         if (dataFailures.isEmpty() == false) {
-            dataFailure(dataFailures);
+            dataFailure("", dataFailures, expected, actualValues, valueTransformer);
         }
         if (expectedValues.size() < actualValues.size()) {
-            fail(
-                "Elasticsearch still has data after [" + expectedValues.size() + "] entries:\n" + row(actualValues, expectedValues.size())
+            dataFailure(
+                "Elasticsearch still has data after [" + expectedValues.size() + "] entries",
+                dataFailures,
+                expected,
+                actualValues,
+                valueTransformer
             );
         }
     }
 
-    private static void dataFailure(List<DataFailure> dataFailures) {
-        dataFailure(dataFailures, "");
+    private static void dataFailure(
+        String description,
+        List<DataFailure> dataFailures,
+        ExpectedResults expectedValues,
+        List<List<Object>> actualValues,
+        BiFunction<Type, Object, Object> valueTransformer
+    ) {
+        var expected = pipeTable(
+            "Expected:",
+            expectedValues.columnNames(),
+            expectedValues.columnTypes(),
+            expectedValues.values(),
+            (type, value) -> valueTransformer.apply(type, convertExpectedValue(type, value))
+        );
+        var actual = pipeTable("Actual:", expectedValues.columnNames(), expectedValues.columnTypes(), actualValues, valueTransformer);
+        fail(description + System.lineSeparator() + describeFailures(dataFailures) + actual + expected);
+    }
+
+    private static final int MAX_ROWS = 25;
+
+    private static String pipeTable(
+        String description,
+        List<String> headers,
+        List<Type> types,
+        List<List<Object>> values,
+        BiFunction<Type, Object, Object> valueTransformer
+    ) {
+        int rows = Math.min(MAX_ROWS, values.size());
+        int[] width = new int[headers.size()];
+        String[][] printableValues = new String[rows][headers.size()];
+        for (int c = 0; c < headers.size(); c++) {
+            width[c] = header(headers.get(c), types.get(c)).length();
+        }
+        for (int r = 0; r < rows; r++) {
+            for (int c = 0; c < headers.size(); c++) {
+                printableValues[r][c] = String.valueOf(valueTransformer.apply(types.get(c), values.get(r).get(c)));
+                width[c] = Math.max(width[c], printableValues[r][c].length());
+            }
+        }
+
+        var result = new StringBuilder().append(System.lineSeparator()).append(description).append(System.lineSeparator());
+        // headers
+        appendPaddedValue(result, header(headers.get(0), types.get(0)), width[0]);
+        for (int c = 1; c < width.length; c++) {
+            result.append(" | ");
+            appendPaddedValue(result, header(headers.get(c), types.get(c)), width[c]);
+        }
+        result.append(System.lineSeparator());
+        // values
+        for (int r = 0; r < printableValues.length; r++) {
+            appendPaddedValue(result, printableValues[r][0], width[0]);
+            for (int c =
1; c < printableValues[r].length; c++) { + result.append(" | "); + appendPaddedValue(result, printableValues[r][c], width[c]); + } + result.append(System.lineSeparator()); + } + if (values.size() > rows) { + result.append("...").append(System.lineSeparator()); + } + return result.toString(); + } + + private static String header(String name, Type type) { + return name + ':' + Strings.toLowercaseAscii(type.name()); + } + + private static void appendPaddedValue(StringBuilder result, String value, int width) { + result.append(value); + for (int i = 0; i < width - (value != null ? value.length() : 4); i++) { + result.append(' '); + } } - private static void dataFailure(List dataFailures, String prefixError) { - fail(prefixError + "Data mismatch:\n" + dataFailures.stream().map(f -> { + private static String describeFailures(List dataFailures) { + return "Data mismatch:" + System.lineSeparator() + dataFailures.stream().map(f -> { Description description = new StringDescription(); ListMatcher expected; if (f.expected instanceof List e) { @@ -299,7 +353,7 @@ private static void dataFailure(List dataFailures, String prefixErr expected.describeMismatch(actualList, description); String prefix = "row " + f.row + " column " + f.column + ":"; return prefix + description.toString().replace("\n", "\n" + prefix); - }).collect(Collectors.joining("\n"))); + }).collect(Collectors.joining(System.lineSeparator())); } private static Comparator> resultRowComparator(List types) { @@ -331,6 +385,30 @@ private static Comparator> resultRowComparator(List types) { }; } + private static Object convertExpectedValue(Type expectedType, Object expectedValue) { + if (expectedValue == null) { + return null; + } + + // convert the long from CSV back to its STRING form + return switch (expectedType) { + case DATETIME -> rebuildExpected(expectedValue, Long.class, x -> UTC_DATE_TIME_FORMATTER.formatMillis((long) x)); + case DATE_NANOS -> rebuildExpected( + expectedValue, + Long.class, + x -> DateFormatter.forPattern("strict_date_optional_time_nanos").formatNanos((long) x) + ); + case GEO_POINT, GEO_SHAPE -> rebuildExpected(expectedValue, BytesRef.class, x -> GEO.wkbToWkt((BytesRef) x)); + case CARTESIAN_POINT, CARTESIAN_SHAPE -> rebuildExpected(expectedValue, BytesRef.class, x -> CARTESIAN.wkbToWkt((BytesRef) x)); + case IP -> // convert BytesRef-packed IP to String, allowing subsequent comparison with what's expected + rebuildExpected(expectedValue, BytesRef.class, x -> DocValueFormat.IP.format((BytesRef) x)); + case VERSION -> // convert BytesRef-packed Version to String + rebuildExpected(expectedValue, BytesRef.class, x -> new Version((BytesRef) x).toString()); + case UNSIGNED_LONG -> rebuildExpected(expectedValue, Long.class, x -> unsignedLongAsNumber((long) x)); + default -> expectedValue; + }; + } + private static Object rebuildExpected(Object expectedValue, Class clazz, Function mapper) { if (List.class.isAssignableFrom(expectedValue.getClass())) { assertThat(((List) expectedValue).get(0), instanceOf(clazz)); diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java index 67ea456aeff34..c66d6839fa7d2 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java @@ -41,7 +41,6 @@ import java.util.List; 
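For reference, the table layout introduced by CsvAssert.pipeTable/appendPaddedValue above boils down to: measure the widest cell in each column (headers included), then pad every value out to that width when printing. A minimal, self-contained sketch of that idea — plain JDK formatting and illustrative names, not the PR's code:

    import java.util.List;

    public class PipeTableSketch {
        // Compute per-column widths from headers and cells, then left-pad
        // every value to its column width, mirroring CsvAssert.pipeTable.
        public static void main(String[] args) {
            List<String> headers = List.of("emp_no:integer", "language_name:keyword");
            List<List<String>> rows = List.of(List.of("10001", "English"), List.of("10002", "null"));
            int[] width = new int[headers.size()];
            for (int c = 0; c < headers.size(); c++) {
                width[c] = headers.get(c).length();
                for (List<String> row : rows) {
                    width[c] = Math.max(width[c], row.get(c).length());
                }
            }
            print(headers, width);
            rows.forEach(row -> print(row, width));
        }

        private static void print(List<String> cells, int[] width) {
            StringBuilder line = new StringBuilder();
            for (int c = 0; c < cells.size(); c++) {
                if (c > 0) {
                    line.append(" | ");
                }
                // String.format here stands in for the manual padding loop in appendPaddedValue
                line.append(String.format("%-" + width[c] + "s", cells.get(c)));
            }
            System.out.println(line);
        }
    }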
import java.util.Map; import java.util.Set; -import java.util.stream.Collectors; import static org.elasticsearch.common.logging.LoggerMessageFormat.format; import static org.elasticsearch.xpack.esql.CsvTestUtils.COMMA_ESCAPING_REGEX; @@ -63,6 +62,8 @@ public class CsvTestsDataLoader { private static final TestsDataset LANGUAGES = new TestsDataset("languages"); private static final TestsDataset LANGUAGES_LOOKUP = LANGUAGES.withIndex("languages_lookup") .withSetting("languages_lookup-settings.json"); + private static final TestsDataset LANGUAGES_LOOKUP_NON_UNIQUE_KEY = LANGUAGES_LOOKUP.withIndex("languages_lookup_non_unique_key") + .withData("languages_non_unique_key.csv"); private static final TestsDataset ALERTS = new TestsDataset("alerts"); private static final TestsDataset UL_LOGS = new TestsDataset("ul_logs"); private static final TestsDataset SAMPLE_DATA = new TestsDataset("sample_data"); @@ -114,6 +115,7 @@ public class CsvTestsDataLoader { Map.entry(APPS_SHORT.indexName, APPS_SHORT), Map.entry(LANGUAGES.indexName, LANGUAGES), Map.entry(LANGUAGES_LOOKUP.indexName, LANGUAGES_LOOKUP), + Map.entry(LANGUAGES_LOOKUP_NON_UNIQUE_KEY.indexName, LANGUAGES_LOOKUP_NON_UNIQUE_KEY), Map.entry(UL_LOGS.indexName, UL_LOGS), Map.entry(SAMPLE_DATA.indexName, SAMPLE_DATA), Map.entry(MV_SAMPLE_DATA.indexName, MV_SAMPLE_DATA), @@ -258,11 +260,22 @@ public static void main(String[] args) throws IOException { public static Set availableDatasetsForEs(RestClient client, boolean supportsIndexModeLookup) throws IOException { boolean inferenceEnabled = clusterHasInferenceEndpoint(client); - return CSV_DATASET_MAP.values() - .stream() - .filter(d -> d.requiresInferenceEndpoint == false || inferenceEnabled) - .filter(d -> supportsIndexModeLookup || d.indexName.endsWith("_lookup") == false) // TODO: use actual index settings - .collect(Collectors.toCollection(HashSet::new)); + Set testDataSets = new HashSet<>(); + + for (TestsDataset dataset : CSV_DATASET_MAP.values()) { + if ((inferenceEnabled || dataset.requiresInferenceEndpoint == false) + && (supportsIndexModeLookup || isLookupDataset(dataset) == false)) { + testDataSets.add(dataset); + } + } + + return testDataSets; + } + + public static boolean isLookupDataset(TestsDataset dataset) throws IOException { + Settings settings = dataset.readSettingsFile(); + String mode = settings.get("index.mode"); + return (mode != null && mode.equalsIgnoreCase("lookup")); } public static void loadDataSetIntoEs(RestClient client, boolean supportsIndexModeLookup) throws IOException { @@ -354,13 +367,8 @@ private static void load(RestClient client, TestsDataset dataset, Logger logger, if (data == null) { throw new IllegalArgumentException("Cannot find resource " + dataName); } - Settings indexSettings = Settings.EMPTY; - final String settingName = dataset.settingFileName != null ? 
"/" + dataset.settingFileName : null; - if (settingName != null) { - indexSettings = Settings.builder() - .loadFromStream(settingName, CsvTestsDataLoader.class.getResourceAsStream(settingName), false) - .build(); - } + + Settings indexSettings = dataset.readSettingsFile(); indexCreator.createIndex(client, dataset.indexName, readMappingFile(mapping, dataset.typeMapping), indexSettings); loadCsvData(client, dataset.indexName, data, dataset.allowSubFields, logger); } @@ -669,6 +677,18 @@ public TestsDataset withTypeMapping(Map typeMapping) { public TestsDataset withInferenceEndpoint(boolean needsInference) { return new TestsDataset(indexName, mappingFileName, dataFileName, settingFileName, allowSubFields, typeMapping, needsInference); } + + private Settings readSettingsFile() throws IOException { + Settings indexSettings = Settings.EMPTY; + final String settingName = settingFileName != null ? "/" + settingFileName : null; + if (settingName != null) { + indexSettings = Settings.builder() + .loadFromStream(settingName, CsvTestsDataLoader.class.getResourceAsStream(settingName), false) + .build(); + } + + return indexSettings; + } } public record EnrichConfig(String policyName, String policyFileName) {} diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/date_nanos.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/date_nanos.csv-spec index bf0fd72f4f3f0..22b0bc2878cbb 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/date_nanos.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/date_nanos.csv-spec @@ -548,6 +548,80 @@ yr:date_nanos | mo:date_nanos | mn:date_nanos 2023-01-01T00:00:00.000000000Z | 2023-10-01T00:00:00.000000000Z | 2023-10-23T12:10:00.000000000Z | 2023-10-23T12:15:03.360000000Z ; +Bucket Date nanos by Year +required_capability: date_trunc_date_nanos +required_capability: date_nanos_bucket + +FROM date_nanos +| WHERE millis > "2020-01-01" +| STATS ct = count(*) BY yr = BUCKET(nanos, 1 year); + +ct:long | yr:date_nanos +8 | 2023-01-01T00:00:00.000000000Z +; + +Bucket Date nanos by Year, range version +required_capability: date_trunc_date_nanos +required_capability: date_nanos_bucket + +FROM date_nanos +| WHERE millis > "2020-01-01" +| STATS ct = count(*) BY yr = BUCKET(nanos, 5, "1999-01-01", NOW()); + +ct:long | yr:date_nanos +8 | 2023-01-01T00:00:00.000000000Z +; + +Bucket Date nanos by Month +required_capability: date_trunc_date_nanos +required_capability: date_nanos_bucket + +FROM date_nanos +| WHERE millis > "2020-01-01" +| STATS ct = count(*) BY mo = BUCKET(nanos, 1 month); + +ct:long | mo:date_nanos +8 | 2023-10-01T00:00:00.000000000Z +; + +Bucket Date nanos by Month, range version +required_capability: date_trunc_date_nanos +required_capability: date_nanos_bucket + +FROM date_nanos +| WHERE millis > "2020-01-01" +| STATS ct = count(*) BY mo = BUCKET(nanos, 20, "2023-01-01", "2023-12-31"); + +ct:long | mo:date_nanos +8 | 2023-10-01T00:00:00.000000000Z +; + +Bucket Date nanos by Week, range version +required_capability: date_trunc_date_nanos +required_capability: date_nanos_bucket + +FROM date_nanos +| WHERE millis > "2020-01-01" +| STATS ct = count(*) BY mo = BUCKET(nanos, 55, "2023-01-01", "2023-12-31"); + +ct:long | mo:date_nanos +8 | 2023-10-23T00:00:00.000000000Z +; +Bucket Date nanos by 10 minutes +required_capability: date_trunc_date_nanos +required_capability: date_nanos_bucket + +FROM date_nanos +| WHERE millis > "2020-01-01" +| STATS ct = count(*) BY mn = BUCKET(nanos, 10 minutes); + +ct:long | 
mn:date_nanos +4 | 2023-10-23T13:50:00.000000000Z +1 | 2023-10-23T13:30:00.000000000Z +1 | 2023-10-23T12:20:00.000000000Z +2 | 2023-10-23T12:10:00.000000000Z +; + Add date nanos required_capability: date_nanos_add_subtract diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dissect.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dissect.csv-spec index cde5427bf37d6..2b3b0bee93471 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dissect.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dissect.csv-spec @@ -223,7 +223,7 @@ null | null | null ; -// the query is incorrectly physically plan (fails the verification) in pre-8.13.0 versions +// the query is incorrectly physically planned (fails the verification) in pre-8.13.0 versions overwriteName#[skip:-8.12.99] from employees | sort emp_no asc | eval full_name = concat(first_name, " ", last_name) | dissect full_name "%{emp_no} %{b}" | keep full_name, emp_no, b | limit 3; @@ -245,7 +245,7 @@ emp_no:integer | first_name:keyword | rest:keyword ; -// the query is incorrectly physically plan (fails the verification) in pre-8.13.0 versions +// the query is incorrectly physically planned (fails the verification) in pre-8.13.0 versions overwriteNameWhere#[skip:-8.12.99] from employees | sort emp_no asc | eval full_name = concat(first_name, " ", last_name) | dissect full_name "%{emp_no} %{b}" | where emp_no == "Bezalel" | keep full_name, emp_no, b | limit 3; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/eval.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/eval.csv-spec index 592b06107c8b5..72660c11d8b73 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/eval.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/eval.csv-spec @@ -601,3 +601,39 @@ Mokhtar |Bernatsky |38992 |BM Parto |Bamford |61805 |BP Premal |Baek |52833 |BP ; + + +caseInsensitiveRegex +from employees | where first_name RLIKE "(?i)geor.*" | keep first_name +; + +first_name:keyword +; + + +caseInsensitiveRegex2 +from employees | where first_name RLIKE "(?i)Geor.*" | keep first_name +; + +first_name:keyword +; + + +caseInsensitiveRegexFold +required_capability: fixed_regex_fold +row foo = "Bar" | where foo rlike "(?i)ba.*" +; + +foo:keyword +; + + +caseInsensitiveRegexFold2 +required_capability: fixed_regex_fold +row foo = "Bar" | where foo rlike "(?i)Ba.*" +; + +foo:keyword +; + + diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/grok.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/grok.csv-spec index eece1bdfbffa4..6dc9148ffc0e8 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/grok.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/grok.csv-spec @@ -199,7 +199,7 @@ null | null | null ; -// the query is incorrectly physically plan (fails the verification) in pre-8.13.0 versions +// the query is incorrectly physically planned (fails the verification) in pre-8.13.0 versions overwriteName#[skip:-8.12.99] from employees | sort emp_no asc | eval full_name = concat(first_name, " ", last_name) | grok full_name "%{WORD:emp_no} %{WORD:b}" | keep full_name, emp_no, b | limit 3; @@ -210,7 +210,7 @@ Parto Bamford | Parto | Bamford ; -// the query is incorrectly physically plan (fails the verification) in pre-8.13.0 versions +// the query is incorrectly physically planned (fails the verification) in pre-8.13.0 versions overwriteNameWhere#[skip:-8.12.99] from employees | sort emp_no asc | eval 
full_name = concat(first_name, " ", last_name) | grok full_name "%{WORD:emp_no} %{WORD:b}" | where emp_no == "Bezalel" | keep full_name, emp_no, b | limit 3; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/languages_non_unique_key.csv b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/languages_non_unique_key.csv new file mode 100644 index 0000000000000..d6381b174d739 --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/languages_non_unique_key.csv @@ -0,0 +1,14 @@ +language_code:integer,language_name:keyword,country:keyword +1,English,Canada +1,English, +1,,United Kingdom +1,English,United States of America +2,German,[Germany,Austria] +2,German,Switzerland +2,German, +4,Quenya, +5,,Atlantis +[6,7],Mv-Lang,Mv-Land +[7,8],Mv-Lang2,Mv-Land2 +,Null-Lang,Null-Land +,Null-Lang2,Null-Land2 diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/lookup-join.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/lookup-join.csv-spec index 74b7a19d06bd6..7fed4f377096f 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/lookup-join.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/lookup-join.csv-spec @@ -3,7 +3,6 @@ // Reuses the sample dataset and commands from enrich.csv-spec // -//TODO: this sometimes returns null instead of the looked up value (likely related to the execution order) basicOnTheDataNode required_capability: join_lookup_v5 @@ -102,6 +101,150 @@ emp_no:integer | language_code:integer | language_name:keyword 10003 | 4 | German ; +nonUniqueLeftKeyOnTheDataNode +required_capability: join_lookup_v5 + +FROM employees +| WHERE emp_no <= 10030 +| EVAL language_code = emp_no % 10 +| WHERE language_code < 3 +| LOOKUP JOIN languages_lookup ON language_code +| SORT emp_no +| KEEP emp_no, language_code, language_name +; + +emp_no:integer | language_code:integer | language_name:keyword +10001 |1 | English +10002 |2 | French +10010 |0 | null +10011 |1 | English +10012 |2 | French +10020 |0 | null +10021 |1 | English +10022 |2 | French +10030 |0 | null +; + +nonUniqueRightKeyOnTheDataNode +required_capability: join_lookup_v5 + +FROM employees +| EVAL language_code = emp_no % 10 +| LOOKUP JOIN languages_lookup_non_unique_key ON language_code +| WHERE emp_no > 10090 AND emp_no < 10096 +| SORT emp_no +| EVAL country = MV_SORT(country) +| KEEP emp_no, language_code, language_name, country +; + +emp_no:integer | language_code:integer | language_name:keyword | country:keyword +10091 | 1 | [English, English, English] | [Canada, United Kingdom, United States of America] +10092 | 2 | [German, German, German] | [Austria, Germany, Switzerland] +10093 | 3 | null | null +10094 | 4 | Quenya | null +10095 | 5 | null | Atlantis +; + +nonUniqueRightKeyOnTheCoordinator +required_capability: join_lookup_v5 + +FROM employees +| SORT emp_no +| LIMIT 5 +| EVAL language_code = emp_no % 10 +| LOOKUP JOIN languages_lookup_non_unique_key ON language_code +| EVAL country = MV_SORT(country) +| KEEP emp_no, language_code, language_name, country +; + +emp_no:integer | language_code:integer | language_name:keyword | country:keyword +10001 | 1 | [English, English, English] | [Canada, United Kingdom, United States of America] +10002 | 2 | [German, German, German] | [Austria, Germany, Switzerland] +10003 | 3 | null | null +10004 | 4 | Quenya | null +10005 | 5 | null | Atlantis +; + +nonUniqueRightKeyFromRow +required_capability: join_lookup_v5 + +ROW language_code = 2 +| LOOKUP JOIN languages_lookup_non_unique_key ON language_code 
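The languages_non_unique_key.csv dataset above is what drives the multivalued results in these specs: when the lookup side holds several rows for one key, a single input row collects every matching value, which is why language_code 2 yields [German, German, German]. A rough conceptual sketch in plain Java — the key-to-many-rows shape only, not the actual ES|QL lookup-join execution model:

    import java.util.List;
    import java.util.Map;

    public class NonUniqueLookupSketch {
        public static void main(String[] args) {
            // lookup side keyed by language_code; code 2 appears on three rows
            Map<Integer, List<String>> languageNamesByCode = Map.of(
                1, List.of("English", "English", "English"),
                2, List.of("German", "German", "German")
            );
            int languageCode = 2;
            // absent key -> null column, duplicated key -> multivalued column
            System.out.println(languageCode + " -> " + languageNamesByCode.get(languageCode));
        }
    }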
+| DROP country.keyword +| EVAL country = MV_SORT(country) +; + +language_code:integer | language_name:keyword | country:keyword +2 | [German, German, German] | [Austria, Germany, Switzerland] +; + +nullJoinKeyOnTheDataNode +required_capability: join_lookup_v5 + +FROM employees +| WHERE emp_no < 10004 +| EVAL language_code = emp_no % 10, language_code = CASE(language_code == 3, null, language_code) +| LOOKUP JOIN languages_lookup_non_unique_key ON language_code +| SORT emp_no +| KEEP emp_no, language_code, language_name +; + +emp_no:integer | language_code:integer | language_name:keyword +10001 | 1 | [English, English, English] +10002 | 2 | [German, German, German] +10003 | null | null +; + + +mvJoinKeyOnTheDataNode +required_capability: join_lookup_v5 + +FROM employees +| WHERE 10003 < emp_no AND emp_no < 10008 +| EVAL language_code = emp_no % 10 +| LOOKUP JOIN languages_lookup_non_unique_key ON language_code +| SORT emp_no +| KEEP emp_no, language_code, language_name +; + +emp_no:integer | language_code:integer | language_name:keyword +10004 | 4 | Quenya +10005 | 5 | null +10006 | 6 | Mv-Lang +10007 | 7 | [Mv-Lang, Mv-Lang2] +; + +mvJoinKeyFromRow +required_capability: join_lookup_v5 + +ROW language_code = [4, 5, 6, 7] +| LOOKUP JOIN languages_lookup_non_unique_key ON language_code +| EVAL language_name = MV_SORT(language_name), country = MV_SORT(country) +| KEEP language_code, language_name, country +; + +language_code:integer | language_name:keyword | country:keyword +[4, 5, 6, 7] | [Mv-Lang, Mv-Lang2, Quenya] | [Atlantis, Mv-Land, Mv-Land2] +; + +mvJoinKeyFromRowExpanded +required_capability: join_lookup_v5 + +ROW language_code = [4, 5, 6, 7, 8] +| MV_EXPAND language_code +| LOOKUP JOIN languages_lookup_non_unique_key ON language_code +| EVAL language_name = MV_SORT(language_name), country = MV_SORT(country) +| KEEP language_code, language_name, country +; + +language_code:integer | language_name:keyword | country:keyword +4 | Quenya | null +5 | null | Atlantis +6 | Mv-Lang | Mv-Land +7 | [Mv-Lang, Mv-Lang2] | [Mv-Land, Mv-Land2] +8 | Mv-Lang2 | Mv-Land2 +; + lookupIPFromRow required_capability: join_lookup_v5 diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec index add6f18887464..51080d8289353 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec @@ -564,7 +564,7 @@ c:long | gender:keyword | trunk_worked_seconds:long 0 | null | 200000000 ; -// the query is incorrectly physically plan (fails the verification) in pre-8.13.0 versions +// the query is incorrectly physically planned (fails the verification) in pre-8.13.0 versions byStringAndLongWithAlias#[skip:-8.12.99] FROM employees | EVAL trunk_worked_seconds = avg_worked_seconds / 100000000 * 100000000 @@ -720,7 +720,8 @@ c:long | d:date | gender:keyword | languages:integer 2 | 1987-01-01T00:00:00.000Z | M | 1 ; -byDateAndKeywordAndIntWithAlias +// the query is incorrectly physically planned (fails the verification) in pre-8.13.0 versions +byDateAndKeywordAndIntWithAlias#[skip:-8.12.99] from employees | eval d = date_trunc(1 year, hire_date) | rename gender as g, languages as l, emp_no as e | keep d, g, l, e | stats c = count(e) by d, g, l | sort c desc, d, l desc, g desc | limit 10; c:long | d:date | g:keyword | l:integer diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 6853747171048..e2e7b67ccf988 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -357,6 +357,11 @@ public enum Cap { */ DATE_TRUNC_DATE_NANOS(), + /** + * Support date nanos values as the field argument to bucket + */ + DATE_NANOS_BUCKET(), + /** * support aggregations on date nanos */ @@ -557,7 +562,12 @@ public enum Cap { /** * Additional types for match function and operator */ - MATCH_ADDITIONAL_TYPES; + MATCH_ADDITIONAL_TYPES, + + /** + * Fix for regex folding with case-insensitive pattern https://github.com/elastic/elasticsearch/issues/118371 + */ + FIXED_REGEX_FOLD; private final boolean enabled; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponse.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponse.java index dc0e9fd1fb06d..4163a222b1a28 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponse.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponse.java @@ -218,7 +218,7 @@ public Iterator toXContentChunked(ToXContent.Params params }); } - private boolean[] nullColumns() { + public boolean[] nullColumns() { boolean[] nullColumns = new boolean[columns.size()]; for (int c = 0; c < nullColumns.length; c++) { nullColumns[c] = allColumnsAreNull(c); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java index 9e40b85fd6590..347d542f5212d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java @@ -90,7 +90,7 @@ public class Bucket extends GroupingFunction implements Validatable, TwoOptional private final Expression to; @FunctionInfo( - returnType = { "double", "date" }, + returnType = { "double", "date", "date_nanos" }, description = """ Creates groups of values - buckets - out of a datetime or numeric input. The size of the buckets can either be provided directly, or chosen based on a recommended count and values range.""", @@ -169,7 +169,7 @@ public Bucket( Source source, @Param( name = "field", - type = { "integer", "long", "double", "date" }, + type = { "integer", "long", "double", "date", "date_nanos" }, description = "Numeric or date expression from which to derive buckets." 
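The date_nanos support added to Bucket leans on millisecond Rounding, matching what BucketTests asserts further down: the nanosecond value is converted down to millis, rounded, and converted back up. A hedged sketch of that round-trip using plain java.time in place of Elasticsearch's Rounding and DateUtils helpers:

    import java.time.Instant;
    import java.time.temporal.ChronoUnit;

    public class DateNanosBucketSketch {
        public static void main(String[] args) {
            Instant ts = Instant.parse("2023-02-17T09:00:00.123456789Z");
            long nanos = ChronoUnit.NANOS.between(Instant.EPOCH, ts);
            long millis = nanos / 1_000_000;                     // ~DateUtils.toMilliSeconds
            long dayMillis = Instant.ofEpochMilli(millis)
                .truncatedTo(ChronoUnit.DAYS)                    // ~Rounding by DAY_OF_MONTH
                .toEpochMilli();
            long dayNanos = dayMillis * 1_000_000;               // ~DateUtils.toNanoSeconds
            System.out.println(Instant.ofEpochMilli(dayMillis)); // 2023-02-17T00:00:00Z
            System.out.println(dayNanos);                        // same instant in epoch nanos
        }
    }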
) Expression field, @Param( @@ -241,7 +241,7 @@ public boolean foldable() { @Override public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { - if (field.dataType() == DataType.DATETIME) { + if (field.dataType() == DataType.DATETIME || field.dataType() == DataType.DATE_NANOS) { Rounding.Prepared preparedRounding; if (buckets.dataType().isWholeNumber()) { int b = ((Number) buckets.fold()).intValue(); @@ -314,8 +314,8 @@ private double pickRounding(int buckets, double from, double to) { } // supported parameter type combinations (1st, 2nd, 3rd, 4th): - // datetime, integer, string/datetime, string/datetime - // datetime, rounding/duration, -, - + // datetime/date_nanos, integer, string/datetime, string/datetime + // datetime/date_nanos, rounding/duration, -, - // numeric, integer, numeric, numeric // numeric, numeric, -, - @Override @@ -329,7 +329,7 @@ protected TypeResolution resolveType() { return TypeResolution.TYPE_RESOLVED; } - if (fieldType == DataType.DATETIME) { + if (fieldType == DataType.DATETIME || fieldType == DataType.DATE_NANOS) { TypeResolution resolution = isType( buckets, dt -> dt.isWholeNumber() || DataType.isTemporalAmount(dt), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLike.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLike.java index cd42711177510..996c90a8e40bc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLike.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLike.java @@ -79,7 +79,7 @@ public String getWriteableName() { } @Override - protected NodeInfo info() { + protected NodeInfo info() { return NodeInfo.create(this, RLike::new, field(), pattern(), caseInsensitive()); } @@ -93,6 +93,11 @@ protected TypeResolution resolveType() { return isString(field(), sourceText(), DEFAULT); } + @Override + public Boolean fold() { + return (Boolean) EvaluatorMapper.super.fold(); + } + @Override public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { return AutomataMatch.toEvaluator(source(), toEvaluator.apply(field()), pattern().createAutomaton()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLike.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLike.java index c1b4f20f41795..d2edb0f92e8f2 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLike.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLike.java @@ -99,6 +99,11 @@ protected TypeResolution resolveType() { return isString(field(), sourceText(), DEFAULT); } + @Override + public Boolean fold() { + return (Boolean) EvaluatorMapper.super.fold(); + } + @Override public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { return AutomataMatch.toEvaluator( diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/formatter/TextFormat.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/formatter/TextFormat.java index 5c0d6b138b326..7a7e4677b0dca 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/formatter/TextFormat.java +++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/formatter/TextFormat.java @@ -39,7 +39,8 @@ public enum TextFormat implements MediaType { PLAIN_TEXT() { @Override public Iterator> format(RestRequest request, EsqlQueryResponse esqlResponse) { - return new TextFormatter(esqlResponse).format(hasHeader(request)); + boolean dropNullColumns = request.paramAsBoolean(DROP_NULL_COLUMNS_OPTION, false); + return new TextFormatter(esqlResponse, hasHeader(request), dropNullColumns).format(); } @Override @@ -282,15 +283,21 @@ public Set headerValues() { */ public static final String URL_PARAM_FORMAT = "format"; public static final String URL_PARAM_DELIMITER = "delimiter"; + public static final String DROP_NULL_COLUMNS_OPTION = "drop_null_columns"; public Iterator> format(RestRequest request, EsqlQueryResponse esqlResponse) { final var delimiter = delimiter(request); + boolean dropNullColumns = request.paramAsBoolean(DROP_NULL_COLUMNS_OPTION, false); + boolean[] dropColumns = dropNullColumns ? esqlResponse.nullColumns() : new boolean[esqlResponse.columns().size()]; return Iterators.concat( // if the header is requested return the info hasHeader(request) && esqlResponse.columns() != null - ? Iterators.single(writer -> row(writer, esqlResponse.columns().iterator(), ColumnInfo::name, delimiter)) + ? Iterators.single(writer -> row(writer, esqlResponse.columns().iterator(), ColumnInfo::name, delimiter, dropColumns)) : Collections.emptyIterator(), - Iterators.map(esqlResponse.values(), row -> writer -> row(writer, row, f -> Objects.toString(f, StringUtils.EMPTY), delimiter)) + Iterators.map( + esqlResponse.values(), + row -> writer -> row(writer, row, f -> Objects.toString(f, StringUtils.EMPTY), delimiter, dropColumns) + ) ); } @@ -313,9 +320,14 @@ public String contentType(RestRequest request) { } // utility method for consuming a row. - void row(Writer writer, Iterator row, Function toString, Character delimiter) throws IOException { + void row(Writer writer, Iterator row, Function toString, Character delimiter, boolean[] dropColumns) + throws IOException { boolean firstColumn = true; - while (row.hasNext()) { + for (int i = 0; row.hasNext(); i++) { + if (dropColumns[i]) { + row.next(); + continue; + } if (firstColumn) { firstColumn = false; } else { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/formatter/TextFormatter.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/formatter/TextFormatter.java index 0535e4adfe346..95b46958be351 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/formatter/TextFormatter.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/formatter/TextFormatter.java @@ -30,13 +30,17 @@ public class TextFormatter { private final EsqlQueryResponse response; private final int[] width; private final Function FORMATTER = Objects::toString; + private final boolean includeHeader; + private final boolean[] dropColumns; /** - * Create a new {@linkplain TextFormatter} for formatting responses. + * Create a new {@linkplain TextFormatter} for formatting responses */ - public TextFormatter(EsqlQueryResponse response) { + public TextFormatter(EsqlQueryResponse response, boolean includeHeader, boolean dropNullColumns) { this.response = response; var columns = response.columns(); + this.includeHeader = includeHeader; + this.dropColumns = dropNullColumns ? response.nullColumns() : new boolean[columns.size()]; // Figure out the column widths: // 1. 
Start with the widths of the column names width = new int[columns.size()]; @@ -58,12 +62,12 @@ public TextFormatter(EsqlQueryResponse response) { } /** - * Format the provided {@linkplain EsqlQueryResponse} optionally including the header lines. + * Format the provided {@linkplain EsqlQueryResponse} */ - public Iterator> format(boolean includeHeader) { + public Iterator> format() { return Iterators.concat( // The header lines - includeHeader && response.columns().size() > 0 ? Iterators.single(this::formatHeader) : Collections.emptyIterator(), + includeHeader && response.columns().isEmpty() == false ? Iterators.single(this::formatHeader) : Collections.emptyIterator(), // Now format the results. formatResults() ); @@ -71,6 +75,9 @@ public Iterator> format(boolean includeHead private void formatHeader(Writer writer) throws IOException { for (int i = 0; i < width.length; i++) { + if (dropColumns[i]) { + continue; + } if (i > 0) { writer.append('|'); } @@ -86,6 +93,9 @@ private void formatHeader(Writer writer) throws IOException { writer.append('\n'); for (int i = 0; i < width.length; i++) { + if (dropColumns[i]) { + continue; + } if (i > 0) { writer.append('+'); } @@ -98,6 +108,10 @@ private Iterator> formatResults() { return Iterators.map(response.values(), row -> writer -> { for (int i = 0; i < width.length; i++) { assert row.hasNext(); + if (dropColumns[i]) { + row.next(); + continue; + } if (i > 0) { writer.append('|'); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsqlTranslatorHandler.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsqlTranslatorHandler.java index c07be82ed2a16..6fce6c43f12d4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsqlTranslatorHandler.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsqlTranslatorHandler.java @@ -17,9 +17,7 @@ import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNull; import org.elasticsearch.xpack.esql.core.planner.TranslatorHandler; import org.elasticsearch.xpack.esql.core.querydsl.query.Query; -import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.querydsl.query.SingleValueQuery; -import org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter; import java.util.function.Supplier; @@ -30,11 +28,6 @@ public Query asQuery(Expression e) { return EsqlExpressionTranslators.toQuery(e, this); } - @Override - public Object convert(Object value, DataType dataType) { - return EsqlDataTypeConverter.convert(value, dataType); - } - @Override public Query wrapFunctionQuery(ScalarFunction sf, Expression field, Supplier querySupplier) { if (field instanceof FieldAttribute fa) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/BucketTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/BucketTests.java index 7e7d91cdf76f4..f01b06c23e8a8 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/BucketTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/BucketTests.java @@ -12,15 +12,19 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.Rounding; +import org.elasticsearch.common.time.DateUtils; import org.elasticsearch.index.mapper.DateFieldMapper; +import org.elasticsearch.logging.LogManager; import 
org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; import org.hamcrest.Matcher; +import org.hamcrest.Matchers; import java.time.Duration; +import java.time.Instant; import java.time.Period; import java.util.ArrayList; import java.util.List; @@ -38,6 +42,7 @@ public BucketTests(@Name("TestCase") Supplier testCas public static Iterable parameters() { List suppliers = new ArrayList<>(); dateCases(suppliers, "fixed date", () -> DateFieldMapper.DEFAULT_DATE_TIME_FORMATTER.parseMillis("2023-02-17T09:00:00.00Z")); + dateNanosCases(suppliers, "fixed date nanos", () -> DateUtils.toLong(Instant.parse("2023-02-17T09:00:00.00Z"))); dateCasesWithSpan( suppliers, "fixed date with period", @@ -54,6 +59,22 @@ public static Iterable parameters() { Duration.ofDays(1L), "[86400000 in Z][fixed]" ); + dateNanosCasesWithSpan( + suppliers, + "fixed date nanos with period", + () -> DateUtils.toLong(Instant.parse("2023-01-01T00:00:00.00Z")), + DataType.DATE_PERIOD, + Period.ofYears(1), + "[YEAR_OF_CENTURY in Z][fixed to midnight]" + ); + dateNanosCasesWithSpan( + suppliers, + "fixed date nanos with duration", + () -> DateUtils.toLong(Instant.parse("2023-02-17T09:00:00.00Z")), + DataType.TIME_DURATION, + Duration.ofDays(1L), + "[86400000 in Z][fixed]" + ); numberCases(suppliers, "fixed long", DataType.LONG, () -> 100L); numberCasesWithSpan(suppliers, "fixed long with span", DataType.LONG, () -> 100L); numberCases(suppliers, "fixed int", DataType.INTEGER, () -> 100); @@ -142,6 +163,62 @@ private static void dateCasesWithSpan( })); } + private static void dateNanosCasesWithSpan( + List suppliers, + String name, + LongSupplier date, + DataType spanType, + Object span, + String spanStr + ) { + suppliers.add(new TestCaseSupplier(name, List.of(DataType.DATE_NANOS, spanType), () -> { + List args = new ArrayList<>(); + args.add(new TestCaseSupplier.TypedData(date.getAsLong(), DataType.DATE_NANOS, "field")); + args.add(new TestCaseSupplier.TypedData(span, spanType, "buckets").forceLiteral()); + return new TestCaseSupplier.TestCase( + args, + Matchers.startsWith("DateTruncDateNanosEvaluator[fieldVal=Attribute[channel=0], rounding=Rounding["), + DataType.DATE_NANOS, + resultsMatcher(args) + ); + })); + } + + private static void dateNanosCases(List suppliers, String name, LongSupplier date) { + for (DataType fromType : DATE_BOUNDS_TYPE) { + for (DataType toType : DATE_BOUNDS_TYPE) { + suppliers.add(new TestCaseSupplier(name, List.of(DataType.DATE_NANOS, DataType.INTEGER, fromType, toType), () -> { + List args = new ArrayList<>(); + args.add(new TestCaseSupplier.TypedData(date.getAsLong(), DataType.DATE_NANOS, "field")); + // TODO more "from" and "to" and "buckets" + args.add(new TestCaseSupplier.TypedData(50, DataType.INTEGER, "buckets").forceLiteral()); + args.add(dateBound("from", fromType, "2023-02-01T00:00:00.00Z")); + args.add(dateBound("to", toType, "2023-03-01T09:00:00.00Z")); + return new TestCaseSupplier.TestCase( + args, + Matchers.startsWith("DateTruncDateNanosEvaluator[fieldVal=Attribute[channel=0], rounding=Rounding["), + DataType.DATE_NANOS, + resultsMatcher(args) + ); + })); + // same as above, but a low bucket count and datetime bounds that match it (at hour span) + suppliers.add(new TestCaseSupplier(name, List.of(DataType.DATE_NANOS, 
DataType.INTEGER, fromType, toType), () -> { + List args = new ArrayList<>(); + args.add(new TestCaseSupplier.TypedData(date.getAsLong(), DataType.DATE_NANOS, "field")); + args.add(new TestCaseSupplier.TypedData(4, DataType.INTEGER, "buckets").forceLiteral()); + args.add(dateBound("from", fromType, "2023-02-17T09:00:00Z")); + args.add(dateBound("to", toType, "2023-02-17T12:00:00Z")); + return new TestCaseSupplier.TestCase( + args, + Matchers.startsWith("DateTruncDateNanosEvaluator[fieldVal=Attribute[channel=0], rounding=Rounding["), + DataType.DATE_NANOS, + equalTo(Rounding.builder(Rounding.DateTimeUnit.HOUR_OF_DAY).build().prepareForUnknown().round(date.getAsLong())) + ); + })); + } + } + } + private static final DataType[] NUMBER_BOUNDS_TYPES = new DataType[] { DataType.INTEGER, DataType.LONG, DataType.DOUBLE }; private static void numberCases(List suppliers, String name, DataType numberType, Supplier number) { @@ -221,7 +298,19 @@ private static TestCaseSupplier.TypedData keywordDateLiteral(String name, DataTy private static Matcher resultsMatcher(List typedData) { if (typedData.get(0).type() == DataType.DATETIME) { long millis = ((Number) typedData.get(0).data()).longValue(); - return equalTo(Rounding.builder(Rounding.DateTimeUnit.DAY_OF_MONTH).build().prepareForUnknown().round(millis)); + long expected = Rounding.builder(Rounding.DateTimeUnit.DAY_OF_MONTH).build().prepareForUnknown().round(millis); + LogManager.getLogger(getTestClass()).info("Expected: " + Instant.ofEpochMilli(expected)); + LogManager.getLogger(getTestClass()).info("Input: " + Instant.ofEpochMilli(millis)); + return equalTo(expected); + } + if (typedData.get(0).type() == DataType.DATE_NANOS) { + long nanos = ((Number) typedData.get(0).data()).longValue(); + long expected = DateUtils.toNanoSeconds( + Rounding.builder(Rounding.DateTimeUnit.DAY_OF_MONTH).build().prepareForUnknown().round(DateUtils.toMilliSeconds(nanos)) + ); + LogManager.getLogger(getTestClass()).info("Expected: " + DateUtils.toInstant(expected)); + LogManager.getLogger(getTestClass()).info("Input: " + DateUtils.toInstant(nanos)); + return equalTo(expected); } return equalTo(((Number) typedData.get(0).data()).doubleValue()); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/formatter/TextFormatTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/formatter/TextFormatTests.java index fe1ac52427627..ca47e0cb329b3 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/formatter/TextFormatTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/formatter/TextFormatTests.java @@ -123,17 +123,17 @@ public void testTsvFormatWithEmptyData() { public void testCsvFormatWithRegularData() { String text = format(CSV, req(), regularData()); assertEquals(""" - string,number,location,location2\r - Along The River Bank,708,POINT (12.0 56.0),POINT (1234.0 5678.0)\r - Mind Train,280,POINT (-97.0 26.0),POINT (-9753.0 2611.0)\r + string,number,location,location2,null_field\r + Along The River Bank,708,POINT (12.0 56.0),POINT (1234.0 5678.0),\r + Mind Train,280,POINT (-97.0 26.0),POINT (-9753.0 2611.0),\r """, text); } public void testCsvFormatNoHeaderWithRegularData() { String text = format(CSV, reqWithParam("header", "absent"), regularData()); assertEquals(""" - Along The River Bank,708,POINT (12.0 56.0),POINT (1234.0 5678.0)\r - Mind Train,280,POINT (-97.0 26.0),POINT (-9753.0 2611.0)\r + Along The River Bank,708,POINT (12.0 56.0),POINT (1234.0 5678.0),\r + Mind Train,280,POINT (-97.0 
26.0),POINT (-9753.0 2611.0),\r """, text); } @@ -146,14 +146,17 @@ public void testCsvFormatWithCustomDelimiterRegularData() { "number", "location", "location2", + "null_field", "Along The River Bank", "708", "POINT (12.0 56.0)", "POINT (1234.0 5678.0)", + "", "Mind Train", "280", "POINT (-97.0 26.0)", - "POINT (-9753.0 2611.0)" + "POINT (-9753.0 2611.0)", + "" ); List expectedTerms = terms.stream() .map(x -> x.contains(String.valueOf(delim)) ? '"' + x + '"' : x) @@ -167,6 +170,8 @@ public void testCsvFormatWithCustomDelimiterRegularData() { sb.append(expectedTerms.remove(0)); sb.append(delim); sb.append(expectedTerms.remove(0)); + sb.append(delim); + sb.append(expectedTerms.remove(0)); sb.append("\r\n"); } while (expectedTerms.size() > 0); assertEquals(sb.toString(), text); @@ -175,9 +180,9 @@ public void testCsvFormatWithCustomDelimiterRegularData() { public void testTsvFormatWithRegularData() { String text = format(TSV, req(), regularData()); assertEquals(""" - string\tnumber\tlocation\tlocation2 - Along The River Bank\t708\tPOINT (12.0 56.0)\tPOINT (1234.0 5678.0) - Mind Train\t280\tPOINT (-97.0 26.0)\tPOINT (-9753.0 2611.0) + string\tnumber\tlocation\tlocation2\tnull_field + Along The River Bank\t708\tPOINT (12.0 56.0)\tPOINT (1234.0 5678.0)\t + Mind Train\t280\tPOINT (-97.0 26.0)\tPOINT (-9753.0 2611.0)\t """, text); } @@ -245,6 +250,24 @@ public void testPlainTextEmptyCursorWithoutColumns() { ); } + public void testCsvFormatWithDropNullColumns() { + String text = format(CSV, reqWithParam("drop_null_columns", "true"), regularData()); + assertEquals(""" + string,number,location,location2\r + Along The River Bank,708,POINT (12.0 56.0),POINT (1234.0 5678.0)\r + Mind Train,280,POINT (-97.0 26.0),POINT (-9753.0 2611.0)\r + """, text); + } + + public void testTsvFormatWithDropNullColumns() { + String text = format(TSV, reqWithParam("drop_null_columns", "true"), regularData()); + assertEquals(""" + string\tnumber\tlocation\tlocation2 + Along The River Bank\t708\tPOINT (12.0 56.0)\tPOINT (1234.0 5678.0) + Mind Train\t280\tPOINT (-97.0 26.0)\tPOINT (-9753.0 2611.0) + """, text); + } + private static EsqlQueryResponse emptyData() { return new EsqlQueryResponse(singletonList(new ColumnInfoImpl("name", "keyword")), emptyList(), null, false, false, null); } @@ -256,7 +279,8 @@ private static EsqlQueryResponse regularData() { new ColumnInfoImpl("string", "keyword"), new ColumnInfoImpl("number", "integer"), new ColumnInfoImpl("location", "geo_point"), - new ColumnInfoImpl("location2", "cartesian_point") + new ColumnInfoImpl("location2", "cartesian_point"), + new ColumnInfoImpl("null_field", "keyword") ); BytesRefArray geoPoints = new BytesRefArray(2, BigArrays.NON_RECYCLING_INSTANCE); @@ -274,7 +298,8 @@ private static EsqlQueryResponse regularData() { blockFactory.newBytesRefBlockBuilder(2) .appendBytesRef(CARTESIAN.asWkb(new Point(1234, 5678))) .appendBytesRef(CARTESIAN.asWkb(new Point(-9753, 2611))) - .build() + .build(), + blockFactory.newConstantNullBlock(2) ) ); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/formatter/TextFormatterTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/formatter/TextFormatterTests.java index e735ba83168bb..4e90fe53d96d7 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/formatter/TextFormatterTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/formatter/TextFormatterTests.java @@ -85,8 +85,6 @@ public class TextFormatterTests extends ESTestCase { new 
EsqlExecutionInfo(randomBoolean()) ); - TextFormatter formatter = new TextFormatter(esqlResponse); - /** * Tests for {@link TextFormatter#format} with header, values * of exactly the minimum column size, column names of exactly @@ -95,7 +93,7 @@ public class TextFormatterTests extends ESTestCase { * column size. */ public void testFormatWithHeader() { - String[] result = getTextBodyContent(formatter.format(true)).split("\n"); + String[] result = getTextBodyContent(new TextFormatter(esqlResponse, true, false).format()).split("\n"); assertThat(result, arrayWithSize(4)); assertEquals( " foo | bar |15charwidename!| null_field1 |superduperwidename!!!| baz |" @@ -119,6 +117,35 @@ public void testFormatWithHeader() { ); } + /** + * Tests for {@link TextFormatter#format} with drop_null_columns and + * truncation of long columns. + */ + public void testFormatWithDropNullColumns() { + String[] result = getTextBodyContent(new TextFormatter(esqlResponse, true, true).format()).split("\n"); + assertThat(result, arrayWithSize(4)); + assertEquals( + " foo | bar |15charwidename!|superduperwidename!!!| baz |" + + " date | location | location2 ", + result[0] + ); + assertEquals( + "---------------+---------------+---------------+---------------------+---------------+-------" + + "-----------------+------------------+----------------------", + result[1] + ); + assertEquals( + "15charwidedata!|1 |6.888 |12.0 |rabbit |" + + "1953-09-02T00:00:00.000Z|POINT (12.0 56.0) |POINT (1234.0 5678.0) ", + result[2] + ); + assertEquals( + "dog |2 |123124.888 |9912.0 |goat |" + + "2000-03-15T21:34:37.443Z|POINT (-97.0 26.0)|POINT (-9753.0 2611.0)", + result[3] + ); + } + /** * Tests for {@link TextFormatter#format} without header and * truncation of long columns. @@ -160,7 +187,7 @@ public void testFormatWithoutHeader() { new EsqlExecutionInfo(randomBoolean()) ); - String[] result = getTextBodyContent(new TextFormatter(response).format(false)).split("\n"); + String[] result = getTextBodyContent(new TextFormatter(response, false, false).format()).split("\n"); assertThat(result, arrayWithSize(2)); assertEquals( "doggie |4 |1.0 |null |77.0 |wombat |" @@ -199,8 +226,10 @@ public void testVeryLongPadding() { randomBoolean(), randomBoolean(), new EsqlExecutionInfo(randomBoolean()) - ) - ).format(false) + ), + false, + false + ).format() ) ); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFoldingTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFoldingTests.java index c2e85cc43284a..c4f4dac67acd3 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFoldingTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFoldingTests.java @@ -17,11 +17,11 @@ import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLike; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardLike; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardPattern; import org.elasticsearch.xpack.esql.core.type.DataType; +import 
org.elasticsearch.xpack.esql.expression.function.scalar.string.RLike; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.WildcardLike; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mod; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java index e159e5ed0bd7d..bc22fbb6bd828 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java @@ -199,7 +199,7 @@ public void testPushDownFilterOnAliasInEval() { public void testPushDownLikeRlikeFilter() { EsRelation relation = relation(); - org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLike conditionA = rlike(getFieldAttribute("a"), "foo"); + RLike conditionA = rlike(getFieldAttribute("a"), "foo"); WildcardLike conditionB = wildcardLike(getFieldAttribute("b"), "bar"); Filter fa = new Filter(EMPTY, relation, conditionA); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatchTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatchTests.java index 20d638a113bf2..c7206c6971bde 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatchTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatchTests.java @@ -11,11 +11,11 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNotNull; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLike; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardLike; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardPattern; import org.elasticsearch.xpack.esql.core.util.StringUtils; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.RLike; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.WildcardLike; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; import static java.util.Arrays.asList; diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index f5f682b143a72..a6888f28159f4 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -17,7 +17,7 @@ import org.elasticsearch.core.Nullable; import 
org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; @@ -37,7 +37,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import java.io.IOException; @@ -151,7 +151,7 @@ public void chunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { switch (model.getConfigurations().getTaskType()) { case ANY, TEXT_EMBEDDING -> { @@ -176,9 +176,24 @@ private InferenceTextEmbeddingFloatResults makeResults(List input, int d return new InferenceTextEmbeddingFloatResults(embeddings); } - private List makeChunkedResults(List input, int dimensions) { + private List makeChunkedResults(List input, int dimensions) { InferenceTextEmbeddingFloatResults nonChunkedResults = makeResults(input, dimensions); - return InferenceChunkedTextEmbeddingFloatResults.listOf(input, nonChunkedResults); + + var results = new ArrayList(); + for (int i = 0; i < input.size(); i++) { + results.add( + new ChunkedInferenceEmbeddingFloat( + List.of( + new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk( + nonChunkedResults.embeddings().get(i).values(), + input.get(i), + new ChunkedInference.TextOffset(0, input.get(i).length()) + ) + ) + ) + ); + } + return results; } protected ServiceSettings getServiceSettingsFromMap(Map serviceSettingsMap) { diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java index fa1e27005c287..bbb773aa5129a 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java @@ -16,7 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; @@ -139,7 +139,7 @@ public void chunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { listener.onFailure( new ElasticsearchStatusException( diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java 
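The rewritten makeChunkedResults above shows the new result shape: a ChunkedInferenceEmbeddingFloat holds chunks that pair the embedding values with the matched text and an explicit TextOffset. A hypothetical helper (not part of this PR) condensing the mock service's single-whole-input-chunk pattern; the embedding is assumed to be the float[] returned by values():

    // Hypothetical: wrap one embedding as a single chunk spanning the entire input.
    static ChunkedInferenceEmbeddingFloat singleChunk(String input, float[] embedding) {
        return new ChunkedInferenceEmbeddingFloat(
            List.of(
                new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk(
                    embedding,
                    input,
                    new ChunkedInference.TextOffset(0, input.length())
                )
            )
        );
    }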
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 64569fd8c5c6a..eea64304f503a 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -16,7 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; @@ -35,9 +35,8 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; import org.elasticsearch.xpack.core.ml.search.WeightedToken; import java.io.IOException; @@ -142,7 +141,7 @@ public void chunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { switch (model.getConfigurations().getTaskType()) { case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeChunkedResults(input)); @@ -167,16 +166,22 @@ private SparseEmbeddingResults makeResults(List input) { return new SparseEmbeddingResults(embeddings); } - private List makeChunkedResults(List input) { - List results = new ArrayList<>(); + private List makeChunkedResults(List input) { + List results = new ArrayList<>(); for (int i = 0; i < input.size(); i++) { var tokens = new ArrayList(); for (int j = 0; j < 5; j++) { tokens.add(new WeightedToken("feature_" + j, generateEmbedding(input.get(i), j))); } results.add( - new InferenceChunkedSparseEmbeddingResults( - List.of(new MlChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens)) + new ChunkedInferenceEmbeddingSparse( + List.of( + new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk( + tokens, + input.get(i), + new ChunkedInference.TextOffset(0, input.get(i).length()) + ) + ) ) ); } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index f7a05a27354ef..8325017f8e390 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -18,7 +18,7 @@ import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; import org.elasticsearch.core.TimeValue; -import 
org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; @@ -233,7 +233,7 @@ public void chunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { listener.onFailure( new ElasticsearchStatusException( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index a4187f4c4fa90..71fbcf6d8ef49 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -18,10 +18,6 @@ import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; @@ -108,7 +104,6 @@ public static List getNamedWriteables() { ); addInferenceResultsNamedWriteables(namedWriteables); - addChunkedInferenceResultsNamedWriteables(namedWriteables); // Empty default task settings namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, EmptyTaskSettings.NAME, EmptyTaskSettings::new)); @@ -433,37 +428,6 @@ private static void addInternalNamedWriteables(List namedWriteables) { - namedWriteables.add( - new NamedWriteableRegistry.Entry( - InferenceServiceResults.class, - ErrorChunkedInferenceResults.NAME, - ErrorChunkedInferenceResults::new - ) - ); - namedWriteables.add( - new NamedWriteableRegistry.Entry( - InferenceServiceResults.class, - InferenceChunkedSparseEmbeddingResults.NAME, - InferenceChunkedSparseEmbeddingResults::new - ) - ); - namedWriteables.add( - new NamedWriteableRegistry.Entry( - InferenceServiceResults.class, - InferenceChunkedTextEmbeddingFloatResults.NAME, - InferenceChunkedTextEmbeddingFloatResults::new - ) - ); - namedWriteables.add( - new NamedWriteableRegistry.Entry( - InferenceServiceResults.class, - InferenceChunkedTextEmbeddingByteResults.NAME, - InferenceChunkedTextEmbeddingByteResults::new - ) - ); - } - private static void addChunkingSettingsNamedWriteables(List namedWriteables) { namedWriteables.add( new NamedWriteableRegistry.Entry(ChunkingSettings.class, WordBoundaryChunkingSettings.NAME, WordBoundaryChunkingSettings::new) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index d178e927aa65d..a9195ea24af3a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -29,7 +29,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InputType; @@ -37,11 +37,12 @@ import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import java.io.IOException; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -141,7 +142,7 @@ private record FieldInferenceResponse( int inputOrder, boolean isOriginalFieldInput, Model model, - ChunkedInferenceServiceResults chunkedResults + ChunkedInference chunkedResults ) {} private record FieldInferenceResponseAccumulator( @@ -273,19 +274,19 @@ public void onFailure(Exception exc) { final List currentBatch = requests.subList(0, currentBatchSize); final List nextBatch = requests.subList(currentBatchSize, requests.size()); final List inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); - ActionListener> completionListener = new ActionListener<>() { + ActionListener> completionListener = new ActionListener<>() { @Override - public void onResponse(List results) { + public void onResponse(List results) { try { var requestsIterator = requests.iterator(); - for (ChunkedInferenceServiceResults result : results) { + for (ChunkedInference result : results) { var request = requestsIterator.next(); var acc = inferenceResults.get(request.index); - if (result instanceof ErrorChunkedInferenceResults error) { + if (result instanceof ChunkedInferenceError error) { acc.addFailure( new ElasticsearchException( "Exception when running inference id [{}] on field [{}]", - error.getException(), + error.exception(), inferenceProvider.model.getInferenceEntityId(), request.field ) @@ -359,7 +360,7 @@ private void addInferenceResponseFailure(int id, Exception failure) { * Otherwise, the source of the request is augmented with the field inference results under the * {@link SemanticTextField#INFERENCE_FIELD} field. 
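The filter hunks above also pin down the new failure contract: errors travel as ChunkedInferenceError, itself a ChunkedInference, and the cause is read through the record accessor exception() rather than getException(). Sketch of the consuming branch, with the surrounding names (acc, inferenceProvider, request) taken from the code above:

    if (result instanceof ChunkedInferenceError error) {
        acc.addFailure(new ElasticsearchException(
            "Exception when running inference id [{}] on field [{}]",
            error.exception(),  // record accessor replaces getException()
            inferenceProvider.model.getInferenceEntityId(),
            request.field
        ));
    }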
*/ - private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) { + private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) throws IOException { if (response.failures().isEmpty() == false) { for (var failure : response.failures()) { item.abort(item.index(), failure); @@ -376,7 +377,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons // ensure that the order in the original field is consistent in case of multiple inputs Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder)); List inputs = responses.stream().filter(r -> r.isOriginalFieldInput).map(r -> r.input).collect(Collectors.toList()); - List results = responses.stream().map(r -> r.chunkedResults).collect(Collectors.toList()); + List results = responses.stream().map(r -> r.chunkedResults).collect(Collectors.toList()); var result = new SemanticTextField( fieldName, inputs, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index 2aef54e56f4b9..9b0b1104df660 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -11,18 +11,17 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingByte; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; import java.util.ArrayList; import java.util.List; @@ -72,8 +71,8 @@ public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.El private List>> floatResults; private List>> byteResults; private List>> sparseResults; - private AtomicArray errors; - private ActionListener> finalListener; + private AtomicArray errors; + private ActionListener> finalListener; public EmbeddingRequestChunker(List inputs, int maxNumberOfInputsPerBatch, EmbeddingType embeddingType) { this(inputs, 
maxNumberOfInputsPerBatch, DEFAULT_WORDS_PER_CHUNK, DEFAULT_CHUNK_OVERLAP, embeddingType); @@ -189,7 +188,7 @@ private int addToBatches(ChunkOffsetsAndInput chunk, int inputIndex) { * @param finalListener The listener to call once all the batches are processed * @return Batches and listeners */ - public List batchRequestsWithListeners(ActionListener> finalListener) { + public List batchRequestsWithListeners(ActionListener> finalListener) { this.finalListener = finalListener; int numberOfRequests = batchedRequests.size(); @@ -331,9 +330,8 @@ private ElasticsearchStatusException unexpectedResultTypeException(String got, S @Override public void onFailure(Exception e) { - var errorResult = new ErrorChunkedInferenceResults(e); for (var pos : positions) { - errors.set(pos.inputIndex(), errorResult); + errors.set(pos.inputIndex(), e); } if (resultCount.incrementAndGet() == totalNumberOfRequests) { @@ -342,10 +340,10 @@ public void onFailure(Exception e) { } private void sendResponse() { - var response = new ArrayList(chunkedOffsets.size()); + var response = new ArrayList(chunkedOffsets.size()); for (int i = 0; i < chunkedOffsets.size(); i++) { if (errors.get(i) != null) { - response.add(errors.get(i)); + response.add(new ChunkedInferenceError(errors.get(i))); } else { response.add(mergeResultsWithInputs(i)); } @@ -355,16 +353,16 @@ private void sendResponse() { } } - private ChunkedInferenceServiceResults mergeResultsWithInputs(int resultIndex) { + private ChunkedInference mergeResultsWithInputs(int resultIndex) { return switch (embeddingType) { - case FLOAT -> mergeFloatResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), floatResults.get(resultIndex)); - case BYTE -> mergeByteResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), byteResults.get(resultIndex)); - case SPARSE -> mergeSparseResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), sparseResults.get(resultIndex)); + case FLOAT -> mergeFloatResultsWithInputs(chunkedOffsets.get(resultIndex), floatResults.get(resultIndex)); + case BYTE -> mergeByteResultsWithInputs(chunkedOffsets.get(resultIndex), byteResults.get(resultIndex)); + case SPARSE -> mergeSparseResultsWithInputs(chunkedOffsets.get(resultIndex), sparseResults.get(resultIndex)); }; } - private InferenceChunkedTextEmbeddingFloatResults mergeFloatResultsWithInputs( - List chunks, + private ChunkedInferenceEmbeddingFloat mergeFloatResultsWithInputs( + ChunkOffsetsAndInput chunks, AtomicArray> debatchedResults ) { var all = new ArrayList(); @@ -375,18 +373,22 @@ private InferenceChunkedTextEmbeddingFloatResults mergeFloatResultsWithInputs( assert chunks.size() == all.size(); - var embeddingChunks = new ArrayList(); + var embeddingChunks = new ArrayList(); for (int i = 0; i < chunks.size(); i++) { embeddingChunks.add( - new InferenceChunkedTextEmbeddingFloatResults.InferenceFloatEmbeddingChunk(chunks.get(i), all.get(i).values()) + new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk( + all.get(i).values(), + chunks.chunkText(i), + new ChunkedInference.TextOffset(chunks.offsets().get(i).start(), chunks.offsets().get(i).end()) + ) ); } - return new InferenceChunkedTextEmbeddingFloatResults(embeddingChunks); + return new ChunkedInferenceEmbeddingFloat(embeddingChunks); } - private InferenceChunkedTextEmbeddingByteResults mergeByteResultsWithInputs( - List chunks, + private ChunkedInferenceEmbeddingByte mergeByteResultsWithInputs( + ChunkOffsetsAndInput chunks, AtomicArray> debatchedResults ) { var all = new ArrayList(); @@ -397,18 +399,22 @@ private 
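A behavioural nuance in the EmbeddingRequestChunker hunks above: onFailure now stores the bare Exception per input, and sendResponse wraps it in a ChunkedInferenceError only while assembling the final list. A tiny hypothetical helper (names mine, not from the PR) capturing that deferred-wrapping decision:

    // Hypothetical: wrap a stored failure lazily; otherwise return the merged result.
    ChunkedInference toFinalResult(Exception storedError, ChunkedInference merged) {
        return storedError != null ? new ChunkedInferenceError(storedError) : merged;
    }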
InferenceChunkedTextEmbeddingByteResults mergeByteResultsWithInputs( assert chunks.size() == all.size(); - var embeddingChunks = new ArrayList(); + var embeddingChunks = new ArrayList(); for (int i = 0; i < chunks.size(); i++) { embeddingChunks.add( - new InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk(chunks.get(i), all.get(i).values()) + new ChunkedInferenceEmbeddingByte.ByteEmbeddingChunk( + all.get(i).values(), + chunks.chunkText(i), + new ChunkedInference.TextOffset(chunks.offsets().get(i).start(), chunks.offsets().get(i).end()) + ) ); } - return new InferenceChunkedTextEmbeddingByteResults(embeddingChunks, false); + return new ChunkedInferenceEmbeddingByte(embeddingChunks); } - private InferenceChunkedSparseEmbeddingResults mergeSparseResultsWithInputs( - List chunks, + private ChunkedInferenceEmbeddingSparse mergeSparseResultsWithInputs( + ChunkOffsetsAndInput chunks, AtomicArray> debatchedResults ) { var all = new ArrayList(); @@ -419,12 +425,18 @@ private InferenceChunkedSparseEmbeddingResults mergeSparseResultsWithInputs( assert chunks.size() == all.size(); - var embeddingChunks = new ArrayList(); + var embeddingChunks = new ArrayList(); for (int i = 0; i < chunks.size(); i++) { - embeddingChunks.add(new MlChunkedTextExpansionResults.ChunkedResult(chunks.get(i), all.get(i).tokens())); + embeddingChunks.add( + new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk( + all.get(i).tokens(), + chunks.chunkText(i), + new ChunkedInference.TextOffset(chunks.offsets().get(i).start(), chunks.offsets().get(i).end()) + ) + ); } - return new InferenceChunkedSparseEmbeddingResults(embeddingChunks); + return new ChunkedInferenceEmbeddingSparse(embeddingChunks); } public record BatchRequest(List subBatches) { @@ -460,5 +472,13 @@ record ChunkOffsetsAndInput(List offsets, String input) { List toChunkText() { return offsets.stream().map(o -> input.substring(o.start(), o.end())).collect(Collectors.toList()); } + + int size() { + return offsets.size(); + } + + String chunkText(int index) { + return input.substring(offsets.get(index).start(), offsets.get(index).end()); + } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index 0f26f6577860f..d651729dee259 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -13,7 +13,7 @@ import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.core.Tuple; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -31,7 +31,6 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; @@ -70,7 +69,7 @@ public record SemanticTextField(String fieldName, List originalValues, I public record InferenceResult(String inferenceId, ModelSettings modelSettings, List chunks) {} - public record Chunk(String text, BytesReference rawEmbeddings) {} + record Chunk(String text, BytesReference rawEmbeddings) {} public 
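The ChunkOffsetsAndInput helpers added above (size() and chunkText(int)) let every merge loop derive a chunk's text and its TextOffset from the same offsets entry, so the two can never drift apart. Condensed sketch of the pattern shared by the float, byte and sparse merges, using the loop variables from the code above:

    for (int i = 0; i < chunks.size(); i++) {
        var off = chunks.offsets().get(i);
        embeddingChunks.add(new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk(
            all.get(i).values(),
            chunks.chunkText(i),  // equals input.substring(off.start(), off.end())
            new ChunkedInference.TextOffset(off.start(), off.end())
        ));
    }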
record ModelSettings( TaskType taskType, @@ -307,13 +306,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } /** - * Converts the provided {@link ChunkedInferenceServiceResults} into a list of {@link Chunk}. + * Converts the provided {@link ChunkedInference} into a list of {@link Chunk}. */ - public static List toSemanticTextFieldChunks(List results, XContentType contentType) { + public static List toSemanticTextFieldChunks(List results, XContentType contentType) throws IOException { List chunks = new ArrayList<>(); for (var result : results) { - for (Iterator it = result.chunksAsMatchedTextAndByteReference(contentType.xContent()); it - .hasNext();) { + for (var it = result.chunksAsMatchedTextAndByteReference(contentType.xContent()); it.hasNext();) { var chunkAsByteReference = it.next(); chunks.add(new Chunk(chunkAsByteReference.matchedText(), chunkAsByteReference.bytesReference())); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index ce6ac6747eba8..208744b40ce9d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -13,7 +13,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -100,7 +100,7 @@ public void chunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { init(); chunkedInfer(model, null, input, taskSettings, inputType, timeout, listener); @@ -114,7 +114,7 @@ public void chunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { init(); // a non-null query is not supported and is dropped by all providers @@ -143,7 +143,7 @@ protected abstract void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ); public void start(Model model, ActionListener listener) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index ffd26b9ac534d..42a276c6ee838 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import 
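Call-site impact of the SemanticTextField hunks above: toSemanticTextFieldChunks now accepts the ChunkedInference list and declares IOException (presumably because chunksAsMatchedTextAndByteReference can now throw), and Chunk is package-private. Sketch of a caller inside the mapper package; XContentType.JSON stands in for the index's actual content type:

    // May throw IOException now; Chunk is only visible within the mapper package.
    List<SemanticTextField.Chunk> chunks =
        SemanticTextField.toSemanticTextFieldChunks(results, XContentType.JSON);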
org.elasticsearch.inference.InferenceServiceConfiguration; @@ -299,7 +299,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { if (model instanceof AlibabaCloudSearchModel == false) { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index d224e50bb650d..a88881220f933 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -16,7 +16,7 @@ import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -126,7 +126,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { var actionCreator = new AmazonBedrockActionCreator(amazonBedrockSender, this.getServiceComponents(), timeout); if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index f1840af18779f..41852e4758a8c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -232,7 +232,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { throw new UnsupportedOperationException("Anthropic service does not support chunked inference"); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index f8ea11e4b15a5..bcd4a1abfbf00 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -15,7 +15,7 @@ import 
org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -119,7 +119,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { if (model instanceof AzureAiStudioModel baseAzureAiStudioModel) { var actionCreator = new AzureAiStudioActionCreator(getSender(), getServiceComponents()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index a38c265d2613c..0ed69604c258a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -273,7 +273,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { if (model instanceof AzureOpenAiModel == false) { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index ccb8d79dacd6c..a7d17192bfa92 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -272,7 +272,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { if (model instanceof CohereModel == false) { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index f107d64f93e4e..0fcd9db27a3d3 100644 --- 
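Every service in this PR repeats the signature change shown above for Azure AI Studio, Azure OpenAI and Cohere: the completion listener is now typed against List<ChunkedInference>. A reconstructed sketch of the hook's shape (generics and parameter names are inferred from the implementations, since the hunks only show the changed line):

    @Override
    protected void doChunkedInfer(
        Model model,
        DocumentsOnlyInput inputs,
        Map<String, Object> taskSettings,
        InputType inputType,
        TimeValue timeout,
        ActionListener<List<ChunkedInference>> listener  // previously List<ChunkedInferenceServiceResults>
    ) {
        // service-specific batching and dispatch goes here
    }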
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -15,7 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -29,8 +29,8 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.inference.external.action.elastic.ElasticInferenceServiceActionCreator; @@ -113,7 +113,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { // Pass-through without actually performing chunking (result will have a single chunk per input) ActionListener inferListener = listener.delegateFailureAndWrap( @@ -265,14 +265,12 @@ public void checkModelConfig(Model model, ActionListener listener) { } } - private static List translateToChunkedResults( - InferenceInputs inputs, - InferenceServiceResults inferenceResults - ) { + private static List translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) { if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { - return InferenceChunkedSparseEmbeddingResults.listOf(DocumentsOnlyInput.of(inputs).getInputs(), sparseEmbeddingResults); + var inputsAsList = DocumentsOnlyInput.of(inputs).getInputs(); + return ChunkedInferenceEmbeddingSparse.listOf(inputsAsList, sparseEmbeddingResults); } else if (inferenceResults instanceof ErrorInferenceResults error) { - return List.of(new ErrorChunkedInferenceResults(error.getException())); + return List.of(new ChunkedInferenceError(error.getException())); } else { String expectedClass = Strings.format("%s", SparseEmbeddingResults.class.getSimpleName()); throw createInvalidChunkedResultException(expectedClass, inferenceResults.getWriteableName()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 5856e08c8dc9b..25e8fc14da491 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ 
-18,7 +18,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceResults; @@ -699,7 +699,7 @@ public void chunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { chunkedInfer(model, null, input, taskSettings, inputType, timeout, listener); } @@ -712,7 +712,7 @@ public void chunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { if ((TaskType.TEXT_EMBEDDING.equals(model.getTaskType()) || TaskType.SPARSE_EMBEDDING.equals(model.getTaskType())) == false) { listener.onFailure( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index b681722a82136..837a001d1f8f9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -326,7 +326,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { GoogleAiStudioModel googleAiStudioModel = (GoogleAiStudioModel) model; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 87a2d98dca92c..b412f20289880 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -225,7 +225,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { GoogleVertexAiModel googleVertexAiModel = (GoogleVertexAiModel) model; var actionCreator = new 
GoogleVertexAiActionCreator(getSender(), getServiceComponents()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index b74ec01cd76e7..acb082cd2de8d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -119,7 +119,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { if (model instanceof HuggingFaceModel == false) { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 5b038781b96af..ef357ac308c2f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -15,7 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -28,9 +28,9 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; @@ -44,6 +44,7 @@ import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import 
org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import java.util.ArrayList; import java.util.EnumSet; import java.util.HashMap; import java.util.List; @@ -51,6 +52,7 @@ import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings; import static org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings.URL; public class HuggingFaceElserService extends HuggingFaceBaseService { @@ -100,7 +102,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { ActionListener inferListener = listener.delegateFailureAndWrap( (delegate, response) -> delegate.onResponse(translateToChunkedResults(inputs, response)) @@ -110,16 +112,31 @@ protected void doChunkedInfer( doInfer(model, inputs, taskSettings, inputType, timeout, inferListener); } - private static List translateToChunkedResults( - DocumentsOnlyInput inputs, - InferenceServiceResults inferenceResults - ) { + private static List translateToChunkedResults(DocumentsOnlyInput inputs, InferenceServiceResults inferenceResults) { if (inferenceResults instanceof InferenceTextEmbeddingFloatResults textEmbeddingResults) { - return InferenceChunkedTextEmbeddingFloatResults.listOf(inputs.getInputs(), textEmbeddingResults); + validateInputSizeAgainstEmbeddings(inputs.getInputs(), textEmbeddingResults.embeddings().size()); + + var results = new ArrayList(inputs.getInputs().size()); + + for (int i = 0; i < inputs.getInputs().size(); i++) { + results.add( + new ChunkedInferenceEmbeddingFloat( + List.of( + new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk( + textEmbeddingResults.embeddings().get(i).values(), + inputs.getInputs().get(i), + new ChunkedInference.TextOffset(0, inputs.getInputs().get(i).length()) + ) + ) + ) + ); + } + return results; } else if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { - return InferenceChunkedSparseEmbeddingResults.listOf(inputs.getInputs(), sparseEmbeddingResults); + var inputsAsList = DocumentsOnlyInput.of(inputs).getInputs(); + return ChunkedInferenceEmbeddingSparse.listOf(inputsAsList, sparseEmbeddingResults); } else if (inferenceResults instanceof ErrorInferenceResults error) { - return List.of(new ErrorChunkedInferenceResults(error.getException())); + return List.of(new ChunkedInferenceError(error.getException())); } else { String expectedClasses = Strings.format( "One of [%s,%s]", diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index cc66d5fd7ee74..482554e060a47 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import 
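The Hugging Face ELSER path above now guards its one-chunk-per-input conversion with the newly imported validateInputSizeAgainstEmbeddings. A hypothetical standalone equivalent of that guard (the real helper lives in TextEmbeddingUtils; the exact exception type and message here are assumptions):

    static void requireOneEmbeddingPerInput(List<String> inputs, int embeddingCount) {
        if (inputs.size() != embeddingCount) {
            throw new IllegalArgumentException(
                "expected [" + inputs.size() + "] embeddings, got [" + embeddingCount + "]");
        }
    }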
org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -295,7 +295,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { IbmWatsonxModel ibmWatsonxModel = (IbmWatsonxModel) model; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index 881e7d36f2a21..dc0576651aeba 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -107,7 +107,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { var actionCreator = new MistralActionCreator(getSender(), getServiceComponents()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 7b51b068708ca..ed7a01e829dd2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -291,7 +291,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { if (model instanceof OpenAiModel == false) { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index c68a629b999c5..0b7d136ffb04c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -21,7 +21,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import 
org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.index.shard.ShardId; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.Model; @@ -34,8 +34,8 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; @@ -57,9 +57,9 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE; import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingSparse; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput; -import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSparseEmbeddings; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.semanticTextFieldFromChunkedInferenceResults; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.toChunkedResult; import static org.hamcrest.Matchers.containsString; @@ -160,8 +160,8 @@ public void testItemFailures() throws Exception { Map.of(model.getInferenceEntityId(), model), randomIntBetween(1, 10) ); - model.putResult("I am a failure", new ErrorChunkedInferenceResults(new IllegalArgumentException("boom"))); - model.putResult("I am a success", randomSparseEmbeddings(List.of("I am a success"))); + model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom"))); + model.putResult("I am a success", randomChunkedInferenceEmbeddingSparse(List.of("I am a success"))); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { @@ -290,10 +290,9 @@ private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool Answer chunkedInferAnswer = invocationOnMock -> { StaticModel model = (StaticModel) invocationOnMock.getArguments()[0]; List inputs = (List) invocationOnMock.getArguments()[2]; - ActionListener> listener = (ActionListener< - List>) invocationOnMock.getArguments()[6]; + ActionListener> listener = (ActionListener>) invocationOnMock.getArguments()[6]; Runnable runnable = () -> { - List results = new ArrayList<>(); + List results = new ArrayList<>(); for (String input : inputs) { results.add(model.getResults(input)); } @@ -348,7 +347,7 @@ private static BulkItemRequest[] randomBulkItemRequest( // This prevents a situation where embeddings in the expected 
docMap do not match those in the model, which could happen if // embeddings were overwritten. if (model.hasResult(inputText)) { - ChunkedInferenceServiceResults results = model.getResults(inputText); + var results = model.getResults(inputText); semanticTextField = semanticTextFieldFromChunkedInferenceResults( field, model, @@ -371,7 +370,7 @@ private static BulkItemRequest[] randomBulkItemRequest( } private static class StaticModel extends TestModel { - private final Map resultMap; + private final Map resultMap; StaticModel( String inferenceEntityId, @@ -397,11 +396,11 @@ public static StaticModel createRandomInstance() { ); } - ChunkedInferenceServiceResults getResults(String text) { - return resultMap.getOrDefault(text, new InferenceChunkedSparseEmbeddingResults(List.of())); + ChunkedInference getResults(String text) { + return resultMap.getOrDefault(text, new ChunkedInferenceEmbeddingSparse(List.of())); } - void putResult(String text, ChunkedInferenceServiceResults result) { + void putResult(String text, ChunkedInference result) { resultMap.put(text, result); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java index dec7d15760aa6..03249163c7f82 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java @@ -8,12 +8,12 @@ package org.elasticsearch.xpack.inference.chunking; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingByte; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; @@ -313,16 +313,16 @@ public void testMergingListener_Float() { assertThat(finalListener.results, hasSize(4)); { var chunkedResult = finalListener.results.get(0); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var chunkedFloatResult = (InferenceChunkedTextEmbeddingFloatResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var chunkedFloatResult = (ChunkedInferenceEmbeddingFloat) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); assertEquals("1st small", chunkedFloatResult.chunks().get(0).matchedText()); } { // this is the large input split in multiple chunks var 
chunkedResult = finalListener.results.get(1); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var chunkedFloatResult = (InferenceChunkedTextEmbeddingFloatResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var chunkedFloatResult = (ChunkedInferenceEmbeddingFloat) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(6)); assertThat(chunkedFloatResult.chunks().get(0).matchedText(), startsWith("passage_input0 ")); assertThat(chunkedFloatResult.chunks().get(1).matchedText(), startsWith(" passage_input20 ")); @@ -333,15 +333,15 @@ public void testMergingListener_Float() { } { var chunkedResult = finalListener.results.get(2); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var chunkedFloatResult = (InferenceChunkedTextEmbeddingFloatResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var chunkedFloatResult = (ChunkedInferenceEmbeddingFloat) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); assertEquals("2nd small", chunkedFloatResult.chunks().get(0).matchedText()); } { var chunkedResult = finalListener.results.get(3); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var chunkedFloatResult = (InferenceChunkedTextEmbeddingFloatResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var chunkedFloatResult = (ChunkedInferenceEmbeddingFloat) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); assertEquals("3rd small", chunkedFloatResult.chunks().get(0).matchedText()); } @@ -386,16 +386,16 @@ public void testMergingListener_Byte() { assertThat(finalListener.results, hasSize(4)); { var chunkedResult = finalListener.results.get(0); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingByteResults.class)); - var chunkedByteResult = (InferenceChunkedTextEmbeddingByteResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class)); + var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); assertEquals("1st small", chunkedByteResult.chunks().get(0).matchedText()); } { // this is the large input split in multiple chunks var chunkedResult = finalListener.results.get(1); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingByteResults.class)); - var chunkedByteResult = (InferenceChunkedTextEmbeddingByteResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class)); + var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(6)); assertThat(chunkedByteResult.chunks().get(0).matchedText(), startsWith("passage_input0 ")); assertThat(chunkedByteResult.chunks().get(1).matchedText(), startsWith(" passage_input20 ")); @@ -406,15 +406,15 @@ public void testMergingListener_Byte() { } { var chunkedResult = finalListener.results.get(2); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingByteResults.class)); - var chunkedByteResult = (InferenceChunkedTextEmbeddingByteResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class)); + var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); assertEquals("2nd small", 
chunkedByteResult.chunks().get(0).matchedText()); } { var chunkedResult = finalListener.results.get(3); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingByteResults.class)); - var chunkedByteResult = (InferenceChunkedTextEmbeddingByteResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class)); + var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); assertEquals("3rd small", chunkedByteResult.chunks().get(0).matchedText()); } @@ -466,34 +466,34 @@ public void testMergingListener_Sparse() { assertThat(finalListener.results, hasSize(4)); { var chunkedResult = finalListener.results.get(0); - assertThat(chunkedResult, instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var chunkedSparseResult = (InferenceChunkedSparseEmbeddingResults) chunkedResult; - assertThat(chunkedSparseResult.getChunkedResults(), hasSize(1)); - assertEquals("1st small", chunkedSparseResult.getChunkedResults().get(0).matchedText()); + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingSparse.class)); + var chunkedSparseResult = (ChunkedInferenceEmbeddingSparse) chunkedResult; + assertThat(chunkedSparseResult.chunks(), hasSize(1)); + assertEquals("1st small", chunkedSparseResult.chunks().get(0).matchedText()); } { var chunkedResult = finalListener.results.get(1); - assertThat(chunkedResult, instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var chunkedSparseResult = (InferenceChunkedSparseEmbeddingResults) chunkedResult; - assertThat(chunkedSparseResult.getChunkedResults(), hasSize(1)); - assertEquals("2nd small", chunkedSparseResult.getChunkedResults().get(0).matchedText()); + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingSparse.class)); + var chunkedSparseResult = (ChunkedInferenceEmbeddingSparse) chunkedResult; + assertThat(chunkedSparseResult.chunks(), hasSize(1)); + assertEquals("2nd small", chunkedSparseResult.chunks().get(0).matchedText()); } { var chunkedResult = finalListener.results.get(2); - assertThat(chunkedResult, instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var chunkedSparseResult = (InferenceChunkedSparseEmbeddingResults) chunkedResult; - assertThat(chunkedSparseResult.getChunkedResults(), hasSize(1)); - assertEquals("3rd small", chunkedSparseResult.getChunkedResults().get(0).matchedText()); + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingSparse.class)); + var chunkedSparseResult = (ChunkedInferenceEmbeddingSparse) chunkedResult; + assertThat(chunkedSparseResult.chunks(), hasSize(1)); + assertEquals("3rd small", chunkedSparseResult.chunks().get(0).matchedText()); } { // this is the large input split in multiple chunks var chunkedResult = finalListener.results.get(3); - assertThat(chunkedResult, instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var chunkedSparseResult = (InferenceChunkedSparseEmbeddingResults) chunkedResult; - assertThat(chunkedSparseResult.getChunkedResults(), hasSize(9)); // passage is split into 9 chunks, 10 words each - assertThat(chunkedSparseResult.getChunkedResults().get(0).matchedText(), startsWith("passage_input0 ")); - assertThat(chunkedSparseResult.getChunkedResults().get(1).matchedText(), startsWith(" passage_input10 ")); - assertThat(chunkedSparseResult.getChunkedResults().get(8).matchedText(), startsWith(" passage_input80 ")); + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingSparse.class)); + var chunkedSparseResult = 
(ChunkedInferenceEmbeddingSparse) chunkedResult; + assertThat(chunkedSparseResult.chunks(), hasSize(9)); // passage is split into 9 chunks, 10 words each + assertThat(chunkedSparseResult.chunks().get(0).matchedText(), startsWith("passage_input0 ")); + assertThat(chunkedSparseResult.chunks().get(1).matchedText(), startsWith(" passage_input10 ")); + assertThat(chunkedSparseResult.chunks().get(8).matchedText(), startsWith(" passage_input80 ")); } } @@ -501,13 +501,13 @@ public void testListenerErrorsWithWrongNumberOfResponses() { List inputs = List.of("1st small", "2nd small", "3rd small"); var failureMessage = new AtomicReference(); - var listener = new ActionListener>() { + var listener = new ActionListener>() { @Override - public void onResponse(List chunkedInferenceServiceResults) { - assertThat(chunkedInferenceServiceResults.get(0), instanceOf(ErrorChunkedInferenceResults.class)); - var error = (ErrorChunkedInferenceResults) chunkedInferenceServiceResults.get(0); - failureMessage.set(error.getException().getMessage()); + public void onResponse(List chunkedResults) { + assertThat(chunkedResults.get(0), instanceOf(ChunkedInferenceError.class)); + var error = (ChunkedInferenceError) chunkedResults.get(0); + failureMessage.set(error.exception().getMessage()); } @Override @@ -531,12 +531,12 @@ private ChunkedResultsListener testListener() { return new ChunkedResultsListener(); } - private static class ChunkedResultsListener implements ActionListener> { - List results; + private static class ChunkedResultsListener implements ActionListener> { + List results; @Override - public void onResponse(List chunkedInferenceServiceResults) { - this.results = chunkedInferenceServiceResults; + public void onResponse(List chunks) { + this.results = chunks; } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index 563093930c358..dcdd9b3d42341 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -11,7 +11,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Tuple; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -19,9 +19,8 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; import org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.core.utils.FloatConversionUtils; import org.elasticsearch.xpack.inference.model.TestModel; @@ -158,38 
+157,39 @@ public void testModelSettingsValidation() { assertThat(ex.getMessage(), containsString("required [element_type] field is missing")); } - public static InferenceChunkedTextEmbeddingFloatResults randomInferenceChunkedTextEmbeddingFloatResults( - Model model, - List inputs - ) throws IOException { - List chunks = new ArrayList<>(); + public static ChunkedInferenceEmbeddingFloat randomChunkedInferenceEmbeddingFloat(Model model, List inputs) { + List chunks = new ArrayList<>(); for (String input : inputs) { float[] values = new float[model.getServiceSettings().dimensions()]; for (int j = 0; j < values.length; j++) { values[j] = (float) randomDouble(); } - chunks.add(new InferenceChunkedTextEmbeddingFloatResults.InferenceFloatEmbeddingChunk(input, values)); + chunks.add( + new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk(values, input, new ChunkedInference.TextOffset(0, input.length())) + ); } - return new InferenceChunkedTextEmbeddingFloatResults(chunks); + return new ChunkedInferenceEmbeddingFloat(chunks); } - public static InferenceChunkedSparseEmbeddingResults randomSparseEmbeddings(List inputs) { - List chunks = new ArrayList<>(); + public static ChunkedInferenceEmbeddingSparse randomChunkedInferenceEmbeddingSparse(List inputs) { + List chunks = new ArrayList<>(); for (String input : inputs) { var tokens = new ArrayList(); for (var token : input.split("\\s+")) { tokens.add(new WeightedToken(token, randomFloat())); } - chunks.add(new MlChunkedTextExpansionResults.ChunkedResult(input, tokens)); + chunks.add( + new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk(tokens, input, new ChunkedInference.TextOffset(0, input.length())) + ); } - return new InferenceChunkedSparseEmbeddingResults(chunks); + return new ChunkedInferenceEmbeddingSparse(chunks); } public static SemanticTextField randomSemanticText(String fieldName, Model model, List inputs, XContentType contentType) throws IOException { - ChunkedInferenceServiceResults results = switch (model.getTaskType()) { - case TEXT_EMBEDDING -> randomInferenceChunkedTextEmbeddingFloatResults(model, inputs); - case SPARSE_EMBEDDING -> randomSparseEmbeddings(inputs); + ChunkedInference results = switch (model.getTaskType()) { + case TEXT_EMBEDDING -> randomChunkedInferenceEmbeddingFloat(model, inputs); + case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs); default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); }; return semanticTextFieldFromChunkedInferenceResults(fieldName, model, inputs, results, contentType); @@ -199,9 +199,9 @@ public static SemanticTextField semanticTextFieldFromChunkedInferenceResults( String fieldName, Model model, List inputs, - ChunkedInferenceServiceResults results, + ChunkedInference results, XContentType contentType - ) { + ) throws IOException { return new SemanticTextField( fieldName, inputs, @@ -232,18 +232,24 @@ public static Object randomSemanticTextInput() { } } - public static ChunkedInferenceServiceResults toChunkedResult(SemanticTextField field) throws IOException { + public static ChunkedInference toChunkedResult(SemanticTextField field) throws IOException { switch (field.inference().modelSettings().taskType()) { case SPARSE_EMBEDDING -> { - List chunks = new ArrayList<>(); + List chunks = new ArrayList<>(); for (var chunk : field.inference().chunks()) { var tokens = parseWeightedTokens(chunk.rawEmbeddings(), field.contentType()); - chunks.add(new MlChunkedTextExpansionResults.ChunkedResult(chunk.text(), tokens)); + chunks.add( + new 
ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk( + tokens, + chunk.text(), + new ChunkedInference.TextOffset(0, chunk.text().length()) + ) + ); } - return new InferenceChunkedSparseEmbeddingResults(chunks); + return new ChunkedInferenceEmbeddingSparse(chunks); } case TEXT_EMBEDDING -> { - List chunks = new ArrayList<>(); + List chunks = new ArrayList<>(); for (var chunk : field.inference().chunks()) { double[] values = parseDenseVector( chunk.rawEmbeddings(), @@ -251,13 +257,14 @@ public static ChunkedInferenceServiceResults toChunkedResult(SemanticTextField f field.contentType() ); chunks.add( - new InferenceChunkedTextEmbeddingFloatResults.InferenceFloatEmbeddingChunk( + new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk( + FloatConversionUtils.floatArrayOf(values), chunk.text(), - FloatConversionUtils.floatArrayOf(values) + new ChunkedInference.TextOffset(0, chunk.text().length()) ) ); } - return new InferenceChunkedTextEmbeddingFloatResults(chunks); + return new ChunkedInferenceEmbeddingFloat(chunks); } default -> throw new AssertionError("Invalid task_type: " + field.inference().modelSettings().taskType().name()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ErrorChunkedInferenceResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ErrorChunkedInferenceResultsTests.java deleted file mode 100644 index 4be00ea9e5822..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ErrorChunkedInferenceResultsTests.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.results; - -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.ElasticsearchTimeoutException; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; - -import java.io.IOException; - -public class ErrorChunkedInferenceResultsTests extends AbstractWireSerializingTestCase { - - public static ErrorChunkedInferenceResults createRandomResults() { - return new ErrorChunkedInferenceResults( - randomBoolean() - ? 
new ElasticsearchTimeoutException(randomAlphaOfLengthBetween(10, 50)) - : new ElasticsearchStatusException(randomAlphaOfLengthBetween(10, 50), randomFrom(RestStatus.values())) - ); - } - - @Override - protected Writeable.Reader instanceReader() { - return ErrorChunkedInferenceResults::new; - } - - @Override - protected ErrorChunkedInferenceResults createTestInstance() { - return createRandomResults(); - } - - @Override - protected ErrorChunkedInferenceResults mutateInstance(ErrorChunkedInferenceResults instance) throws IOException { - return new ErrorChunkedInferenceResults(new RuntimeException(randomAlphaOfLengthBetween(10, 50))); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedSparseEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedSparseEmbeddingResultsTests.java deleted file mode 100644 index 8685ad9f0e124..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedSparseEmbeddingResultsTests.java +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.results; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; -import org.elasticsearch.xpack.core.ml.search.WeightedToken; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - -import static org.hamcrest.Matchers.is; - -public class InferenceChunkedSparseEmbeddingResultsTests extends AbstractWireSerializingTestCase { - - public static InferenceChunkedSparseEmbeddingResults createRandomResults() { - var chunks = new ArrayList(); - int numChunks = randomIntBetween(1, 5); - - for (int i = 0; i < numChunks; i++) { - var tokenWeights = new ArrayList(); - int numTokens = randomIntBetween(1, 8); - for (int j = 0; j < numTokens; j++) { - tokenWeights.add(new WeightedToken(Integer.toString(j), (float) randomDoubleBetween(0.0, 5.0, false))); - } - chunks.add(new MlChunkedTextExpansionResults.ChunkedResult(randomAlphaOfLength(6), tokenWeights)); - } - - return new InferenceChunkedSparseEmbeddingResults(chunks); - } - - public void testToXContent_CreatesTheRightJsonForASingleChunk() { - var entity = new InferenceChunkedSparseEmbeddingResults( - List.of(new MlChunkedTextExpansionResults.ChunkedResult("text", List.of(new WeightedToken("token", 0.1f)))) - ); - - assertThat( - entity.asMap(), - is( - Map.of( - InferenceChunkedSparseEmbeddingResults.FIELD_NAME, - List.of(Map.of(ChunkedNlpInferenceResults.TEXT, "text", ChunkedNlpInferenceResults.INFERENCE, Map.of("token", 0.1f))) - ) - ) - ); - - String xContentResult = Strings.toString(entity, true, true); - assertThat(xContentResult, is(""" - { - "sparse_embedding_chunk" : [ - { - "text" : "text", - "inference" : { - "token" : 
0.1 - } - } - ] - }""")); - } - - public void testToXContent_CreatesTheRightJsonForASingleChunk_FromSparseEmbeddingResults() { - var entity = InferenceChunkedSparseEmbeddingResults.listOf( - List.of("text"), - new SparseEmbeddingResults(List.of(new SparseEmbeddingResults.Embedding(List.of(new WeightedToken("token", 0.1f)), false))) - ); - - assertThat(entity.size(), is(1)); - - var firstEntry = entity.get(0); - - assertThat( - firstEntry.asMap(), - is( - Map.of( - InferenceChunkedSparseEmbeddingResults.FIELD_NAME, - List.of(Map.of(ChunkedNlpInferenceResults.TEXT, "text", ChunkedNlpInferenceResults.INFERENCE, Map.of("token", 0.1f))) - ) - ) - ); - - String xContentResult = Strings.toString(firstEntry, true, true); - assertThat(xContentResult, is(""" - { - "sparse_embedding_chunk" : [ - { - "text" : "text", - "inference" : { - "token" : 0.1 - } - } - ] - }""")); - } - - public void testToXContent_ThrowsWhenInputSizeIsDifferentThanEmbeddings() { - var exception = expectThrows( - IllegalArgumentException.class, - () -> InferenceChunkedSparseEmbeddingResults.listOf( - List.of("text", "text2"), - new SparseEmbeddingResults(List.of(new SparseEmbeddingResults.Embedding(List.of(new WeightedToken("token", 0.1f)), false))) - ) - ); - - assertThat(exception.getMessage(), is("The number of inputs [2] does not match the embeddings [1]")); - } - - @Override - protected Writeable.Reader instanceReader() { - return InferenceChunkedSparseEmbeddingResults::new; - } - - @Override - protected InferenceChunkedSparseEmbeddingResults createTestInstance() { - return createRandomResults(); - } - - @Override - protected InferenceChunkedSparseEmbeddingResults mutateInstance(InferenceChunkedSparseEmbeddingResults instance) throws IOException { - return randomValueOtherThan(instance, InferenceChunkedSparseEmbeddingResultsTests::createRandomResults); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedTextEmbeddingByteResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedTextEmbeddingByteResultsTests.java deleted file mode 100644 index c1215e8a3d71b..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedTextEmbeddingByteResultsTests.java +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.results; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - -import static org.hamcrest.Matchers.is; - -public class InferenceChunkedTextEmbeddingByteResultsTests extends AbstractWireSerializingTestCase< - InferenceChunkedTextEmbeddingByteResults> { - - public static InferenceChunkedTextEmbeddingByteResults createRandomResults() { - int numChunks = randomIntBetween(1, 5); - var chunks = new ArrayList(numChunks); - - for (int i = 0; i < numChunks; i++) { - chunks.add(createRandomChunk()); - } - - return new InferenceChunkedTextEmbeddingByteResults(chunks, randomBoolean()); - } - - private static InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk createRandomChunk() { - int columns = randomIntBetween(1, 10); - byte[] bytes = new byte[columns]; - for (int i = 0; i < columns; i++) { - bytes[i] = randomByte(); - } - - return new InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk(randomAlphaOfLength(6), bytes); - } - - public void testToXContent_CreatesTheRightJsonForASingleChunk() { - var entity = new InferenceChunkedTextEmbeddingByteResults( - List.of(new InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk("text", new byte[] { (byte) 1 })), - false - ); - - assertThat( - entity.asMap(), - is( - Map.of( - InferenceChunkedTextEmbeddingByteResults.FIELD_NAME, - List.of(new InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk("text", new byte[] { (byte) 1 })) - ) - ) - ); - String xContentResult = Strings.toString(entity, true, true); - assertThat(xContentResult, is(""" - { - "text_embedding_byte_chunk" : [ - { - "text" : "text", - "inference" : [ - 1 - ] - } - ] - }""")); - } - - public void testToXContent_CreatesTheRightJsonForASingleChunk_ForTextEmbeddingByteResults() { - var entity = InferenceChunkedTextEmbeddingByteResults.listOf( - List.of("text"), - new InferenceTextEmbeddingByteResults( - List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 1 })) - ) - ); - - assertThat(entity.size(), is(1)); - - var firstEntry = entity.get(0); - - assertThat( - firstEntry.asMap(), - is( - Map.of( - InferenceChunkedTextEmbeddingByteResults.FIELD_NAME, - List.of(new InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk("text", new byte[] { (byte) 1 })) - ) - ) - ); - String xContentResult = Strings.toString(firstEntry, true, true); - assertThat(xContentResult, is(""" - { - "text_embedding_byte_chunk" : [ - { - "text" : "text", - "inference" : [ - 1 - ] - } - ] - }""")); - } - - public void testToXContent_ThrowsWhenInputSizeIsDifferentThanEmbeddings() { - var exception = expectThrows( - IllegalArgumentException.class, - () -> InferenceChunkedTextEmbeddingByteResults.listOf( - List.of("text", "text2"), - new InferenceTextEmbeddingByteResults( - List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 1 })) - ) - ) - ); - - assertThat(exception.getMessage(), is("The number of inputs [2] does not match the embeddings [1]")); - } - - @Override - protected Writeable.Reader instanceReader() { - return 
InferenceChunkedTextEmbeddingByteResults::new; - } - - @Override - protected InferenceChunkedTextEmbeddingByteResults createTestInstance() { - return createRandomResults(); - } - - @Override - protected InferenceChunkedTextEmbeddingByteResults mutateInstance(InferenceChunkedTextEmbeddingByteResults instance) - throws IOException { - return randomValueOtherThan(instance, InferenceChunkedTextEmbeddingByteResultsTests::createRandomResults); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index 6768583598b2d..f8a5bd8a54d31 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -11,7 +11,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -135,7 +135,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index a154ded395822..46c3a062f7db0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -15,7 +15,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -29,8 +29,8 @@ import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; @@ -392,7 +392,7 @@ public void testChunkedInfer_InvalidTaskType() throws IOException { null ); - 
PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); try { service.chunkedInfer( model, @@ -417,7 +417,7 @@ private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettin try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { var model = createModelForTaskType(taskType, chunkingSettings); - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer(model, null, input, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener); var results = listener.actionGet(TIMEOUT); @@ -425,9 +425,9 @@ private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettin assertThat(results, hasSize(2)); var firstResult = results.get(0); if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - assertThat(firstResult, instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); + assertThat(firstResult, instanceOf(ChunkedInferenceEmbeddingFloat.class)); } else if (TaskType.SPARSE_EMBEDDING.equals(taskType)) { - assertThat(firstResult, instanceOf(InferenceChunkedSparseEmbeddingResults.class)); + assertThat(firstResult, instanceOf(ChunkedInferenceEmbeddingSparse.class)); } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index 197606df02a1f..80c2b672a8feb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -19,7 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -35,7 +35,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.Utils; import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockMockRequestSender; @@ -1551,7 +1551,7 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep requestSender.enqueue(mockResults2); } - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, List.of("abc", "xyz"), @@ -1564,15 +1564,15 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep var results = listener.actionGet(TIMEOUT); assertThat(results, hasSize(2)); { - assertThat(results.get(0), 
CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0); + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("abc", floatResult.chunks().get(0).matchedText()); assertArrayEquals(new float[] { 0.123F, 0.678F }, floatResult.chunks().get(0).embedding(), 0.0f); } { - assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1); + assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("xyz", floatResult.chunks().get(0).matchedText()); assertArrayEquals(new float[] { 0.223F, 0.278F }, floatResult.chunks().get(0).embedding(), 0.0f); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 08fc097a56f40..a9ef4bd551175 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -19,7 +19,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -36,7 +36,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -1186,7 +1186,7 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, List.of("foo", "bar"), @@ -1199,15 +1199,15 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep var results = listener.actionGet(TIMEOUT); assertThat(results, hasSize(2)); { - assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0); + assertThat(results.get(0), 
CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("foo", floatResult.chunks().get(0).matchedText()); assertArrayEquals(new float[] { 0.0123f, -0.0123f }, floatResult.chunks().get(0).embedding(), 0.0f); } { - assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1); + assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("bar", floatResult.chunks().get(0).matchedText()); assertArrayEquals(new float[] { 1.0123f, -1.0123f }, floatResult.chunks().get(0).embedding(), 0.0f); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index cc68d54b11e91..ac8e769ef13a3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -19,7 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -35,7 +35,7 @@ import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -1335,7 +1335,7 @@ private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOExcepti webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); model.setUri(new URI(getUrl(webServer))); - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, List.of("foo", "bar"), @@ -1348,15 +1348,15 @@ private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOExcepti var results = listener.actionGet(TIMEOUT); assertThat(results, hasSize(2)); { - assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0); + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("foo", 
floatResult.chunks().get(0).matchedText()); assertArrayEquals(new float[] { 0.123f, -0.123f }, floatResult.chunks().get(0).embedding(), 0.0f); } { - assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1); + assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("bar", floatResult.chunks().get(0).matchedText()); assertArrayEquals(new float[] { 1.123f, -1.123f }, floatResult.chunks().get(0).embedding(), 0.0f); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index a8d1a1ec28d09..e207bcfdeada5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -20,7 +20,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -36,8 +36,8 @@ import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingByte; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -1442,7 +1442,7 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); // 2 inputs service.chunkedInfer( model, @@ -1456,15 +1456,15 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { var results = listener.actionGet(TIMEOUT); assertThat(results, hasSize(2)); { - assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0); + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("foo", floatResult.chunks().get(0).matchedText()); assertArrayEquals(new float[] { 0.123f, -0.123f }, 
floatResult.chunks().get(0).embedding(), 0.0f); } { - assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1); + assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("bar", floatResult.chunks().get(0).matchedText()); assertArrayEquals(new float[] { 0.223f, -0.223f }, floatResult.chunks().get(0).embedding(), 0.0f); @@ -1532,7 +1532,7 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException { "model", CohereEmbeddingType.BYTE ); - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); // 2 inputs service.chunkedInfer( model, @@ -1546,15 +1546,15 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException { var results = listener.actionGet(TIMEOUT); assertThat(results, hasSize(2)); { - assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingByteResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingByteResults) results.get(0); + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingByte.class)); + var floatResult = (ChunkedInferenceEmbeddingByte) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("foo", floatResult.chunks().get(0).matchedText()); assertArrayEquals(new byte[] { 23, -23 }, floatResult.chunks().get(0).embedding()); } { - assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingByteResults.class)); - var byteResult = (InferenceChunkedTextEmbeddingByteResults) results.get(1); + assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingByte.class)); + var byteResult = (ChunkedInferenceEmbeddingByte) results.get(1); assertThat(byteResult.chunks(), hasSize(1)); assertEquals("bar", byteResult.chunks().get(0).matchedText()); assertArrayEquals(new byte[] { 24, -24 }, byteResult.chunks().get(0).embedding()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index dae99cea77ec4..11dc7206d959a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -16,7 +16,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -31,8 +31,8 @@ import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults; +import 
org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; +import org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -453,7 +453,7 @@ public void testChunkedInfer_PassesThrough() throws IOException { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(eisGatewayUrl); - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, List.of("input text"), @@ -464,22 +464,21 @@ public void testChunkedInfer_PassesThrough() throws IOException { ); var results = listener.actionGet(TIMEOUT); - MatcherAssert.assertThat( - results.get(0).asMap(), - Matchers.is( - Map.of( - InferenceChunkedSparseEmbeddingResults.FIELD_NAME, - List.of( - Map.of( - ChunkedNlpInferenceResults.TEXT, - "input text", - ChunkedNlpInferenceResults.INFERENCE, - Map.of("hello", 2.1259406f, "greet", 1.7073475f) - ) + assertThat(results.get(0), instanceOf(ChunkedInferenceEmbeddingSparse.class)); + var sparseResult = (ChunkedInferenceEmbeddingSparse) results.get(0); + assertThat( + sparseResult.chunks(), + is( + List.of( + new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk( + List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)), + "input text", + new ChunkedInference.TextOffset(0, "input text".length()) ) ) ) ); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); MatcherAssert.assertThat( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 17e6583f11c8f..21d7efbc7b03c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -24,7 +24,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceResults; @@ -42,9 +42,9 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.action.util.QueryPage; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; +import 
org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.core.ml.MachineLearningField; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; @@ -865,26 +865,26 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws Interru var service = createService(client); var gotResults = new AtomicBoolean(); - var resultsListener = ActionListener.>wrap(chunkedResponse -> { + var resultsListener = ActionListener.>wrap(chunkedResponse -> { assertThat(chunkedResponse, hasSize(2)); - assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var result1 = (InferenceChunkedTextEmbeddingFloatResults) chunkedResponse.get(0); + assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var result1 = (ChunkedInferenceEmbeddingFloat) chunkedResponse.get(0); assertThat(result1.chunks(), hasSize(1)); assertArrayEquals( ((MlTextEmbeddingResults) mlTrainedModelResults.get(0)).getInferenceAsFloat(), - result1.getChunks().get(0).embedding(), + result1.chunks().get(0).embedding(), 0.0001f ); - assertEquals("foo", result1.getChunks().get(0).matchedText()); - assertThat(chunkedResponse.get(1), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var result2 = (InferenceChunkedTextEmbeddingFloatResults) chunkedResponse.get(1); + assertEquals("foo", result1.chunks().get(0).matchedText()); + assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var result2 = (ChunkedInferenceEmbeddingFloat) chunkedResponse.get(1); assertThat(result2.chunks(), hasSize(1)); assertArrayEquals( ((MlTextEmbeddingResults) mlTrainedModelResults.get(1)).getInferenceAsFloat(), - result2.getChunks().get(0).embedding(), + result2.chunks().get(0).embedding(), 0.0001f ); - assertEquals("bar", result2.getChunks().get(0).matchedText()); + assertEquals("bar", result2.chunks().get(0).matchedText()); gotResults.set(true); }, ESTestCase::fail); @@ -940,22 +940,22 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) throws Int var gotResults = new AtomicBoolean(); - var resultsListener = ActionListener.>wrap(chunkedResponse -> { + var resultsListener = ActionListener.>wrap(chunkedResponse -> { assertThat(chunkedResponse, hasSize(2)); - assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var result1 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(0); + assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbeddingSparse.class)); + var result1 = (ChunkedInferenceEmbeddingSparse) chunkedResponse.get(0); assertEquals( ((TextExpansionResults) mlTrainedModelResults.get(0)).getWeightedTokens(), - result1.getChunkedResults().get(0).weightedTokens() + result1.chunks().get(0).weightedTokens() ); - assertEquals("foo", result1.getChunkedResults().get(0).matchedText()); - assertThat(chunkedResponse.get(1), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var result2 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(1); + assertEquals("foo", result1.chunks().get(0).matchedText()); + assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbeddingSparse.class)); + var result2 = (ChunkedInferenceEmbeddingSparse) chunkedResponse.get(1); assertEquals( ((TextExpansionResults) 
mlTrainedModelResults.get(1)).getWeightedTokens(), - result2.getChunkedResults().get(0).weightedTokens() + result2.chunks().get(0).weightedTokens() ); - assertEquals("bar", result2.getChunkedResults().get(0).matchedText()); + assertEquals("bar", result2.chunks().get(0).matchedText()); gotResults.set(true); }, ESTestCase::fail); @@ -1010,22 +1010,22 @@ private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) throws Inte var service = createService(client); var gotResults = new AtomicBoolean(); - var resultsListener = ActionListener.>wrap(chunkedResponse -> { + var resultsListener = ActionListener.>wrap(chunkedResponse -> { assertThat(chunkedResponse, hasSize(2)); - assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var result1 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(0); + assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbeddingSparse.class)); + var result1 = (ChunkedInferenceEmbeddingSparse) chunkedResponse.get(0); assertEquals( ((TextExpansionResults) mlTrainedModelResults.get(0)).getWeightedTokens(), - result1.getChunkedResults().get(0).weightedTokens() + result1.chunks().get(0).weightedTokens() ); - assertEquals("foo", result1.getChunkedResults().get(0).matchedText()); - assertThat(chunkedResponse.get(1), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var result2 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(1); + assertEquals("foo", result1.chunks().get(0).matchedText()); + assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbeddingSparse.class)); + var result2 = (ChunkedInferenceEmbeddingSparse) chunkedResponse.get(1); assertEquals( ((TextExpansionResults) mlTrainedModelResults.get(1)).getWeightedTokens(), - result2.getChunkedResults().get(0).weightedTokens() + result2.chunks().get(0).weightedTokens() ); - assertEquals("bar", result2.getChunkedResults().get(0).matchedText()); + assertEquals("bar", result2.chunks().get(0).matchedText()); gotResults.set(true); }, ESTestCase::fail); @@ -1126,12 +1126,12 @@ public void testChunkInfer_FailsBatch() throws InterruptedException { var service = createService(client); var gotResults = new AtomicBoolean(); - var resultsListener = ActionListener.>wrap(chunkedResponse -> { + var resultsListener = ActionListener.>wrap(chunkedResponse -> { assertThat(chunkedResponse, hasSize(3)); // a single failure fails the batch for (var er : chunkedResponse) { - assertThat(er, instanceOf(ErrorChunkedInferenceResults.class)); - assertEquals("boom", ((ErrorChunkedInferenceResults) er).getException().getMessage()); + assertThat(er, instanceOf(ChunkedInferenceError.class)); + assertEquals("boom", ((ChunkedInferenceError) er).exception().getMessage()); } gotResults.set(true); @@ -1190,10 +1190,10 @@ public void testChunkingLargeDocument() throws InterruptedException { var service = createService(client); var gotResults = new AtomicBoolean(); - var resultsListener = ActionListener.>wrap(chunkedResponse -> { + var resultsListener = ActionListener.>wrap(chunkedResponse -> { assertThat(chunkedResponse, hasSize(1)); - assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var sparseResults = (InferenceChunkedTextEmbeddingFloatResults) chunkedResponse.get(0); + assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var sparseResults = (ChunkedInferenceEmbeddingFloat) chunkedResponse.get(0); assertThat(sparseResults.chunks(), hasSize(numChunks)); 
             gotResults.set(true);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
index 0e2f4847c88ee..ea82c09eef1e8 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
@@ -18,7 +18,7 @@
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.Strings;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -36,7 +36,7 @@
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -868,7 +868,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed
 
         webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
 
-        PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
+        PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
         service.chunkedInfer(model, input, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
         var results = listener.actionGet(TIMEOUT);
@@ -876,8 +876,8 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed
 
         // first result
         {
-            assertThat(results.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0);
+            assertThat(results.get(0), instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(input.get(0), floatResult.chunks().get(0).matchedText());
             assertTrue(Arrays.equals(new float[] { 0.0123f, -0.0123f }, floatResult.chunks().get(0).embedding()));
@@ -885,8 +885,8 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed
 
         // second result
         {
-            assertThat(results.get(1), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1);
+            assertThat(results.get(1), instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(input.get(1), floatResult.chunks().get(0).matchedText());
             assertTrue(Arrays.equals(new float[] { 0.0456f, -0.0456f }, floatResult.chunks().get(0).embedding()));
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java
index c70692eb29a27..64f86a0d0f280 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java
@@ -14,7 +14,7 @@
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.test.ESTestCase;
@@ -24,14 +24,13 @@
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
-import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
+import org.elasticsearch.xpack.core.ml.search.WeightedToken;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
 import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests;
 import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService;
-import org.hamcrest.MatcherAssert;
 import org.hamcrest.Matchers;
 import org.junit.After;
 import org.junit.Before;
@@ -39,7 +38,6 @@
 import java.io.IOException;
 import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
 import java.util.concurrent.TimeUnit;
 
 import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
@@ -49,6 +47,8 @@
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
+import static org.hamcrest.CoreMatchers.instanceOf;
+import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.mockito.Mockito.mock;
@@ -90,7 +90,7 @@ public void testChunkedInfer_CallsInfer_Elser_ConvertsFloatResponse() throws IOE
         webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
 
         var model = HuggingFaceElserModelTests.createModel(getUrl(webServer), "secret");
-        PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
+        PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
         service.chunkedInfer(
             model,
             List.of("abc"),
@@ -101,14 +101,16 @@ public void testChunkedInfer_CallsInfer_Elser_ConvertsFloatResponse() throws IOE
         );
 
         var result = listener.actionGet(TIMEOUT).get(0);
-
-        MatcherAssert.assertThat(
-            result.asMap(),
-            Matchers.is(
-                Map.of(
-                    InferenceChunkedSparseEmbeddingResults.FIELD_NAME,
-                    List.of(
-                        Map.of(ChunkedNlpInferenceResults.TEXT, "abc", ChunkedNlpInferenceResults.INFERENCE, Map.of(".", 0.13315596f))
+        assertThat(result, instanceOf(ChunkedInferenceEmbeddingSparse.class));
+        var sparseResult = (ChunkedInferenceEmbeddingSparse) result;
+        assertThat(
+            sparseResult.chunks(),
+            is(
+                List.of(
+                    new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk(
+                        List.of(new WeightedToken(".", 0.13315596f)),
+                        "abc",
+                        new ChunkedInference.TextOffset(0, "abc".length())
                     )
                 )
             )
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
index 3e5e2d7c12074..f3d7cbfea38dc 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
@@ -18,7 +18,7 @@
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -34,8 +34,7 @@
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -46,7 +45,6 @@
 import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests;
 import org.hamcrest.CoreMatchers;
-import org.hamcrest.MatcherAssert;
 import org.hamcrest.Matchers;
 import org.junit.After;
 import org.junit.Before;
@@ -59,7 +57,6 @@
 
 import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
-import static org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResultsTests.asMapWithListsInsteadOfArrays;
 import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
@@ -774,7 +771,7 @@ public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() th
         webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
 
         var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret");
-        PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
+        PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
         service.chunkedInfer(
             model,
             List.of("abc"),
@@ -785,19 +782,12 @@ public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() th
         );
 
         var result = listener.actionGet(TIMEOUT).get(0);
-        assertThat(result, CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-
-        MatcherAssert.assertThat(
-            asMapWithListsInsteadOfArrays((InferenceChunkedTextEmbeddingFloatResults) result),
-            Matchers.is(
-                Map.of(
-                    InferenceChunkedTextEmbeddingFloatResults.FIELD_NAME,
-                    List.of(
-                        Map.of(ChunkedNlpInferenceResults.TEXT, "abc", ChunkedNlpInferenceResults.INFERENCE, List.of(-0.0123f, 0.0123f))
-                    )
-                )
-            )
-        );
+        assertThat(result, CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class));
+        var embeddingResult = (ChunkedInferenceEmbeddingFloat) result;
+        assertThat(embeddingResult.chunks(), hasSize(1));
+        assertThat(embeddingResult.chunks().get(0).matchedText(), is("abc"));
+        assertThat(embeddingResult.chunks().get(0).offset(), is(new ChunkedInference.TextOffset(0, "abc".length())));
+        assertArrayEquals(new float[] { -0.0123f, 0.0123f }, embeddingResult.chunks().get(0).embedding(), 0.001f);
         assertThat(webServer.requests(), hasSize(1));
         assertNull(webServer.requests().get(0).getUri().getQuery());
         assertThat(
@@ -828,7 +818,7 @@ public void testChunkedInfer() throws IOException {
         webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
 
         var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret");
-        PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
+        PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
         service.chunkedInfer(
             model,
             List.of("abc"),
@@ -841,8 +831,8 @@ public void testChunkedInfer() throws IOException {
         var results = listener.actionGet(TIMEOUT);
         assertThat(results, hasSize(1));
         {
-            assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0);
+            assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals("abc", floatResult.chunks().get(0).matchedText());
             assertArrayEquals(new float[] { 0.123f, -0.123f }, floatResult.chunks().get(0).embedding(), 0.0f);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java
index 5aa826f1d80fe..3d298823ea19f 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java
@@ -18,7 +18,7 @@
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -35,7 +35,7 @@
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
 import org.elasticsearch.xpack.inference.common.Truncator;
 import org.elasticsearch.xpack.inference.external.action.ibmwatsonx.IbmWatsonxActionCreator;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
@@ -684,7 +684,7 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws
             apiKey,
             getUrl(webServer)
         );
-        PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
+        PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
         service.chunkedInfer(model, null, input, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
         var results = listener.actionGet(TIMEOUT);
@@ -692,8 +692,8 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws
 
         // first result
         {
-            assertThat(results.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0);
+            assertThat(results.get(0), instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(input.get(0), floatResult.chunks().get(0).matchedText());
             assertTrue(Arrays.equals(new float[] { 0.0123f, -0.0123f }, floatResult.chunks().get(0).embedding()));
@@ -701,8 +701,8 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws
 
         // second result
         {
-            assertThat(results.get(1), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1);
+            assertThat(results.get(1), instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(input.get(1), floatResult.chunks().get(0).matchedText());
             assertTrue(Arrays.equals(new float[] { 0.0456f, -0.0456f }, floatResult.chunks().get(0).embedding()));
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
index e0cfa4a5ca4be..c547531ec1289 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
@@ -18,7 +18,7 @@
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -34,7 +34,7 @@
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
 import org.elasticsearch.xpack.inference.ModelConfigurationsTests;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -665,7 +665,7 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException {
             """;
         webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
 
-        PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
+        PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
         service.chunkedInfer(
             model,
             List.of("abc", "def"),
@@ -679,14 +679,14 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException {
 
         assertThat(results, hasSize(2));
         {
-            assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0);
+            assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertTrue(Arrays.equals(new float[] { 0.123f, -0.123f }, floatResult.chunks().get(0).embedding()));
         }
         {
-            assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1);
+            assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1);
             assertThat(floatResult.chunks(), hasSize(1));
             assertTrue(Arrays.equals(new float[] { 0.223f, -0.223f }, floatResult.chunks().get(0).embedding()));
         }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
index c812ca67861fb..67cff8ef1e48d 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
@@ -19,7 +19,7 @@
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -35,7 +35,7 @@
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
 import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -1613,7 +1613,7 @@ private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException {
             """;
         webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
 
-        PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
+        PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
         service.chunkedInfer(
             model,
             List.of("foo", "bar"),
@@ -1626,15 +1626,15 @@ private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException {
         var results = listener.actionGet(TIMEOUT);
         assertThat(results, hasSize(2));
         {
-            assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0);
+            assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals("foo", floatResult.chunks().get(0).matchedText());
             assertTrue(Arrays.equals(new float[] { 0.123f, -0.123f }, floatResult.chunks().get(0).embedding()));
         }
         {
-            assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1);
+            assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals("bar", floatResult.chunks().get(0).matchedText());
             assertTrue(Arrays.equals(new float[] { 0.223f, -0.223f }, floatResult.chunks().get(0).embedding()));
diff --git a/x-pack/plugin/monitoring/src/main/java/org/elasticsearch/xpack/monitoring/MonitoringTemplateRegistry.java b/x-pack/plugin/monitoring/src/main/java/org/elasticsearch/xpack/monitoring/MonitoringTemplateRegistry.java
index e0433ea6fdd71..cfd322d04e92f 100644
--- a/x-pack/plugin/monitoring/src/main/java/org/elasticsearch/xpack/monitoring/MonitoringTemplateRegistry.java
+++ b/x-pack/plugin/monitoring/src/main/java/org/elasticsearch/xpack/monitoring/MonitoringTemplateRegistry.java
@@ -77,7 +77,7 @@ public class MonitoringTemplateRegistry extends IndexTemplateRegistry {
      * writes monitoring data in ECS format as of 8.0. These templates define the ECS schema as well as alias fields for the old monitoring
      * mappings that point to the corresponding ECS fields.
      */
-    public static final int STACK_MONITORING_REGISTRY_VERSION = 8_00_00_99 + 18;
+    public static final int STACK_MONITORING_REGISTRY_VERSION = 8_00_00_99 + 19;
     private static final String STACK_MONITORING_REGISTRY_VERSION_VARIABLE = "xpack.stack.monitoring.template.release.version";
     private static final String STACK_TEMPLATE_VERSION = "8";
     private static final String STACK_TEMPLATE_VERSION_VARIABLE = "xpack.stack.monitoring.template.version";