diff --git a/build-tools-internal/src/test/groovy/org/elasticsearch/gradle/internal/doc/DocSnippetTaskSpec.groovy b/build-tools-internal/src/test/groovy/org/elasticsearch/gradle/internal/doc/DocSnippetTaskSpec.groovy index 85ce3c1804474..96888357d8433 100644 --- a/build-tools-internal/src/test/groovy/org/elasticsearch/gradle/internal/doc/DocSnippetTaskSpec.groovy +++ b/build-tools-internal/src/test/groovy/org/elasticsearch/gradle/internal/doc/DocSnippetTaskSpec.groovy @@ -535,7 +535,7 @@ GET /_analyze ] } ], - "text": "My license plate is ٢٥٠١٥" + "text": "My license plate is empty" } ---- """ @@ -557,7 +557,7 @@ GET /_analyze ] } ], - "text": "My license plate is ٢٥٠١٥" + "text": "My license plate is empty" }""" } diff --git a/docs/changelog/106068.yaml b/docs/changelog/106068.yaml index fbc30aa86a33e..51bcc2bcf98b0 100644 --- a/docs/changelog/106068.yaml +++ b/docs/changelog/106068.yaml @@ -3,3 +3,19 @@ summary: Add `modelId` and `modelText` to `KnnVectorQueryBuilder` area: Search type: enhancement issues: [] +highlight: + title: Query phase KNN now supports query_vector_builder + body: |- + It is now possible to pass `model_text` and `model_id` within a `knn` query + in the [query DSL](https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-knn-query.html) to convert a text query into a dense vector and run the + nearest neighbor query on it, instead of requiring the dense vector to be + directly passed (within the `query_vector` parameter). Similar to the + [top-level knn query](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html) (executed in the DFS phase), it is possible to supply + a `query_vector_builder` object containing a `text_embedding` object with + `model_text` (the text query to be converted into a dense vector) and + `model_id` (the identifier of a deployed model responsible for transforming + the text query into a dense vector). Note that an embedding model with the + referenced `model_id` needs to be [deployed on an ML node](https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html) + in the cluster. 
+ notable: true + diff --git a/docs/changelog/106796.yaml b/docs/changelog/106796.yaml new file mode 100644 index 0000000000000..83eb99dba1603 --- /dev/null +++ b/docs/changelog/106796.yaml @@ -0,0 +1,5 @@ +pr: 106796 +summary: Bulk loading enrich fields in ESQL +area: ES|QL +type: enhancement +issues: [] diff --git a/docs/changelog/107196.yaml b/docs/changelog/107196.yaml new file mode 100644 index 0000000000000..9892ccf71856f --- /dev/null +++ b/docs/changelog/107196.yaml @@ -0,0 +1,5 @@ +pr: 107196 +summary: Add metric for calculating index flush time excluding waiting on locks +area: Engine +type: enhancement +issues: [] diff --git a/docs/changelog/107370.yaml b/docs/changelog/107370.yaml new file mode 100644 index 0000000000000..e7bdeef68cffe --- /dev/null +++ b/docs/changelog/107370.yaml @@ -0,0 +1,5 @@ +pr: 107370 +summary: Fork when handling remote field-caps responses +area: Search +type: bug +issues: [] diff --git a/docs/changelog/107432.yaml b/docs/changelog/107432.yaml new file mode 100644 index 0000000000000..c492644c5baf2 --- /dev/null +++ b/docs/changelog/107432.yaml @@ -0,0 +1,6 @@ +pr: 107432 +summary: "Percolator named queries: rewrite for matched info" +area: Percolator +type: bug +issues: + - 107176 diff --git a/docs/reference/cluster/nodes-stats.asciidoc b/docs/reference/cluster/nodes-stats.asciidoc index c008b074acccd..07328ba98bcec 100644 --- a/docs/reference/cluster/nodes-stats.asciidoc +++ b/docs/reference/cluster/nodes-stats.asciidoc @@ -626,6 +626,7 @@ Total time spent performing flush operations. (integer) Total time in milliseconds spent performing flush operations. + ======= `warmer`:: diff --git a/docs/reference/inference/post-inference.asciidoc b/docs/reference/inference/post-inference.asciidoc index 10ed9f20ce21f..5a9ae283e895c 100644 --- a/docs/reference/inference/post-inference.asciidoc +++ b/docs/reference/inference/post-inference.asciidoc @@ -54,6 +54,7 @@ The unique identifier of the {infer} endpoint. (Optional, string) The type of {infer} task that the model performs. 
+ [discrete] [[post-inference-api-query-params]] ==== {api-query-parms-title} diff --git a/modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/DataStreamIT.java b/modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/DataStreamIT.java index a0a391a0f019b..cf4eaab763011 100644 --- a/modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/DataStreamIT.java +++ b/modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/DataStreamIT.java @@ -1778,22 +1778,9 @@ public void testRemoveGhostReference() throws Exception { @Override public ClusterState execute(ClusterState currentState) throws Exception { DataStream original = currentState.getMetadata().dataStreams().get(dataStreamName); - DataStream broken = new DataStream( - original.getName(), - List.of(new Index(original.getIndices().get(0).getName(), "broken"), original.getIndices().get(1)), - original.getGeneration(), - original.getMetadata(), - original.isHidden(), - original.isReplicated(), - original.isSystem(), - original.isAllowCustomRouting(), - original.getIndexMode(), - original.getLifecycle(), - original.isFailureStore(), - original.getFailureIndices(), - original.rolloverOnWrite(), - original.getAutoShardingEvent() - ); + DataStream broken = original.copy() + .setIndices(List.of(new Index(original.getIndices().get(0).getName(), "broken"), original.getIndices().get(1))) + .build(); brokenDataStreamHolder.set(broken); return ClusterState.builder(currentState) .metadata(Metadata.builder(currentState.getMetadata()).put(broken).build()) diff --git a/modules/data-streams/src/main/java/org/elasticsearch/datastreams/action/GetDataStreamsTransportAction.java b/modules/data-streams/src/main/java/org/elasticsearch/datastreams/action/GetDataStreamsTransportAction.java index 41e62508cafbb..0fc00ad9ebe59 100644 --- a/modules/data-streams/src/main/java/org/elasticsearch/datastreams/action/GetDataStreamsTransportAction.java +++ b/modules/data-streams/src/main/java/org/elasticsearch/datastreams/action/GetDataStreamsTransportAction.java @@ -139,7 +139,7 @@ static GetDataStreamAction.Response innerOperation( Map backingIndicesSettingsValues = new HashMap<>(); Metadata metadata = state.getMetadata(); collectIndexSettingsValues(dataStream, backingIndicesSettingsValues, metadata, dataStream.getIndices()); - if (DataStream.isFailureStoreEnabled() && dataStream.getFailureIndices().isEmpty() == false) { + if (DataStream.isFailureStoreFeatureFlagEnabled() && dataStream.getFailureIndices().isEmpty() == false) { collectIndexSettingsValues(dataStream, backingIndicesSettingsValues, metadata, dataStream.getFailureIndices()); } diff --git a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/DataStreamIndexSettingsProviderTests.java b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/DataStreamIndexSettingsProviderTests.java index 11446a2a2a761..5933b5caba001 100644 --- a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/DataStreamIndexSettingsProviderTests.java +++ b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/DataStreamIndexSettingsProviderTests.java @@ -301,24 +301,7 @@ public void testGetAdditionalIndexSettingsDataStreamAlreadyCreatedTimeSettingsMi ).getMetadata() ); DataStream ds = mb.dataStream(dataStreamName); - mb.put( - new DataStream( - ds.getName(), - ds.getIndices(), - ds.getGeneration(), - ds.getMetadata(), - ds.isHidden(), - ds.isReplicated(), - ds.isSystem(), - ds.isAllowCustomRouting(), 
- IndexMode.TIME_SERIES, - ds.getLifecycle(), - ds.isFailureStore(), - ds.getFailureIndices(), - ds.rolloverOnWrite(), - ds.getAutoShardingEvent() - ) - ); + mb.put(ds.copy().setIndexMode(IndexMode.TIME_SERIES).build()); Metadata metadata = mb.build(); Instant now = twoHoursAgo.plus(2, ChronoUnit.HOURS); diff --git a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/MetadataDataStreamRolloverServiceTests.java b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/MetadataDataStreamRolloverServiceTests.java index 61f0efe89504d..2185f8f50a93f 100644 --- a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/MetadataDataStreamRolloverServiceTests.java +++ b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/MetadataDataStreamRolloverServiceTests.java @@ -60,17 +60,10 @@ public class MetadataDataStreamRolloverServiceTests extends ESTestCase { public void testRolloverClusterStateForDataStream() throws Exception { Instant now = Instant.now(); String dataStreamName = "logs-my-app"; - final DataStream dataStream = new DataStream( + final DataStream dataStream = DataStream.builder( dataStreamName, - List.of(new Index(DataStream.getDefaultBackingIndexName(dataStreamName, 1, now.toEpochMilli()), "uuid")), - 1, - null, - false, - false, - false, - false, - IndexMode.TIME_SERIES - ); + List.of(new Index(DataStream.getDefaultBackingIndexName(dataStreamName, 1, now.toEpochMilli()), "uuid")) + ).setIndexMode(IndexMode.TIME_SERIES).build(); ComposableIndexTemplate template = ComposableIndexTemplate.builder() .indexPatterns(List.of(dataStream.getName() + "*")) .template( @@ -168,17 +161,10 @@ public void testRolloverAndMigrateDataStream() throws Exception { Instant now = Instant.now().truncatedTo(ChronoUnit.SECONDS); String dataStreamName = "logs-my-app"; IndexMode dsIndexMode = randomBoolean() ? 
null : IndexMode.STANDARD; - final DataStream dataStream = new DataStream( + final DataStream dataStream = DataStream.builder( dataStreamName, - List.of(new Index(DataStream.getDefaultBackingIndexName(dataStreamName, 1, now.toEpochMilli()), "uuid")), - 1, - null, - false, - false, - false, - false, - dsIndexMode - ); + List.of(new Index(DataStream.getDefaultBackingIndexName(dataStreamName, 1, now.toEpochMilli()), "uuid")) + ).setIndexMode(dsIndexMode).build(); ComposableIndexTemplate template = ComposableIndexTemplate.builder() .indexPatterns(List.of(dataStream.getName() + "*")) .template( @@ -257,17 +243,10 @@ public void testRolloverAndMigrateDataStream() throws Exception { public void testChangingIndexModeFromTimeSeriesToSomethingElseNoEffectOnExistingDataStreams() throws Exception { Instant now = Instant.now().truncatedTo(ChronoUnit.SECONDS); String dataStreamName = "logs-my-app"; - final DataStream dataStream = new DataStream( + final DataStream dataStream = DataStream.builder( dataStreamName, - List.of(new Index(DataStream.getDefaultBackingIndexName(dataStreamName, 1, now.toEpochMilli()), "uuid")), - 1, - null, - false, - false, - false, - false, - IndexMode.TIME_SERIES - ); + List.of(new Index(DataStream.getDefaultBackingIndexName(dataStreamName, 1, now.toEpochMilli()), "uuid")) + ).setIndexMode(IndexMode.TIME_SERIES).build(); ComposableIndexTemplate template = ComposableIndexTemplate.builder() .indexPatterns(List.of(dataStream.getName() + "*")) .template( @@ -479,17 +458,7 @@ private static ClusterState createClusterState(String dataStreamName, int number for (int i = 1; i <= numberOfBackingIndices; i++) { backingIndices.add(new Index(DataStream.getDefaultBackingIndexName(dataStreamName, i, now.toEpochMilli()), "uuid" + i)); } - final DataStream dataStream = new DataStream( - dataStreamName, - backingIndices, - numberOfBackingIndices, - null, - false, - false, - false, - false, - null - ); + final DataStream dataStream = DataStream.builder(dataStreamName, backingIndices).setGeneration(numberOfBackingIndices).build(); ComposableIndexTemplate template = ComposableIndexTemplate.builder() .indexPatterns(List.of(dataStream.getName() + "*")) .template( diff --git a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/UpdateTimeSeriesRangeServiceTests.java b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/UpdateTimeSeriesRangeServiceTests.java index 1c63deadf92a4..66133e9fbe0f2 100644 --- a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/UpdateTimeSeriesRangeServiceTests.java +++ b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/UpdateTimeSeriesRangeServiceTests.java @@ -139,26 +139,7 @@ public void testUpdateTimeSeriesTemporalRange_NoUpdateBecauseReplicated() { List.of(new Tuple<>(start.minus(4, ChronoUnit.HOURS), start), new Tuple<>(start, end)) ).getMetadata(); DataStream d = metadata.dataStreams().get(dataStreamName); - metadata = Metadata.builder(metadata) - .put( - new DataStream( - d.getName(), - d.getIndices(), - d.getGeneration(), - d.getMetadata(), - d.isHidden(), - true, - d.isSystem(), - d.isAllowCustomRouting(), - d.getIndexMode(), - d.getLifecycle(), - d.isFailureStore(), - d.getFailureIndices(), - false, - d.getAutoShardingEvent() - ) - ) - .build(); + metadata = Metadata.builder(metadata).put(d.copy().setReplicated(true).setRolloverOnWrite(false).build()).build(); now = now.plus(1, ChronoUnit.HOURS); ClusterState in = ClusterState.builder(ClusterState.EMPTY_STATE).metadata(metadata).build(); diff --git 
a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/action/DeleteDataStreamTransportActionTests.java b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/action/DeleteDataStreamTransportActionTests.java index a5c3b348b1f1b..d394db9523cce 100644 --- a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/action/DeleteDataStreamTransportActionTests.java +++ b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/action/DeleteDataStreamTransportActionTests.java @@ -57,7 +57,7 @@ public void testDeleteDataStream() { } public void testDeleteDataStreamWithFailureStore() { - Assume.assumeTrue(DataStream.isFailureStoreEnabled()); + Assume.assumeTrue(DataStream.isFailureStoreFeatureFlagEnabled()); final String dataStreamName = "my-data-stream"; final List otherIndices = randomSubsetOf(List.of("foo", "bar", "baz")); diff --git a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/action/GetDataStreamsResponseTests.java b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/action/GetDataStreamsResponseTests.java index 9fc646995bc0e..ec6e624794a03 100644 --- a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/action/GetDataStreamsResponseTests.java +++ b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/action/GetDataStreamsResponseTests.java @@ -76,22 +76,14 @@ public void testResponseIlmAndDataStreamLifecycleRepresentation() throws Excepti List failureStores = List.of(failureStoreIndex); { // data stream has an enabled lifecycle - DataStream logs = new DataStream( - "logs", - indices, - 3, - null, - false, - false, - false, - true, - IndexMode.STANDARD, - new DataStreamLifecycle(), - true, - failureStores, - false, - null - ); + DataStream logs = DataStream.builder("logs", indices) + .setGeneration(3) + .setAllowCustomRouting(true) + .setIndexMode(IndexMode.STANDARD) + .setLifecycle(new DataStreamLifecycle()) + .setFailureStoreEnabled(true) + .setFailureIndices(failureStores) + .build(); String ilmPolicyName = "rollover-30days"; Map indexSettingsValues = Map.of( @@ -166,7 +158,7 @@ public void testResponseIlmAndDataStreamLifecycleRepresentation() throws Excepti is(ManagedBy.LIFECYCLE.displayValue) ); - if (DataStream.isFailureStoreEnabled()) { + if (DataStream.isFailureStoreFeatureFlagEnabled()) { List failureStoresRepresentation = (List) dataStreamMap.get( DataStream.FAILURE_INDICES_FIELD.getPreferredName() ); @@ -187,22 +179,14 @@ public void testResponseIlmAndDataStreamLifecycleRepresentation() throws Excepti { // data stream has a lifecycle that's not enabled - DataStream logs = new DataStream( - "logs", - indices, - 3, - null, - false, - false, - false, - true, - IndexMode.STANDARD, - new DataStreamLifecycle(null, null, false), - true, - failureStores, - false, - null - ); + DataStream logs = DataStream.builder("logs", indices) + .setGeneration(3) + .setAllowCustomRouting(true) + .setIndexMode(IndexMode.STANDARD) + .setLifecycle(new DataStreamLifecycle(null, null, false)) + .setFailureStoreEnabled(true) + .setFailureIndices(failureStores) + .build(); String ilmPolicyName = "rollover-30days"; Map indexSettingsValues = Map.of( @@ -266,7 +250,7 @@ public void testResponseIlmAndDataStreamLifecycleRepresentation() throws Excepti is(ManagedBy.UNMANAGED.displayValue) ); - if (DataStream.isFailureStoreEnabled()) { + if (DataStream.isFailureStoreFeatureFlagEnabled()) { List failureStoresRepresentation = (List) dataStreamMap.get( DataStream.FAILURE_INDICES_FIELD.getPreferredName() ); diff --git 
a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/lifecycle/DataStreamLifecycleServiceTests.java b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/lifecycle/DataStreamLifecycleServiceTests.java index a67fa72cb3079..8524ef30856a8 100644 --- a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/lifecycle/DataStreamLifecycleServiceTests.java +++ b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/lifecycle/DataStreamLifecycleServiceTests.java @@ -283,22 +283,11 @@ public void testRetentionNotExecutedForTSIndicesWithinTimeBounds() { Metadata.Builder builder = Metadata.builder(clusterState.metadata()); DataStream dataStream = builder.dataStream(dataStreamName); builder.put( - new DataStream( - dataStreamName, - dataStream.getIndices(), - dataStream.getGeneration() + 1, - dataStream.getMetadata(), - dataStream.isHidden(), - dataStream.isReplicated(), - dataStream.isSystem(), - dataStream.isAllowCustomRouting(), - dataStream.getIndexMode(), - DataStreamLifecycle.newBuilder().dataRetention(0L).build(), - dataStream.isFailureStore(), - dataStream.getFailureIndices(), - dataStream.rolloverOnWrite(), - dataStream.getAutoShardingEvent() - ) + dataStream.copy() + .setName(dataStreamName) + .setGeneration(dataStream.getGeneration() + 1) + .setLifecycle(DataStreamLifecycle.newBuilder().dataRetention(0L).build()) + .build() ); clusterState = ClusterState.builder(clusterState).metadata(builder).build(); diff --git a/modules/percolator/src/main/java/org/elasticsearch/percolator/PercolatorMatchedSlotSubFetchPhase.java b/modules/percolator/src/main/java/org/elasticsearch/percolator/PercolatorMatchedSlotSubFetchPhase.java index 83703dcf10971..fe4bfc7741c87 100644 --- a/modules/percolator/src/main/java/org/elasticsearch/percolator/PercolatorMatchedSlotSubFetchPhase.java +++ b/modules/percolator/src/main/java/org/elasticsearch/percolator/PercolatorMatchedSlotSubFetchPhase.java @@ -85,8 +85,9 @@ public void process(HitContext hitContext) throws IOException { // This is not a document with a percolator field. 
continue; } - query = pc.filterNestedDocs(query, fetchContext.getSearchExecutionContext().indexVersionCreated()); IndexSearcher percolatorIndexSearcher = pc.percolateQuery.getPercolatorIndexSearcher(); + query = pc.filterNestedDocs(query, fetchContext.getSearchExecutionContext().indexVersionCreated()); + query = percolatorIndexSearcher.rewrite(query); int memoryIndexMaxDoc = percolatorIndexSearcher.getIndexReader().maxDoc(); TopDocs topDocs = percolatorIndexSearcher.search(query, memoryIndexMaxDoc, new Sort(SortField.FIELD_DOC)); if (topDocs.totalHits.value == 0) { diff --git a/modules/percolator/src/yamlRestTest/resources/rest-api-spec/test/30_matched_complex_queries.yml b/modules/percolator/src/yamlRestTest/resources/rest-api-spec/test/30_matched_complex_queries.yml new file mode 100644 index 0000000000000..eb0c020a8199e --- /dev/null +++ b/modules/percolator/src/yamlRestTest/resources/rest-api-spec/test/30_matched_complex_queries.yml @@ -0,0 +1,86 @@ +setup: + - requires: + cluster_features: ["gte_v8.14.0"] + reason: "Displaying matched complex named queries within percolator queries was fixed in 8.14" + - do: + indices.create: + index: houses + body: + mappings: + dynamic: strict + properties: + my_query: + type: percolator + description: + type: text + num_of_bedrooms: + type: integer + type: + type: keyword + price: + type: integer + + - do: + index: + refresh: true + index: houses + id: query_cheap_houses_with_swimming_pool + body: + my_query: + { + "bool": { + "should": [ + { "range": { "price": { "lte": 399999, "_name": "cheap_query" } } }, + { "wildcard": { "description": { "value": "swim*", "_name": "swimming_pool_query" } } } + ] + } + } + + - do: + index: + refresh: true + index: houses + id: query_big_houses_with_fireplace + body: + my_query: + { + "bool": { + "should": [ + { "range": { "num_of_bedrooms": { "gte": 3, "_name": "big_house_query" } } }, + { "query_string": { "query": "fire*", "fields" : ["description"], "_name": "fireplace_query" } } + ] + } + } + +--- +"Matched named queries within percolator queries: percolate existing document": + - do: + index: + refresh: true + index: houses + id: house1 + body: + description: "house with a beautiful fireplace and swimming pool" + num_of_bedrooms: 3 + type: detached + price: 1000000 + + - do: + search: + index: houses + body: + query: + percolate: + field: my_query + index: houses + id: house1 + + - match: { hits.total.value: 2 } + + - match: { hits.hits.0._id: query_big_houses_with_fireplace } + - match: { hits.hits.0.fields._percolator_document_slot: [ 0 ] } + - match: { hits.hits.0.fields._percolator_document_slot_0_matched_queries: [ "big_house_query", "fireplace_query" ] } + + - match: { hits.hits.1._id: query_cheap_houses_with_swimming_pool } + - match: { hits.hits.1.fields._percolator_document_slot: [ 0 ] } + - match: { hits.hits.1.fields._percolator_document_slot_0_matched_queries: [ "swimming_pool_query" ] } diff --git a/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/VectorSearchIT.java b/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/VectorSearchIT.java index 34b2f5d723949..e78e0978b1d80 100644 --- a/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/VectorSearchIT.java +++ b/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/VectorSearchIT.java @@ -285,7 +285,6 @@ public void testByteVectorSearch() throws Exception { assertThat((double) hits.get(0).get("_score"), closeTo(0.028571429, 0.0001)); } - @AwaitsFix(bugUrl = 
"https://github.com/elastic/elasticsearch/issues/107332") public void testQuantizedVectorSearch() throws Exception { assumeTrue( "Quantized vector search is not supported on this version", @@ -357,7 +356,7 @@ public void testQuantizedVectorSearch() throws Exception { assertThat(extractValue(response, "hits.total.value"), equalTo(2)); hits = extractValue(response, "hits.hits"); assertThat(hits.get(0).get("_id"), equalTo("0")); - assertThat((double) hits.get(0).get("_score"), closeTo(0.9934857, 0.0001)); + assertThat((double) hits.get(0).get("_score"), closeTo(0.9934857, 0.005)); } private void indexVectors(String indexName) throws Exception { diff --git a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/PrevalidateNodeRemovalRestIT.java b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/PrevalidateNodeRemovalRestIT.java index 17d0f04b9e2cf..ae1764310a34d 100644 --- a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/PrevalidateNodeRemovalRestIT.java +++ b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/PrevalidateNodeRemovalRestIT.java @@ -29,7 +29,7 @@ public class PrevalidateNodeRemovalRestIT extends HttpSmokeTestCase { public void testRestStatusCode() throws IOException { String node1Name = internalCluster().getRandomNodeName(); - String node1Id = internalCluster().clusterService(node1Name).localNode().getId(); + String node1Id = getNodeId(node1Name); ensureGreen(); RestClient client = getRestClient(); diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml index 66f88315032c3..68755f80c428d 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml @@ -17,6 +17,8 @@ setup: type: dense_vector dims: 5 index: true + index_options: + type: hnsw similarity: l2_norm - do: diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/PrevalidateNodeRemovalIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/PrevalidateNodeRemovalIT.java index f53e559bfda5d..38921840a2c64 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/PrevalidateNodeRemovalIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/PrevalidateNodeRemovalIT.java @@ -58,7 +58,7 @@ public void testNodeRemovalFromNonRedCluster() throws Exception { PrevalidateNodeRemovalRequest.Builder req = PrevalidateNodeRemovalRequest.builder(); switch (randomIntBetween(0, 2)) { case 0 -> req.setNames(nodeName); - case 1 -> req.setIds(internalCluster().clusterService(nodeName).localNode().getId()); + case 1 -> req.setIds(getNodeId(nodeName)); case 2 -> req.setExternalIds(internalCluster().clusterService(nodeName).localNode().getExternalId()); default -> throw new IllegalStateException("Unexpected value"); } @@ -156,7 +156,7 @@ public void testNodeRemovalFromRedClusterWithLocalShardCopy() throws Exception { // Prevalidate removal of node1 PrevalidateNodeRemovalRequest req = PrevalidateNodeRemovalRequest.builder().setNames(node1).build(); PrevalidateNodeRemovalResponse resp = client().execute(PrevalidateNodeRemovalAction.INSTANCE, req).get(); - String node1Id = internalCluster().clusterService(node1).localNode().getId(); + String node1Id = getNodeId(node1); assertFalse(resp.getPrevalidation().isSafe()); 
assertThat(resp.getPrevalidation().message(), equalTo("removal of the following nodes might not be safe: [" + node1Id + "]")); assertThat(resp.getPrevalidation().nodes().size(), equalTo(1)); @@ -187,7 +187,7 @@ public void testNodeRemovalFromRedClusterWithTimeout() throws Exception { .timeout(TimeValue.timeValueSeconds(1)); PrevalidateNodeRemovalResponse resp = client().execute(PrevalidateNodeRemovalAction.INSTANCE, req).get(); assertFalse("prevalidation result should return false", resp.getPrevalidation().isSafe()); - String node2Id = internalCluster().clusterService(node2).localNode().getId(); + String node2Id = getNodeId(node2); assertThat( resp.getPrevalidation().message(), equalTo("cannot prevalidate removal of nodes with the following IDs: [" + node2Id + "]") diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/PrevalidateShardPathIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/PrevalidateShardPathIT.java index 560a525ec526c..77bcaf1e1970c 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/PrevalidateShardPathIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/PrevalidateShardPathIT.java @@ -53,8 +53,8 @@ public void testCheckShards() throws Exception { .stream() .map(ShardRouting::shardId) .collect(Collectors.toSet()); - String node1Id = internalCluster().clusterService(node1).localNode().getId(); - String node2Id = internalCluster().clusterService(node2).localNode().getId(); + String node1Id = getNodeId(node1); + String node2Id = getNodeId(node2); Set shardIdsToCheck = new HashSet<>(shardIds); boolean includeUnknownShardId = randomBoolean(); if (includeUnknownShardId) { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/discovery/ClusterDisruptionIT.java b/server/src/internalClusterTest/java/org/elasticsearch/discovery/ClusterDisruptionIT.java index a0efb81c18668..c661894840261 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/discovery/ClusterDisruptionIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/discovery/ClusterDisruptionIT.java @@ -326,7 +326,7 @@ public void testSendingShardFailure() throws Exception { String nonMasterNode = randomFrom(nonMasterNodes); assertAcked(prepareCreate("test").setSettings(indexSettings(3, 2))); ensureGreen(); - String nonMasterNodeId = internalCluster().clusterService(nonMasterNode).localNode().getId(); + String nonMasterNodeId = getNodeId(nonMasterNode); // fail a random shard ShardRouting failedShard = randomFrom( diff --git a/server/src/internalClusterTest/java/org/elasticsearch/indices/store/IndicesStoreIntegrationIT.java b/server/src/internalClusterTest/java/org/elasticsearch/indices/store/IndicesStoreIntegrationIT.java index ca749eeaef545..5805eab831230 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/indices/store/IndicesStoreIntegrationIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/indices/store/IndicesStoreIntegrationIT.java @@ -386,8 +386,8 @@ public void testShardActiveElseWhere() throws Exception { final String masterNode = internalCluster().getMasterName(); final String nonMasterNode = nodes.get(0).equals(masterNode) ? 
nodes.get(1) : nodes.get(0); - final String masterId = internalCluster().clusterService(masterNode).localNode().getId(); - final String nonMasterId = internalCluster().clusterService(nonMasterNode).localNode().getId(); + final String masterId = getNodeId(masterNode); + final String nonMasterId = getNodeId(nonMasterNode); final int numShards = scaledRandomIntBetween(2, 10); assertAcked(prepareCreate("test").setSettings(indexSettings(numShards, 0))); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/nodesinfo/SimpleNodesInfoIT.java b/server/src/internalClusterTest/java/org/elasticsearch/nodesinfo/SimpleNodesInfoIT.java index cafc0e9426eea..a5700c319aa59 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/nodesinfo/SimpleNodesInfoIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/nodesinfo/SimpleNodesInfoIT.java @@ -11,7 +11,6 @@ import org.elasticsearch.action.admin.cluster.health.ClusterHealthResponse; import org.elasticsearch.action.admin.cluster.node.info.NodesInfoRequest; import org.elasticsearch.action.admin.cluster.node.info.NodesInfoResponse; -import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.monitor.os.OsInfo; @@ -29,16 +28,16 @@ @ClusterScope(scope = Scope.TEST, numDataNodes = 0) public class SimpleNodesInfoIT extends ESIntegTestCase { - public void testNodesInfos() throws Exception { - List<String> nodesIds = internalCluster().startNodes(2); - final String node_1 = nodesIds.get(0); - final String node_2 = nodesIds.get(1); + public void testNodesInfos() { + List<String> nodesNames = internalCluster().startNodes(2); + final String node_1 = nodesNames.get(0); + final String node_2 = nodesNames.get(1); ClusterHealthResponse clusterHealth = clusterAdmin().prepareHealth().setWaitForGreenStatus().setWaitForNodes("2").get(); logger.info("--> done cluster_health, status {}", clusterHealth.getStatus()); - String server1NodeId = internalCluster().getInstance(ClusterService.class, node_1).state().nodes().getLocalNodeId(); - String server2NodeId = internalCluster().getInstance(ClusterService.class, node_2).state().nodes().getLocalNodeId(); + String server1NodeId = getNodeId(node_1); + String server2NodeId = getNodeId(node_2); logger.info("--> started nodes: {} and {}", server1NodeId, server2NodeId); NodesInfoResponse response = clusterAdmin().prepareNodesInfo().get(); @@ -68,16 +67,16 @@ public void testNodesInfos() throws Exception { assertThat(response.getNodesMap().get(server2NodeId), notNullValue()); } - public void testNodesInfosTotalIndexingBuffer() throws Exception { - List<String> nodesIds = internalCluster().startNodes(2); - final String node_1 = nodesIds.get(0); - final String node_2 = nodesIds.get(1); + public void testNodesInfosTotalIndexingBuffer() { + List<String> nodesNames = internalCluster().startNodes(2); + final String node_1 = nodesNames.get(0); + final String node_2 = nodesNames.get(1); ClusterHealthResponse clusterHealth = clusterAdmin().prepareHealth().setWaitForGreenStatus().setWaitForNodes("2").get(); logger.info("--> done cluster_health, status {}", clusterHealth.getStatus()); - String server1NodeId = internalCluster().getInstance(ClusterService.class, node_1).state().nodes().getLocalNodeId(); - String server2NodeId = internalCluster().getInstance(ClusterService.class, node_2).state().nodes().getLocalNodeId(); + String server1NodeId = getNodeId(node_1); + String server2NodeId = getNodeId(node_2); 
logger.info("--> started nodes: {} and {}", server1NodeId, server2NodeId); NodesInfoResponse response = clusterAdmin().prepareNodesInfo().get(); @@ -103,19 +102,19 @@ public void testNodesInfosTotalIndexingBuffer() throws Exception { } public void testAllocatedProcessors() throws Exception { - List nodesIds = internalCluster().startNodes( + List nodeNames = internalCluster().startNodes( Settings.builder().put(EsExecutors.NODE_PROCESSORS_SETTING.getKey(), 2.9).build(), Settings.builder().put(EsExecutors.NODE_PROCESSORS_SETTING.getKey(), 5.9).build() ); - final String node_1 = nodesIds.get(0); - final String node_2 = nodesIds.get(1); + final String node_1 = nodeNames.get(0); + final String node_2 = nodeNames.get(1); ClusterHealthResponse clusterHealth = clusterAdmin().prepareHealth().setWaitForGreenStatus().setWaitForNodes("2").get(); logger.info("--> done cluster_health, status {}", clusterHealth.getStatus()); - String server1NodeId = internalCluster().getInstance(ClusterService.class, node_1).state().nodes().getLocalNodeId(); - String server2NodeId = internalCluster().getInstance(ClusterService.class, node_2).state().nodes().getLocalNodeId(); + String server1NodeId = getNodeId(node_1); + String server2NodeId = getNodeId(node_2); logger.info("--> started nodes: {} and {}", server1NodeId, server2NodeId); NodesInfoResponse response = clusterAdmin().prepareNodesInfo().get(); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/persistent/PersistentTasksExecutorIT.java b/server/src/internalClusterTest/java/org/elasticsearch/persistent/PersistentTasksExecutorIT.java index 813c06d9f02f3..d71718f3f3a6b 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/persistent/PersistentTasksExecutorIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/persistent/PersistentTasksExecutorIT.java @@ -145,7 +145,7 @@ public void testPersistentActionWithNoAvailableNode() throws Exception { Settings nodeSettings = Settings.builder().put(nodeSettings(0, Settings.EMPTY)).put("node.attr.test_attr", "test").build(); String newNode = internalCluster().startNode(nodeSettings); - String newNodeId = internalCluster().clusterService(newNode).localNode().getId(); + String newNodeId = getNodeId(newNode); waitForTaskToStart(); TaskInfo taskInfo = clusterAdmin().prepareListTasks().setActions(TestPersistentTasksExecutor.NAME + "[c]").get().getTasks().get(0); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/CCSFieldCapabilitiesIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/CCSFieldCapabilitiesIT.java index afc62323ca544..08c4d2aab4bc9 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/CCSFieldCapabilitiesIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/CCSFieldCapabilitiesIT.java @@ -25,8 +25,10 @@ import java.util.List; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.hamcrest.Matchers.arrayContaining; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; public class CCSFieldCapabilitiesIT extends AbstractMultiClustersTestCase { @@ -35,6 +37,11 @@ protected Collection remoteClusterAlias() { return List.of("remote_cluster"); } + @Override + protected boolean reuseClusters() { + return false; + } + @Override protected Collection> nodePlugins(String clusterAlias) { final List> plugins = new 
ArrayList<>(super.nodePlugins(clusterAlias)); @@ -105,4 +112,17 @@ public void testFailuresFromRemote() { assertEquals(IllegalArgumentException.class, ex.getClass()); assertEquals("I throw because I choose to.", ex.getMessage()); } + + public void testFailedToConnectToRemoteCluster() throws Exception { + String localIndex = "local_index"; + assertAcked(client(LOCAL_CLUSTER).admin().indices().prepareCreate(localIndex)); + client(LOCAL_CLUSTER).prepareIndex(localIndex).setId("1").setSource("foo", "bar").get(); + client(LOCAL_CLUSTER).admin().indices().prepareRefresh(localIndex).get(); + cluster("remote_cluster").close(); + FieldCapabilitiesResponse response = client().prepareFieldCaps("*", "remote_cluster:*").setFields("*").get(); + assertThat(response.getIndices(), arrayContaining(localIndex)); + List failures = response.getFailures(); + assertThat(failures, hasSize(1)); + assertThat(failures.get(0).getIndices(), arrayContaining("remote_cluster:*")); + } } diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 50008832712b4..978ad1ce31e28 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -171,6 +171,8 @@ static TransportVersion def(int id) { public static final TransportVersion MODIFY_DATA_STREAM_FAILURE_STORES = def(8_630_00_0); public static final TransportVersion ML_INFERENCE_RERANK_NEW_RESPONSE_FORMAT = def(8_631_00_0); public static final TransportVersion HIGHLIGHTERS_TAGS_ON_FIELD_LEVEL = def(8_632_00_0); + public static final TransportVersion TRACK_FLUSH_TIME_EXCLUDING_WAITING_ON_LOCKS = def(8_633_00_0); + public static final TransportVersion ML_INFERENCE_AZURE_OPENAI_EMBEDDINGS = def(8_634_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/get/GetIndexRequest.java b/server/src/main/java/org/elasticsearch/action/admin/indices/get/GetIndexRequest.java index a550350c20f6b..0b94e89fcc64d 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/get/GetIndexRequest.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/get/GetIndexRequest.java @@ -95,7 +95,7 @@ public static Feature[] fromRequest(RestRequest request) { public GetIndexRequest() { super( - DataStream.isFailureStoreEnabled() + DataStream.isFailureStoreFeatureFlagEnabled() ? 
IndicesOptions.builder(IndicesOptions.strictExpandOpen()) .failureStoreOptions( IndicesOptions.FailureStoreOptions.builder().includeRegularIndices(true).includeFailureIndices(true) diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/MetadataRolloverService.java b/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/MetadataRolloverService.java index 75852098170c6..cef0b3797b1d4 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/MetadataRolloverService.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/MetadataRolloverService.java @@ -640,7 +640,7 @@ static void validate( ); } var dataStream = (DataStream) indexAbstraction; - if (isFailureStoreRollover && dataStream.isFailureStore() == false) { + if (isFailureStoreRollover && dataStream.isFailureStoreEnabled() == false) { throw new IllegalArgumentException( "unable to roll over failure store because [" + indexAbstraction.getName() + "] does not have the failure store enabled" ); diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java index 412e4f3c875e8..ea4d278227849 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java @@ -306,7 +306,7 @@ private void executeBulkRequestsByShard( } private void redirectFailuresOrCompleteBulkOperation() { - if (DataStream.isFailureStoreEnabled() && failureStoreRedirects.isEmpty() == false) { + if (DataStream.isFailureStoreFeatureFlagEnabled() && failureStoreRedirects.isEmpty() == false) { doRedirectFailures(); } else { completeBulkOperation(); @@ -412,7 +412,7 @@ private void completeShardOperation() { */ private static String getRedirectTarget(DocWriteRequest docWriteRequest, Metadata metadata) { // Feature flag guard - if (DataStream.isFailureStoreEnabled() == false) { + if (DataStream.isFailureStoreFeatureFlagEnabled() == false) { return null; } // Do not resolve a failure store for documents that were already headed to one @@ -431,7 +431,7 @@ private static String getRedirectTarget(DocWriteRequest docWriteRequest, Meta Index concreteIndex = ia.getWriteIndex(); IndexAbstraction writeIndexAbstraction = metadata.getIndicesLookup().get(concreteIndex.getName()); DataStream parentDataStream = writeIndexAbstraction.getParentDataStream(); - if (parentDataStream != null && parentDataStream.isFailureStore()) { + if (parentDataStream != null && parentDataStream.isFailureStoreEnabled()) { // Keep the data stream name around to resolve the redirect to failure store if the shard level request fails. return parentDataStream.getName(); } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkRequestModifier.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkRequestModifier.java index 2112ad48bec62..d0a75bdf109c5 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkRequestModifier.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkRequestModifier.java @@ -215,7 +215,7 @@ synchronized void markItemAsDropped(int slot) { * @param e the failure encountered. */ public void markItemForFailureStore(int slot, String targetIndexName, Exception e) { - if (DataStream.isFailureStoreEnabled() == false) { + if (DataStream.isFailureStoreFeatureFlagEnabled() == false) { // Assert false for development, but if we somehow find ourselves here, default to failure logic. 
assert false : "Attempting to route a failed write request type to a failure store but the failure store is not enabled! " diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index 3494701cf5b7a..13c4009cbc3e2 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -738,7 +738,7 @@ public boolean isForceExecution() { * or if it matches a template that has a data stream failure store enabled. */ static boolean shouldStoreFailure(String indexName, Metadata metadata, long epochMillis) { - return DataStream.isFailureStoreEnabled() + return DataStream.isFailureStoreFeatureFlagEnabled() && resolveFailureStoreFromMetadata(indexName, metadata, epochMillis).or( () -> resolveFailureStoreFromTemplate(indexName, metadata) ).orElse(false); @@ -774,7 +774,7 @@ private static Optional resolveFailureStoreFromMetadata(String indexNam DataStream targetDataStream = writeAbstraction.getParentDataStream(); // We will store the failure if the write target belongs to a data stream with a failure store. - return Optional.of(targetDataStream != null && targetDataStream.isFailureStore()); + return Optional.of(targetDataStream != null && targetDataStream.isFailureStoreEnabled()); } /** diff --git a/server/src/main/java/org/elasticsearch/action/datastreams/GetDataStreamAction.java b/server/src/main/java/org/elasticsearch/action/datastreams/GetDataStreamAction.java index 36f2ff4fffa96..1a2103d665b38 100644 --- a/server/src/main/java/org/elasticsearch/action/datastreams/GetDataStreamAction.java +++ b/server/src/main/java/org/elasticsearch/action/datastreams/GetDataStreamAction.java @@ -319,7 +319,7 @@ public XContentBuilder toXContent( builder.endArray(); } builder.field(DataStream.GENERATION_FIELD.getPreferredName(), dataStream.getGeneration()); - if (DataStream.isFailureStoreEnabled()) { + if (DataStream.isFailureStoreFeatureFlagEnabled()) { builder.field(DataStream.FAILURE_INDICES_FIELD.getPreferredName()); builder.startArray(); for (Index failureStore : dataStream.getFailureIndices()) { @@ -358,8 +358,8 @@ public XContentBuilder toXContent( builder.field(ALLOW_CUSTOM_ROUTING.getPreferredName(), dataStream.isAllowCustomRouting()); builder.field(REPLICATED.getPreferredName(), dataStream.isReplicated()); builder.field(ROLLOVER_ON_WRITE.getPreferredName(), dataStream.rolloverOnWrite()); - if (DataStream.isFailureStoreEnabled()) { - builder.field(DataStream.FAILURE_STORE_FIELD.getPreferredName(), dataStream.isFailureStore()); + if (DataStream.isFailureStoreFeatureFlagEnabled()) { + builder.field(DataStream.FAILURE_STORE_FIELD.getPreferredName(), dataStream.isFailureStoreEnabled()); } if (dataStream.getAutoShardingEvent() != null) { DataStreamAutoShardingEvent autoShardingEvent = dataStream.getAutoShardingEvent(); diff --git a/server/src/main/java/org/elasticsearch/action/fieldcaps/TransportFieldCapabilitiesAction.java b/server/src/main/java/org/elasticsearch/action/fieldcaps/TransportFieldCapabilitiesAction.java index e6acaba8307f6..7a8ea12568006 100644 --- a/server/src/main/java/org/elasticsearch/action/fieldcaps/TransportFieldCapabilitiesAction.java +++ b/server/src/main/java/org/elasticsearch/action/fieldcaps/TransportFieldCapabilitiesAction.java @@ -15,6 +15,7 @@ import org.elasticsearch.action.ActionType; import org.elasticsearch.action.OriginalIndices; import 
org.elasticsearch.action.RemoteClusterActionType; +import org.elasticsearch.action.support.AbstractThreadedActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.ChannelActionListener; import org.elasticsearch.action.support.HandledTransportAction; @@ -252,7 +253,14 @@ private void doExecuteForked(Task task, FieldCapabilitiesRequest request, final remoteClusterClient.execute( TransportFieldCapabilitiesAction.REMOTE_TYPE, remoteRequest, - ActionListener.releaseAfter(remoteListener, refs.acquire()) + // The underlying transport service may call onFailure with a thread pool other than search_coordinator. + // This fork is a workaround to ensure that the merging of field-caps always occurs on the search_coordinator. + // TODO: remove this workaround after we fixed https://github.com/elastic/elasticsearch/issues/107439 + new ForkingOnFailureActionListener<>( + searchCoordinationExecutor, + true, + ActionListener.releaseAfter(remoteListener, refs.acquire()) + ) ); } } @@ -569,4 +577,15 @@ public void messageReceived(FieldCapabilitiesNodeRequest request, TransportChann }); } } + + private static class ForkingOnFailureActionListener<Response> extends AbstractThreadedActionListener<Response> { + ForkingOnFailureActionListener(Executor executor, boolean forceExecution, ActionListener<Response> delegate) { + super(executor, forceExecution, delegate); + } + + @Override + public void onResponse(Response response) { + delegate.onResponse(response); + } + } } diff --git a/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java b/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java index d142db2d5a1ab..9d0eeb20dacef 100644 --- a/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java +++ b/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java @@ -858,7 +858,7 @@ public IndexRequest setRequireDataStream(boolean requireDataStream) { @Override public Index getConcreteWriteIndex(IndexAbstraction ia, Metadata metadata) { - if (DataStream.isFailureStoreEnabled() && writeToFailureStore) { + if (DataStream.isFailureStoreFeatureFlagEnabled() && writeToFailureStore) { if (ia.isDataStreamRelated() == false) { throw new ElasticsearchException( "Attempting to write a document to a failure store but the targeted index is not a data stream" diff --git a/server/src/main/java/org/elasticsearch/action/support/IndicesOptions.java b/server/src/main/java/org/elasticsearch/action/support/IndicesOptions.java index e46a7bd5f0ec2..1070a5d0bddd0 100644 --- a/server/src/main/java/org/elasticsearch/action/support/IndicesOptions.java +++ b/server/src/main/java/org/elasticsearch/action/support/IndicesOptions.java @@ -1109,7 +1109,7 @@ public static IndicesOptions fromRequest(RestRequest request, IndicesOptions def request.param(ConcreteTargetOptions.IGNORE_UNAVAILABLE), request.param(WildcardOptions.ALLOW_NO_INDICES), request.param(GatekeeperOptions.IGNORE_THROTTLED), - DataStream.isFailureStoreEnabled() + DataStream.isFailureStoreFeatureFlagEnabled() ? request.param(FailureStoreOptions.FAILURE_STORE) : FailureStoreOptions.INCLUDE_ONLY_REGULAR_INDICES, defaultSettings @@ -1117,7 +1117,7 @@ } public static IndicesOptions fromMap(Map<String, Object> map, IndicesOptions defaultSettings) { - if (DataStream.isFailureStoreEnabled()) { + if (DataStream.isFailureStoreFeatureFlagEnabled()) { return fromParameters( map.containsKey(WildcardOptions.EXPAND_WILDCARDS) ? 
map.get(WildcardOptions.EXPAND_WILDCARDS) : map.get("expandWildcards"), map.containsKey(ConcreteTargetOptions.IGNORE_UNAVAILABLE) @@ -1155,8 +1155,8 @@ public static boolean isIndicesOptions(String name) { || "ignoreThrottled".equals(name) || WildcardOptions.ALLOW_NO_INDICES.equals(name) || "allowNoIndices".equals(name) - || (DataStream.isFailureStoreEnabled() && FailureStoreOptions.FAILURE_STORE.equals(name)) - || (DataStream.isFailureStoreEnabled() && "failureStore".equals(name)); + || (DataStream.isFailureStoreFeatureFlagEnabled() && FailureStoreOptions.FAILURE_STORE.equals(name)) + || (DataStream.isFailureStoreFeatureFlagEnabled() && "failureStore".equals(name)); } public static IndicesOptions fromParameters( @@ -1187,7 +1187,7 @@ public static IndicesOptions fromParameters( WildcardOptions wildcards = WildcardOptions.parseParameters(wildcardsString, allowNoIndicesString, defaultSettings.wildcardOptions); GatekeeperOptions gatekeeperOptions = GatekeeperOptions.parseParameter(ignoreThrottled, defaultSettings.gatekeeperOptions); - FailureStoreOptions failureStoreOptions = DataStream.isFailureStoreEnabled() + FailureStoreOptions failureStoreOptions = DataStream.isFailureStoreFeatureFlagEnabled() ? FailureStoreOptions.parseParameters(failureStoreString, defaultSettings.failureStoreOptions) : FailureStoreOptions.DEFAULT; @@ -1205,7 +1205,7 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par concreteTargetOptions.toXContent(builder, params); wildcardOptions.toXContent(builder, params); gatekeeperOptions.toXContent(builder, params); - if (DataStream.isFailureStoreEnabled()) { + if (DataStream.isFailureStoreFeatureFlagEnabled()) { failureStoreOptions.toXContent(builder, params); } return builder; @@ -1276,7 +1276,7 @@ public static IndicesOptions fromXContent(XContentParser parser, @Nullable Indic allowNoIndices = parser.booleanValue(); } else if (IGNORE_THROTTLED_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { generalOptions.ignoreThrottled(parser.booleanValue()); - } else if (DataStream.isFailureStoreEnabled() + } else if (DataStream.isFailureStoreFeatureFlagEnabled() && FAILURE_STORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { failureStoreOptions = FailureStoreOptions.parseParameters(parser.text(), failureStoreOptions); } else { @@ -1423,7 +1423,7 @@ public String toString() { + ignoreAliases() + ", ignore_throttled=" + ignoreThrottled() - + (DataStream.isFailureStoreEnabled() + + (DataStream.isFailureStoreFeatureFlagEnabled() ? 
", include_regular_indices=" + includeRegularIndices() + ", include_failure_indices=" diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/ComposableIndexTemplate.java b/server/src/main/java/org/elasticsearch/cluster/metadata/ComposableIndexTemplate.java index 8e8e6fff4cc6a..e6e48bfbd46b3 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/ComposableIndexTemplate.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/ComposableIndexTemplate.java @@ -376,14 +376,14 @@ public static class DataStreamTemplate implements Writeable, ToXContentObject { args -> new DataStreamTemplate( args[0] != null && (boolean) args[0], args[1] != null && (boolean) args[1], - DataStream.isFailureStoreEnabled() && args[2] != null && (boolean) args[2] + DataStream.isFailureStoreFeatureFlagEnabled() && args[2] != null && (boolean) args[2] ) ); static { PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), HIDDEN); PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), ALLOW_CUSTOM_ROUTING); - if (DataStream.isFailureStoreEnabled()) { + if (DataStream.isFailureStoreFeatureFlagEnabled()) { PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), FAILURE_STORE); } } @@ -478,7 +478,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); builder.field("hidden", hidden); builder.field(ALLOW_CUSTOM_ROUTING.getPreferredName(), allowCustomRouting); - if (DataStream.isFailureStoreEnabled()) { + if (DataStream.isFailureStoreFeatureFlagEnabled()) { builder.field(FAILURE_STORE.getPreferredName(), failureStore); } builder.endObject(); diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/DataStream.java b/server/src/main/java/org/elasticsearch/cluster/metadata/DataStream.java index 5ada0126dc62b..33dab20a81494 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/DataStream.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/DataStream.java @@ -73,7 +73,7 @@ public final class DataStream implements SimpleDiffable, ToXContentO public static final TransportVersion ADDED_FAILURE_STORE_TRANSPORT_VERSION = TransportVersions.V_8_12_0; public static final TransportVersion ADDED_AUTO_SHARDING_EVENT_VERSION = TransportVersions.DATA_STREAM_AUTO_SHARDING_EVENT; - public static boolean isFailureStoreEnabled() { + public static boolean isFailureStoreFeatureFlagEnabled() { return FAILURE_STORE_FEATURE_FLAG.isEnabled(); } @@ -104,16 +104,18 @@ public static boolean isFailureStoreEnabled() { private final String name; private final List<Index> indices; private final long generation; + @Nullable private final Map<String, Object> metadata; private final boolean hidden; private final boolean replicated; private final boolean system; private final boolean allowCustomRouting; + @Nullable private final IndexMode indexMode; @Nullable private final DataStreamLifecycle lifecycle; private final boolean rolloverOnWrite; - private final boolean failureStore; + private final boolean failureStoreEnabled; private final List<Index> failureIndices; private volatile Set<String> failureStoreLookup; @Nullable @@ -130,7 +132,7 @@ public DataStream( boolean allowCustomRouting, IndexMode indexMode, DataStreamLifecycle lifecycle, - boolean failureStore, + boolean failureStoreEnabled, List<Index> failureIndices, boolean rolloverOnWrite, @Nullable DataStreamAutoShardingEvent autoShardingEvent @@ -147,7 +149,7 @@ public DataStream( allowCustomRouting, indexMode, lifecycle, - failureStore, + failureStoreEnabled, 
failureIndices, rolloverOnWrite, autoShardingEvent @@ -167,7 +169,7 @@ public DataStream( boolean allowCustomRouting, IndexMode indexMode, DataStreamLifecycle lifecycle, - boolean failureStore, + boolean failureStoreEnabled, List<Index> failureIndices, boolean rolloverOnWrite, @Nullable DataStreamAutoShardingEvent autoShardingEvent @@ -185,7 +187,7 @@ public DataStream( this.allowCustomRouting = allowCustomRouting; this.indexMode = indexMode; this.lifecycle = lifecycle; - this.failureStore = failureStore; + this.failureStoreEnabled = failureStoreEnabled; this.failureIndices = failureIndices; assert assertConsistent(this.indices); assert replicated == false || rolloverOnWrite == false : "replicated data streams cannot be marked for lazy rollover"; @@ -193,36 +195,6 @@ public DataStream( this.autoShardingEvent = autoShardingEvent; } - // mainly available for testing - public DataStream( - String name, - List<Index> indices, - long generation, - Map<String, Object> metadata, - boolean hidden, - boolean replicated, - boolean system, - boolean allowCustomRouting, - IndexMode indexMode - ) { - this( - name, - indices, - generation, - metadata, - hidden, - replicated, - system, - allowCustomRouting, - indexMode, - null, - false, - List.of(), - false, - null - ); - } - private static boolean assertConsistent(List<Index> indices) { assert indices.size() > 0; final Set<String> indexNames = new HashSet<>(); @@ -271,7 +243,7 @@ public Index getWriteIndex() { */ @Nullable public Index getFailureStoreWriteIndex() { - return isFailureStore() == false || failureIndices.isEmpty() ? null : failureIndices.get(failureIndices.size() - 1); + return isFailureStoreEnabled() == false || failureIndices.isEmpty() ? null : failureIndices.get(failureIndices.size() - 1); } /** @@ -417,8 +389,8 @@ public boolean isAllowCustomRouting() { * * @return Whether this data stream should store ingestion failures. 
*/ - public boolean isFailureStore() { - return failureStore; + public boolean isFailureStoreEnabled() { + return failureStoreEnabled; } @Nullable @@ -476,22 +448,13 @@ public DataStream unsafeRollover(Index writeIndex, long generation, boolean time List backingIndices = new ArrayList<>(indices); backingIndices.add(writeIndex); - return new DataStream( - name, - backingIndices, - generation, - metadata, - hidden, - false, - system, - allowCustomRouting, - indexMode, - lifecycle, - failureStore, - failureIndices, - false, - autoShardingEvent - ); + return copy().setIndices(backingIndices) + .setGeneration(generation) + .setReplicated(false) + .setIndexMode(indexMode) + .setAutoShardingEvent(autoShardingEvent) + .setRolloverOnWrite(false) + .build(); } /** @@ -514,22 +477,7 @@ public DataStream rolloverFailureStore(Index writeIndex, long generation) { public DataStream unsafeRolloverFailureStore(Index writeIndex, long generation) { List failureIndices = new ArrayList<>(this.failureIndices); failureIndices.add(writeIndex); - return new DataStream( - name, - indices, - generation, - metadata, - hidden, - false, - system, - allowCustomRouting, - indexMode, - lifecycle, - failureStore, - failureIndices, - rolloverOnWrite, - autoShardingEvent - ); + return copy().setGeneration(generation).setReplicated(false).setFailureIndices(failureIndices).build(); } /** @@ -617,22 +565,7 @@ public DataStream removeBackingIndex(Index index) { List backingIndices = new ArrayList<>(indices); backingIndices.remove(index); assert backingIndices.size() == indices.size() - 1; - return new DataStream( - name, - backingIndices, - generation + 1, - metadata, - hidden, - replicated, - system, - allowCustomRouting, - indexMode, - lifecycle, - failureStore, - failureIndices, - rolloverOnWrite, - autoShardingEvent - ); + return copy().setIndices(backingIndices).setGeneration(generation + 1).build(); } /** @@ -669,22 +602,7 @@ public DataStream removeFailureStoreIndex(Index index) { List updatedFailureIndices = new ArrayList<>(failureIndices); updatedFailureIndices.remove(index); assert updatedFailureIndices.size() == failureIndices.size() - 1; - return new DataStream( - name, - indices, - generation + 1, - metadata, - hidden, - replicated, - system, - allowCustomRouting, - indexMode, - lifecycle, - failureStore, - updatedFailureIndices, - rolloverOnWrite, - autoShardingEvent - ); + return copy().setGeneration(generation + 1).setFailureIndices(updatedFailureIndices).build(); } /** @@ -716,22 +634,7 @@ public DataStream replaceBackingIndex(Index existingBackingIndex, Index newBacki ); } backingIndices.set(backingIndexPosition, newBackingIndex); - return new DataStream( - name, - backingIndices, - generation + 1, - metadata, - hidden, - replicated, - system, - allowCustomRouting, - indexMode, - lifecycle, - failureStore, - failureIndices, - rolloverOnWrite, - autoShardingEvent - ); + return copy().setIndices(backingIndices).setGeneration(generation + 1).build(); } /** @@ -756,22 +659,7 @@ public DataStream addBackingIndex(Metadata clusterMetadata, Index index) { List backingIndices = new ArrayList<>(indices); backingIndices.add(0, index); assert backingIndices.size() == indices.size() + 1; - return new DataStream( - name, - backingIndices, - generation + 1, - metadata, - hidden, - replicated, - system, - allowCustomRouting, - indexMode, - lifecycle, - failureStore, - failureIndices, - rolloverOnWrite, - autoShardingEvent - ); + return copy().setIndices(backingIndices).setGeneration(generation + 1).build(); } /** @@ -795,22 
+683,7 @@ public DataStream addFailureStoreIndex(Metadata clusterMetadata, Index index) { List updatedFailureIndices = new ArrayList<>(failureIndices); updatedFailureIndices.add(0, index); assert updatedFailureIndices.size() == failureIndices.size() + 1; - return new DataStream( - name, - indices, - generation + 1, - metadata, - hidden, - replicated, - system, - allowCustomRouting, - indexMode, - lifecycle, - failureStore, - updatedFailureIndices, - rolloverOnWrite, - autoShardingEvent - ); + return copy().setGeneration(generation + 1).setFailureIndices(updatedFailureIndices).build(); } /** @@ -855,23 +728,7 @@ private void ensureNoAliasesOnIndex(Metadata clusterMetadata, Index index) { } public DataStream promoteDataStream() { - return new DataStream( - name, - indices, - getGeneration(), - metadata, - hidden, - false, - system, - timeProvider, - allowCustomRouting, - indexMode, - lifecycle, - failureStore, - failureIndices, - rolloverOnWrite, - autoShardingEvent - ); + return copy().setReplicated(false).build(); } /** @@ -894,22 +751,7 @@ public DataStream snapshot(Collection indicesInSnapshot) { return null; } - return new DataStream( - name, - reconciledIndices, - generation, - metadata == null ? null : new HashMap<>(metadata), - hidden, - replicated, - system, - allowCustomRouting, - indexMode, - lifecycle, - failureStore, - failureIndices, - rolloverOnWrite, - autoShardingEvent - ); + return copy().setIndices(reconciledIndices).setMetadata(metadata == null ? null : new HashMap<>(metadata)).build(); } /** @@ -1175,7 +1017,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalWriteable(lifecycle); } if (out.getTransportVersion().onOrAfter(DataStream.ADDED_FAILURE_STORE_TRANSPORT_VERSION)) { - out.writeBoolean(failureStore); + out.writeBoolean(failureStoreEnabled); out.writeCollection(failureIndices); } if (out.getTransportVersion().onOrAfter(TransportVersions.LAZY_ROLLOVER_ADDED)) { @@ -1206,8 +1048,10 @@ public void writeTo(StreamOutput out) throws IOException { private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("data_stream", args -> { // Fields behind a feature flag need to be parsed last otherwise the parser will fail when the feature flag is disabled. // Until the feature flag is removed we keep them separately to be mindful of this. - boolean failureStoreEnabled = DataStream.isFailureStoreEnabled() && args[12] != null && (boolean) args[12]; - List failureStoreIndices = DataStream.isFailureStoreEnabled() && args[13] != null ? (List) args[13] : List.of(); + boolean failureStoreEnabled = DataStream.isFailureStoreFeatureFlagEnabled() && args[12] != null && (boolean) args[12]; + List failureStoreIndices = DataStream.isFailureStoreFeatureFlagEnabled() && args[13] != null + ? (List) args[13] + : List.of(); return new DataStream( (String) args[0], (List) args[1], @@ -1252,7 +1096,7 @@ public void writeTo(StreamOutput out) throws IOException { AUTO_SHARDING_FIELD ); // The fields behind the feature flag should always be last. 
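// Illustrative note (assuming the arg positions used in the parser above): the optional
// failure store fields occupy args[12] and args[13], so declaring them after every
// unconditional field keeps args[0] through args[11] stable whether or not the feature
// flag is enabled.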
- if (DataStream.isFailureStoreEnabled()) { + if (DataStream.isFailureStoreFeatureFlagEnabled()) { PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), FAILURE_STORE_FIELD); PARSER.declareObjectArray( ConstructingObjectParser.optionalConstructorArg(), @@ -1288,7 +1132,7 @@ public XContentBuilder toXContent( .endObject(); builder.xContentList(INDICES_FIELD.getPreferredName(), indices); builder.field(GENERATION_FIELD.getPreferredName(), generation); - if (DataStream.isFailureStoreEnabled() && failureIndices.isEmpty() == false) { + if (DataStream.isFailureStoreFeatureFlagEnabled() && failureIndices.isEmpty() == false) { builder.xContentList(FAILURE_INDICES_FIELD.getPreferredName(), failureIndices); } if (metadata != null) { @@ -1298,8 +1142,8 @@ public XContentBuilder toXContent( builder.field(REPLICATED_FIELD.getPreferredName(), replicated); builder.field(SYSTEM_FIELD.getPreferredName(), system); builder.field(ALLOW_CUSTOM_ROUTING.getPreferredName(), allowCustomRouting); - if (DataStream.isFailureStoreEnabled()) { - builder.field(FAILURE_STORE_FIELD.getPreferredName(), failureStore); + if (DataStream.isFailureStoreFeatureFlagEnabled()) { + builder.field(FAILURE_STORE_FIELD.getPreferredName(), failureStoreEnabled); } if (indexMode != null) { builder.field(INDEX_MODE.getPreferredName(), indexMode); @@ -1333,7 +1177,7 @@ public boolean equals(Object o) { && allowCustomRouting == that.allowCustomRouting && indexMode == that.indexMode && Objects.equals(lifecycle, that.lifecycle) - && failureStore == that.failureStore + && failureStoreEnabled == that.failureStoreEnabled && failureIndices.equals(that.failureIndices) && rolloverOnWrite == that.rolloverOnWrite && Objects.equals(autoShardingEvent, that.autoShardingEvent); @@ -1352,7 +1196,7 @@ public int hashCode() { allowCustomRouting, indexMode, lifecycle, - failureStore, + failureStoreEnabled, failureIndices, rolloverOnWrite, autoShardingEvent @@ -1494,4 +1338,154 @@ private static Instant getTimestampFromParser(BytesReference source, XContentTyp public static Instant getCanonicalTimestampBound(Instant time) { return time.truncatedTo(ChronoUnit.SECONDS); } + + public static Builder builder(String name, List indices) { + return new Builder(name, indices); + } + + public Builder copy() { + return new Builder(this); + } + + public static class Builder { + private LongSupplier timeProvider = System::currentTimeMillis; + private String name; + private List indices; + private long generation = 1; + @Nullable + private Map metadata = null; + private boolean hidden = false; + private boolean replicated = false; + private boolean system = false; + private boolean allowCustomRouting = false; + @Nullable + private IndexMode indexMode = null; + @Nullable + private DataStreamLifecycle lifecycle = null; + private boolean rolloverOnWrite = false; + private boolean failureStoreEnabled = false; + private List failureIndices = List.of(); + @Nullable + private DataStreamAutoShardingEvent autoShardingEvent = null; + + public Builder(String name, List indices) { + this.name = name; + assert indices.isEmpty() == false : "Cannot create data stream with empty backing indices"; + this.indices = indices; + } + + public Builder(DataStream dataStream) { + timeProvider = dataStream.timeProvider; + name = dataStream.name; + indices = dataStream.indices; + generation = dataStream.generation; + metadata = dataStream.metadata; + hidden = dataStream.hidden; + replicated = dataStream.replicated; + system = dataStream.system; + allowCustomRouting = 
dataStream.allowCustomRouting; + indexMode = dataStream.indexMode; + lifecycle = dataStream.lifecycle; + rolloverOnWrite = dataStream.rolloverOnWrite; + failureStoreEnabled = dataStream.failureStoreEnabled; + failureIndices = dataStream.failureIndices; + autoShardingEvent = dataStream.autoShardingEvent; + } + + public Builder setTimeProvider(LongSupplier timeProvider) { + this.timeProvider = timeProvider; + return this; + } + + public Builder setName(String name) { + this.name = name; + return this; + } + + public Builder setIndices(List indices) { + assert indices.isEmpty() == false : "Cannot create data stream with empty backing indices"; + this.indices = indices; + return this; + } + + public Builder setGeneration(long generation) { + this.generation = generation; + return this; + } + + public Builder setMetadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public Builder setHidden(boolean hidden) { + this.hidden = hidden; + return this; + } + + public Builder setReplicated(boolean replicated) { + this.replicated = replicated; + return this; + } + + public Builder setSystem(boolean system) { + this.system = system; + return this; + } + + public Builder setAllowCustomRouting(boolean allowCustomRouting) { + this.allowCustomRouting = allowCustomRouting; + return this; + } + + public Builder setIndexMode(IndexMode indexMode) { + this.indexMode = indexMode; + return this; + } + + public Builder setLifecycle(DataStreamLifecycle lifecycle) { + this.lifecycle = lifecycle; + return this; + } + + public Builder setRolloverOnWrite(boolean rolloverOnWrite) { + this.rolloverOnWrite = rolloverOnWrite; + return this; + } + + public Builder setFailureStoreEnabled(boolean failureStoreEnabled) { + this.failureStoreEnabled = failureStoreEnabled; + return this; + } + + public Builder setFailureIndices(List failureIndices) { + this.failureIndices = failureIndices; + return this; + } + + public Builder setAutoShardingEvent(DataStreamAutoShardingEvent autoShardingEvent) { + this.autoShardingEvent = autoShardingEvent; + return this; + } + + public DataStream build() { + return new DataStream( + name, + indices, + generation, + metadata, + hidden, + replicated, + system, + timeProvider, + allowCustomRouting, + indexMode, + lifecycle, + failureStoreEnabled, + failureIndices, + rolloverOnWrite, + autoShardingEvent + ); + } + } } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamAction.java b/server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamAction.java index 0148315c322be..f260b48cd7b7a 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamAction.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamAction.java @@ -142,7 +142,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(type.fieldName); builder.field(DATA_STREAM.getPreferredName(), dataStream); builder.field(INDEX.getPreferredName(), index); - if (DataStream.isFailureStoreEnabled() && failureStore) { + if (DataStream.isFailureStoreFeatureFlagEnabled() && failureStore) { builder.field(FAILURE_STORE.getPreferredName(), failureStore); } builder.endObject(); @@ -180,7 +180,7 @@ public static DataStreamAction fromXContent(XContentParser parser) throws IOExce ObjectParser.ValueType.STRING ); ADD_BACKING_INDEX_PARSER.declareField(DataStreamAction::setIndex, XContentParser::text, INDEX, ObjectParser.ValueType.STRING); - if (DataStream.isFailureStoreEnabled()) { + if 
(DataStream.isFailureStoreFeatureFlagEnabled()) { ADD_BACKING_INDEX_PARSER.declareField( DataStreamAction::setFailureStore, XContentParser::booleanValue, @@ -195,7 +195,7 @@ public static DataStreamAction fromXContent(XContentParser parser) throws IOExce ObjectParser.ValueType.STRING ); REMOVE_BACKING_INDEX_PARSER.declareField(DataStreamAction::setIndex, XContentParser::text, INDEX, ObjectParser.ValueType.STRING); - if (DataStream.isFailureStoreEnabled()) { + if (DataStream.isFailureStoreFeatureFlagEnabled()) { REMOVE_BACKING_INDEX_PARSER.declareField( DataStreamAction::setFailureStore, XContentParser::booleanValue, diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexNameExpressionResolver.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexNameExpressionResolver.java index b88292d4ed79b..effc89d8e535a 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexNameExpressionResolver.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexNameExpressionResolver.java @@ -387,7 +387,7 @@ Index[] concreteIndices(Context context, String... indexExpressions) { resolveIndicesForDataStream(context, (DataStream) indexAbstraction, concreteIndicesResult); } else if (indexAbstraction.getType() == Type.ALIAS && indexAbstraction.isDataStreamRelated() - && DataStream.isFailureStoreEnabled() + && DataStream.isFailureStoreFeatureFlagEnabled() && context.getOptions().includeFailureIndices()) { // Collect the data streams involved Set aliasDataStreams = new HashSet<>(); @@ -453,11 +453,13 @@ private static void resolveWriteIndexForDataStreams(Context context, DataStream } private static boolean shouldIncludeRegularIndices(IndicesOptions indicesOptions) { - return DataStream.isFailureStoreEnabled() == false || indicesOptions.includeRegularIndices(); + return DataStream.isFailureStoreFeatureFlagEnabled() == false || indicesOptions.includeRegularIndices(); } private static boolean shouldIncludeFailureIndices(IndicesOptions indicesOptions, DataStream dataStream) { - return DataStream.isFailureStoreEnabled() && indicesOptions.includeFailureIndices() && dataStream.isFailureStore(); + return DataStream.isFailureStoreFeatureFlagEnabled() + && indicesOptions.includeFailureIndices() + && dataStream.isFailureStoreEnabled(); } private static boolean resolvesToMoreThanOneIndex(IndexAbstraction indexAbstraction, Context context) { @@ -566,11 +568,11 @@ private static boolean shouldTrackConcreteIndex(Context context, IndicesOptions // Exclude this one as it's a net-new system index, and we explicitly don't want those. 
return false; } - if (DataStream.isFailureStoreEnabled()) { + if (DataStream.isFailureStoreFeatureFlagEnabled()) { IndexAbstraction indexAbstraction = context.getState().metadata().getIndicesLookup().get(index.getName()); if (context.options.allowFailureIndices() == false) { DataStream parentDataStream = indexAbstraction.getParentDataStream(); - if (parentDataStream != null && parentDataStream.isFailureStore()) { + if (parentDataStream != null && parentDataStream.isFailureStoreEnabled()) { if (parentDataStream.isFailureStoreIndex(index.getName())) { if (options.ignoreUnavailable()) { return false; diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java index f424861c5b7ff..fec209960597b 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java @@ -2598,8 +2598,8 @@ private static void collectIndices( private static boolean assertContainsIndexIfDataStream(DataStream parent, IndexMetadata indexMetadata) { assert parent == null || parent.getIndices().stream().anyMatch(index -> indexMetadata.getIndex().getName().equals(index.getName())) - || (DataStream.isFailureStoreEnabled() - && parent.isFailureStore() + || (DataStream.isFailureStoreFeatureFlagEnabled() + && parent.isFailureStoreEnabled() && parent.getFailureIndices().stream().anyMatch(index -> indexMetadata.getIndex().getName().equals(index.getName()))) : "Expected data stream [" + parent.getName() + "] to contain index " + indexMetadata.getIndex(); return true; @@ -2622,7 +2622,7 @@ private static void collectDataStreams( for (Index i : dataStream.getIndices()) { indexToDataStreamLookup.put(i.getName(), dataStream); } - if (DataStream.isFailureStoreEnabled() && dataStream.isFailureStore()) { + if (DataStream.isFailureStoreFeatureFlagEnabled() && dataStream.isFailureStoreEnabled()) { for (Index i : dataStream.getFailureIndices()) { indexToDataStreamLookup.put(i.getName(), dataStream); } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateDataStreamService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateDataStreamService.java index 3c3ff0d130f0a..2d1d38ac926d6 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateDataStreamService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateDataStreamService.java @@ -418,7 +418,7 @@ public static ClusterState createFailureStoreIndex( String failureStoreIndexName, @Nullable BiConsumer metadataTransformer ) throws Exception { - if (DataStream.isFailureStoreEnabled() == false) { + if (DataStream.isFailureStoreFeatureFlagEnabled() == false) { return currentState; } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataDataStreamsService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataDataStreamsService.java index f30bc8ab7bfcb..29001a078956a 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataDataStreamsService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataDataStreamsService.java @@ -208,24 +208,7 @@ static ClusterState updateDataLifecycle( Metadata.Builder builder = Metadata.builder(metadata); for (var dataStreamName : dataStreamNames) { var dataStream = validateDataStream(metadata, dataStreamName); - builder.put( - new DataStream( - dataStream.getName(), - dataStream.getIndices(), - 
dataStream.getGeneration(), - dataStream.getMetadata(), - dataStream.isHidden(), - dataStream.isReplicated(), - dataStream.isSystem(), - dataStream.isAllowCustomRouting(), - dataStream.getIndexMode(), - lifecycle, - dataStream.isFailureStore(), - dataStream.getFailureIndices(), - dataStream.rolloverOnWrite(), - dataStream.getAutoShardingEvent() - ) - ); + builder.put(dataStream.copy().setLifecycle(lifecycle).build()); } return ClusterState.builder(currentState).metadata(builder.build()).build(); } @@ -246,24 +229,7 @@ public static ClusterState setRolloverOnWrite(ClusterState currentState, String return currentState; } Metadata.Builder builder = Metadata.builder(metadata); - builder.put( - new DataStream( - dataStream.getName(), - dataStream.getIndices(), - dataStream.getGeneration(), - dataStream.getMetadata(), - dataStream.isHidden(), - dataStream.isReplicated(), - dataStream.isSystem(), - dataStream.isAllowCustomRouting(), - dataStream.getIndexMode(), - dataStream.getLifecycle(), - dataStream.isFailureStore(), - dataStream.getFailureIndices(), - rolloverOnWrite, - dataStream.getAutoShardingEvent() - ) - ); + builder.put(dataStream.copy().setRolloverOnWrite(rolloverOnWrite).build()); return ClusterState.builder(currentState).metadata(builder.build()).build(); } diff --git a/server/src/main/java/org/elasticsearch/index/engine/Engine.java b/server/src/main/java/org/elasticsearch/index/engine/Engine.java index a910e496ce1b5..8ee536ec72248 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/Engine.java +++ b/server/src/main/java/org/elasticsearch/index/engine/Engine.java @@ -395,6 +395,13 @@ boolean throttleLockIsHeldByCurrentThread() { // to be used in assertions and te */ public abstract void trimOperationsFromTranslog(long belowTerm, long aboveSeqNo) throws EngineException; + /** + * Returns the total time flushes have been executed excluding waiting on locks. 
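+ * The base implementation returns 0; {@code InternalEngine} overrides it with a value
+ * accumulated while holding the flush lock (see the {@code MeanMetric} added below).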
+ */ + public long getTotalFlushTimeExcludingWaitingOnLockInMillis() { + return 0; + } + /** A Lock implementation that always allows the lock to be acquired */ protected static final class NoOpLock implements Lock { diff --git a/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java b/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java index 0b2e83532d030..d4371d71a4324 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java +++ b/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java @@ -52,6 +52,7 @@ import org.elasticsearch.common.lucene.uid.VersionsAndSeqNoResolver; import org.elasticsearch.common.lucene.uid.VersionsAndSeqNoResolver.DocIdAndSeqNo; import org.elasticsearch.common.metrics.CounterMetric; +import org.elasticsearch.common.metrics.MeanMetric; import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.Maps; @@ -107,6 +108,7 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; @@ -177,6 +179,8 @@ public class InternalEngine extends Engine { private final CounterMetric numDocDeletes = new CounterMetric(); private final CounterMetric numDocAppends = new CounterMetric(); private final CounterMetric numDocUpdates = new CounterMetric(); + private final MeanMetric totalFlushTimeExcludingWaitingOnLock = new MeanMetric(); + private final NumericDocValuesField softDeletesField = Lucene.newSoftDeletesField(); private final SoftDeletesPolicy softDeletesPolicy; private final LastRefreshedCheckpointListener lastRefreshedCheckpointListener; @@ -2195,6 +2199,7 @@ protected void flushHoldingLock(boolean force, boolean waitIfOngoing, ActionList logger.trace("acquired flush lock immediately"); } + final long startTime = System.nanoTime(); try { // Only flush if (1) Lucene has uncommitted docs, or (2) forced by caller, or (3) the // newly created commit points to a different translog generation (can free translog), @@ -2246,6 +2251,7 @@ protected void flushHoldingLock(boolean force, boolean waitIfOngoing, ActionList listener.onFailure(e); return; } finally { + totalFlushTimeExcludingWaitingOnLock.inc(System.nanoTime() - startTime); flushLock.unlock(); logger.trace("released flush lock"); } @@ -3066,6 +3072,11 @@ long getNumDocUpdates() { return numDocUpdates.count(); } + @Override + public long getTotalFlushTimeExcludingWaitingOnLockInMillis() { + return TimeUnit.NANOSECONDS.toMillis(totalFlushTimeExcludingWaitingOnLock.sum()); + } + @Override public int countChanges(String source, long fromSeqNo, long toSeqNo) throws IOException { ensureOpen(); diff --git a/server/src/main/java/org/elasticsearch/index/flush/FlushStats.java b/server/src/main/java/org/elasticsearch/index/flush/FlushStats.java index 7114b7b0e5c4f..e514a6d2adac0 100644 --- a/server/src/main/java/org/elasticsearch/index/flush/FlushStats.java +++ b/server/src/main/java/org/elasticsearch/index/flush/FlushStats.java @@ -8,6 +8,7 @@ package org.elasticsearch.index.flush; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -23,6 +24,7 @@ public class FlushStats implements Writeable, 
ToXContentFragment { private long total; private long periodic; private long totalTimeInMillis; + private long totalTimeExcludingWaitingOnLockInMillis; public FlushStats() { @@ -32,18 +34,22 @@ public FlushStats(StreamInput in) throws IOException { total = in.readVLong(); totalTimeInMillis = in.readVLong(); periodic = in.readVLong(); + totalTimeExcludingWaitingOnLockInMillis = in.getTransportVersion() + .onOrAfter(TransportVersions.TRACK_FLUSH_TIME_EXCLUDING_WAITING_ON_LOCKS) ? in.readVLong() : 0L; } - public FlushStats(long total, long periodic, long totalTimeInMillis) { + public FlushStats(long total, long periodic, long totalTimeInMillis, long totalTimeExcludingWaitingOnLockInMillis) { this.total = total; this.periodic = periodic; this.totalTimeInMillis = totalTimeInMillis; + this.totalTimeExcludingWaitingOnLockInMillis = totalTimeExcludingWaitingOnLockInMillis; } - public void add(long total, long periodic, long totalTimeInMillis) { + public void add(long total, long periodic, long totalTimeInMillis, long totalTimeWithoutWaitingInMillis) { this.total += total; this.periodic += periodic; this.totalTimeInMillis += totalTimeInMillis; + this.totalTimeExcludingWaitingOnLockInMillis += totalTimeWithoutWaitingInMillis; } public void add(FlushStats flushStats) { @@ -57,6 +63,7 @@ public void addTotals(FlushStats flushStats) { this.total += flushStats.total; this.periodic += flushStats.periodic; this.totalTimeInMillis += flushStats.totalTimeInMillis; + this.totalTimeExcludingWaitingOnLockInMillis += flushStats.totalTimeExcludingWaitingOnLockInMillis; } /** @@ -81,18 +88,30 @@ public long getTotalTimeInMillis() { } /** - * The total time merges have been executed. + * The total time flushes have been executed. */ public TimeValue getTotalTime() { return new TimeValue(totalTimeInMillis); } + /** + * The total time flushes have been executed excluding waiting time on locks (in milliseconds). 
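+ * Streams from nodes older than {@code TransportVersions.TRACK_FLUSH_TIME_EXCLUDING_WAITING_ON_LOCKS}
+ * do not carry this field, so it reads as 0 for them.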
+ */ + public long getTotalTimeExcludingWaitingOnLockMillis() { + return totalTimeExcludingWaitingOnLockInMillis; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(Fields.FLUSH); builder.field(Fields.TOTAL, total); builder.field(Fields.PERIODIC, periodic); builder.humanReadableField(Fields.TOTAL_TIME_IN_MILLIS, Fields.TOTAL_TIME, getTotalTime()); + builder.humanReadableField( + Fields.TOTAL_TIME_EXCLUDING_WAITING_ON_LOCK_IN_MILLIS, + Fields.TOTAL_TIME_EXCLUDING_WAITING, + new TimeValue(getTotalTimeExcludingWaitingOnLockMillis()) + ); builder.endObject(); return builder; } @@ -103,6 +122,8 @@ static final class Fields { static final String PERIODIC = "periodic"; static final String TOTAL_TIME = "total_time"; static final String TOTAL_TIME_IN_MILLIS = "total_time_in_millis"; + static final String TOTAL_TIME_EXCLUDING_WAITING = "total_time_excluding_waiting"; + static final String TOTAL_TIME_EXCLUDING_WAITING_ON_LOCK_IN_MILLIS = "total_time_excluding_waiting_on_lock_in_millis"; } @Override @@ -110,6 +131,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeVLong(total); out.writeVLong(totalTimeInMillis); out.writeVLong(periodic); + if (out.getTransportVersion().onOrAfter(TransportVersions.TRACK_FLUSH_TIME_EXCLUDING_WAITING_ON_LOCKS)) { + out.writeVLong(totalTimeExcludingWaitingOnLockInMillis); + } } @Override @@ -117,11 +141,14 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; FlushStats that = (FlushStats) o; - return total == that.total && totalTimeInMillis == that.totalTimeInMillis && periodic == that.periodic; + return total == that.total + && totalTimeInMillis == that.totalTimeInMillis + && periodic == that.periodic + && totalTimeExcludingWaitingOnLockInMillis == that.totalTimeExcludingWaitingOnLockInMillis; } @Override public int hashCode() { - return Objects.hash(total, totalTimeInMillis, periodic); + return Objects.hash(total, totalTimeInMillis, periodic, totalTimeExcludingWaitingOnLockInMillis); } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/RangeType.java b/server/src/main/java/org/elasticsearch/index/mapper/RangeType.java index f8100e794dbd9..f339269d93636 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/RangeType.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/RangeType.java @@ -107,9 +107,8 @@ public BytesRef encodeRanges(Set ranges) throws IOExcept } @Override - public List decodeRanges(BytesRef bytes) { - // TODO: Implement this. - throw new UnsupportedOperationException(); + public List decodeRanges(BytesRef bytes) throws IOException { + return BinaryRangeUtil.decodeIPRanges(bytes); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java index 4a6eaa5b26c39..233faf462400b 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java @@ -185,12 +185,13 @@ public SourceFieldMapper build() { if (mode.get() == Mode.DISABLED) { disallowed.add("mode=disabled"); } - assert disallowed.isEmpty() == false; - throw new MapperParsingException( - disallowed.size() == 1 - ? 
"Parameter [" + disallowed.get(0) + "] is not allowed in source" - : "Parameters [" + String.join(",", disallowed) + "] are not allowed in source" - ); + if (disallowed.isEmpty() == false) { + throw new MapperParsingException( + disallowed.size() == 1 + ? "Parameter [" + disallowed.get(0) + "] is not allowed in source" + : "Parameters [" + String.join(",", disallowed) + "] are not allowed in source" + ); + } } SourceFieldMapper sourceFieldMapper = new SourceFieldMapper( mode.get(), diff --git a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java index 046483a6b074f..a52b289493cd6 100644 --- a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java +++ b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java @@ -757,6 +757,16 @@ public IndexShardState markAsRecovering(String reason, RecoveryState recoverySta private final AtomicBoolean primaryReplicaResyncInProgress = new AtomicBoolean(); + // temporary compatibility shim while adding targetNodeId parameter to dependencies + @Deprecated(forRemoval = true) + public void relocated( + final String targetAllocationId, + final BiConsumer> consumer, + final ActionListener listener + ) throws IllegalIndexShardStateException, IllegalStateException { + relocated(null, targetAllocationId, consumer, listener); + } + /** * Completes the relocation. Operations are blocked and current operations are drained before changing state to relocated. The provided * {@link BiConsumer} is executed after all operations are successfully blocked. @@ -768,6 +778,7 @@ public IndexShardState markAsRecovering(String reason, RecoveryState recoverySta * @throws IllegalStateException if the relocation target is no longer part of the replication group */ public void relocated( + final String targetNodeId, final String targetAllocationId, final BiConsumer> consumer, final ActionListener listener @@ -788,7 +799,7 @@ public void onResponse(Releasable releasable) { * context via a network operation. Doing this under the mutex can implicitly block the cluster state update thread * on network operations. 
*/ - verifyRelocatingState(); + verifyRelocatingState(targetNodeId); final ReplicationTracker.PrimaryContext primaryContext = replicationTracker.startRelocationHandoff( targetAllocationId ); @@ -803,7 +814,7 @@ public void onResponse(Void unused) { try { // make changes to primaryMode and relocated flag only under mutex synchronized (mutex) { - verifyRelocatingState(); + verifyRelocatingState(targetNodeId); replicationTracker.completeRelocationHandoff(); } wrappedInnerListener.onResponse(null); @@ -857,7 +868,8 @@ public void onFailure(Exception e) { } } - private void verifyRelocatingState() { + // TODO only nullable temporarily, remove once deprecated relocated() override is removed, see ES-6725 + private void verifyRelocatingState(@Nullable String targetNodeId) { if (state != IndexShardState.STARTED) { throw new IndexShardNotStartedException(shardId, state); } @@ -871,6 +883,16 @@ private void verifyRelocatingState() { throw new IllegalIndexShardStateException(shardId, IndexShardState.STARTED, ": shard is no longer relocating " + shardRouting); } + if (targetNodeId != null) { + if (targetNodeId.equals(shardRouting.relocatingNodeId()) == false) { + throw new IllegalIndexShardStateException( + shardId, + IndexShardState.STARTED, + ": shard is no longer relocating to node [" + targetNodeId + "]: " + shardRouting + ); + } + } + if (primaryReplicaResyncInProgress.get()) { throw new IllegalIndexShardStateException( shardId, @@ -1307,7 +1329,12 @@ public RefreshStats refreshStats() { } public FlushStats flushStats() { - return new FlushStats(flushMetric.count(), periodicFlushMetric.count(), TimeUnit.NANOSECONDS.toMillis(flushMetric.sum())); + return new FlushStats( + flushMetric.count(), + periodicFlushMetric.count(), + TimeUnit.NANOSECONDS.toMillis(flushMetric.sum()), + getEngineOrNull() != null ? getEngineOrNull().getTotalFlushTimeExcludingWaitingOnLockInMillis() : 0L + ); } public DocsStats docStats() { diff --git a/server/src/main/java/org/elasticsearch/monitor/metrics/NodeMetrics.java b/server/src/main/java/org/elasticsearch/monitor/metrics/NodeMetrics.java index e689898b05da6..68cbcdb5657f9 100644 --- a/server/src/main/java/org/elasticsearch/monitor/metrics/NodeMetrics.java +++ b/server/src/main/java/org/elasticsearch/monitor/metrics/NodeMetrics.java @@ -651,6 +651,29 @@ private void registerAsyncMetrics(MeterRegistry registry) { ) ); + metrics.add( + registry.registerLongAsyncCounter( + "es.flush.total.time", + "The total time flushes have been executed", + "milliseconds", + () -> new LongWithAttributes( + stats.getOrRefresh() != null ? stats.getOrRefresh().getIndices().getFlush().getTotalTimeInMillis() : 0L + ) + ) + ); + + metrics.add( + registry.registerLongAsyncCounter( + "es.flush.total_excluding_lock_waiting.time", + "The total time flushes have been executed excluding waiting time on locks", + "milliseconds", + () -> new LongWithAttributes( + stats.getOrRefresh() != null + ?
stats.getOrRefresh().getIndices().getFlush().getTotalTimeExcludingWaitingOnLockMillis() + : 0L + ) + ) + ); } /** @@ -680,6 +703,7 @@ private long bytesUsedByGCGen(Optional optionalMem, String name) { private NodeStats getNodeStats() { CommonStatsFlags flags = new CommonStatsFlags( CommonStatsFlags.Flag.Indexing, + CommonStatsFlags.Flag.Flush, CommonStatsFlags.Flag.Get, CommonStatsFlags.Flag.Search, CommonStatsFlags.Flag.Merge, diff --git a/server/src/main/java/org/elasticsearch/rest/action/admin/indices/RestRolloverIndexAction.java b/server/src/main/java/org/elasticsearch/rest/action/admin/indices/RestRolloverIndexAction.java index 98895a49fae6e..1718d9af7e5c8 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/admin/indices/RestRolloverIndexAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/admin/indices/RestRolloverIndexAction.java @@ -53,7 +53,7 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC rolloverIndexRequest.lazy(request.paramAsBoolean("lazy", false)); rolloverIndexRequest.timeout(request.paramAsTime("timeout", rolloverIndexRequest.timeout())); rolloverIndexRequest.masterNodeTimeout(request.paramAsTime("master_timeout", rolloverIndexRequest.masterNodeTimeout())); - if (DataStream.isFailureStoreEnabled()) { + if (DataStream.isFailureStoreFeatureFlagEnabled()) { boolean failureStore = request.paramAsBoolean("target_failure_store", false); if (failureStore) { rolloverIndexRequest.setIndicesOptions( diff --git a/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java b/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java index a597901d4600e..5cabe22389529 100644 --- a/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java +++ b/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java @@ -704,22 +704,7 @@ static DataStream updateDataStream(DataStream dataStream, Metadata.Builder metad .stream() .map(i -> metadata.get(renameIndex(i.getName(), request, true)).getIndex()) .toList(); - return new DataStream( - dataStreamName, - updatedIndices, - dataStream.getGeneration(), - dataStream.getMetadata(), - dataStream.isHidden(), - dataStream.isReplicated(), - dataStream.isSystem(), - dataStream.isAllowCustomRouting(), - dataStream.getIndexMode(), - dataStream.getLifecycle(), - dataStream.isFailureStore(), - dataStream.getFailureIndices(), - dataStream.rolloverOnWrite(), - dataStream.getAutoShardingEvent() - ); + return dataStream.copy().setName(dataStreamName).setIndices(updatedIndices).build(); } public static RestoreInProgress updateRestoreStateWithDeletedIndices(RestoreInProgress oldRestore, Set deletedIndices) { diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/node/stats/NodeStatsTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/node/stats/NodeStatsTests.java index b91ea304c5da6..e502904004fef 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/node/stats/NodeStatsTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/node/stats/NodeStatsTests.java @@ -628,7 +628,7 @@ private static CommonStats createShardLevelCommonStats() { indicesCommonStats.getMerge().add(mergeStats); indicesCommonStats.getRefresh().add(new RefreshStats(++iota, ++iota, ++iota, ++iota, ++iota)); - indicesCommonStats.getFlush().add(new FlushStats(++iota, ++iota, ++iota)); + indicesCommonStats.getFlush().add(new FlushStats(++iota, ++iota, ++iota, ++iota)); indicesCommonStats.getWarmer().add(new WarmerStats(++iota, 
++iota, ++iota)); indicesCommonStats.getCompletion().add(new CompletionStats(++iota, null)); indicesCommonStats.getTranslog().add(new TranslogStats(++iota, ++iota, ++iota, ++iota, ++iota)); diff --git a/server/src/test/java/org/elasticsearch/action/admin/indices/rollover/MetadataRolloverServiceTests.java b/server/src/test/java/org/elasticsearch/action/admin/indices/rollover/MetadataRolloverServiceTests.java index d386eb40aea43..0bf92df006894 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/indices/rollover/MetadataRolloverServiceTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/indices/rollover/MetadataRolloverServiceTests.java @@ -744,7 +744,7 @@ public void testValidation() throws Exception { // ensure no replicate data stream .promoteDataStream(); rolloverTarget = dataStream.getName(); - if (dataStream.isFailureStore() && randomBoolean()) { + if (dataStream.isFailureStoreEnabled() && randomBoolean()) { failureStoreOptions = new FailureStoreOptions(false, true); sourceIndexName = dataStream.getFailureStoreWriteIndex().getName(); defaultRolloverIndexName = DataStream.getDefaultFailureStoreName( diff --git a/server/src/test/java/org/elasticsearch/action/admin/indices/rollover/TransportRolloverActionTests.java b/server/src/test/java/org/elasticsearch/action/admin/indices/rollover/TransportRolloverActionTests.java index 9faa6c4ba2d3f..427d2769b7399 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/indices/rollover/TransportRolloverActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/indices/rollover/TransportRolloverActionTests.java @@ -427,17 +427,10 @@ public void testLazyRollover() throws Exception { .numberOfShards(1) .numberOfReplicas(1) .build(); - final DataStream dataStream = new DataStream( - "logs-ds", - List.of(backingIndexMetadata.getIndex()), - 1, - Map.of(), - false, - false, - false, - false, - IndexMode.STANDARD - ); + final DataStream dataStream = DataStream.builder("logs-ds", List.of(backingIndexMetadata.getIndex())) + .setMetadata(Map.of()) + .setIndexMode(IndexMode.STANDARD) + .build(); final ClusterState stateBefore = ClusterState.builder(ClusterName.DEFAULT) .metadata(Metadata.builder().put(backingIndexMetadata, false).put(dataStream)) .build(); @@ -489,17 +482,11 @@ public void testLazyRolloverFails() throws Exception { .numberOfShards(1) .numberOfReplicas(1) .build(); - final DataStream dataStream = new DataStream( - "logs-ds", - List.of(backingIndexMetadata.getIndex()), - randomIntBetween(1, 10), - Map.of(), - false, - false, - false, - false, - IndexMode.STANDARD - ); + final DataStream dataStream = DataStream.builder("logs-ds", List.of(backingIndexMetadata.getIndex())) + .setGeneration(randomIntBetween(1, 10)) + .setMetadata(Map.of()) + .setIndexMode(IndexMode.STANDARD) + .build(); final ClusterState stateBefore = ClusterState.builder(ClusterName.DEFAULT) .metadata(Metadata.builder().put(indexMetadata).put(backingIndexMetadata, false).put(dataStream)) .build(); @@ -559,17 +546,11 @@ public void testRolloverAliasToDataStreamFails() throws Exception { .numberOfShards(1) .numberOfReplicas(1) .build(); - final DataStream dataStream = new DataStream( - "logs-ds", - List.of(backingIndexMetadata.getIndex()), - 1, - Map.of(), - false, - false, - false, - false, - IndexMode.STANDARD - ); + final DataStream dataStream = DataStream.builder("logs-ds", List.of(backingIndexMetadata.getIndex())) + .setGeneration(1) + .setMetadata(Map.of()) + .setIndexMode(IndexMode.STANDARD) + .build(); Metadata.Builder 
metadataBuilder = Metadata.builder().put(backingIndexMetadata, false).put(dataStream); metadataBuilder.put("ds-alias", dataStream.getName(), true, null); final ClusterState stateBefore = ClusterState.builder(ClusterName.DEFAULT).metadata(metadataBuilder).build(); diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java index 23395556761f1..b662f439a0e6f 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java @@ -363,7 +363,7 @@ public void testBulkToDataStreamFailingEntireShard() throws Exception { * A bulk operation to a data stream with a failure store enabled should redirect any shard level failures to the failure store. */ public void testFailingEntireShardRedirectsToFailureStore() throws Exception { - Assume.assumeTrue(DataStream.isFailureStoreEnabled()); + Assume.assumeTrue(DataStream.isFailureStoreFeatureFlagEnabled()); // Requests that go to two separate shards BulkRequest bulkRequest = new BulkRequest(); @@ -393,7 +393,7 @@ public void testFailingEntireShardRedirectsToFailureStore() throws Exception { * failure store. */ public void testFailingDocumentRedirectsToFailureStore() throws Exception { - Assume.assumeTrue(DataStream.isFailureStoreEnabled()); + Assume.assumeTrue(DataStream.isFailureStoreFeatureFlagEnabled()); // Requests that go to two separate shards BulkRequest bulkRequest = new BulkRequest(); @@ -423,7 +423,7 @@ public void testFailingDocumentRedirectsToFailureStore() throws Exception { * a shard-level failure while writing to the failure store indices. */ public void testFailureStoreShardFailureRejectsDocument() throws Exception { - Assume.assumeTrue(DataStream.isFailureStoreEnabled()); + Assume.assumeTrue(DataStream.isFailureStoreFeatureFlagEnabled()); // Requests that go to two separate shards BulkRequest bulkRequest = new BulkRequest(); @@ -467,7 +467,7 @@ public void testFailureStoreShardFailureRejectsDocument() throws Exception { * instead will simply report its original failure in the response, with the conversion failure present as a suppressed exception. */ public void testFailedDocumentCanNotBeConvertedFails() throws Exception { - Assume.assumeTrue(DataStream.isFailureStoreEnabled()); + Assume.assumeTrue(DataStream.isFailureStoreFeatureFlagEnabled()); // Requests that go to two separate shards BulkRequest bulkRequest = new BulkRequest(); @@ -505,7 +505,7 @@ public void testFailedDocumentCanNotBeConvertedFails() throws Exception { * non-retryable block when the redirected documents would be sent to the shard-level action. */ public void testBlockedClusterRejectsFailureStoreDocument() throws Exception { - Assume.assumeTrue(DataStream.isFailureStoreEnabled()); + Assume.assumeTrue(DataStream.isFailureStoreFeatureFlagEnabled()); // Requests that go to two separate shards BulkRequest bulkRequest = new BulkRequest(); @@ -560,7 +560,7 @@ public void testBlockedClusterRejectsFailureStoreDocument() throws Exception { * retryable block to clear when the redirected documents would be sent to the shard-level action. 
*/ public void testOperationTimeoutRejectsFailureStoreDocument() throws Exception { - Assume.assumeTrue(DataStream.isFailureStoreEnabled()); + Assume.assumeTrue(DataStream.isFailureStoreFeatureFlagEnabled()); // Requests that go to two separate shards BulkRequest bulkRequest = new BulkRequest(); @@ -623,7 +623,7 @@ public void testOperationTimeoutRejectsFailureStoreDocument() throws Exception { * for a retryable block to clear when the redirected documents would be sent to the shard-level action. */ public void testNodeClosureRejectsFailureStoreDocument() { - Assume.assumeTrue(DataStream.isFailureStoreEnabled()); + Assume.assumeTrue(DataStream.isFailureStoreFeatureFlagEnabled()); // Requests that go to two separate shards BulkRequest bulkRequest = new BulkRequest(); diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java index 960397033f602..c27263f43eff1 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java @@ -366,7 +366,7 @@ public void testRejectionAfterCreateIndexIsPropagated() { } public void testResolveFailureStoreFromMetadata() throws Exception { - assumeThat(DataStream.isFailureStoreEnabled(), is(true)); + assumeThat(DataStream.isFailureStoreFeatureFlagEnabled(), is(true)); String dataStreamWithFailureStore = "test-data-stream-failure-enabled"; String dataStreamWithoutFailureStore = "test-data-stream-failure-disabled"; @@ -425,7 +425,7 @@ public void testResolveFailureStoreFromMetadata() throws Exception { } public void testResolveFailureStoreFromTemplate() throws Exception { - assumeThat(DataStream.isFailureStoreEnabled(), is(true)); + assumeThat(DataStream.isFailureStoreFeatureFlagEnabled(), is(true)); String dsTemplateWithFailureStore = "test-data-stream-failure-enabled"; String dsTemplateWithoutFailureStore = "test-data-stream-failure-disabled"; diff --git a/server/src/test/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingServiceTests.java b/server/src/test/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingServiceTests.java index 70e291afcaf32..9803082bbd88a 100644 --- a/server/src/test/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingServiceTests.java +++ b/server/src/test/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingServiceTests.java @@ -37,7 +37,6 @@ import org.junit.Before; import java.util.ArrayList; -import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -622,17 +621,11 @@ public void testGetMaxIndexLoadWithinCoolingPeriod() { backingIndices.add(writeIndexMetadata.getIndex()); metadataBuilder.put(writeIndexMetadata, false); - final DataStream dataStream = new DataStream( - dataStreamName, - backingIndices, - backingIndices.size(), - Collections.emptyMap(), - false, - false, - false, - false, - IndexMode.STANDARD - ); + final DataStream dataStream = DataStream.builder(dataStreamName, backingIndices) + .setGeneration(backingIndices.size()) + .setMetadata(Map.of()) + .setIndexMode(IndexMode.STANDARD) + .build(); metadataBuilder.put(dataStream); @@ -684,17 +677,11 @@ public void testIndexLoadWithinCoolingPeriodIsSumOfShardsLoads() { backingIndices.add(writeIndexMetadata.getIndex()); metadataBuilder.put(writeIndexMetadata, false); - final DataStream dataStream = new 
DataStream( - dataStreamName, - backingIndices, - backingIndices.size(), - Collections.emptyMap(), - false, - false, - false, - false, - IndexMode.STANDARD - ); + final DataStream dataStream = DataStream.builder(dataStreamName, backingIndices) + .setGeneration(backingIndices.size()) + .setMetadata(Map.of()) + .setIndexMode(IndexMode.STANDARD) + .build(); metadataBuilder.put(dataStream); @@ -781,22 +768,10 @@ private DataStream createDataStream( builder.put(indexMetadata, false); backingIndices.add(indexMetadata.getIndex()); } - return new DataStream( - dataStreamName, - backingIndices, - backingIndicesCount, - null, - false, - false, - false, - false, - null, - null, - false, - List.of(), - false, - autoShardingEvent - ); + return DataStream.builder(dataStreamName, backingIndices) + .setGeneration(backingIndicesCount) + .setAutoShardingEvent(autoShardingEvent) + .build(); } private IndexMetadata createIndexMetadata( diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamMetadataTests.java index d186dafd6c7c3..1dc9e638e6002 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamMetadataTests.java @@ -71,28 +71,14 @@ protected Writeable.Reader instanceReader() { public void testWithAlias() { Index index1 = new Index("data-stream-1-index", "1"); Index index2 = new Index("data-stream-2-index", "2"); - DataStream dataStream1 = new DataStream( - "data-stream-1", - List.of(index1), - 1, - Map.of(), - false, - false, - false, - false, - IndexMode.STANDARD - ); - DataStream dataStream2 = new DataStream( - "data-stream-2", - List.of(index2), - 1, - Map.of(), - false, - false, - false, - false, - IndexMode.STANDARD - ); + DataStream dataStream1 = DataStream.builder("data-stream-1", List.of(index1)) + .setMetadata(Map.of()) + .setIndexMode(IndexMode.STANDARD) + .build(); + DataStream dataStream2 = DataStream.builder("data-stream-2", List.of(index2)) + .setMetadata(Map.of()) + .setIndexMode(IndexMode.STANDARD) + .build(); ImmutableOpenMap dataStreams = new ImmutableOpenMap.Builder().fPut( "data-stream-1", dataStream1 diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamTests.java index aceb463989bed..14c38a13f3730 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamTests.java @@ -39,7 +39,6 @@ import java.time.temporal.ChronoUnit; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -94,7 +93,7 @@ protected DataStream mutateInstance(DataStream instance) { var allowsCustomRouting = instance.isAllowCustomRouting(); var indexMode = instance.getIndexMode(); var lifecycle = instance.getLifecycle(); - var failureStore = instance.isFailureStore(); + var failureStore = instance.isFailureStoreEnabled(); var failureIndices = instance.getFailureIndices(); var rolloverOnWrite = instance.rolloverOnWrite(); var autoShardingEvent = instance.getAutoShardingEvent(); @@ -208,25 +207,11 @@ public void testRolloverWithConflictingBackingIndexName() { } public void testRolloverUpgradeToTsdbDataStream() { - IndexMode indexMode = randomBoolean() ? 
IndexMode.STANDARD : null; - DataStream ds = DataStreamTestHelper.randomInstance().promoteDataStream(); - // Unsure index_mode=null - ds = new DataStream( - ds.getName(), - ds.getIndices(), - ds.getGeneration(), - ds.getMetadata(), - ds.isHidden(), - ds.isReplicated(), - ds.isSystem(), - ds.isAllowCustomRouting(), - indexMode, - ds.getLifecycle(), - ds.isFailureStore(), - ds.getFailureIndices(), - ds.rolloverOnWrite(), - ds.getAutoShardingEvent() - ); + DataStream ds = DataStreamTestHelper.randomInstance() + .copy() + .setReplicated(false) + .setIndexMode(randomBoolean() ? IndexMode.STANDARD : null) + .build(); var newCoordinates = ds.nextWriteIndexAndGeneration(Metadata.EMPTY_METADATA); var rolledDs = ds.rollover(new Index(newCoordinates.v1(), UUIDs.randomBase64UUID()), newCoordinates.v2(), true, null); @@ -239,23 +224,7 @@ public void testRolloverUpgradeToTsdbDataStream() { } public void testRolloverDowngradeToRegularDataStream() { - DataStream ds = DataStreamTestHelper.randomInstance().promoteDataStream(); - ds = new DataStream( - ds.getName(), - ds.getIndices(), - ds.getGeneration(), - ds.getMetadata(), - ds.isHidden(), - ds.isReplicated(), - ds.isSystem(), - ds.isAllowCustomRouting(), - IndexMode.TIME_SERIES, - ds.getLifecycle(), - ds.isFailureStore(), - ds.getFailureIndices(), - ds.rolloverOnWrite(), - ds.getAutoShardingEvent() - ); + DataStream ds = DataStreamTestHelper.randomInstance().copy().setReplicated(false).setIndexMode(IndexMode.TIME_SERIES).build(); var newCoordinates = ds.nextWriteIndexAndGeneration(Metadata.EMPTY_METADATA); var rolledDs = ds.rollover(new Index(newCoordinates.v1(), UUIDs.randomBase64UUID()), newCoordinates.v2(), false, null); @@ -773,22 +742,13 @@ public void testSnapshot() { postSnapshotIndices.addAll(indicesToAdd); var replicated = preSnapshotDataStream.isReplicated() && randomBoolean(); - var postSnapshotDataStream = new DataStream( - preSnapshotDataStream.getName(), - postSnapshotIndices, - preSnapshotDataStream.getGeneration() + randomIntBetween(0, 5), - preSnapshotDataStream.getMetadata() == null ? null : new HashMap<>(preSnapshotDataStream.getMetadata()), - preSnapshotDataStream.isHidden(), - replicated, - preSnapshotDataStream.isSystem(), - preSnapshotDataStream.isAllowCustomRouting(), - preSnapshotDataStream.getIndexMode(), - preSnapshotDataStream.getLifecycle(), - preSnapshotDataStream.isFailureStore(), - preSnapshotDataStream.getFailureIndices(), - replicated == false && preSnapshotDataStream.rolloverOnWrite(), - preSnapshotDataStream.getAutoShardingEvent() - ); + var postSnapshotDataStream = preSnapshotDataStream.copy() + .setIndices(postSnapshotIndices) + .setGeneration(preSnapshotDataStream.getGeneration() + randomIntBetween(0, 5)) + .setMetadata(preSnapshotDataStream.getMetadata() == null ? 
null : new HashMap<>(preSnapshotDataStream.getMetadata())) + .setReplicated(replicated) + .setRolloverOnWrite(replicated == false && preSnapshotDataStream.rolloverOnWrite()) + .build(); var reconciledDataStream = postSnapshotDataStream.snapshot( preSnapshotDataStream.getIndices().stream().map(Index::getName).toList() @@ -815,22 +775,7 @@ public void testSnapshotWithAllBackingIndicesRemoved() { var preSnapshotDataStream = DataStreamTestHelper.randomInstance(); var indicesToAdd = randomNonEmptyIndexInstances(); - var postSnapshotDataStream = new DataStream( - preSnapshotDataStream.getName(), - indicesToAdd, - preSnapshotDataStream.getGeneration(), - preSnapshotDataStream.getMetadata(), - preSnapshotDataStream.isHidden(), - preSnapshotDataStream.isReplicated(), - preSnapshotDataStream.isSystem(), - preSnapshotDataStream.isAllowCustomRouting(), - preSnapshotDataStream.getIndexMode(), - preSnapshotDataStream.getLifecycle(), - preSnapshotDataStream.isFailureStore(), - preSnapshotDataStream.getFailureIndices(), - preSnapshotDataStream.rolloverOnWrite(), - preSnapshotDataStream.getAutoShardingEvent() - ); + var postSnapshotDataStream = preSnapshotDataStream.copy().setIndices(indicesToAdd).build(); assertNull(postSnapshotDataStream.snapshot(preSnapshotDataStream.getIndices().stream().map(Index::getName).toList())); } @@ -1049,17 +994,7 @@ public void testGetGenerationLifecycleDate() { .numberOfReplicas(1) .creationDate(creationTimeMillis); IndexMetadata indexMetadata = indexMetaBuilder.build(); - DataStream dataStream = new DataStream( - dataStreamName, - List.of(indexMetadata.getIndex()), - 1L, - Map.of(), - false, - randomBoolean(), - false, - randomBoolean(), - IndexMode.STANDARD - ); + DataStream dataStream = createDataStream(dataStreamName, List.of(indexMetadata.getIndex())); assertNull(dataStream.getGenerationLifecycleDate(indexMetadata)); } @@ -1079,16 +1014,9 @@ public void testGetGenerationLifecycleDate() { MaxAgeCondition rolloverCondition = new MaxAgeCondition(TimeValue.timeValueMillis(rolloverTimeMills)); indexMetaBuilder.putRolloverInfo(new RolloverInfo(dataStreamName, List.of(rolloverCondition), now - 2000L)); IndexMetadata indexMetadata = indexMetaBuilder.build(); - DataStream dataStream = new DataStream( + DataStream dataStream = createDataStream( dataStreamName, - List.of(indexMetadata.getIndex(), writeIndexMetaBuilder.build().getIndex()), - 1L, - Map.of(), - false, - randomBoolean(), - false, - randomBoolean(), - IndexMode.STANDARD + List.of(indexMetadata.getIndex(), writeIndexMetaBuilder.build().getIndex()) ); assertThat(dataStream.getGenerationLifecycleDate(indexMetadata).millis(), is(rolloverTimeMills)); } @@ -1109,16 +1037,9 @@ public void testGetGenerationLifecycleDate() { MaxAgeCondition rolloverCondition = new MaxAgeCondition(TimeValue.timeValueMillis(rolloverTimeMills)); indexMetaBuilder.putRolloverInfo(new RolloverInfo("some-alias-name", List.of(rolloverCondition), now - 2000L)); IndexMetadata indexMetadata = indexMetaBuilder.build(); - DataStream dataStream = new DataStream( + DataStream dataStream = createDataStream( dataStreamName, - List.of(indexMetadata.getIndex(), writeIndexMetaBuilder.build().getIndex()), - 1L, - Map.of(), - false, - randomBoolean(), - false, - randomBoolean(), - IndexMode.STANDARD + List.of(indexMetadata.getIndex(), writeIndexMetaBuilder.build().getIndex()) ); assertThat(dataStream.getGenerationLifecycleDate(indexMetadata).millis(), is(creationTimeMillis)); } @@ -1131,17 +1052,7 @@ public void testGetGenerationLifecycleDate() { 
.numberOfReplicas(1) .creationDate(creationTimeMillis); IndexMetadata indexMetadata = indexMetaBuilder.build(); - DataStream dataStream = new DataStream( - dataStreamName, - List.of(indexMetadata.getIndex()), - 1L, - Map.of(), - false, - randomBoolean(), - false, - randomBoolean(), - IndexMode.STANDARD - ); + DataStream dataStream = createDataStream(dataStreamName, List.of(indexMetadata.getIndex())); assertNull(dataStream.getGenerationLifecycleDate(indexMetadata)); } @@ -1160,16 +1071,9 @@ public void testGetGenerationLifecycleDate() { .numberOfReplicas(1) .creationDate(creationTimeMillis); IndexMetadata indexMetadata = indexMetaBuilder.build(); - DataStream dataStream = new DataStream( + DataStream dataStream = createDataStream( dataStreamName, - List.of(indexMetadata.getIndex(), writeIndexMetaBuilder.build().getIndex()), - 1L, - Map.of(), - false, - randomBoolean(), - false, - randomBoolean(), - IndexMode.STANDARD + List.of(indexMetadata.getIndex(), writeIndexMetaBuilder.build().getIndex()) ); assertThat(dataStream.getGenerationLifecycleDate(indexMetadata).millis(), is(originTimeMillis)); } @@ -1190,16 +1094,9 @@ public void testGetGenerationLifecycleDate() { MaxAgeCondition rolloverCondition = new MaxAgeCondition(TimeValue.timeValueMillis(rolloverTimeMills)); indexMetaBuilder.putRolloverInfo(new RolloverInfo(dataStreamName, List.of(rolloverCondition), now - 2000L)); IndexMetadata indexMetadata = indexMetaBuilder.build(); - DataStream dataStream = new DataStream( + DataStream dataStream = createDataStream( dataStreamName, - List.of(indexMetadata.getIndex(), writeIndexMetaBuilder.build().getIndex()), - 1L, - Map.of(), - false, - randomBoolean(), - false, - randomBoolean(), - IndexMode.STANDARD + List.of(indexMetadata.getIndex(), writeIndexMetaBuilder.build().getIndex()) ); assertThat(dataStream.getGenerationLifecycleDate(indexMetadata).millis(), is(originTimeMillis)); } @@ -1220,21 +1117,23 @@ public void testGetGenerationLifecycleDate() { MaxAgeCondition rolloverCondition = new MaxAgeCondition(TimeValue.timeValueMillis(rolloverTimeMills)); indexMetaBuilder.putRolloverInfo(new RolloverInfo("some-alias-name", List.of(rolloverCondition), now - 2000L)); IndexMetadata indexMetadata = indexMetaBuilder.build(); - DataStream dataStream = new DataStream( + DataStream dataStream = createDataStream( dataStreamName, - List.of(indexMetadata.getIndex(), writeIndexMetaBuilder.build().getIndex()), - 1L, - Map.of(), - false, - randomBoolean(), - false, - randomBoolean(), - IndexMode.STANDARD + List.of(indexMetadata.getIndex(), writeIndexMetaBuilder.build().getIndex()) ); assertThat(dataStream.getGenerationLifecycleDate(indexMetadata).millis(), is(originTimeMillis)); } } + private DataStream createDataStream(String name, List indices) { + return DataStream.builder(name, indices) + .setMetadata(Map.of()) + .setReplicated(randomBoolean()) + .setAllowCustomRouting(randomBoolean()) + .setIndexMode(IndexMode.STANDARD) + .build(); + } + public void testGetIndicesOlderThan() { String dataStreamName = "metrics-foo"; long now = System.currentTimeMillis(); @@ -1939,17 +1838,11 @@ public void testGetIndicesWithinMaxAgeRange() { backingIndices.add(writeIndexMetadata.getIndex()); metadataBuilder.put(writeIndexMetadata, false); - final DataStream dataStream = new DataStream( - dataStreamName, - backingIndices, - backingIndices.size(), - Collections.emptyMap(), - false, - false, - false, - false, - randomBoolean() ? 
IndexMode.STANDARD : IndexMode.TIME_SERIES - ); + final DataStream dataStream = DataStream.builder(dataStreamName, backingIndices) + .setGeneration(backingIndices.size()) + .setMetadata(Map.of()) + .setIndexMode(randomBoolean() ? IndexMode.STANDARD : IndexMode.TIME_SERIES) + .build(); metadataBuilder.put(dataStream); @@ -2000,17 +1893,11 @@ public void testGetIndicesWithinMaxAgeRangeAllIndicesOutsideRange() { backingIndices.add(writeIndexMetadata.getIndex()); metadataBuilder.put(writeIndexMetadata, false); - final DataStream dataStream = new DataStream( - dataStreamName, - backingIndices, - backingIndices.size(), - Collections.emptyMap(), - false, - false, - false, - false, - randomBoolean() ? IndexMode.STANDARD : IndexMode.TIME_SERIES - ); + final DataStream dataStream = DataStream.builder(dataStreamName, backingIndices) + .setGeneration(backingIndices.size()) + .setMetadata(Map.of()) + .setIndexMode(randomBoolean() ? IndexMode.STANDARD : IndexMode.TIME_SERIES) + .build(); metadataBuilder.put(dataStream); diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexNameExpressionResolverTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexNameExpressionResolverTests.java index 2fba37772ef94..4cb38604aa49a 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexNameExpressionResolverTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexNameExpressionResolverTests.java @@ -49,6 +49,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.function.Function; @@ -3167,17 +3168,11 @@ public void testHiddenDataStreams() { .put(index2, false) .put(justAnIndex, false) .put( - new DataStream( - dataStream1, - List.of(index1.getIndex(), index2.getIndex()), - 2, - Collections.emptyMap(), - true, - false, - false, - false, - null - ) + DataStream.builder(dataStream1, List.of(index1.getIndex(), index2.getIndex())) + .setGeneration(2) + .setMetadata(Map.of()) + .setHidden(true) + .build() ) ) .build(); diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataDataStreamsServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataDataStreamsServiceTests.java index 0b4883ee4e3c5..3d9368ddb9fc9 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataDataStreamsServiceTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataDataStreamsServiceTests.java @@ -353,22 +353,9 @@ public void testRemoveBrokenBackingIndexReference() { var dataStreamName = "my-logs"; var state = DataStreamTestHelper.getClusterStateWithDataStreams(List.of(new Tuple<>(dataStreamName, 2)), List.of()); var original = state.getMetadata().dataStreams().get(dataStreamName); - var broken = new DataStream( - original.getName(), - List.of(new Index(original.getIndices().get(0).getName(), "broken"), original.getIndices().get(1)), - original.getGeneration(), - original.getMetadata(), - original.isHidden(), - original.isReplicated(), - original.isSystem(), - original.isAllowCustomRouting(), - original.getIndexMode(), - original.getLifecycle(), - original.isFailureStore(), - original.getFailureIndices(), - original.rolloverOnWrite(), - original.getAutoShardingEvent() - ); + var broken = original.copy() + .setIndices(List.of(new Index(original.getIndices().get(0).getName(), "broken"), original.getIndices().get(1))) + .build(); var brokenState = 
ClusterState.builder(state).metadata(Metadata.builder(state.getMetadata()).put(broken).build()).build(); var result = MetadataDataStreamsService.modifyDataStream( diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/WildcardExpressionResolverTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/WildcardExpressionResolverTests.java index c7a30e3eae548..46d14436fd947 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/WildcardExpressionResolverTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/WildcardExpressionResolverTests.java @@ -367,23 +367,7 @@ public void testAllDataStreams() { { // if data stream itself is hidden, backing indices should not be returned - boolean hidden = true; - var dataStream = new DataStream( - dataStreamName, - List.of(firstBackingIndexMetadata.getIndex()), - 1, - null, - hidden, - false, - false, - false, - null, - null, - false, - List.of(), - false, - null - ); + var dataStream = DataStream.builder(dataStreamName, List.of(firstBackingIndexMetadata.getIndex())).setHidden(true).build(); Metadata.Builder mdBuilder = Metadata.builder().put(firstBackingIndexMetadata, true).put(dataStream); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/SourceFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/SourceFieldMapperTests.java index 47b8bb3be36b7..a5264512d8086 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/SourceFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/SourceFieldMapperTests.java @@ -24,6 +24,7 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.notNullValue; public class SourceFieldMapperTests extends MetadataMapperTestCase { @@ -242,6 +243,18 @@ public void testSyntheticSourceInTimeSeries() throws IOException { public void testSupportsNonDefaultParameterValues() throws IOException { Settings settings = Settings.builder().put(SourceFieldMapper.LOSSY_PARAMETERS_ALLOWED_SETTING_NAME, false).build(); + { + var sourceFieldMapper = createMapperService(settings, topMapping(b -> b.startObject("_source").endObject())).documentMapper() + .sourceMapper(); + assertThat(sourceFieldMapper, notNullValue()); + } + { + var sourceFieldMapper = createMapperService( + settings, + topMapping(b -> b.startObject("_source").field("mode", randomBoolean() ? 
"synthetic" : "stored").endObject()) + ).documentMapper().sourceMapper(); + assertThat(sourceFieldMapper, notNullValue()); + } Exception e = expectThrows( MapperParsingException.class, () -> createMapperService(settings, topMapping(b -> b.startObject("_source").field("enabled", false).endObject())) diff --git a/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java b/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java index c2706a7a3cf22..df4bde959d6ca 100644 --- a/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java +++ b/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java @@ -89,6 +89,7 @@ import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.fielddata.IndexFieldDataCache; import org.elasticsearch.index.fielddata.IndexFieldDataService; +import org.elasticsearch.index.flush.FlushStats; import org.elasticsearch.index.mapper.DocumentParsingException; import org.elasticsearch.index.mapper.IdFieldMapper; import org.elasticsearch.index.mapper.LuceneDocument; @@ -155,6 +156,7 @@ import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; +import java.util.concurrent.Executors; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -192,6 +194,7 @@ import static org.hamcrest.Matchers.in; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.matchesRegex; import static org.hamcrest.Matchers.not; @@ -1916,10 +1919,15 @@ public void testDelayedOperationsBeforeAndAfterRelocated() throws Exception { Thread recoveryThread = new Thread(() -> { try { startRecovery.await(); - shard.relocated(routing.getTargetRelocatingShard().allocationId().getId(), (primaryContext, listener) -> { - relocationStarted.countDown(); - listener.onResponse(null); - }, ActionListener.noop()); + shard.relocated( + routing.relocatingNodeId(), + routing.getTargetRelocatingShard().allocationId().getId(), + (primaryContext, listener) -> { + relocationStarted.countDown(); + listener.onResponse(null); + }, + ActionListener.noop() + ); } catch (InterruptedException e) { throw new RuntimeException(e); } @@ -2123,29 +2131,48 @@ protected void doRun() throws Exception { closeShards(shard); } - public void testRelocateMissingTarget() throws Exception { + public void testRelocateMismatchedTarget() throws Exception { final IndexShard shard = newStartedShard(true); final ShardRouting original = shard.routingEntry(); - final ShardRouting toNode1 = ShardRoutingHelper.relocate(original, "node_1"); - IndexShardTestCase.updateRoutingEntry(shard, toNode1); + + final ShardRouting wrongTargetNodeShardRouting = ShardRoutingHelper.relocate(original, "node_1"); + IndexShardTestCase.updateRoutingEntry(shard, wrongTargetNodeShardRouting); + IndexShardTestCase.updateRoutingEntry(shard, original); + + final ShardRouting wrongTargetAllocationIdShardRouting = ShardRoutingHelper.relocate(original, "node_2"); + IndexShardTestCase.updateRoutingEntry(shard, wrongTargetAllocationIdShardRouting); IndexShardTestCase.updateRoutingEntry(shard, original); - final ShardRouting toNode2 = ShardRoutingHelper.relocate(original, "node_2"); - IndexShardTestCase.updateRoutingEntry(shard, toNode2); + + final ShardRouting correctShardRouting = 
+        final ShardRouting correctShardRouting = ShardRoutingHelper.relocate(original, "node_2");
+        IndexShardTestCase.updateRoutingEntry(shard, correctShardRouting);
+
         final AtomicBoolean relocated = new AtomicBoolean();
-        final IllegalStateException error = expectThrows(
+
+        final IllegalIndexShardStateException wrongNodeException = expectThrows(
+            IllegalIndexShardStateException.class,
+            () -> blockingCallRelocated(shard, wrongTargetNodeShardRouting, (ctx, listener) -> relocated.set(true))
+        );
+        assertThat(
+            wrongNodeException.getMessage(),
+            equalTo("CurrentState[STARTED] : shard is no longer relocating to node [node_1]: " + correctShardRouting)
+        );
+        assertFalse(relocated.get());
+
+        final IllegalStateException wrongTargetIdException = expectThrows(
             IllegalStateException.class,
-            () -> blockingCallRelocated(shard, toNode1, (ctx, listener) -> relocated.set(true))
+            () -> blockingCallRelocated(shard, wrongTargetAllocationIdShardRouting, (ctx, listener) -> relocated.set(true))
         );
         assertThat(
-            error.getMessage(),
+            wrongTargetIdException.getMessage(),
             equalTo(
                 "relocation target ["
-                    + toNode1.getTargetRelocatingShard().allocationId().getId()
+                    + wrongTargetAllocationIdShardRouting.getTargetRelocatingShard().allocationId().getId()
                     + "] is no longer part of the replication group"
             )
         );
         assertFalse(relocated.get());
-        blockingCallRelocated(shard, toNode2, (ctx, listener) -> {
+
+        blockingCallRelocated(shard, correctShardRouting, (ctx, listener) -> {
             relocated.set(true);
             listener.onResponse(null);
         });
@@ -3974,6 +4001,39 @@ public void testFlushOnIdle() throws Exception {
         closeShards(shard);
     }
 
+    public void testFlushTimeExcludingWaiting() throws Exception {
+        IndexShard shard = newStartedShard();
+        for (int i = 0; i < randomIntBetween(4, 10); i++) {
+            indexDoc(shard, "_doc", Integer.toString(i));
+        }
+
+        int numFlushes = randomIntBetween(2, 5);
+        var flushesLatch = new CountDownLatch(numFlushes);
+        var executor = Executors.newFixedThreadPool(numFlushes);
+        for (int i = 0; i < numFlushes; i++) {
+            executor.submit(() -> {
+                shard.flush(new FlushRequest().waitIfOngoing(true).force(true));
+                flushesLatch.countDown();
+            });
+        }
+        safeAwait(flushesLatch);
+
+        FlushStats flushStats = shard.flushStats();
+        assertThat(
+            "Flush time excluding waiting should be captured",
+            flushStats.getTotalTimeExcludingWaitingOnLockMillis(),
+            greaterThan(0L)
+        );
+        assertThat(
+            "Flush time excluding waiting should be less than flush time with waiting",
+            flushStats.getTotalTimeExcludingWaitingOnLockMillis(),
+            lessThan(flushStats.getTotalTime().millis())
+        );
+
+        closeShards(shard);
+        executor.shutdown();
+    }
+
     @TestLogging(reason = "testing traces of concurrent flushes", value = "org.elasticsearch.index.engine.Engine:TRACE")
     public void testFlushOnIdleConcurrentFlushDoesNotWait() throws Exception {
         final MockLogAppender mockLogAppender = new MockLogAppender();
@@ -4937,7 +4997,7 @@ private static void blockingCallRelocated(
         BiConsumer<ReplicationTracker.PrimaryContext, ActionListener<Void>> consumer
     ) {
         PlainActionFuture.get(
-            f -> indexShard.relocated(routing.getTargetRelocatingShard().allocationId().getId(), consumer, f)
+            f -> indexShard.relocated(routing.relocatingNodeId(), routing.getTargetRelocatingShard().allocationId().getId(), consumer, f)
         );
     }
 }
diff --git a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/Clusters.java b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/Clusters.java
index fbc191a12d8b0..72e08c340ea0c 100644
---
a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/Clusters.java +++ b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/Clusters.java @@ -19,7 +19,8 @@ static ElasticsearchCluster buildCluster() { .nodes(2) .module("test-esql-heap-attack") .setting("xpack.security.enabled", "false") - .setting("xpack.license.self_generated.type", "trial"); + .setting("xpack.license.self_generated.type", "trial") + .jvmArg("-Xmx512m"); String javaVersion = JvmInfo.jvmInfo().version(); if (javaVersion.equals("20") || javaVersion.equals("21")) { // see https://github.com/elastic/elasticsearch/issues/99592 diff --git a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java index 2f3826f8423b8..4f43817b7b92c 100644 --- a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java +++ b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java @@ -77,7 +77,7 @@ public void skipOnAborted() { */ public void testSortByManyLongsSuccess() throws IOException { initManyLongs(); - Response response = sortByManyLongs(2000); + Response response = sortByManyLongs(500); Map map = responseAsMap(response); ListMatcher columns = matchesList().item(matchesMap().entry("name", "a").entry("type", "long")) .item(matchesMap().entry("name", "b").entry("type", "long")); diff --git a/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java b/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java index e2b03c6b81af3..6c038470b158d 100644 --- a/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java +++ b/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java @@ -130,22 +130,13 @@ public static DataStream newInstance( @Nullable DataStreamLifecycle lifecycle, @Nullable DataStreamAutoShardingEvent autoShardingEvent ) { - return new DataStream( - name, - indices, - generation, - metadata, - false, - replicated, - false, - false, - null, - lifecycle, - false, - List.of(), - false, - autoShardingEvent - ); + return DataStream.builder(name, indices) + .setGeneration(generation) + .setMetadata(metadata) + .setReplicated(replicated) + .setLifecycle(lifecycle) + .setAutoShardingEvent(autoShardingEvent) + .build(); } public static DataStream newInstance( @@ -157,22 +148,14 @@ public static DataStream newInstance( @Nullable DataStreamLifecycle lifecycle, List failureStores ) { - return new DataStream( - name, - indices, - generation, - metadata, - false, - replicated, - false, - false, - null, - lifecycle, - failureStores.size() > 0, - failureStores, - false, - null - ); + return DataStream.builder(name, indices) + .setGeneration(generation) + .setMetadata(metadata) + .setReplicated(replicated) + .setLifecycle(lifecycle) + .setFailureStoreEnabled(failureStores.isEmpty() == false) + .setFailureIndices(failureStores) + .build(); } public static String getLegacyDefaultBackingIndexName( @@ -477,7 +460,11 @@ public static void getClusterStateWithDataStreams( ComposableIndexTemplate.builder() .indexPatterns(List.of("*")) .dataStreamTemplate( - new ComposableIndexTemplate.DataStreamTemplate(false, false, 
DataStream.isFailureStoreEnabled() && storeFailures) + new ComposableIndexTemplate.DataStreamTemplate( + false, + false, + DataStream.isFailureStoreFeatureFlagEnabled() && storeFailures + ) ) .build() ); @@ -493,7 +480,7 @@ public static void getClusterStateWithDataStreams( allIndices.addAll(backingIndices); List failureStores = new ArrayList<>(); - if (DataStream.isFailureStoreEnabled() && storeFailures) { + if (DataStream.isFailureStoreFeatureFlagEnabled() && storeFailures) { for (int failureStoreNumber = 1; failureStoreNumber <= dsTuple.v2(); failureStoreNumber++) { failureStores.add( createIndexMetadata( @@ -561,18 +548,18 @@ public static void getClusterStateWithDataStream( backingIndices.add(im); generation++; } - DataStream ds = new DataStream( + var dataStreamBuilder = DataStream.builder( dataStreamName, - backingIndices.stream().map(IndexMetadata::getIndex).collect(Collectors.toList()), - generation, - existing != null ? existing.getMetadata() : null, - existing != null && existing.isHidden(), - existing != null && existing.isReplicated(), - existing != null && existing.isSystem(), - existing != null && existing.isAllowCustomRouting(), - IndexMode.TIME_SERIES - ); - builder.put(ds); + backingIndices.stream().map(IndexMetadata::getIndex).collect(Collectors.toList()) + ).setGeneration(generation).setIndexMode(IndexMode.TIME_SERIES); + if (existing != null) { + dataStreamBuilder.setMetadata(existing.getMetadata()) + .setHidden(existing.isHidden()) + .setReplicated(existing.isReplicated()) + .setSystem(existing.isSystem()) + .setAllowCustomRouting(existing.isAllowCustomRouting()); + } + builder.put(dataStreamBuilder.build()); } private static IndexMetadata createIndexMetadata(String name, boolean hidden, Settings settings, int replicas) { diff --git a/test/framework/src/main/java/org/elasticsearch/indices/recovery/AbstractIndexRecoveryIntegTestCase.java b/test/framework/src/main/java/org/elasticsearch/indices/recovery/AbstractIndexRecoveryIntegTestCase.java index a5ace3e357f90..97f17858e753d 100644 --- a/test/framework/src/main/java/org/elasticsearch/indices/recovery/AbstractIndexRecoveryIntegTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/indices/recovery/AbstractIndexRecoveryIntegTestCase.java @@ -17,7 +17,6 @@ import org.elasticsearch.cluster.NodeConnectionsService; import org.elasticsearch.cluster.action.shard.ShardStateAction; import org.elasticsearch.cluster.metadata.IndexMetadata; -import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.CircuitBreakingException; @@ -135,7 +134,7 @@ protected void checkTransientErrorsDuringRecoveryAreRetried(String recoveryActio ensureSearchable(indexName); ClusterStateResponse stateResponse = clusterAdmin().prepareState().get(); - final String blueNodeId = internalCluster().getInstance(ClusterService.class, blueNodeName).localNode().getId(); + final String blueNodeId = getNodeId(blueNodeName); assertFalse(stateResponse.getState().getRoutingNodes().node(blueNodeId).isEmpty()); @@ -231,7 +230,7 @@ public void checkDisconnectsWhileRecovering(String recoveryActionToBlock) throws ensureSearchable(indexName); ClusterStateResponse stateResponse = clusterAdmin().prepareState().get(); - final String blueNodeId = internalCluster().getInstance(ClusterService.class, blueNodeName).localNode().getId(); + final String blueNodeId = getNodeId(blueNodeName); 
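        // getNodeId(nodeName) replaces the repeated ClusterService lookup in these tests.
        // A minimal sketch of the helper, assuming it is the new ESIntegTestCase utility
        // referenced here (the hunk adding it is garbled below, so the exact body is an
        // assumption):
        //
        //     protected String getNodeId(String nodeName) {
        //         return internalCluster().getInstance(ClusterService.class, nodeName).localNode().getId();
        //     }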
assertFalse(stateResponse.getState().getRoutingNodes().node(blueNodeId).isEmpty()); diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java index 11d4754eaa596..1056c766e17ca 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java @@ -1093,6 +1093,10 @@ public static void awaitClusterState(Logger logger, String viaNode, Predicate DataStream.getDefaultBackingIndexName(name, value)) .map(value -> new Index(value, "uuid")) .collect(Collectors.toList()); - return new DataStream(name, backingIndices, backingIndices.size(), Map.of(), false, replicate, false, false, null); + long generation = backingIndices.size(); + return DataStream.builder(name, backingIndices).setGeneration(generation).setMetadata(Map.of()).setReplicated(replicate).build(); } static DataStream generateDataSteam(String name, int generation, boolean replicate, String... backingIndexNames) { List backingIndices = Arrays.stream(backingIndexNames).map(value -> new Index(value, "uuid")).collect(Collectors.toList()); - return new DataStream(name, backingIndices, generation, Map.of(), false, replicate, false, false, null); + return DataStream.builder(name, backingIndices).setGeneration(generation).setMetadata(Map.of()).setReplicated(replicate).build(); } } diff --git a/x-pack/plugin/core/src/internalClusterTest/java/org/elasticsearch/nodesinfo/ComponentVersionsNodesInfoIT.java b/x-pack/plugin/core/src/internalClusterTest/java/org/elasticsearch/nodesinfo/ComponentVersionsNodesInfoIT.java index 7f4fca7063cd5..32024ff03ed15 100644 --- a/x-pack/plugin/core/src/internalClusterTest/java/org/elasticsearch/nodesinfo/ComponentVersionsNodesInfoIT.java +++ b/x-pack/plugin/core/src/internalClusterTest/java/org/elasticsearch/nodesinfo/ComponentVersionsNodesInfoIT.java @@ -9,24 +9,20 @@ import org.elasticsearch.action.admin.cluster.health.ClusterHealthResponse; import org.elasticsearch.action.admin.cluster.node.info.NodesInfoResponse; -import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.test.ESIntegTestCase; -import java.util.List; - import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.notNullValue; public class ComponentVersionsNodesInfoIT extends ESIntegTestCase { public void testNodesInfoComponentVersions() { - List nodesIds = internalCluster().startNodes(1); - final String node_1 = nodesIds.get(0); + final String node_1 = internalCluster().startNode(); ClusterHealthResponse clusterHealth = clusterAdmin().prepareHealth().setWaitForGreenStatus().setWaitForNodes("1").get(); logger.info("--> done cluster_health, status {}", clusterHealth.getStatus()); - String server1NodeId = internalCluster().getInstance(ClusterService.class, node_1).state().nodes().getLocalNodeId(); + String server1NodeId = getNodeId(node_1); logger.info("--> started nodes: {}", server1NodeId); NodesInfoResponse response = clusterAdmin().prepareNodesInfo().get(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/ClusterPrivilegeResolver.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/ClusterPrivilegeResolver.java index 47e4a6913897b..3774efcdd2ad2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/ClusterPrivilegeResolver.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/ClusterPrivilegeResolver.java @@ -180,7 +180,9 @@ public class ClusterPrivilegeResolver { RemoteClusterNodesAction.TYPE.name(), XPackInfoAction.NAME, // esql enrich - "cluster:monitor/xpack/enrich/esql/resolve_policy" + "cluster:monitor/xpack/enrich/esql/resolve_policy", + "cluster:internal:data/read/esql/open_exchange", + "cluster:internal:data/read/esql/exchange" ); private static final Set CROSS_CLUSTER_REPLICATION_PATTERN = Set.of( RemoteClusterService.REMOTE_CLUSTER_HANDSHAKE_ACTION_NAME, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/IndexPrivilege.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/IndexPrivilege.java index 066924b21c99f..674706eb9af49 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/IndexPrivilege.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/IndexPrivilege.java @@ -86,10 +86,8 @@ public final class IndexPrivilege extends Privilege { ClusterSearchShardsAction.NAME, TransportSearchShardsAction.TYPE.name(), TransportResolveClusterAction.NAME, - // cross clusters query for ESQL - "internal:data/read/esql/open_exchange", - "internal:data/read/esql/exchange", - "indices:data/read/esql/cluster" + "indices:data/read/esql", + "indices:data/read/esql/compute" ); private static final Automaton CREATE_AUTOMATON = patterns( "indices:data/write/index*", diff --git a/x-pack/plugin/enrich/src/internalClusterTest/java/org/elasticsearch/xpack/enrich/EnrichMultiNodeIT.java b/x-pack/plugin/enrich/src/internalClusterTest/java/org/elasticsearch/xpack/enrich/EnrichMultiNodeIT.java index b81a5e6b902b3..26e38252a4572 100644 --- a/x-pack/plugin/enrich/src/internalClusterTest/java/org/elasticsearch/xpack/enrich/EnrichMultiNodeIT.java +++ b/x-pack/plugin/enrich/src/internalClusterTest/java/org/elasticsearch/xpack/enrich/EnrichMultiNodeIT.java @@ -20,7 +20,6 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.ingest.PutPipelineRequest; import org.elasticsearch.action.search.SearchRequest; -import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.Maps; @@ -268,7 +267,7 @@ private static void enrich(Map> keys, String coordinatingNo EnrichStatsAction.Response statsResponse = client().execute(EnrichStatsAction.INSTANCE, new EnrichStatsAction.Request()) .actionGet(); assertThat(statsResponse.getCoordinatorStats().size(), equalTo(internalCluster().size())); - String nodeId = internalCluster().getInstance(ClusterService.class, coordinatingNode).localNode().getId(); + String nodeId = getNodeId(coordinatingNode); CoordinatorStats stats = statsResponse.getCoordinatorStats().stream().filter(s -> s.getNodeId().equals(nodeId)).findAny().get(); assertThat(stats.getNodeId(), equalTo(nodeId)); assertThat(stats.getRemoteRequestsTotal(), greaterThanOrEqualTo(1L)); diff --git a/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/332_connector_update_filtering.yml b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/332_connector_update_filtering.yml index ac102db163767..5734fdfe67ce8 100644 --- 
a/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/332_connector_update_filtering.yml +++ b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/332_connector_update_filtering.yml @@ -13,91 +13,27 @@ setup: is_native: false service_type: super-connector --- -"Update Connector Filtering with advanced snippet value array": +"Update Connector Filtering - Update draft": - do: connector.update_filtering: connector_id: test-connector body: - filtering: - - active: - advanced_snippet: - created_at: "2023-05-25T12:30:00.000Z" - updated_at: "2023-05-25T12:30:00.000Z" - value: - - tables: - - some_table - query: 'SELECT id, st_geohash(coordinates) FROM my_db.some_table;' - rules: - - created_at: "2023-05-25T12:30:00.000Z" - field: _ - id: RULE-ACTIVE-0 - order: 0 - policy: include - rule: regex - updated_at: "2023-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - domain: DEFAULT - draft: - advanced_snippet: - created_at: "2023-05-25T12:30:00.000Z" - updated_at: "2023-05-25T12:30:00.000Z" - value: - - tables: - - some_table - query: 'SELECT id, st_geohash(coordinates) FROM my_db.some_table;' - - tables: - - another_table - query: 'SELECT id, st_geohash(coordinates) FROM my_db.another_table;' - rules: - - created_at: "2023-05-25T12:30:00.000Z" - field: _ - id: RULE-DRAFT-0 - order: 0 - policy: include - rule: regex - updated_at: "2023-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - - active: - advanced_snippet: - created_at: "2021-05-25T12:30:00.000Z" - updated_at: "2021-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2021-05-25T12:30:00.000Z" - field: _ - id: RULE-ACTIVE-1 - order: 0 - policy: include - rule: regex - updated_at: "2021-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - domain: TEST - draft: - advanced_snippet: - created_at: "2021-05-25T12:30:00.000Z" - updated_at: "2021-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2021-05-25T12:30:00.000Z" - field: _ - id: RULE-DRAFT-1 - order: 0 - policy: exclude - rule: regex - updated_at: "2021-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid + advanced_snippet: + created_at: "2023-05-25T12:30:00.000Z" + updated_at: "2023-05-25T12:30:00.000Z" + value: + - tables: + - some_table + query: 'SELECT id, st_geohash(coordinates) FROM my_db.some_table;' + rules: + - created_at: "2023-05-25T12:30:00.000Z" + field: _ + id: DEFAULT + order: 0 + policy: include + rule: regex + updated_at: "2023-05-25T12:30:00.000Z" + value: ".*" - match: { result: updated } @@ -105,19 +41,88 @@ setup: connector.get: connector_id: test-connector + + - match: { filtering.0.draft.advanced_snippet.created_at: "2023-05-25T12:30:00.000Z" } + - match: { filtering.0.draft.advanced_snippet.value.0.tables.0.: "some_table" } + - match: { filtering.0.draft.rules.0.id: DEFAULT } + - match: { filtering.0.draft.validation.errors: [] } + - match: { filtering.0.draft.validation.state: edited } + + # Default domain and active should be unchanged - match: { filtering.0.domain: DEFAULT } - - match: { filtering.0.active.advanced_snippet.created_at: "2023-05-25T12:30:00.000Z" } - - match: { filtering.0.active.advanced_snippet.value.0.tables.0.: "some_table" } - - match: { filtering.0.active.rules.0.id: "RULE-ACTIVE-0" } - - match: { filtering.0.draft.rules.0.id: "RULE-DRAFT-0" } + - match: { filtering.0.active.advanced_snippet.value: {} } + - match: { 
filtering.0.active.advanced_snippet.value: {} } + - match: { filtering.0.active.rules.0.field: _ } + - match: { filtering.0.active.rules.0.id: DEFAULT } + - match: { filtering.0.active.rules.0.rule: regex } + + +--- +"Update Connector Filtering - Update draft rules only": + - do: + connector.update_filtering: + connector_id: test-connector + body: + rules: + - created_at: "2023-05-25T12:30:00.000Z" + field: my_field + id: MY-RULE-1 + order: 0 + policy: exclude + rule: regex + updated_at: "2023-05-25T12:30:00.000Z" + value: "tax-.*" + - created_at: "2023-05-25T12:30:00.000Z" + field: _ + id: DEFAULT + order: 1 + policy: include + rule: regex + updated_at: "2023-05-25T12:30:00.000Z" + value: ".*" + + - match: { result: updated } - - match: { filtering.1.domain: TEST } - - match: { filtering.1.active.advanced_snippet.created_at: "2021-05-25T12:30:00.000Z" } - - match: { filtering.1.active.rules.0.id: "RULE-ACTIVE-1" } - - match: { filtering.1.draft.rules.0.id: "RULE-DRAFT-1" } + - do: + connector.get: + connector_id: test-connector + + - match: { filtering.0.draft.rules.0.id: MY-RULE-1 } + - match: { filtering.0.draft.rules.1.id: DEFAULT } + + # Default domain and active should be unchanged + - match: { filtering.0.domain: DEFAULT } + - match: { filtering.0.active.advanced_snippet.value: {} } + - match: { filtering.0.active.advanced_snippet.value: {} } + - match: { filtering.0.active.rules.0.field: _ } + - match: { filtering.0.active.rules.0.id: DEFAULT } + - match: { filtering.0.active.rules.0.rule: regex } --- -"Update Connector Filtering with advanced snippet value object": +"Update Connector Filtering - Update draft advanced snippet only": + - do: + connector.update_filtering: + connector_id: test-connector + body: + advanced_snippet: + created_at: "2023-05-25T12:30:00.000Z" + updated_at: "2023-05-25T12:30:00.000Z" + value: + - tables: + - some_table + query: 'SELECT id, st_geohash(coordinates) FROM my_db.some_table;' + + - match: { result: updated } + + - do: + connector.get: + connector_id: test-connector + + - match: { filtering.0.draft.advanced_snippet.created_at: "2023-05-25T12:30:00.000Z" } + - match: { filtering.0.draft.advanced_snippet.value.0.tables.0.: "some_table" } + +--- +"Update Connector Filtering - Update full filtering object": - do: connector.update_filtering: connector_id: test-connector @@ -132,7 +137,7 @@ setup: rules: - created_at: "2023-05-25T12:30:00.000Z" field: _ - id: RULE-ACTIVE-0 + id: DEFAULT order: 0 policy: include rule: regex @@ -141,7 +146,6 @@ setup: validation: errors: [] state: valid - domain: DEFAULT draft: advanced_snippet: created_at: "2023-05-25T12:30:00.000Z" @@ -150,7 +154,7 @@ setup: rules: - created_at: "2023-05-25T12:30:00.000Z" field: _ - id: RULE-DRAFT-0 + id: DEFAULT order: 0 policy: include rule: regex @@ -159,41 +163,7 @@ setup: validation: errors: [] state: valid - - active: - advanced_snippet: - created_at: "2021-05-25T12:30:00.000Z" - updated_at: "2021-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2021-05-25T12:30:00.000Z" - field: _ - id: RULE-ACTIVE-1 - order: 0 - policy: include - rule: regex - updated_at: "2021-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - domain: TEST - draft: - advanced_snippet: - created_at: "2021-05-25T12:30:00.000Z" - updated_at: "2021-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2021-05-25T12:30:00.000Z" - field: _ - id: RULE-DRAFT-1 - order: 0 - policy: exclude - rule: regex - updated_at: "2021-05-25T12:30:00.000Z" - value: ".*" - validation: - 
errors: [] - state: valid + - match: { result: updated } @@ -204,13 +174,9 @@ setup: - match: { filtering.0.domain: DEFAULT } - match: { filtering.0.active.advanced_snippet.created_at: "2023-05-25T12:30:00.000Z" } - match: { filtering.0.active.advanced_snippet.value.some_filtering_key: "some_filtering_value" } - - match: { filtering.0.active.rules.0.id: "RULE-ACTIVE-0" } - - match: { filtering.0.draft.rules.0.id: "RULE-DRAFT-0" } + - match: { filtering.0.active.rules.0.id: "DEFAULT" } + - match: { filtering.0.draft.rules.0.id: "DEFAULT" } - - match: { filtering.1.domain: TEST } - - match: { filtering.1.active.advanced_snippet.created_at: "2021-05-25T12:30:00.000Z" } - - match: { filtering.1.active.rules.0.id: "RULE-ACTIVE-1" } - - match: { filtering.1.draft.rules.0.id: "RULE-DRAFT-1" } --- "Update Connector Filtering with value literal - Wrong advanced snippet value": @@ -219,77 +185,34 @@ setup: connector.update_filtering: connector_id: test-connector body: - filtering: - - active: - advanced_snippet: - created_at: "2023-05-25T12:30:00.000Z" - updated_at: "2023-05-25T12:30:00.000Z" - value: "string literal" - rules: - - created_at: "2023-05-25T12:30:00.000Z" - field: _ - id: RULE-ACTIVE-0 - order: 0 - policy: include - rule: regex - updated_at: "2023-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - domain: DEFAULT - draft: - advanced_snippet: - created_at: "2023-05-25T12:30:00.000Z" - updated_at: "2023-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2023-05-25T12:30:00.000Z" - field: _ - id: RULE-DRAFT-0 - order: 0 - policy: include - rule: regex - updated_at: "2023-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - - active: - advanced_snippet: - created_at: "2021-05-25T12:30:00.000Z" - updated_at: "2021-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2021-05-25T12:30:00.000Z" - field: _ - id: RULE-ACTIVE-1 - order: 0 - policy: include - rule: regex - updated_at: "2021-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - domain: TEST - draft: - advanced_snippet: - created_at: "2021-05-25T12:30:00.000Z" - updated_at: "2021-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2021-05-25T12:30:00.000Z" - field: _ - id: RULE-DRAFT-1 - order: 0 - policy: exclude - rule: regex - updated_at: "2021-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid + advanced_snippet: + value: "string literal" + +--- +"Update Connector Filtering with value literal - Empty rules": + - do: + catch: "bad_request" + connector.update_filtering: + connector_id: test-connector + body: + rules: [ ] + +--- +"Update Connector Filtering with value literal - Default rule not present": + - do: + catch: "bad_request" + connector.update_filtering: + connector_id: test-connector + body: + rules: + - created_at: "2023-05-25T12:30:00.000Z" + field: my_field + id: MY_RULE + order: 0 + policy: exclude + rule: regex + updated_at: "2023-05-25T12:30:00.000Z" + value: "hello-not-default-rule.*" --- "Update Connector Filtering - Connector doesn't exist": @@ -298,77 +221,8 @@ setup: connector.update_filtering: connector_id: test-non-existent-connector body: - filtering: - - active: - advanced_snippet: - created_at: "2023-05-25T12:30:00.000Z" - updated_at: "2023-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2023-05-25T12:30:00.000Z" - field: _ - id: RULE-ACTIVE-0 - order: 0 - policy: include - rule: regex - updated_at: "2023-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - 
state: valid - domain: DEFAULT - draft: - advanced_snippet: - created_at: "2023-05-25T12:30:00.000Z" - updated_at: "2023-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2023-05-25T12:30:00.000Z" - field: _ - id: RULE-DRAFT-0 - order: 0 - policy: include - rule: regex - updated_at: "2023-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - - active: - advanced_snippet: - created_at: "2021-05-25T12:30:00.000Z" - updated_at: "2021-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2021-05-25T12:30:00.000Z" - field: _ - id: RULE-ACTIVE-1 - order: 0 - policy: include - rule: regex - updated_at: "2021-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - domain: TEST - draft: - advanced_snippet: - created_at: "2021-05-25T12:30:00.000Z" - updated_at: "2021-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2021-05-25T12:30:00.000Z" - field: _ - id: RULE-DRAFT-1 - order: 0 - policy: exclude - rule: regex - updated_at: "2021-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid + advanced_snippet: + value: {} --- "Update Connector Filtering - Required fields are missing": @@ -376,9 +230,7 @@ setup: catch: "bad_request" connector.update_filtering: connector_id: test-connector - body: - filtering: - - domain: some_domain + body: {} - match: status: 400 @@ -390,74 +242,7 @@ setup: connector.update_filtering: connector_id: test-connector body: - filtering: - - active: - advanced_snippet: - created_at: "this-is-not-a-datetime-!!!!" - updated_at: "2023-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2023-05-25T12:30:00.000Z" - field: _ - id: RULE-ACTIVE-0 - order: 0 - policy: include - rule: regex - updated_at: "2023-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - domain: DEFAULT - draft: - advanced_snippet: - created_at: "2023-05-25T12:30:00.000Z" - updated_at: "2023-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2023-05-25T12:30:00.000Z" - field: _ - id: RULE-DRAFT-0 - order: 0 - policy: include - rule: regex - updated_at: "2023-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - - active: - advanced_snippet: - created_at: "2021-05-25T12:30:00.000Z" - updated_at: "2021-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2021-05-25T12:30:00.000Z" - field: _ - id: RULE-ACTIVE-1 - order: 0 - policy: include - rule: regex - updated_at: "2021-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - domain: TEST - draft: - advanced_snippet: - created_at: "2021-05-25T12:30:00.000Z" - updated_at: "2021-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2021-05-25T12:30:00.000Z" - field: _ - id: RULE-DRAFT-1 - order: 0 - policy: exclude - rule: regex - updated_at: "2021-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid + rules: [ ] + advanced_snippet: + updated_at: "wrong datetime" + value: { } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorFiltering.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorFiltering.java index 62a8a68cea5ca..4d357f459cb2f 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorFiltering.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorFiltering.java @@ -43,26 +43,23 @@ */ public class ConnectorFiltering implements Writeable, 
ToXContentObject { - private final FilteringRules active; - private final String domain; - private final FilteringRules draft; + private FilteringRules active; + private final String domain = "DEFAULT"; // Connectors always use DEFAULT domain, users should not modify it via API + private FilteringRules draft; /** * Constructs a new ConnectorFiltering instance. * * @param active The active filtering rules. - * @param domain The domain associated with the filtering. * @param draft The draft filtering rules. */ - public ConnectorFiltering(FilteringRules active, String domain, FilteringRules draft) { + public ConnectorFiltering(FilteringRules active, FilteringRules draft) { this.active = active; - this.domain = domain; this.draft = draft; } public ConnectorFiltering(StreamInput in) throws IOException { this.active = new FilteringRules(in); - this.domain = in.readString(); this.draft = new FilteringRules(in); } @@ -78,22 +75,27 @@ public FilteringRules getDraft() { return draft; } + public ConnectorFiltering setActive(FilteringRules active) { + this.active = active; + return this; + } + + public ConnectorFiltering setDraft(FilteringRules draft) { + this.draft = draft; + return this; + } + private static final ParseField ACTIVE_FIELD = new ParseField("active"); - private static final ParseField DOMAIN_FIELD = new ParseField("domain"); private static final ParseField DRAFT_FIELD = new ParseField("draft"); private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "connector_filtering", true, - args -> new ConnectorFiltering.Builder().setActive((FilteringRules) args[0]) - .setDomain((String) args[1]) - .setDraft((FilteringRules) args[2]) - .build() + args -> new ConnectorFiltering.Builder().setActive((FilteringRules) args[0]).setDraft((FilteringRules) args[1]).build() ); static { PARSER.declareObject(constructorArg(), (p, c) -> FilteringRules.fromXContent(p), ACTIVE_FIELD); - PARSER.declareString(constructorArg(), DOMAIN_FIELD); PARSER.declareObject(constructorArg(), (p, c) -> FilteringRules.fromXContent(p), DRAFT_FIELD); } @@ -102,7 +104,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); { builder.field(ACTIVE_FIELD.getPreferredName(), active); - builder.field(DOMAIN_FIELD.getPreferredName(), domain); + builder.field("domain", domain); // We still want to write the DEFAULT domain to the index builder.field(DRAFT_FIELD.getPreferredName(), draft); } builder.endObject(); @@ -124,7 +126,6 @@ public static ConnectorFiltering fromXContentBytes(BytesReference source, XConte @Override public void writeTo(StreamOutput out) throws IOException { active.writeTo(out); - out.writeString(domain); draft.writeTo(out); } @@ -141,10 +142,41 @@ public int hashCode() { return Objects.hash(active, domain, draft); } + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser, Void> CONNECTOR_FILTERING_PARSER = + new ConstructingObjectParser<>( + "connector_filtering_parser", + true, + args -> (List) args[0] + + ); + + static { + CONNECTOR_FILTERING_PARSER.declareObjectArray( + constructorArg(), + (p, c) -> ConnectorFiltering.fromXContent(p), + Connector.FILTERING_FIELD + ); + } + + /** + * Deserializes the {@link ConnectorFiltering} property from a {@link Connector} byte representation. + * + * @param source Byte representation of the {@link Connector}. + * @param xContentType {@link XContentType} of the content (e.g., JSON). + * @return List of {@link ConnectorFiltering} objects. 
+ */ + public static List fromXContentBytesConnectorFiltering(BytesReference source, XContentType xContentType) { + try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) { + return CONNECTOR_FILTERING_PARSER.parse(parser, null); + } catch (IOException e) { + throw new ElasticsearchParseException("Failed to parse a connector filtering.", e); + } + } + public static class Builder { private FilteringRules active; - private String domain; private FilteringRules draft; public Builder setActive(FilteringRules active) { @@ -152,21 +184,33 @@ public Builder setActive(FilteringRules active) { return this; } - public Builder setDomain(String domain) { - this.domain = domain; - return this; - } - public Builder setDraft(FilteringRules draft) { this.draft = draft; return this; } public ConnectorFiltering build() { - return new ConnectorFiltering(active, domain, draft); + return new ConnectorFiltering(active, draft); } } + public static boolean isDefaultRulePresentInFilteringRules(List rules) { + FilteringRule defaultRule = getDefaultFilteringRule(null); + return rules.stream().anyMatch(rule -> rule.equalsExceptForTimestampsAndOrder(defaultRule)); + } + + public static FilteringRule getDefaultFilteringRule(Instant timestamp) { + return new FilteringRule.Builder().setCreatedAt(timestamp) + .setField("_") + .setId("DEFAULT") + .setOrder(0) + .setPolicy(FilteringPolicy.INCLUDE) + .setRule(FilteringRuleCondition.REGEX) + .setUpdatedAt(timestamp) + .setValue(".*") + .build(); + } + public static ConnectorFiltering getDefaultConnectorFilteringConfig() { Instant currentTimestamp = Instant.now(); @@ -178,19 +222,7 @@ public static ConnectorFiltering getDefaultConnectorFilteringConfig() { .setAdvancedSnippetValue(Collections.emptyMap()) .build() ) - .setRules( - List.of( - new FilteringRule.Builder().setCreatedAt(currentTimestamp) - .setField("_") - .setId("DEFAULT") - .setOrder(0) - .setPolicy(FilteringPolicy.INCLUDE) - .setRule(FilteringRuleCondition.REGEX) - .setUpdatedAt(currentTimestamp) - .setValue(".*") - .build() - ) - ) + .setRules(List.of(getDefaultFilteringRule(currentTimestamp))) .setFilteringValidationInfo( new FilteringValidationInfo.Builder().setValidationErrors(Collections.emptyList()) .setValidationState(FilteringValidationState.VALID) @@ -198,7 +230,6 @@ public static ConnectorFiltering getDefaultConnectorFilteringConfig() { ) .build() ) - .setDomain("DEFAULT") .setDraft( new FilteringRules.Builder().setAdvancedSnippet( new FilteringAdvancedSnippet.Builder().setAdvancedSnippetCreatedAt(currentTimestamp) @@ -206,19 +237,7 @@ public static ConnectorFiltering getDefaultConnectorFilteringConfig() { .setAdvancedSnippetValue(Collections.emptyMap()) .build() ) - .setRules( - List.of( - new FilteringRule.Builder().setCreatedAt(currentTimestamp) - .setField("_") - .setId("DEFAULT") - .setOrder(0) - .setPolicy(FilteringPolicy.INCLUDE) - .setRule(FilteringRuleCondition.REGEX) - .setUpdatedAt(currentTimestamp) - .setValue(".*") - .build() - ) - ) + .setRules(List.of(getDefaultFilteringRule(currentTimestamp))) .setFilteringValidationInfo( new FilteringValidationInfo.Builder().setValidationErrors(Collections.emptyList()) .setValidationState(FilteringValidationState.VALID) diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorIndexService.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorIndexService.java index bceeece6ec17b..20b9a8ec74027 100644 
--- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorIndexService.java
+++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorIndexService.java
@@ -40,12 +40,12 @@
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.sort.SortOrder;
 import org.elasticsearch.xcontent.ToXContent;
+import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.application.connector.action.PostConnectorAction;
 import org.elasticsearch.xpack.application.connector.action.PutConnectorAction;
 import org.elasticsearch.xpack.application.connector.action.UpdateConnectorApiKeyIdAction;
 import org.elasticsearch.xpack.application.connector.action.UpdateConnectorConfigurationAction;
 import org.elasticsearch.xpack.application.connector.action.UpdateConnectorErrorAction;
-import org.elasticsearch.xpack.application.connector.action.UpdateConnectorFilteringAction;
 import org.elasticsearch.xpack.application.connector.action.UpdateConnectorIndexNameAction;
 import org.elasticsearch.xpack.application.connector.action.UpdateConnectorLastSyncStatsAction;
 import org.elasticsearch.xpack.application.connector.action.UpdateConnectorNameAction;
@@ -54,6 +54,10 @@
 import org.elasticsearch.xpack.application.connector.action.UpdateConnectorSchedulingAction;
 import org.elasticsearch.xpack.application.connector.action.UpdateConnectorServiceTypeAction;
 import org.elasticsearch.xpack.application.connector.action.UpdateConnectorStatusAction;
+import org.elasticsearch.xpack.application.connector.filtering.FilteringAdvancedSnippet;
+import org.elasticsearch.xpack.application.connector.filtering.FilteringRule;
+import org.elasticsearch.xpack.application.connector.filtering.FilteringRules;
+import org.elasticsearch.xpack.application.connector.filtering.FilteringValidationInfo;
 import org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJob;
 import org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJobIndexService;
@@ -70,6 +74,7 @@
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder;
+import static org.elasticsearch.xpack.application.connector.ConnectorFiltering.fromXContentBytesConnectorFiltering;
 
 /**
  * A service that manages persistent {@link Connector} configurations.
@@ -555,19 +560,19 @@ public void updateConnectorNameOrDescription(UpdateConnectorNameAction.Request r
     }
 
     /**
-     * Updates the {@link ConnectorFiltering} property of a {@link Connector}.
+     * Sets the {@link ConnectorFiltering} property of a {@link Connector}.
      *
-     * @param request Request for updating connector filtering property.
-     * @param listener Listener to respond to a successful response or an error.
+     * @param connectorId The ID of the {@link Connector} to update.
+     * @param filtering The list of {@link ConnectorFiltering}.
+     * @param listener Listener to respond to a successful response or an error.
      */
-    public void updateConnectorFiltering(UpdateConnectorFilteringAction.Request request, ActionListener<UpdateResponse> listener) {
+    public void updateConnectorFiltering(String connectorId, List<ConnectorFiltering> filtering, ActionListener<UpdateResponse> listener) {
         try {
-            String connectorId = request.getConnectorId();
             final UpdateRequest updateRequest = new UpdateRequest(CONNECTOR_INDEX_NAME, connectorId).doc(
                 new IndexRequest(CONNECTOR_INDEX_NAME).opType(DocWriteRequest.OpType.INDEX)
                     .id(connectorId)
                     .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
-                    .source(Map.of(Connector.FILTERING_FIELD.getPreferredName(), request.getFiltering()))
+                    .source(Map.of(Connector.FILTERING_FIELD.getPreferredName(), filtering))
             );
             client.update(updateRequest, new DelegatingIndexNotFoundActionListener<>(connectorId, listener, (l, updateResponse) -> {
                 if (updateResponse.getResult() == UpdateResponse.Result.NOT_FOUND) {
@@ -581,6 +586,64 @@ public void updateConnectorFiltering(UpdateConnectorFilteringAction.Request requ
         }
     }
 
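    // Usage sketch for the full-override path above (identifiers from this file; the
    // listener wiring is assumed): passing a complete [filtering] list replaces the
    // whole object at once, which the docs discourage in favor of draft-only updates.
    //
    //     connectorIndexService.updateConnectorFiltering(
    //         "my-connector-id",
    //         List.of(ConnectorFiltering.getDefaultConnectorFilteringConfig()),
    //         updateListener
    //     );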
+    /**
+     * Updates the draft filtering in a given {@link Connector}.
+     *
+     * @param connectorId The ID of the {@link Connector} to be updated.
+     * @param advancedSnippet An instance of {@link FilteringAdvancedSnippet}.
+     * @param rules A list of instances of {@link FilteringRule} to be applied.
+     * @param listener Listener to respond to a successful response or an error.
+     */
+    public void updateConnectorFilteringDraft(
+        String connectorId,
+        FilteringAdvancedSnippet advancedSnippet,
+        List<FilteringRule> rules,
+        ActionListener<UpdateResponse> listener
+    ) {
+        try {
+            getConnector(connectorId, listener.delegateFailure((l, connector) -> {
+                List<ConnectorFiltering> connectorFilteringList = fromXContentBytesConnectorFiltering(
+                    connector.getSourceRef(),
+                    XContentType.JSON
+                );
+                // Connectors represent their filtering configuration as a singleton list
+                ConnectorFiltering connectorFilteringSingleton = connectorFilteringList.get(0);
+
+                // If advanced snippet or rules are not defined, keep the current draft state
+                FilteringAdvancedSnippet newDraftAdvancedSnippet = advancedSnippet == null
+                    ? connectorFilteringSingleton.getDraft().getAdvancedSnippet()
+                    : advancedSnippet;
+
+                List<FilteringRule> newDraftRules = rules == null ? connectorFilteringSingleton.getDraft().getRules() : rules;
+
+                ConnectorFiltering connectorFilteringWithUpdatedDraft = connectorFilteringSingleton.setDraft(
+                    new FilteringRules.Builder().setRules(newDraftRules)
+                        .setAdvancedSnippet(newDraftAdvancedSnippet)
+                        .setFilteringValidationInfo(FilteringValidationInfo.getInitialDraftValidationInfo())
+                        .build()
+                );
+
+                final UpdateRequest updateRequest = new UpdateRequest(CONNECTOR_INDEX_NAME, connectorId).doc(
+                    new IndexRequest(CONNECTOR_INDEX_NAME).opType(DocWriteRequest.OpType.INDEX)
+                        .id(connectorId)
+                        .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
+                        .source(Map.of(Connector.FILTERING_FIELD.getPreferredName(), List.of(connectorFilteringWithUpdatedDraft)))
+                );
+
+                client.update(updateRequest, new DelegatingIndexNotFoundActionListener<>(connectorId, listener, (ll, updateResponse) -> {
+                    if (updateResponse.getResult() == UpdateResponse.Result.NOT_FOUND) {
+                        ll.onFailure(new ResourceNotFoundException(connectorNotFoundErrorMsg(connectorId)));
+                        return;
+                    }
+                    ll.onResponse(updateResponse);
+                }));
+            }));
+
+        } catch (Exception e) {
+            listener.onFailure(e);
+        }
+    }
+
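    // Draft-merge semantics of updateConnectorFilteringDraft in one example (hypothetical
    // values): a null advanced snippet keeps the connector's current draft snippet, the
    // non-null rules replace the draft rules, and the draft validation info is reset to
    // the initial "edited" state.
    //
    //     updateConnectorFilteringDraft(
    //         "my-connector-id",
    //         null,                                                               // keep current draft advanced snippet
    //         List.of(ConnectorFiltering.getDefaultFilteringRule(Instant.now())), // new draft rules
    //         updateListener
    //     );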
* diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/TransportUpdateConnectorFilteringAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/TransportUpdateConnectorFilteringAction.java index 658a8075121af..ac3b3212c02da 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/TransportUpdateConnectorFilteringAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/TransportUpdateConnectorFilteringAction.java @@ -16,7 +16,12 @@ import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.application.connector.ConnectorFiltering; import org.elasticsearch.xpack.application.connector.ConnectorIndexService; +import org.elasticsearch.xpack.application.connector.filtering.FilteringAdvancedSnippet; +import org.elasticsearch.xpack.application.connector.filtering.FilteringRule; + +import java.util.List; public class TransportUpdateConnectorFilteringAction extends HandledTransportAction< UpdateConnectorFilteringAction.Request, @@ -47,6 +52,27 @@ protected void doExecute( UpdateConnectorFilteringAction.Request request, ActionListener<ConnectorUpdateActionResponse> listener ) { - connectorIndexService.updateConnectorFiltering(request, listener.map(r -> new ConnectorUpdateActionResponse(r.getResult()))); + String connectorId = request.getConnectorId(); + List<ConnectorFiltering> filtering = request.getFiltering(); + FilteringAdvancedSnippet advancedSnippet = request.getAdvancedSnippet(); + List<FilteringRule> rules = request.getRules(); + // If [filtering] is not present in the request body, the user's intention is to + // update the draft's rules or advanced snippet + if (request.getFiltering() == null) { + connectorIndexService.updateConnectorFilteringDraft( + connectorId, + advancedSnippet, + rules, + listener.map(r -> new ConnectorUpdateActionResponse(r.getResult())) + ); + } + // Otherwise, override the whole filtering object (discouraged in the docs) + else { + connectorIndexService.updateConnectorFiltering( + connectorId, + filtering, + listener.map(r -> new ConnectorUpdateActionResponse(r.getResult())) + ); + } } } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringAction.java index 9d55c12e4b7a1..566a01b855b99 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringAction.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; @@ -23,13 +24,17 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.application.connector.Connector; import org.elasticsearch.xpack.application.connector.ConnectorFiltering; +import 
org.elasticsearch.xpack.application.connector.filtering.FilteringAdvancedSnippet; +import org.elasticsearch.xpack.application.connector.filtering.FilteringRule; +import org.elasticsearch.xpack.application.connector.filtering.FilteringRules; import java.io.IOException; import java.util.List; import java.util.Objects; import static org.elasticsearch.action.ValidateActions.addValidationError; -import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.elasticsearch.xpack.application.connector.ConnectorFiltering.isDefaultRulePresentInFilteringRules; public class UpdateConnectorFilteringAction { @@ -41,17 +46,31 @@ private UpdateConnectorFilteringAction() {/* no instances */} public static class Request extends ConnectorActionRequest implements ToXContentObject { private final String connectorId; + @Nullable private final List<ConnectorFiltering> filtering; + @Nullable + private final FilteringAdvancedSnippet advancedSnippet; + @Nullable + private final List<FilteringRule> rules; - public Request(String connectorId, List<ConnectorFiltering> filtering) { + public Request( + String connectorId, + List<ConnectorFiltering> filtering, + FilteringAdvancedSnippet advancedSnippet, + List<FilteringRule> rules + ) { this.connectorId = connectorId; this.filtering = filtering; + this.advancedSnippet = advancedSnippet; + this.rules = rules; } public Request(StreamInput in) throws IOException { super(in); this.connectorId = in.readString(); this.filtering = in.readOptionalCollectionAsList(ConnectorFiltering::new); + this.advancedSnippet = new FilteringAdvancedSnippet(in); + this.rules = in.readCollectionAsList(FilteringRule::new); } public String getConnectorId() { @@ -62,6 +81,14 @@ public List<ConnectorFiltering> getFiltering() { return filtering; } + public FilteringAdvancedSnippet getAdvancedSnippet() { + return advancedSnippet; + } + + public List<FilteringRule> getRules() { + return rules; + } + @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; @@ -70,8 +97,29 @@ public ActionRequestValidationException validate() { validationException = addValidationError("[connector_id] cannot be [null] or [\"\"].", validationException); } + // If [filtering] is not present in the request payload, it means that the user should define [rules] and/or [advanced_snippet] if (filtering == null) { - validationException = addValidationError("[filtering] cannot be [null].", validationException); + if (rules == null && advancedSnippet == null) { + validationException = addValidationError("[advanced_snippet] and [rules] cannot both be [null].", validationException); + } else if (rules != null) { + if (rules.isEmpty()) { + validationException = addValidationError("[rules] cannot be an empty list.", validationException); + } else if (isDefaultRulePresentInFilteringRules(rules) == false) { + validationException = addValidationError( + "[rules] need to include the default filtering rule.", + validationException + ); + } + } + } + // If [filtering] is present we don't expect [rules] or [advanced_snippet] in the request body + else { + if (rules != null || advancedSnippet != null) { + validationException = addValidationError( + "If [filtering] is specified, [rules] and [advanced_snippet] should not be present in the request body.", + validationException + ); + } } return validationException; @@ -82,11 +130,22 @@ public ActionRequestValidationException validate() { new ConstructingObjectParser<>( "connector_update_filtering_request", false, - ((args, 
connectorId) -> new UpdateConnectorFilteringAction.Request(connectorId, (List<ConnectorFiltering>) args[0])) + ((args, connectorId) -> new UpdateConnectorFilteringAction.Request( + connectorId, + (List<ConnectorFiltering>) args[0], + (FilteringAdvancedSnippet) args[1], + (List<FilteringRule>) args[2] + )) ); static { - PARSER.declareObjectArray(constructorArg(), (p, c) -> ConnectorFiltering.fromXContent(p), Connector.FILTERING_FIELD); + PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> ConnectorFiltering.fromXContent(p), Connector.FILTERING_FIELD); + PARSER.declareObject( + optionalConstructorArg(), + (p, c) -> FilteringAdvancedSnippet.fromXContent(p), + FilteringRules.ADVANCED_SNIPPET_FIELD + ); + PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> FilteringRule.fromXContent(p), FilteringRules.RULES_FIELD); } public static UpdateConnectorFilteringAction.Request fromXContentBytes( @@ -110,6 +169,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); { builder.field(Connector.FILTERING_FIELD.getPreferredName(), filtering); + builder.field(FilteringRules.ADVANCED_SNIPPET_FIELD.getPreferredName(), advancedSnippet); + builder.xContentList(FilteringRules.RULES_FIELD.getPreferredName(), rules); } builder.endObject(); return builder; @@ -120,6 +181,8 @@ public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(connectorId); out.writeOptionalCollection(filtering); + advancedSnippet.writeTo(out); + out.writeCollection(rules); } @Override @@ -127,12 +190,15 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Request request = (Request) o; - return Objects.equals(connectorId, request.connectorId) && Objects.equals(filtering, request.filtering); + return Objects.equals(connectorId, request.connectorId) + && Objects.equals(filtering, request.filtering) + && Objects.equals(advancedSnippet, request.advancedSnippet) + && Objects.equals(rules, request.rules); } @Override public int hashCode() { - return Objects.hash(connectorId, filtering); + return Objects.hash(connectorId, filtering, advancedSnippet, rules); } } } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringAdvancedSnippet.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringAdvancedSnippet.java index 384fbc7bb5340..62da1dab08358 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringAdvancedSnippet.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringAdvancedSnippet.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; @@ -24,6 +25,7 @@ import java.util.Objects; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; /** * Represents an advanced snippet used in filtering processes, providing detailed criteria or rules. @@ -31,8 +33,9 @@ * actual snippet content represented as a map. 
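+ * Both timestamps are optional; when left unset, the {@link Builder} fills them in with its creation-time instant.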
*/ public class FilteringAdvancedSnippet implements Writeable, ToXContentObject { - + @Nullable private final Instant advancedSnippetCreatedAt; + @Nullable private final Instant advancedSnippetUpdatedAt; private final Object advancedSnippetValue; @@ -48,8 +51,8 @@ private FilteringAdvancedSnippet(Instant advancedSnippetCreatedAt, Instant advan } public FilteringAdvancedSnippet(StreamInput in) throws IOException { - this.advancedSnippetCreatedAt = in.readInstant(); - this.advancedSnippetUpdatedAt = in.readInstant(); + this.advancedSnippetCreatedAt = in.readOptionalInstant(); + this.advancedSnippetUpdatedAt = in.readOptionalInstant(); this.advancedSnippetValue = in.readGenericValue(); } @@ -57,6 +60,18 @@ public FilteringAdvancedSnippet(StreamInput in) throws IOException { private static final ParseField UPDATED_AT_FIELD = new ParseField("updated_at"); private static final ParseField VALUE_FIELD = new ParseField("value"); + public Instant getAdvancedSnippetCreatedAt() { + return advancedSnippetCreatedAt; + } + + public Instant getAdvancedSnippetUpdatedAt() { + return advancedSnippetUpdatedAt; + } + + public Object getAdvancedSnippetValue() { + return advancedSnippetValue; + } + @SuppressWarnings("unchecked") private static final ConstructingObjectParser<FilteringAdvancedSnippet, Void> PARSER = new ConstructingObjectParser<>( "connector_filtering_advanced_snippet", @@ -69,13 +84,13 @@ public FilteringAdvancedSnippet(StreamInput in) throws IOException { static { PARSER.declareField( - constructorArg(), + optionalConstructorArg(), (p, c) -> ConnectorUtils.parseInstant(p, CREATED_AT_FIELD.getPreferredName()), CREATED_AT_FIELD, ObjectParser.ValueType.STRING ); PARSER.declareField( - constructorArg(), + optionalConstructorArg(), (p, c) -> ConnectorUtils.parseInstant(p, UPDATED_AT_FIELD.getPreferredName()), UPDATED_AT_FIELD, ObjectParser.ValueType.STRING @@ -108,8 +123,8 @@ public static FilteringAdvancedSnippet fromXContent(XContentParser parser) throw @Override public void writeTo(StreamOutput out) throws IOException { - out.writeInstant(advancedSnippetCreatedAt); - out.writeInstant(advancedSnippetUpdatedAt); + out.writeOptionalInstant(advancedSnippetCreatedAt); + out.writeOptionalInstant(advancedSnippetUpdatedAt); out.writeGenericValue(advancedSnippetValue); } @@ -133,14 +148,15 @@ public static class Builder { private Instant advancedSnippetCreatedAt; private Instant advancedSnippetUpdatedAt; private Object advancedSnippetValue; + private final Instant currentTimestamp = Instant.now(); public Builder setAdvancedSnippetCreatedAt(Instant advancedSnippetCreatedAt) { - this.advancedSnippetCreatedAt = advancedSnippetCreatedAt; + this.advancedSnippetCreatedAt = Objects.requireNonNullElse(advancedSnippetCreatedAt, currentTimestamp); return this; } public Builder setAdvancedSnippetUpdatedAt(Instant advancedSnippetUpdatedAt) { - this.advancedSnippetUpdatedAt = advancedSnippetUpdatedAt; + this.advancedSnippetUpdatedAt = Objects.requireNonNullElse(advancedSnippetUpdatedAt, currentTimestamp); return this; } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringRule.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringRule.java index 02571078f4e21..3829eb7442522 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringRule.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringRule.java @@ -23,6 +23,7 
@@ import java.util.Objects; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; /** * Represents a single rule used for filtering in a data processing or querying context. @@ -75,13 +76,13 @@ public FilteringRule( } public FilteringRule(StreamInput in) throws IOException { - this.createdAt = in.readInstant(); + this.createdAt = in.readOptionalInstant(); this.field = in.readString(); this.id = in.readString(); this.order = in.readInt(); this.policy = in.readEnum(FilteringPolicy.class); this.rule = in.readEnum(FilteringRuleCondition.class); - this.updatedAt = in.readInstant(); + this.updatedAt = in.readOptionalInstant(); this.value = in.readString(); } @@ -110,7 +111,7 @@ public FilteringRule(StreamInput in) throws IOException { static { PARSER.declareField( - constructorArg(), + optionalConstructorArg(), (p, c) -> ConnectorUtils.parseInstant(p, CREATED_AT_FIELD.getPreferredName()), CREATED_AT_FIELD, ObjectParser.ValueType.STRING @@ -131,7 +132,7 @@ public FilteringRule(StreamInput in) throws IOException { ObjectParser.ValueType.STRING ); PARSER.declareField( - constructorArg(), + optionalConstructorArg(), (p, c) -> ConnectorUtils.parseInstant(p, UPDATED_AT_FIELD.getPreferredName()), UPDATED_AT_FIELD, ObjectParser.ValueType.STRING @@ -160,13 +161,13 @@ public static FilteringRule fromXContent(XContentParser parser) throws IOExcepti @Override public void writeTo(StreamOutput out) throws IOException { - out.writeInstant(createdAt); + out.writeOptionalInstant(createdAt); out.writeString(field); out.writeString(id); out.writeInt(order); out.writeEnum(policy); out.writeEnum(rule); - out.writeInstant(updatedAt); + out.writeOptionalInstant(updatedAt); out.writeString(value); } @@ -185,6 +186,18 @@ public boolean equals(Object o) { && Objects.equals(value, that.value); } + /** + * Compares this {@code FilteringRule} to another rule for equality, ignoring differences + * in created_at, updated_at timestamps and order. 
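+     * This supports comparisons where timestamps are re-stamped on write, as the {@code Builder} does for null created_at/updated_at values.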
+ */ + public boolean equalsExceptForTimestampsAndOrder(FilteringRule that) { + return Objects.equals(field, that.field) + && Objects.equals(id, that.id) + && policy == that.policy + && rule == that.rule + && Objects.equals(value, that.value); + } + @Override public int hashCode() { return Objects.hash(createdAt, field, id, order, policy, rule, updatedAt, value); @@ -200,9 +213,10 @@ public static class Builder { private FilteringRuleCondition rule; private Instant updatedAt; private String value; + private final Instant currentTimestamp = Instant.now(); public Builder setCreatedAt(Instant createdAt) { - this.createdAt = createdAt; + this.createdAt = Objects.requireNonNullElse(createdAt, currentTimestamp); return this; } @@ -232,7 +246,7 @@ public Builder setRule(FilteringRuleCondition rule) { } public Builder setUpdatedAt(Instant updatedAt) { - this.updatedAt = updatedAt; + this.updatedAt = Objects.requireNonNullElse(updatedAt, currentTimestamp); return this; } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringRules.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringRules.java index fb4e25131449d..35d18d23450b1 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringRules.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringRules.java @@ -69,9 +69,9 @@ public FilteringValidationInfo getFilteringValidationInfo() { return filteringValidationInfo; } - private static final ParseField ADVANCED_SNIPPET_FIELD = new ParseField("advanced_snippet"); - private static final ParseField RULES_FIELD = new ParseField("rules"); - private static final ParseField VALIDATION_FIELD = new ParseField("validation"); + public static final ParseField ADVANCED_SNIPPET_FIELD = new ParseField("advanced_snippet"); + public static final ParseField RULES_FIELD = new ParseField("rules"); + public static final ParseField VALIDATION_FIELD = new ParseField("validation"); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringValidationInfo.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringValidationInfo.java index c0cd80d867592..cd197bf0538e4 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringValidationInfo.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringValidationInfo.java @@ -18,6 +18,7 @@ import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; +import java.util.Collections; import java.util.List; import java.util.Objects; @@ -105,6 +106,12 @@ public int hashCode() { return Objects.hash(validationErrors, validationState); } + public static FilteringValidationInfo getInitialDraftValidationInfo() { + return new FilteringValidationInfo.Builder().setValidationErrors(Collections.emptyList()) + .setValidationState(FilteringValidationState.EDITED) + .build(); + } + public static class Builder { private List validationErrors; diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorIndexServiceTests.java 
b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorIndexServiceTests.java index f483887c4d81b..ea510086fcf8c 100644 --- a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorIndexServiceTests.java +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorIndexServiceTests.java @@ -28,7 +28,6 @@ import org.elasticsearch.xpack.application.connector.action.UpdateConnectorApiKeyIdAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorConfigurationAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorErrorAction; -import org.elasticsearch.xpack.application.connector.action.UpdateConnectorFilteringAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorIndexNameAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorLastSeenAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorLastSyncStatsAction; @@ -38,6 +37,9 @@ import org.elasticsearch.xpack.application.connector.action.UpdateConnectorSchedulingAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorServiceTypeAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorStatusAction; +import org.elasticsearch.xpack.application.connector.filtering.FilteringAdvancedSnippet; +import org.elasticsearch.xpack.application.connector.filtering.FilteringRule; +import org.elasticsearch.xpack.application.connector.filtering.FilteringValidationInfo; import org.junit.Before; import java.util.ArrayList; @@ -248,17 +250,46 @@ public void testUpdateConnectorFiltering() throws Exception { .mapToObj((i) -> ConnectorTestUtils.getRandomConnectorFiltering()) .collect(Collectors.toList()); - UpdateConnectorFilteringAction.Request updateFilteringRequest = new UpdateConnectorFilteringAction.Request( - connectorId, - filteringList - ); - - DocWriteResponse updateResponse = awaitUpdateConnectorFiltering(updateFilteringRequest); + DocWriteResponse updateResponse = awaitUpdateConnectorFiltering(connectorId, filteringList); assertThat(updateResponse.status(), equalTo(RestStatus.OK)); Connector indexedConnector = awaitGetConnector(connectorId); assertThat(filteringList, equalTo(indexedConnector.getFiltering())); } + public void testUpdateConnectorFiltering_updateDraft() throws Exception { + Connector connector = ConnectorTestUtils.getRandomConnector(); + String connectorId = randomUUID(); + + DocWriteResponse resp = buildRequestAndAwaitPutConnector(connectorId, connector); + assertThat(resp.status(), anyOf(equalTo(RestStatus.CREATED), equalTo(RestStatus.OK))); + + FilteringAdvancedSnippet advancedSnippet = ConnectorTestUtils.getRandomConnectorFiltering().getDraft().getAdvancedSnippet(); + List rules = ConnectorTestUtils.getRandomConnectorFiltering().getDraft().getRules(); + + DocWriteResponse updateResponse = awaitUpdateConnectorFilteringDraft(connectorId, advancedSnippet, rules); + assertThat(updateResponse.status(), equalTo(RestStatus.OK)); + Connector indexedConnector = awaitGetConnector(connectorId); + + // Assert that draft got updated + assertThat(advancedSnippet, equalTo(indexedConnector.getFiltering().get(0).getDraft().getAdvancedSnippet())); + assertThat(rules, equalTo(indexedConnector.getFiltering().get(0).getDraft().getRules())); + // Assert that draft is marked as EDITED + assertThat( + FilteringValidationInfo.getInitialDraftValidationInfo(), + 
equalTo(indexedConnector.getFiltering().get(0).getDraft().getFilteringValidationInfo()) + ); + // Assert that default active rules are unchanged, avoid comparing timestamps + assertThat( + ConnectorFiltering.getDefaultConnectorFilteringConfig().getActive().getAdvancedSnippet().getAdvancedSnippetValue(), + equalTo(indexedConnector.getFiltering().get(0).getActive().getAdvancedSnippet().getAdvancedSnippetValue()) + ); + // Assert that domain is unchanged + assertThat( + ConnectorFiltering.getDefaultConnectorFilteringConfig().getDomain(), + equalTo(indexedConnector.getFiltering().get(0).getDomain()) + ); + } + public void testUpdateConnectorLastSeen() throws Exception { Connector connector = ConnectorTestUtils.getRandomConnector(); String connectorId = randomUUID(); @@ -717,11 +748,11 @@ public void onFailure(Exception e) { return resp.get(); } - private UpdateResponse awaitUpdateConnectorFiltering(UpdateConnectorFilteringAction.Request updateFiltering) throws Exception { + private UpdateResponse awaitUpdateConnectorFiltering(String connectorId, List filtering) throws Exception { CountDownLatch latch = new CountDownLatch(1); final AtomicReference resp = new AtomicReference<>(null); final AtomicReference exc = new AtomicReference<>(null); - connectorIndexService.updateConnectorFiltering(updateFiltering, new ActionListener<>() { + connectorIndexService.updateConnectorFiltering(connectorId, filtering, new ActionListener<>() { @Override public void onResponse(UpdateResponse indexResponse) { @@ -744,6 +775,36 @@ public void onFailure(Exception e) { return resp.get(); } + private UpdateResponse awaitUpdateConnectorFilteringDraft( + String connectorId, + FilteringAdvancedSnippet advancedSnippet, + List rules + ) throws Exception { + CountDownLatch latch = new CountDownLatch(1); + final AtomicReference resp = new AtomicReference<>(null); + final AtomicReference exc = new AtomicReference<>(null); + connectorIndexService.updateConnectorFilteringDraft(connectorId, advancedSnippet, rules, new ActionListener<>() { + @Override + public void onResponse(UpdateResponse indexResponse) { + resp.set(indexResponse); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + exc.set(e); + latch.countDown(); + } + }); + + assertTrue("Timeout waiting for update filtering request", latch.await(REQUEST_TIMEOUT_SECONDS, TimeUnit.SECONDS)); + if (exc.get() != null) { + throw exc.get(); + } + assertNotNull("Received null response from update filtering request", resp.get()); + return resp.get(); + } + private UpdateResponse awaitUpdateConnectorIndexName(UpdateConnectorIndexNameAction.Request updateIndexNameRequest) throws Exception { CountDownLatch latch = new CountDownLatch(1); final AtomicReference resp = new AtomicReference<>(null); diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorTestUtils.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorTestUtils.java index 35a910b5641a9..876a1092a1d5b 100644 --- a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorTestUtils.java +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorTestUtils.java @@ -210,7 +210,6 @@ public static ConnectorFiltering getRandomConnectorFiltering() { ) .build() ) - .setDomain(randomAlphaOfLength(10)) .setDraft( new FilteringRules.Builder().setAdvancedSnippet( new 
FilteringAdvancedSnippet.Builder().setAdvancedSnippetCreatedAt(currentTimestamp) diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringActionRequestBWCSerializingTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringActionRequestBWCSerializingTests.java index 1d433d58be6ad..6874f4b2a1b36 100644 --- a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringActionRequestBWCSerializingTests.java +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringActionRequestBWCSerializingTests.java @@ -31,7 +31,9 @@ protected UpdateConnectorFilteringAction.Request createTestInstance() { this.connectorId = randomUUID(); return new UpdateConnectorFilteringAction.Request( connectorId, - List.of(ConnectorTestUtils.getRandomConnectorFiltering(), ConnectorTestUtils.getRandomConnectorFiltering()) + List.of(ConnectorTestUtils.getRandomConnectorFiltering(), ConnectorTestUtils.getRandomConnectorFiltering()), + ConnectorTestUtils.getRandomConnectorFiltering().getActive().getAdvancedSnippet(), + ConnectorTestUtils.getRandomConnectorFiltering().getActive().getRules() ); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java index a8afce1a3b223..da014ada387d6 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java @@ -50,8 +50,10 @@ public final class ExchangeService extends AbstractLifecycleComponent { // TODO: Make this a child action of the data node transport to ensure that exchanges // are accessed only by the user who initialized the session. 
public static final String EXCHANGE_ACTION_NAME = "internal:data/read/esql/exchange"; + public static final String EXCHANGE_ACTION_NAME_FOR_CCS = "cluster:internal:data/read/esql/exchange"; private static final String OPEN_EXCHANGE_ACTION_NAME = "internal:data/read/esql/open_exchange"; + private static final String OPEN_EXCHANGE_ACTION_NAME_FOR_CCS = "cluster:internal:data/read/esql/open_exchange"; /** * The time interval for an exchange sink handler to be considered inactive and subsequently @@ -85,6 +87,21 @@ public void registerTransportHandler(TransportService transportService) { OpenExchangeRequest::new, new OpenExchangeRequestHandler() ); + + // This allows the system user to access this action when executed over CCS and the API key-based security model is in use + transportService.registerRequestHandler( + EXCHANGE_ACTION_NAME_FOR_CCS, + this.executor, + ExchangeRequest::new, + new ExchangeTransportAction() + ); + transportService.registerRequestHandler( + OPEN_EXCHANGE_ACTION_NAME_FOR_CCS, + this.executor, + OpenExchangeRequest::new, + new OpenExchangeRequestHandler() + ); + } /** diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java index 6883f71c9ee14..86d48aca3baed 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java @@ -94,8 +94,6 @@ public abstract class RestEsqlTestCase extends ESRestTestCase { // larger than any (unsigned) long private static final String HUMONGOUS_DOUBLE = "1E300"; - private static final String INFINITY = "1.0/0.0"; - private static final String NAN = "0.0/0.0"; public static boolean shouldLog() { return false; @@ -431,22 +429,19 @@ public void testOutOfRangeComparisons() throws IOException { String equalPlusMinus = randomFrom(" == ", " == -"); // TODO: once we do not support infinity and NaN anymore, remove INFINITY/NAN cases. 
// https://github.com/elastic/elasticsearch/issues/98698#issuecomment-1847423390 - String humongousPositiveLiteral = randomFrom(HUMONGOUS_DOUBLE, INFINITY); - String nanOrNull = randomFrom(NAN, "to_double(null)"); List trueForSingleValuesPredicates = List.of( - lessOrLessEqual + humongousPositiveLiteral, - largerOrLargerEqual + " -" + humongousPositiveLiteral, - inEqualPlusMinus + humongousPositiveLiteral, - inEqualPlusMinus + NAN + lessOrLessEqual + HUMONGOUS_DOUBLE, + largerOrLargerEqual + " -" + HUMONGOUS_DOUBLE, + inEqualPlusMinus + HUMONGOUS_DOUBLE ); List alwaysFalsePredicates = List.of( - lessOrLessEqual + " -" + humongousPositiveLiteral, - largerOrLargerEqual + humongousPositiveLiteral, - equalPlusMinus + humongousPositiveLiteral, - lessOrLessEqual + nanOrNull, - largerOrLargerEqual + nanOrNull, - equalPlusMinus + nanOrNull, + lessOrLessEqual + " -" + HUMONGOUS_DOUBLE, + largerOrLargerEqual + HUMONGOUS_DOUBLE, + equalPlusMinus + HUMONGOUS_DOUBLE, + lessOrLessEqual + "to_double(null)", + largerOrLargerEqual + "to_double(null)", + equalPlusMinus + "to_double(null)", inEqualPlusMinus + "to_double(null)" ); diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/math.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/math.csv-spec index 905eac30a3012..399e1b5dc791b 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/math.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/math.csv-spec @@ -1213,21 +1213,6 @@ a:double // end::floor-result[] ; -ceilFloorOfInfinite -row i = 1.0/0.0 | eval c = ceil(i), f = floor(i); - -i:double | c:double | f:double -Infinity | Infinity | Infinity -; - -ceilFloorOfNegativeInfinite -row i = -1.0/0.0 | eval c = ceil(i), f = floor(i); - -i:double | c:double | f:double --Infinity | -Infinity | -Infinity -; - - ceilFloorOfInteger row i = 1 | eval c = ceil(i), f = floor(i); diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec index 867ff127c90e8..749c44d1f6ece 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec @@ -235,7 +235,7 @@ h:double ; sumOfScaledFloat -from employees | stats h = sum(height.scaled_float); +from employees | stats h = sum(height.scaled_float) | eval h = round(h, 10); h:double 176.82 diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DivDoublesEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DivDoublesEvaluator.java index bb9f55f2b5b85..88bf948749ffc 100644 --- a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DivDoublesEvaluator.java +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DivDoublesEvaluator.java @@ -4,6 +4,7 @@ // 2.0. 
package org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic; +import java.lang.ArithmeticException; import java.lang.IllegalArgumentException; import java.lang.Override; import java.lang.String; @@ -50,7 +51,7 @@ public Block eval(Page page) { if (rhsVector == null) { return eval(page.getPositionCount(), lhsBlock, rhsBlock); } - return eval(page.getPositionCount(), lhsVector, rhsVector).asBlock(); + return eval(page.getPositionCount(), lhsVector, rhsVector); } } } @@ -80,16 +81,26 @@ public DoubleBlock eval(int positionCount, DoubleBlock lhsBlock, DoubleBlock rhs result.appendNull(); continue position; } - result.appendDouble(Div.processDoubles(lhsBlock.getDouble(lhsBlock.getFirstValueIndex(p)), rhsBlock.getDouble(rhsBlock.getFirstValueIndex(p)))); + try { + result.appendDouble(Div.processDoubles(lhsBlock.getDouble(lhsBlock.getFirstValueIndex(p)), rhsBlock.getDouble(rhsBlock.getFirstValueIndex(p)))); + } catch (ArithmeticException e) { + warnings.registerException(e); + result.appendNull(); + } } return result.build(); } } - public DoubleVector eval(int positionCount, DoubleVector lhsVector, DoubleVector rhsVector) { - try(DoubleVector.Builder result = driverContext.blockFactory().newDoubleVectorBuilder(positionCount)) { + public DoubleBlock eval(int positionCount, DoubleVector lhsVector, DoubleVector rhsVector) { + try(DoubleBlock.Builder result = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { position: for (int p = 0; p < positionCount; p++) { - result.appendDouble(Div.processDoubles(lhsVector.getDouble(p), rhsVector.getDouble(p))); + try { + result.appendDouble(Div.processDoubles(lhsVector.getDouble(p), rhsVector.getDouble(p))); + } catch (ArithmeticException e) { + warnings.registerException(e); + result.appendNull(); + } } return result.build(); } diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/ModDoublesEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/ModDoublesEvaluator.java index 8d441ffe10a48..3afcac77973fb 100644 --- a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/ModDoublesEvaluator.java +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/ModDoublesEvaluator.java @@ -4,6 +4,7 @@ // 2.0. 
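+// Note: as in DivDoublesEvaluator above, the vector fast path in this generated class now returns a DoubleBlock instead of a DoubleVector, because a division or remainder by zero is recorded as a warning and yields null, and vectors cannot hold null entries.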
package org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic; +import java.lang.ArithmeticException; import java.lang.IllegalArgumentException; import java.lang.Override; import java.lang.String; @@ -50,7 +51,7 @@ public Block eval(Page page) { if (rhsVector == null) { return eval(page.getPositionCount(), lhsBlock, rhsBlock); } - return eval(page.getPositionCount(), lhsVector, rhsVector).asBlock(); + return eval(page.getPositionCount(), lhsVector, rhsVector); } } } @@ -80,16 +81,26 @@ public DoubleBlock eval(int positionCount, DoubleBlock lhsBlock, DoubleBlock rhs result.appendNull(); continue position; } - result.appendDouble(Mod.processDoubles(lhsBlock.getDouble(lhsBlock.getFirstValueIndex(p)), rhsBlock.getDouble(rhsBlock.getFirstValueIndex(p)))); + try { + result.appendDouble(Mod.processDoubles(lhsBlock.getDouble(lhsBlock.getFirstValueIndex(p)), rhsBlock.getDouble(rhsBlock.getFirstValueIndex(p)))); + } catch (ArithmeticException e) { + warnings.registerException(e); + result.appendNull(); + } } return result.build(); } } - public DoubleVector eval(int positionCount, DoubleVector lhsVector, DoubleVector rhsVector) { - try(DoubleVector.Builder result = driverContext.blockFactory().newDoubleVectorBuilder(positionCount)) { + public DoubleBlock eval(int positionCount, DoubleVector lhsVector, DoubleVector rhsVector) { + try(DoubleBlock.Builder result = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { position: for (int p = 0; p < positionCount; p++) { - result.appendDouble(Mod.processDoubles(lhsVector.getDouble(p), rhsVector.getDouble(p))); + try { + result.appendDouble(Mod.processDoubles(lhsVector.getDouble(p), rhsVector.getDouble(p))); + } catch (ArithmeticException e) { + warnings.registerException(e); + result.appendNull(); + } } return result.build(); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index 13e088b81c95f..02969ed56798f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -52,6 +52,7 @@ import org.elasticsearch.xpack.ql.expression.function.FunctionRegistry; import org.elasticsearch.xpack.ql.expression.function.UnresolvedFunction; import org.elasticsearch.xpack.ql.expression.function.scalar.ScalarFunction; +import org.elasticsearch.xpack.ql.expression.predicate.BinaryOperator; import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison; import org.elasticsearch.xpack.ql.index.EsIndex; import org.elasticsearch.xpack.ql.plan.TableIdentifier; @@ -62,7 +63,6 @@ import org.elasticsearch.xpack.ql.plan.logical.Project; import org.elasticsearch.xpack.ql.rule.ParameterizedRule; import org.elasticsearch.xpack.ql.rule.ParameterizedRuleExecutor; -import org.elasticsearch.xpack.ql.rule.Rule; import org.elasticsearch.xpack.ql.rule.RuleExecutor; import org.elasticsearch.xpack.ql.session.Configuration; import org.elasticsearch.xpack.ql.tree.Source; @@ -94,7 +94,6 @@ import static org.elasticsearch.common.logging.LoggerMessageFormat.format; import static org.elasticsearch.xpack.core.enrich.EnrichPolicy.GEO_MATCH_TYPE; import static org.elasticsearch.xpack.esql.stats.FeatureMetric.LIMIT; -import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.dateTimeToLong; import static org.elasticsearch.xpack.esql.type.EsqlDataTypes.GEO_POINT; import 
static org.elasticsearch.xpack.esql.type.EsqlDataTypes.GEO_SHAPE; import static org.elasticsearch.xpack.ql.type.DataTypes.DATETIME; @@ -124,7 +123,7 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerContext> { - var finish = new Batch<>("Finish Analysis", Limiter.ONCE, new AddImplicitLimit(), new PromoteStringsInDateComparisons()); + var finish = new Batch<>("Finish Analysis", Limiter.ONCE, new AddImplicitLimit()); rules = List.of(resolution, finish); } @@ -778,58 +777,6 @@ public LogicalPlan apply(LogicalPlan logicalPlan, AnalyzerContext context) { } } - private static class PromoteStringsInDateComparisons extends Rule<LogicalPlan, LogicalPlan> { - - @Override - public LogicalPlan apply(LogicalPlan plan) { - return plan.transformExpressionsUp(BinaryComparison.class, PromoteStringsInDateComparisons::promote); - } - - private static Expression promote(BinaryComparison cmp) { - if (cmp.resolved() == false) { - return cmp; - } - var left = cmp.left(); - var right = cmp.right(); - boolean modified = false; - if (left.dataType() == DATETIME) { - if (right.dataType() == KEYWORD && right.foldable() && ((right instanceof EsqlScalarFunction) == false)) { - right = stringToDate(right); - modified = true; - } - } else { - if (right.dataType() == DATETIME) { - if (left.dataType() == KEYWORD && left.foldable() && ((left instanceof EsqlScalarFunction) == false)) { - left = stringToDate(left); - modified = true; - } - } - } - return modified ? cmp.replaceChildren(List.of(left, right)) : cmp; - } - - private static Expression stringToDate(Expression stringExpression) { - var str = stringExpression.fold().toString(); - - Long millis = null; - // TODO: better control over this string format - do we want this to be flexible or always redirect folks to use date parsing - try { - millis = str == null ? null : dateTimeToLong(str); - } catch (Exception ex) { // in case of exception, millis will be null which will trigger an error - } - - var source = stringExpression.source(); - Expression result; - if (millis == null) { - var errorMessage = format(null, "Invalid date [{}]", str); - result = new UnresolvedAttribute(source, source.text(), null, errorMessage); - } else { - result = new Literal(source, millis, DATETIME); - } - return result; - } - } - private BitSet gatherPreAnalysisMetrics(LogicalPlan plan, BitSet b) { // count only the explicit "limit" the user added, otherwise all queries will have a "limit" and telemetry won't reflect reality if (plan.collectFirstChildren(Limit.class::isInstance).isEmpty() == false) { @@ -852,11 +799,9 @@ private static Expression cast(ScalarFunction f, EsqlFunctionRegistry registry) if (f instanceof EsqlScalarFunction esf) { return processScalarFunction(esf, registry); } - - if (f instanceof EsqlArithmeticOperation eao) { - return processArithmeticOperation(eao); + if (f instanceof EsqlArithmeticOperation || f instanceof BinaryComparison) { + return processBinaryOperator((BinaryOperator) f); } - return f; } @@ -888,7 +833,7 @@ private static Expression processScalarFunction(EsqlScalarFunction f, EsqlFuncti return childrenChanged ? 
f.replaceChildren(newChildren) : f; } - private static Expression processArithmeticOperation(EsqlArithmeticOperation o) { + private static Expression processBinaryOperator(BinaryOperator o) { Expression left = o.left(); Expression right = o.right(); if (left.resolved() == false || right.resolved() == false) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java index e5d4e58d9d61b..366fb4ff55ba6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java @@ -270,6 +270,7 @@ private void doLookup( }; var queryOperator = new EnrichQuerySourceOperator( driverContext.blockFactory(), + EnrichQuerySourceOperator.DEFAULT_MAX_PAGE_SIZE, queryList, searchExecutionContext.getIndexReader() ); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichQuerySourceOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichQuerySourceOperator.java index b0582e211fdba..6937f1a8c7772 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichQuerySourceOperator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichQuerySourceOperator.java @@ -15,7 +15,6 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.Scorable; import org.apache.lucene.search.ScoreMode; -import org.apache.lucene.search.Weight; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.DocVector; import org.elasticsearch.compute.data.IntBlock; @@ -36,14 +35,17 @@ final class EnrichQuerySourceOperator extends SourceOperator { private final BlockFactory blockFactory; private final QueryList queryList; - private int queryPosition; - private Weight weight = null; + private int queryPosition = -1; private final IndexReader indexReader; - private int leafIndex = 0; private final IndexSearcher searcher; + private final int maxPageSize; - EnrichQuerySourceOperator(BlockFactory blockFactory, QueryList queryList, IndexReader indexReader) { + // using smaller pages enables quick cancellation and reduces sorting costs + static final int DEFAULT_MAX_PAGE_SIZE = 256; + + EnrichQuerySourceOperator(BlockFactory blockFactory, int maxPageSize, QueryList queryList, IndexReader indexReader) { this.blockFactory = blockFactory; + this.maxPageSize = maxPageSize; this.queryList = queryList; this.indexReader = indexReader; this.searcher = new IndexSearcher(indexReader); @@ -59,62 +61,96 @@ public boolean isFinished() { @Override public Page getOutput() { - if (leafIndex == indexReader.leaves().size()) { - queryPosition++; - leafIndex = 0; - weight = null; - } - if (isFinished()) { - return null; - } - if (weight == null) { - Query query = queryList.getQuery(queryPosition); - if (query != null) { - try { - query = searcher.rewrite(new ConstantScoreQuery(query)); - weight = searcher.createWeight(query, ScoreMode.COMPLETE_NO_SCORES, 1.0f); - } catch (IOException e) { - throw new UncheckedIOException(e); - } + int estimatedSize = Math.min(maxPageSize, queryList.getPositionCount() - queryPosition); + IntVector.Builder positionsBuilder = null; + IntVector.Builder docsBuilder = null; + IntVector.Builder segmentsBuilder = null; + try { + positionsBuilder = blockFactory.newIntVectorBuilder(estimatedSize); + 
docsBuilder = blockFactory.newIntVectorBuilder(estimatedSize); + if (indexReader.leaves().size() > 1) { + segmentsBuilder = blockFactory.newIntVectorBuilder(estimatedSize); } + int totalMatches = 0; + do { + Query query = nextQuery(); + if (query == null) { + assert isFinished(); + break; + } + query = searcher.rewrite(new ConstantScoreQuery(query)); + final var weight = searcher.createWeight(query, ScoreMode.COMPLETE_NO_SCORES, 1.0f); + if (weight == null) { + continue; + } + for (LeafReaderContext leaf : indexReader.leaves()) { + var scorer = weight.bulkScorer(leaf); + if (scorer == null) { + continue; + } + final DocCollector collector = new DocCollector(docsBuilder); + scorer.score(collector, leaf.reader().getLiveDocs()); + int matches = collector.matches; + + if (segmentsBuilder != null) { + for (int i = 0; i < matches; i++) { + segmentsBuilder.appendInt(leaf.ord); + } + } + for (int i = 0; i < matches; i++) { + positionsBuilder.appendInt(queryPosition); + } + totalMatches += matches; + } + } while (totalMatches < maxPageSize); + + return buildPage(totalMatches, positionsBuilder, segmentsBuilder, docsBuilder); + } catch (IOException e) { + throw new UncheckedIOException(e); + } finally { + Releasables.close(docsBuilder, segmentsBuilder, positionsBuilder); } + } + + Page buildPage(int positions, IntVector.Builder positionsBuilder, IntVector.Builder segmentsBuilder, IntVector.Builder docsBuilder) { + IntVector positionsVector = null; + IntVector shardsVector = null; + IntVector segmentsVector = null; + IntVector docsVector = null; + Page page = null; try { - return queryOneLeaf(weight, leafIndex++); - } catch (IOException ex) { - throw new UncheckedIOException(ex); + positionsVector = positionsBuilder.build(); + shardsVector = blockFactory.newConstantIntVector(0, positions); + if (segmentsBuilder == null) { + segmentsVector = blockFactory.newConstantIntVector(0, positions); + } else { + segmentsVector = segmentsBuilder.build(); + } + docsVector = docsBuilder.build(); + page = new Page(new DocVector(shardsVector, segmentsVector, docsVector, null).asBlock(), positionsVector.asBlock()); + } finally { + if (page == null) { + Releasables.close(positionsBuilder, segmentsVector, docsBuilder, positionsVector, shardsVector, docsVector); + } } + return page; } - private Page queryOneLeaf(Weight weight, int leafIndex) throws IOException { - if (weight == null) { - return null; - } - LeafReaderContext leafReaderContext = indexReader.leaves().get(leafIndex); - var scorer = weight.bulkScorer(leafReaderContext); - if (scorer == null) { - return null; - } - IntVector docs = null, segments = null, shards = null, positions = null; - boolean success = false; - try (IntVector.Builder docsBuilder = blockFactory.newIntVectorBuilder(1)) { - scorer.score(new DocCollector(docsBuilder), leafReaderContext.reader().getLiveDocs()); - docs = docsBuilder.build(); - final int positionCount = docs.getPositionCount(); - segments = blockFactory.newConstantIntVector(leafIndex, positionCount); - shards = blockFactory.newConstantIntVector(0, positionCount); - positions = blockFactory.newConstantIntVector(queryPosition, positionCount); - Page page = new Page(new DocVector(shards, segments, docs, true).asBlock(), positions.asBlock()); - success = true; - return page; - } finally { - if (success == false) { - Releasables.close(docs, shards, segments, positions); + private Query nextQuery() { + ++queryPosition; + while (isFinished() == false) { + Query query = queryList.getQuery(queryPosition); + if (query != null) { + return 
query; } + ++queryPosition; } + return null; } private static class DocCollector implements LeafCollector { final IntVector.Builder docIds; + int matches = 0; DocCollector(IntVector.Builder docIds) { this.docIds = docIds; @@ -127,6 +163,7 @@ public void setScorer(Scorable scorer) { @Override public void collect(int doc) { + ++matches; docIds.appendInt(doc); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Div.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Div.java index 170e3de6e4209..73863d308f6e4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Div.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Div.java @@ -13,6 +13,7 @@ import org.elasticsearch.xpack.ql.tree.NodeInfo; import org.elasticsearch.xpack.ql.tree.Source; import org.elasticsearch.xpack.ql.type.DataType; +import org.elasticsearch.xpack.ql.util.NumericUtils; import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.DIV; import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.longToUnsignedLong; @@ -63,21 +64,34 @@ public ArithmeticOperationFactory binaryComparisonInverse() { @Evaluator(extraName = "Ints", warnExceptions = { ArithmeticException.class }) static int processInts(int lhs, int rhs) { + if (rhs == 0) { + throw new ArithmeticException("/ by zero"); + } return lhs / rhs; } @Evaluator(extraName = "Longs", warnExceptions = { ArithmeticException.class }) static long processLongs(long lhs, long rhs) { + if (rhs == 0L) { + throw new ArithmeticException("/ by zero"); + } return lhs / rhs; } @Evaluator(extraName = "UnsignedLongs", warnExceptions = { ArithmeticException.class }) static long processUnsignedLongs(long lhs, long rhs) { + if (rhs == NumericUtils.ZERO_AS_UNSIGNED_LONG) { + throw new ArithmeticException("/ by zero"); + } return longToUnsignedLong(Long.divideUnsigned(longToUnsignedLong(lhs, true), longToUnsignedLong(rhs, true)), true); } - @Evaluator(extraName = "Doubles") + @Evaluator(extraName = "Doubles", warnExceptions = { ArithmeticException.class }) static double processDoubles(double lhs, double rhs) { - return lhs / rhs; + double value = lhs / rhs; + if (Double.isNaN(value) || Double.isInfinite(value)) { + throw new ArithmeticException("/ by zero"); + } + return value; } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Mod.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Mod.java index bc1ad8fcb5f94..df3b8f27c4880 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Mod.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Mod.java @@ -11,6 +11,7 @@ import org.elasticsearch.xpack.ql.expression.Expression; import org.elasticsearch.xpack.ql.tree.NodeInfo; import org.elasticsearch.xpack.ql.tree.Source; +import org.elasticsearch.xpack.ql.util.NumericUtils; import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.MOD; import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.longToUnsignedLong; @@ -42,21 +43,34 @@ protected Mod 
replaceChildren(Expression left, Expression right) { @Evaluator(extraName = "Ints", warnExceptions = { ArithmeticException.class }) static int processInts(int lhs, int rhs) { + if (rhs == 0) { + throw new ArithmeticException("/ by zero"); + } return lhs % rhs; } @Evaluator(extraName = "Longs", warnExceptions = { ArithmeticException.class }) static long processLongs(long lhs, long rhs) { + if (rhs == 0L) { + throw new ArithmeticException("/ by zero"); + } return lhs % rhs; } @Evaluator(extraName = "UnsignedLongs", warnExceptions = { ArithmeticException.class }) static long processUnsignedLongs(long lhs, long rhs) { + if (rhs == NumericUtils.ZERO_AS_UNSIGNED_LONG) { + throw new ArithmeticException("/ by zero"); + } return longToUnsignedLong(Long.remainderUnsigned(longToUnsignedLong(lhs, true), longToUnsignedLong(rhs, true)), true); } - @Evaluator(extraName = "Doubles") + @Evaluator(extraName = "Doubles", warnExceptions = { ArithmeticException.class }) static double processDoubles(double lhs, double rhs) { - return lhs % rhs; + double value = lhs % rhs; + if (Double.isNaN(value) || Double.isInfinite(value)) { + throw new ArithmeticException("/ by zero"); + } + return value; } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index f4ecf38915a29..7a85ca1628048 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -1003,13 +1003,7 @@ public void testCompareIntToString() { from test | where emp_no COMPARISON "foo" """.replace("COMPARISON", comparison))); - assertThat( - e.getMessage(), - containsString( - "first argument of [emp_no COMPARISON \"foo\"] is [numeric] so second argument must also be [numeric] but was [keyword]" - .replace("COMPARISON", comparison) - ) - ); + assertThat(e.getMessage(), containsString("Cannot convert string [foo] to [INTEGER]".replace("COMPARISON", comparison))); } } @@ -1019,13 +1013,7 @@ public void testCompareStringToInt() { from test | where "foo" COMPARISON emp_no """.replace("COMPARISON", comparison))); - assertThat( - e.getMessage(), - containsString( - "first argument of [\"foo\" COMPARISON emp_no] is [keyword] so second argument must also be [keyword] but was [integer]" - .replace("COMPARISON", comparison) - ) - ); + assertThat(e.getMessage(), containsString("Cannot convert string [foo] to [INTEGER]".replace("COMPARISON", comparison))); } } @@ -1051,11 +1039,15 @@ public void testCompareStringToDate() { public void testCompareDateToStringFails() { for (String comparison : COMPARISONS) { - verifyUnsupported(""" - from test - | where date COMPARISON "not-a-date" - | keep date - """.replace("COMPARISON", comparison), "Invalid date [not-a-date]", "mapping-multi-field-variation.json"); + verifyUnsupported( + """ + from test + | where date COMPARISON "not-a-date" + | keep date + """.replace("COMPARISON", comparison), + "Cannot convert string [not-a-date] to [DATETIME]", + "mapping-multi-field-variation.json" + ); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index e558dbe615642..8275f76d9a55c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -317,7 +317,7 @@ public void testSumOnDate() { public void testWrongInputParam() { assertEquals( - "1:19: first argument of [emp_no == ?] is [numeric] so second argument must also be [numeric] but was [keyword]", + "1:29: Cannot convert string [foo] to [INTEGER], error [Cannot parse number [foo]]", error("from test | where emp_no == ?", "foo") ); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/enrich/EnrichQuerySourceOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/enrich/EnrichQuerySourceOperatorTests.java index 7f8e1f7113e22..eef29f0681fbd 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/enrich/EnrichQuerySourceOperatorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/enrich/EnrichQuerySourceOperatorTests.java @@ -48,6 +48,7 @@ import static org.elasticsearch.xpack.ql.type.DataTypes.KEYWORD; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.mockito.Mockito.mock; public class EnrichQuerySourceOperatorTests extends ESTestCase { @@ -120,60 +121,26 @@ public void testQueries() throws Exception { // 3 -> [] -> [] // 4 -> [a1] -> [3] // 5 -> [] -> [] - EnrichQuerySourceOperator queryOperator = new EnrichQuerySourceOperator(blockFactory, queryList, reader); - { - Page p0 = queryOperator.getOutput(); - assertNotNull(p0); - assertThat(p0.getPositionCount(), equalTo(2)); - IntVector docs = getDocVector(p0, 0); - assertThat(docs.getInt(0), equalTo(1)); - assertThat(docs.getInt(1), equalTo(4)); - Block positions = p0.getBlock(1); - assertThat(BlockUtils.toJavaObject(positions, 0), equalTo(0)); - assertThat(BlockUtils.toJavaObject(positions, 1), equalTo(0)); - p0.releaseBlocks(); - } - { - Page p1 = queryOperator.getOutput(); - assertNotNull(p1); - assertThat(p1.getPositionCount(), equalTo(3)); - IntVector docs = getDocVector(p1, 0); - assertThat(docs.getInt(0), equalTo(0)); - assertThat(docs.getInt(1), equalTo(1)); - assertThat(docs.getInt(2), equalTo(2)); - Block positions = p1.getBlock(1); - assertThat(BlockUtils.toJavaObject(positions, 0), equalTo(1)); - assertThat(BlockUtils.toJavaObject(positions, 1), equalTo(1)); - assertThat(BlockUtils.toJavaObject(positions, 2), equalTo(1)); - p1.releaseBlocks(); - } - { - Page p2 = queryOperator.getOutput(); - assertNull(p2); - } - { - Page p3 = queryOperator.getOutput(); - assertNull(p3); - } - { - Page p4 = queryOperator.getOutput(); - assertNotNull(p4); - assertThat(p4.getPositionCount(), equalTo(1)); - IntVector docs = getDocVector(p4, 0); - assertThat(docs.getInt(0), equalTo(3)); - Block positions = p4.getBlock(1); - assertThat(BlockUtils.toJavaObject(positions, 0), equalTo(4)); - p4.releaseBlocks(); - } - { - Page p5 = queryOperator.getOutput(); - assertNull(p5); - } - { - assertFalse(queryOperator.isFinished()); - Page p6 = queryOperator.getOutput(); - assertNull(p6); - } + EnrichQuerySourceOperator queryOperator = new EnrichQuerySourceOperator(blockFactory, 128, queryList, reader); + Page p0 = queryOperator.getOutput(); + assertNotNull(p0); + assertThat(p0.getPositionCount(), equalTo(6)); + IntVector docs = getDocVector(p0, 0); + assertThat(docs.getInt(0), equalTo(1)); + assertThat(docs.getInt(1), equalTo(4)); + assertThat(docs.getInt(2), equalTo(0)); + assertThat(docs.getInt(3), equalTo(1)); + assertThat(docs.getInt(4), equalTo(2)); + assertThat(docs.getInt(5), equalTo(3)); 
+
+        Block positions = p0.getBlock(1);
+        assertThat(BlockUtils.toJavaObject(positions, 0), equalTo(0));
+        assertThat(BlockUtils.toJavaObject(positions, 1), equalTo(0));
+        assertThat(BlockUtils.toJavaObject(positions, 2), equalTo(1));
+        assertThat(BlockUtils.toJavaObject(positions, 3), equalTo(1));
+        assertThat(BlockUtils.toJavaObject(positions, 4), equalTo(1));
+        assertThat(BlockUtils.toJavaObject(positions, 5), equalTo(4));
+        p0.releaseBlocks();
         assertTrue(queryOperator.isFinished());
         IOUtils.close(reader, dir, inputTerms);
     }
@@ -220,13 +187,15 @@ public void testRandomMatchQueries() throws Exception {
         }
         MappedFieldType uidField = new KeywordFieldMapper.KeywordFieldType("uid");
         var queryList = QueryList.termQueryList(uidField, mock(SearchExecutionContext.class), inputTerms, KEYWORD);
-        EnrichQuerySourceOperator queryOperator = new EnrichQuerySourceOperator(blockFactory, queryList, reader);
+        int maxPageSize = between(1, 256);
+        EnrichQuerySourceOperator queryOperator = new EnrichQuerySourceOperator(blockFactory, maxPageSize, queryList, reader);
         Map<Integer, Set<Integer>> actualPositions = new HashMap<>();
         while (queryOperator.isFinished() == false) {
             Page page = queryOperator.getOutput();
             if (page != null) {
                 IntVector docs = getDocVector(page, 0);
                 IntBlock positions = page.getBlock(1);
+                assertThat(positions.getPositionCount(), lessThanOrEqualTo(maxPageSize));
                 for (int i = 0; i < page.getPositionCount(); i++) {
                     int doc = docs.getInt(i);
                     int position = positions.getInt(i);
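The rewritten assertions above capture the operator's new contract: instead of emitting one output page per input position (with null returns for positions that match nothing), the operator now accumulates matches across input positions and cuts a page only when the configured maximum page size is reached, hence the new `maxPageSize` constructor argument and the single six-position page in `testQueries`. A minimal, self-contained sketch of that batching shape, with hypothetical types standing in for Page/IntVector:

import java.util.ArrayList;
import java.util.List;

// Illustrative batcher, not the ES|QL operator: doc IDs accumulate across
// queries and are emitted as pages of at most maxPageSize positions.
final class PageBatcher {
    private final int maxPageSize;
    private final List<int[]> pages = new ArrayList<>();
    private List<Integer> current = new ArrayList<>();

    PageBatcher(int maxPageSize) {
        this.maxPageSize = maxPageSize;
    }

    void append(int docId) {
        current.add(docId);
        if (current.size() >= maxPageSize) {
            flush(); // emit a full page and start a new one
        }
    }

    void flush() {
        if (current.isEmpty() == false) {
            pages.add(current.stream().mapToInt(Integer::intValue).toArray());
            current = new ArrayList<>();
        }
    }

    List<int[]> pages() {
        return pages;
    }
}

The remaining hunks move on to the ES|QL function test-case suppliers:

diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java index b2d00a98dfa6c..0b6c64679dc1f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java @@ -857,6 +857,7 @@ protected static String typeErrorMessage(boolean includeOrdinal, List forBinaryCastingToDouble( return suppliers; } - private static void casesCrossProduct( + public static void casesCrossProduct( BinaryOperator expected, List lhsSuppliers, List rhsSuppliers, @@ -251,10 +251,10 @@ private static TestCaseSupplier testCaseSupplier( public static List castToDoubleSuppliersFromRange(Double Min, Double Max) { List suppliers = new ArrayList<>(); - suppliers.addAll(intCases(Min.intValue(), Max.intValue())); - suppliers.addAll(longCases(Min.longValue(), Max.longValue())); - suppliers.addAll(ulongCases(BigInteger.valueOf((long) Math.ceil(Min)), BigInteger.valueOf((long) Math.floor(Max)))); - suppliers.addAll(doubleCases(Min, Max)); + suppliers.addAll(intCases(Min.intValue(), Max.intValue(), true)); + suppliers.addAll(longCases(Min.longValue(), Max.longValue(), true)); + suppliers.addAll(ulongCases(BigInteger.valueOf((long) Math.ceil(Min)), BigInteger.valueOf((long) Math.floor(Max)), true)); + suppliers.addAll(doubleCases(Min, Max, true)); return suppliers; } @@ -279,7 +279,7 @@ public NumericTypeTestConfig get(DataType type) { } } - private static DataType widen(DataType lhs, DataType rhs) { + public static DataType widen(DataType lhs, DataType rhs) { if (lhs == rhs) { return lhs; } @@ -292,21 +292,22 @@ private static DataType widen(DataType lhs, DataType rhs) { throw new IllegalArgumentException("Invalid numeric widening lhs: [" + lhs + "] rhs: [" + rhs + "]"); } - private static List getSuppliersForNumericType(DataType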
type, Number min, Number max) { + public static List getSuppliersForNumericType(DataType type, Number min, Number max, boolean includeZero) { if (type == DataTypes.INTEGER) { - return intCases(NumericUtils.saturatingIntValue(min), NumericUtils.saturatingIntValue(max)); + return intCases(NumericUtils.saturatingIntValue(min), NumericUtils.saturatingIntValue(max), includeZero); } if (type == DataTypes.LONG) { - return longCases(min.longValue(), max.longValue()); + return longCases(min.longValue(), max.longValue(), includeZero); } if (type == DataTypes.UNSIGNED_LONG) { return ulongCases( min instanceof BigInteger ? (BigInteger) min : BigInteger.valueOf(Math.max(min.longValue(), 0L)), - max instanceof BigInteger ? (BigInteger) max : BigInteger.valueOf(Math.max(max.longValue(), 0L)) + max instanceof BigInteger ? (BigInteger) max : BigInteger.valueOf(Math.max(max.longValue(), 0L)), + includeZero ); } if (type == DataTypes.DOUBLE) { - return doubleCases(min.doubleValue(), max.doubleValue()); + return doubleCases(min.doubleValue(), max.doubleValue(), includeZero); } throw new IllegalArgumentException("bogus numeric type [" + type + "]"); } @@ -315,7 +316,8 @@ public static List forBinaryWithWidening( NumericTypeTestConfigs typeStuff, String lhsName, String rhsName, - List warnings + List warnings, + boolean allowRhsZero ) { List suppliers = new ArrayList<>(); List numericTypes = List.of(DataTypes.INTEGER, DataTypes.LONG, DataTypes.DOUBLE); @@ -336,13 +338,13 @@ public static List forBinaryWithWidening( + "]"; casesCrossProduct( (l, r) -> expectedTypeStuff.expected().apply((Number) l, (Number) r), - getSuppliersForNumericType(lhsType, expectedTypeStuff.min(), expectedTypeStuff.max()), - getSuppliersForNumericType(rhsType, expectedTypeStuff.min(), expectedTypeStuff.max()), + getSuppliersForNumericType(lhsType, expectedTypeStuff.min(), expectedTypeStuff.max(), true), + getSuppliersForNumericType(rhsType, expectedTypeStuff.min(), expectedTypeStuff.max(), allowRhsZero), evaluatorToString, warnings, suppliers, expected, - true + false ); } } @@ -358,7 +360,8 @@ public static List forBinaryNotCasting( DataType expectedType, List lhsSuppliers, List rhsSuppliers, - List warnings + List warnings, + boolean symmetric ) { List suppliers = new ArrayList<>(); casesCrossProduct( @@ -369,7 +372,7 @@ public static List forBinaryNotCasting( warnings, suppliers, expectedType, - true + symmetric ); return suppliers; } @@ -389,7 +392,7 @@ public static void forUnaryInt( unaryNumeric( suppliers, expectedEvaluatorToString, - intCases(lowerBound, upperBound), + intCases(lowerBound, upperBound, true), expectedType, n -> expectedValue.apply(n.intValue()), n -> expectedWarnings.apply(n.intValue()) @@ -423,7 +426,7 @@ public static void forUnaryLong( unaryNumeric( suppliers, expectedEvaluatorToString, - longCases(lowerBound, upperBound), + longCases(lowerBound, upperBound, true), expectedType, n -> expectedValue.apply(n.longValue()), expectedWarnings @@ -457,7 +460,7 @@ public static void forUnaryUnsignedLong( unaryNumeric( suppliers, expectedEvaluatorToString, - ulongCases(lowerBound, upperBound), + ulongCases(lowerBound, upperBound, true), expectedType, n -> expectedValue.apply((BigInteger) n), n -> expectedWarnings.apply((BigInteger) n) @@ -503,7 +506,7 @@ public static void forUnaryDouble( unaryNumeric( suppliers, expectedEvaluatorToString, - doubleCases(lowerBound, upperBound), + doubleCases(lowerBound, upperBound, true), expectedType, n -> expectedValue.apply(n.doubleValue()), n -> 
expectedWarnings.apply(n.doubleValue()) @@ -729,9 +732,9 @@ public static void unary( unary(suppliers, expectedEvaluatorToString, valueSuppliers, expectedOutputType, expected, unused -> warnings); } - public static List intCases(int min, int max) { + public static List intCases(int min, int max, boolean includeZero) { List cases = new ArrayList<>(); - if (0 <= max && 0 >= min) { + if (0 <= max && 0 >= min && includeZero) { cases.add(new TypedDataSupplier("<0 int>", () -> 0, DataTypes.INTEGER)); } @@ -753,9 +756,9 @@ public static List intCases(int min, int max) { return cases; } - public static List longCases(long min, long max) { + public static List longCases(long min, long max, boolean includeZero) { List cases = new ArrayList<>(); - if (0L <= max && 0L >= min) { + if (0L <= max && 0L >= min && includeZero) { cases.add(new TypedDataSupplier("<0 long>", () -> 0L, DataTypes.LONG)); } @@ -778,11 +781,11 @@ public static List longCases(long min, long max) { return cases; } - public static List ulongCases(BigInteger min, BigInteger max) { + public static List ulongCases(BigInteger min, BigInteger max, boolean includeZero) { List cases = new ArrayList<>(); // Zero - if (BigInteger.ZERO.compareTo(max) <= 0 && BigInteger.ZERO.compareTo(min) >= 0) { + if (BigInteger.ZERO.compareTo(max) <= 0 && BigInteger.ZERO.compareTo(min) >= 0 && includeZero) { cases.add(new TypedDataSupplier("<0 unsigned long>", () -> BigInteger.ZERO, DataTypes.UNSIGNED_LONG)); } @@ -818,11 +821,11 @@ public static List ulongCases(BigInteger min, BigInteger max) return cases; } - public static List doubleCases(double min, double max) { + public static List doubleCases(double min, double max, boolean includeZero) { List cases = new ArrayList<>(); // Zeros - if (0d <= max && 0d >= min) { + if (0d <= max && 0d >= min && includeZero) { cases.add(new TypedDataSupplier("<0 double>", () -> 0.0d, DataTypes.DOUBLE)); cases.add(new TypedDataSupplier("<-0 double>", () -> -0.0d, DataTypes.DOUBLE)); } @@ -1046,7 +1049,7 @@ public static List versionCases(String prefix) { ); } - private static String getCastEvaluator(String original, DataType current, DataType target) { + public static String getCastEvaluator(String original, DataType current, DataType target) { if (current == target) { return original; } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToIntegerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToIntegerTests.java index 3a6cb86b7a3c6..e6f6cb7e978f7 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToIntegerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToIntegerTests.java @@ -178,7 +178,7 @@ public static Iterable parameters() { TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.intCases(Integer.MIN_VALUE, Integer.MAX_VALUE) + TestCaseSupplier.intCases(Integer.MIN_VALUE, Integer.MAX_VALUE, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( @@ -196,7 +196,7 @@ public static Iterable parameters() { TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.doubleCases(Integer.MIN_VALUE, Integer.MAX_VALUE) + TestCaseSupplier.doubleCases(Integer.MIN_VALUE, Integer.MAX_VALUE, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( @@ -214,7 +214,7 @@ public static Iterable parameters() { 
TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.doubleCases(Double.NEGATIVE_INFINITY, Integer.MIN_VALUE - 1d) + TestCaseSupplier.doubleCases(Double.NEGATIVE_INFINITY, Integer.MIN_VALUE - 1d, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( @@ -237,7 +237,7 @@ public static Iterable parameters() { TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.doubleCases(Integer.MAX_VALUE + 1d, Double.POSITIVE_INFINITY) + TestCaseSupplier.doubleCases(Integer.MAX_VALUE + 1d, Double.POSITIVE_INFINITY, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToLongTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToLongTests.java index 031ce6193bcc4..1879b7ce97ea8 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToLongTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToLongTests.java @@ -129,7 +129,7 @@ public static Iterable parameters() { TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.longCases(Long.MIN_VALUE, Long.MAX_VALUE) + TestCaseSupplier.longCases(Long.MIN_VALUE, Long.MAX_VALUE, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( @@ -147,7 +147,7 @@ public static Iterable parameters() { TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.doubleCases(Long.MIN_VALUE, Long.MAX_VALUE) + TestCaseSupplier.doubleCases(Long.MIN_VALUE, Long.MAX_VALUE, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( @@ -165,7 +165,7 @@ public static Iterable parameters() { TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.doubleCases(Double.NEGATIVE_INFINITY, Long.MIN_VALUE - 1d) + TestCaseSupplier.doubleCases(Double.NEGATIVE_INFINITY, Long.MIN_VALUE - 1d, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( @@ -188,7 +188,7 @@ public static Iterable parameters() { TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.doubleCases(Long.MAX_VALUE + 1d, Double.POSITIVE_INFINITY) + TestCaseSupplier.doubleCases(Long.MAX_VALUE + 1d, Double.POSITIVE_INFINITY, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToUnsignedLongTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToUnsignedLongTests.java index 8d5ee002a8f78..3cb9c813fd0b5 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToUnsignedLongTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToUnsignedLongTests.java @@ -165,7 +165,7 @@ public static Iterable parameters() { TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.ulongCases(BigInteger.ZERO, UNSIGNED_LONG_MAX) + TestCaseSupplier.ulongCases(BigInteger.ZERO, UNSIGNED_LONG_MAX, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( @@ -183,7 +183,7 @@ public static Iterable parameters() { TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.doubleCases(0, 
UNSIGNED_LONG_MAX_AS_DOUBLE) + TestCaseSupplier.doubleCases(0, UNSIGNED_LONG_MAX_AS_DOUBLE, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( @@ -201,7 +201,7 @@ public static Iterable parameters() { TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.doubleCases(Double.NEGATIVE_INFINITY, -1d) + TestCaseSupplier.doubleCases(Double.NEGATIVE_INFINITY, -1d, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( @@ -224,7 +224,7 @@ public static Iterable parameters() { TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.doubleCases(UNSIGNED_LONG_MAX_AS_DOUBLE + 10e5, Double.POSITIVE_INFINITY) + TestCaseSupplier.doubleCases(UNSIGNED_LONG_MAX_AS_DOUBLE + 10e5, Double.POSITIVE_INFINITY, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/AddTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/AddTests.java index 6a74dd13c1e3a..143f7e5aaba9f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/AddTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/AddTests.java @@ -65,7 +65,8 @@ public static Iterable parameters() { ), "lhs", "rhs", - List.of() + List.of(), + true ) ); @@ -79,9 +80,10 @@ public static Iterable parameters() { "rhs", (l, r) -> (((BigInteger) l).add((BigInteger) r)), DataTypes.UNSIGNED_LONG, - TestCaseSupplier.ulongCases(BigInteger.ZERO, BigInteger.valueOf(Long.MAX_VALUE)), - TestCaseSupplier.ulongCases(BigInteger.ZERO, BigInteger.valueOf(Long.MAX_VALUE)), - List.of() + TestCaseSupplier.ulongCases(BigInteger.ZERO, BigInteger.valueOf(Long.MAX_VALUE), true), + TestCaseSupplier.ulongCases(BigInteger.ZERO, BigInteger.valueOf(Long.MAX_VALUE), true), + List.of(), + true ) ); @@ -96,7 +98,8 @@ public static Iterable parameters() { EsqlDataTypes.DATE_PERIOD, TestCaseSupplier.datePeriodCases(), TestCaseSupplier.datePeriodCases(), - List.of() + List.of(), + true ) ); suppliers.addAll( @@ -108,7 +111,8 @@ public static Iterable parameters() { EsqlDataTypes.TIME_DURATION, TestCaseSupplier.timeDurationCases(), TestCaseSupplier.timeDurationCases(), - List.of() + List.of(), + true ) ); @@ -134,7 +138,8 @@ public static Iterable parameters() { DataTypes.DATETIME, TestCaseSupplier.dateCases(), TestCaseSupplier.datePeriodCases(), - List.of() + List.of(), + true ) ); suppliers.addAll( @@ -159,7 +164,8 @@ public static Iterable parameters() { DataTypes.DATETIME, TestCaseSupplier.dateCases(), TestCaseSupplier.timeDurationCases(), - List.of() + List.of(), + true ) ); suppliers.addAll(TestCaseSupplier.dateCases().stream().mapMulti((tds, consumer) -> { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DivTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DivTests.java index 4aa8786f2cd69..1f5d57394ff4d 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DivTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DivTests.java @@ -10,7 +10,7 @@ import com.carrotsearch.randomizedtesting.annotations.Name; import 
com.carrotsearch.randomizedtesting.annotations.ParametersFactory; -import org.elasticsearch.compute.data.Block; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; import org.elasticsearch.xpack.ql.expression.Expression; import org.elasticsearch.xpack.ql.tree.Source; @@ -18,140 +18,149 @@ import org.elasticsearch.xpack.ql.type.DataTypes; import java.math.BigInteger; +import java.util.ArrayList; import java.util.List; +import java.util.Set; +import java.util.function.BiFunction; import java.util.function.Supplier; -import static org.elasticsearch.compute.data.BlockUtils.toJavaObject; -import static org.elasticsearch.xpack.ql.util.NumericUtils.ZERO_AS_UNSIGNED_LONG; -import static org.elasticsearch.xpack.ql.util.NumericUtils.asLongUnsigned; -import static org.elasticsearch.xpack.ql.util.NumericUtils.unsignedLongAsBigInteger; -import static org.hamcrest.Matchers.equalTo; - -public class DivTests extends AbstractArithmeticTestCase { +public class DivTests extends AbstractFunctionTestCase { public DivTests(@Name("TestCase") Supplier testCaseSupplier) { this.testCase = testCaseSupplier.get(); } @ParametersFactory public static Iterable parameters() { - return parameterSuppliersFromTypedData(List.of(new TestCaseSupplier("Int / Int", () -> { - int lhs = randomInt(); - int rhs; - do { - rhs = randomInt(); - } while (rhs == 0); - return new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(lhs, DataTypes.INTEGER, "lhs"), - new TestCaseSupplier.TypedData(rhs, DataTypes.INTEGER, "rhs") - ), - "DivIntsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", - DataTypes.INTEGER, - equalTo(lhs / rhs) - ); - }), new TestCaseSupplier("Long / Long", () -> { - long lhs = randomLong(); - long rhs; - do { - rhs = randomLong(); - } while (rhs == 0); - return new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(lhs, DataTypes.LONG, "lhs"), - new TestCaseSupplier.TypedData(rhs, DataTypes.LONG, "rhs") + List suppliers = new ArrayList<>(); + suppliers.addAll( + TestCaseSupplier.forBinaryWithWidening( + new TestCaseSupplier.NumericTypeTestConfigs( + new TestCaseSupplier.NumericTypeTestConfig( + (Integer.MIN_VALUE >> 1) - 1, + (Integer.MAX_VALUE >> 1) - 1, + (l, r) -> l.intValue() / r.intValue(), + "DivIntsEvaluator" + ), + new TestCaseSupplier.NumericTypeTestConfig( + (Long.MIN_VALUE >> 1) - 1, + (Long.MAX_VALUE >> 1) - 1, + (l, r) -> l.longValue() / r.longValue(), + "DivLongsEvaluator" + ), + new TestCaseSupplier.NumericTypeTestConfig( + Double.NEGATIVE_INFINITY, + Double.POSITIVE_INFINITY, + (l, r) -> l.doubleValue() / r.doubleValue(), + "DivDoublesEvaluator" + ) ), - "DivLongsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", - DataTypes.LONG, - equalTo(lhs / rhs) - ); - }), new TestCaseSupplier("Double / Double", () -> { - double lhs = randomDouble(); - double rhs; - do { - rhs = randomDouble(); - } while (rhs == 0); - return new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(lhs, DataTypes.DOUBLE, "lhs"), - new TestCaseSupplier.TypedData(rhs, DataTypes.DOUBLE, "rhs") - ), - "DivDoublesEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", - DataTypes.DOUBLE, - equalTo(lhs / rhs) - ); - })/*, new TestCaseSupplier("ULong / ULong", () -> { - // Ensure we don't have an overflow - long lhs = randomLong(); - long rhs; - do { - rhs = randomLong(); - } while (rhs == 0); - BigInteger lhsBI = unsignedLongAsBigInteger(lhs); 
- BigInteger rhsBI = unsignedLongAsBigInteger(rhs); - return new TestCase( - Source.EMPTY, - List.of(new TypedData(lhs, DataTypes.UNSIGNED_LONG, "lhs"), new TypedData(rhs, DataTypes.UNSIGNED_LONG, "rhs")), - "DivUnsignedLongsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", - equalTo(asLongUnsigned(lhsBI.divide(rhsBI).longValue())) - ); - }) - */ - )); - } + "lhs", + "rhs", + List.of(), + false + ) + ); + suppliers.addAll( + TestCaseSupplier.forBinaryNotCasting( + "DivUnsignedLongsEvaluator", + "lhs", + "rhs", + (l, r) -> (((BigInteger) l).divide((BigInteger) r)), + DataTypes.UNSIGNED_LONG, + TestCaseSupplier.ulongCases(BigInteger.ZERO, BigInteger.valueOf(Long.MAX_VALUE), true), + TestCaseSupplier.ulongCases(BigInteger.ONE, BigInteger.valueOf(Long.MAX_VALUE), true), + List.of(), + false + ) + ); - // run dedicated test to avoid the JVM optimized ArithmeticException that lacks a message - public void testDivisionByZero() { - DataType testCaseType = testCase.getData().get(0).type(); - List data = switch (testCaseType.typeName()) { - case "INTEGER" -> List.of(randomInt(), 0); - case "LONG" -> List.of(randomLong(), 0L); - case "UNSIGNED_LONG" -> List.of(randomLong(), ZERO_AS_UNSIGNED_LONG); - default -> null; - }; - if (data != null) { - var op = build(Source.EMPTY, field("lhs", testCaseType), field("rhs", testCaseType)); - try (Block block = evaluator(op).get(driverContext()).eval(row(data))) { - assertCriticalWarnings( - "Line -1:-1: evaluation of [] failed, treating result as null. Only first 20 failures recorded.", - "Line -1:-1: java.lang.ArithmeticException: / by zero" + suppliers = errorsForCasesWithoutExamples(anyNullIsNull(true, suppliers), DivTests::divErrorMessageString); + + // Divide by zero cases - all of these should warn and return null + TestCaseSupplier.NumericTypeTestConfigs typeStuff = new TestCaseSupplier.NumericTypeTestConfigs( + new TestCaseSupplier.NumericTypeTestConfig( + (Integer.MIN_VALUE >> 1) - 1, + (Integer.MAX_VALUE >> 1) - 1, + (l, r) -> null, + "DivIntsEvaluator" + ), + new TestCaseSupplier.NumericTypeTestConfig( + (Long.MIN_VALUE >> 1) - 1, + (Long.MAX_VALUE >> 1) - 1, + (l, r) -> null, + "DivLongsEvaluator" + ), + new TestCaseSupplier.NumericTypeTestConfig( + Double.NEGATIVE_INFINITY, + Double.POSITIVE_INFINITY, + (l, r) -> null, + "DivDoublesEvaluator" + ) + ); + List numericTypes = List.of(DataTypes.INTEGER, DataTypes.LONG, DataTypes.DOUBLE); + + for (DataType lhsType : numericTypes) { + for (DataType rhsType : numericTypes) { + DataType expected = TestCaseSupplier.widen(lhsType, rhsType); + TestCaseSupplier.NumericTypeTestConfig expectedTypeStuff = typeStuff.get(expected); + BiFunction evaluatorToString = (lhs, rhs) -> expectedTypeStuff.evaluatorName() + + "[" + + "lhs" + + "=" + + TestCaseSupplier.getCastEvaluator("Attribute[channel=0]", lhs, expected) + + ", " + + "rhs" + + "=" + + TestCaseSupplier.getCastEvaluator("Attribute[channel=1]", rhs, expected) + + "]"; + TestCaseSupplier.casesCrossProduct( + (l1, r1) -> expectedTypeStuff.expected().apply((Number) l1, (Number) r1), + TestCaseSupplier.getSuppliersForNumericType(lhsType, expectedTypeStuff.min(), expectedTypeStuff.max(), true), + TestCaseSupplier.getSuppliersForNumericType(rhsType, 0, 0, true), + evaluatorToString, + List.of( + "Line -1:-1: evaluation of [] failed, treating result as null. 
Only first 20 failures recorded.", + "Line -1:-1: java.lang.ArithmeticException: / by zero" + ), + suppliers, + expected, + false ); - assertNull(toJavaObject(block, 0)); } } - } - @Override - protected boolean rhsOk(Object o) { - if (o instanceof Number n) { - return n.doubleValue() != 0; - } - return true; - } - - @Override - protected Div build(Source source, Expression lhs, Expression rhs) { - return new Div(source, lhs, rhs); - } + suppliers.addAll( + TestCaseSupplier.forBinaryNotCasting( + "DivUnsignedLongsEvaluator", + "lhs", + "rhs", + (l, r) -> null, + DataTypes.UNSIGNED_LONG, + TestCaseSupplier.ulongCases(BigInteger.ZERO, BigInteger.valueOf(Long.MAX_VALUE), true), + TestCaseSupplier.ulongCases(BigInteger.ZERO, BigInteger.ZERO, true), + List.of( + "Line -1:-1: evaluation of [] failed, treating result as null. Only first 20 failures recorded.", + "Line -1:-1: java.lang.ArithmeticException: / by zero" + ), + false + ) + ); - @Override - protected double expectedValue(double lhs, double rhs) { - return lhs / rhs; + return parameterSuppliersFromTypedData(suppliers); } - @Override - protected int expectedValue(int lhs, int rhs) { - return lhs / rhs; - } + private static String divErrorMessageString(boolean includeOrdinal, List> validPerPosition, List types) { + try { + return typeErrorMessage(includeOrdinal, validPerPosition, types); + } catch (IllegalStateException e) { + // This means all the positional args were okay, so the expected error is from the combination + return "[/] has arguments with incompatible types [" + types.get(0).typeName() + "] and [" + types.get(1).typeName() + "]"; - @Override - protected long expectedValue(long lhs, long rhs) { - return lhs / rhs; + } } @Override - protected long expectedUnsignedLongValue(long lhs, long rhs) { - BigInteger lhsBI = unsignedLongAsBigInteger(lhs); - BigInteger rhsBI = unsignedLongAsBigInteger(rhs); - return asLongUnsigned(lhsBI.divide(rhsBI).longValue()); + protected Expression build(Source source, List args) { + return new Div(source, args.get(0), args.get(1)); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/ModTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/ModTests.java index 5beaf0b782af7..03fbbf6a21ebe 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/ModTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/ModTests.java @@ -10,7 +10,7 @@ import com.carrotsearch.randomizedtesting.annotations.Name; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; -import org.elasticsearch.compute.data.Block; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; import org.elasticsearch.xpack.ql.expression.Expression; import org.elasticsearch.xpack.ql.tree.Source; @@ -18,140 +18,149 @@ import org.elasticsearch.xpack.ql.type.DataTypes; import java.math.BigInteger; +import java.util.ArrayList; import java.util.List; +import java.util.Set; +import java.util.function.BiFunction; import java.util.function.Supplier; -import static org.elasticsearch.compute.data.BlockUtils.toJavaObject; -import static org.elasticsearch.xpack.ql.util.NumericUtils.ZERO_AS_UNSIGNED_LONG; -import static org.elasticsearch.xpack.ql.util.NumericUtils.asLongUnsigned; -import static 
org.elasticsearch.xpack.ql.util.NumericUtils.unsignedLongAsBigInteger; -import static org.hamcrest.Matchers.equalTo; - -public class ModTests extends AbstractArithmeticTestCase { +public class ModTests extends AbstractFunctionTestCase { public ModTests(@Name("TestCase") Supplier testCaseSupplier) { this.testCase = testCaseSupplier.get(); } @ParametersFactory public static Iterable parameters() { - return parameterSuppliersFromTypedData(List.of(new TestCaseSupplier("Int % Int", () -> { - int lhs = randomInt(); - int rhs; - do { - rhs = randomInt(); - } while (rhs == 0); - return new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(lhs, DataTypes.INTEGER, "lhs"), - new TestCaseSupplier.TypedData(rhs, DataTypes.INTEGER, "rhs") - ), - "ModIntsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", - DataTypes.INTEGER, - equalTo(lhs % rhs) - ); - }), new TestCaseSupplier("Long % Long", () -> { - long lhs = randomLong(); - long rhs; - do { - rhs = randomLong(); - } while (rhs == 0); - return new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(lhs, DataTypes.LONG, "lhs"), - new TestCaseSupplier.TypedData(rhs, DataTypes.LONG, "rhs") + List suppliers = new ArrayList<>(); + suppliers.addAll( + TestCaseSupplier.forBinaryWithWidening( + new TestCaseSupplier.NumericTypeTestConfigs( + new TestCaseSupplier.NumericTypeTestConfig( + (Integer.MIN_VALUE >> 1) - 1, + (Integer.MAX_VALUE >> 1) - 1, + (l, r) -> l.intValue() % r.intValue(), + "ModIntsEvaluator" + ), + new TestCaseSupplier.NumericTypeTestConfig( + (Long.MIN_VALUE >> 1) - 1, + (Long.MAX_VALUE >> 1) - 1, + (l, r) -> l.longValue() % r.longValue(), + "ModLongsEvaluator" + ), + new TestCaseSupplier.NumericTypeTestConfig( + Double.NEGATIVE_INFINITY, + Double.POSITIVE_INFINITY, + (l, r) -> l.doubleValue() % r.doubleValue(), + "ModDoublesEvaluator" + ) ), - "ModLongsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", - DataTypes.LONG, - equalTo(lhs % rhs) - ); - }), new TestCaseSupplier("Double % Double", () -> { - double lhs = randomDouble(); - double rhs; - do { - rhs = randomDouble(); - } while (rhs == 0); - return new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(lhs, DataTypes.DOUBLE, "lhs"), - new TestCaseSupplier.TypedData(rhs, DataTypes.DOUBLE, "rhs") - ), - "ModDoublesEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", - DataTypes.DOUBLE, - equalTo(lhs % rhs) - ); - })/*, new TestCaseSupplier("ULong % ULong", () -> { - // Ensure we don't have an overflow - long lhs = randomLong(); - long rhs; - do { - rhs = randomLong(); - } while (rhs == 0); - BigInteger lhsBI = unsignedLongAsBigInteger(lhs); - BigInteger rhsBI = unsignedLongAsBigInteger(rhs); - return new TestCase( - Source.EMPTY, - List.of(new TypedData(lhs, DataTypes.UNSIGNED_LONG, "lhs"), new TypedData(rhs, DataTypes.UNSIGNED_LONG, "rhs")), - "ModUnsignedLongsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", - equalTo(asLongUnsigned(lhsBI.mod(rhsBI).longValue())) - ); - }) - */ - )); - } + "lhs", + "rhs", + List.of(), + false + ) + ); + suppliers.addAll( + TestCaseSupplier.forBinaryNotCasting( + "ModUnsignedLongsEvaluator", + "lhs", + "rhs", + (l, r) -> (((BigInteger) l).mod((BigInteger) r)), + DataTypes.UNSIGNED_LONG, + TestCaseSupplier.ulongCases(BigInteger.ZERO, BigInteger.valueOf(Long.MAX_VALUE), true), + TestCaseSupplier.ulongCases(BigInteger.ONE, BigInteger.valueOf(Long.MAX_VALUE), true), + List.of(), + false + ) + ); - // run dedicated test to avoid the JVM optimized 
ArithmeticException that lacks a message - public void testDivisionByZero() { - DataType testCaseType = testCase.getData().get(0).type(); - List data = switch (testCaseType.typeName()) { - case "INTEGER" -> List.of(randomInt(), 0); - case "LONG" -> List.of(randomLong(), 0L); - case "UNSIGNED_LONG" -> List.of(randomLong(), ZERO_AS_UNSIGNED_LONG); - default -> null; - }; - if (data != null) { - var op = build(Source.EMPTY, field("lhs", testCaseType), field("rhs", testCaseType)); - try (Block block = evaluator(op).get(driverContext()).eval(row(data))) { - assertCriticalWarnings( - "Line -1:-1: evaluation of [] failed, treating result as null. Only first 20 failures recorded.", - "Line -1:-1: java.lang.ArithmeticException: / by zero" + suppliers = errorsForCasesWithoutExamples(anyNullIsNull(true, suppliers), ModTests::modErrorMessageString); + + // Divide by zero cases - all of these should warn and return null + TestCaseSupplier.NumericTypeTestConfigs typeStuff = new TestCaseSupplier.NumericTypeTestConfigs( + new TestCaseSupplier.NumericTypeTestConfig( + (Integer.MIN_VALUE >> 1) - 1, + (Integer.MAX_VALUE >> 1) - 1, + (l, r) -> null, + "ModIntsEvaluator" + ), + new TestCaseSupplier.NumericTypeTestConfig( + (Long.MIN_VALUE >> 1) - 1, + (Long.MAX_VALUE >> 1) - 1, + (l, r) -> null, + "ModLongsEvaluator" + ), + new TestCaseSupplier.NumericTypeTestConfig( + Double.NEGATIVE_INFINITY, + Double.POSITIVE_INFINITY, + (l, r) -> null, + "ModDoublesEvaluator" + ) + ); + List numericTypes = List.of(DataTypes.INTEGER, DataTypes.LONG, DataTypes.DOUBLE); + + for (DataType lhsType : numericTypes) { + for (DataType rhsType : numericTypes) { + DataType expected = TestCaseSupplier.widen(lhsType, rhsType); + TestCaseSupplier.NumericTypeTestConfig expectedTypeStuff = typeStuff.get(expected); + BiFunction evaluatorToString = (lhs, rhs) -> expectedTypeStuff.evaluatorName() + + "[" + + "lhs" + + "=" + + TestCaseSupplier.getCastEvaluator("Attribute[channel=0]", lhs, expected) + + ", " + + "rhs" + + "=" + + TestCaseSupplier.getCastEvaluator("Attribute[channel=1]", rhs, expected) + + "]"; + TestCaseSupplier.casesCrossProduct( + (l1, r1) -> expectedTypeStuff.expected().apply((Number) l1, (Number) r1), + TestCaseSupplier.getSuppliersForNumericType(lhsType, expectedTypeStuff.min(), expectedTypeStuff.max(), true), + TestCaseSupplier.getSuppliersForNumericType(rhsType, 0, 0, true), + evaluatorToString, + List.of( + "Line -1:-1: evaluation of [] failed, treating result as null. Only first 20 failures recorded.", + "Line -1:-1: java.lang.ArithmeticException: / by zero" + ), + suppliers, + expected, + false ); - assertNull(toJavaObject(block, 0)); } } - } - @Override - protected boolean rhsOk(Object o) { - if (o instanceof Number n) { - return n.doubleValue() != 0; - } - return true; - } - - @Override - protected Mod build(Source source, Expression lhs, Expression rhs) { - return new Mod(source, lhs, rhs); - } + suppliers.addAll( + TestCaseSupplier.forBinaryNotCasting( + "ModUnsignedLongsEvaluator", + "lhs", + "rhs", + (l, r) -> null, + DataTypes.UNSIGNED_LONG, + TestCaseSupplier.ulongCases(BigInteger.ZERO, BigInteger.valueOf(Long.MAX_VALUE), true), + TestCaseSupplier.ulongCases(BigInteger.ZERO, BigInteger.ZERO, true), + List.of( + "Line -1:-1: evaluation of [] failed, treating result as null. 
Only first 20 failures recorded.", + "Line -1:-1: java.lang.ArithmeticException: / by zero" + ), + false + ) + ); - @Override - protected double expectedValue(double lhs, double rhs) { - return lhs % rhs; + return parameterSuppliersFromTypedData(suppliers); } - @Override - protected int expectedValue(int lhs, int rhs) { - return lhs % rhs; - } + private static String modErrorMessageString(boolean includeOrdinal, List> validPerPosition, List types) { + try { + return typeErrorMessage(includeOrdinal, validPerPosition, types); + } catch (IllegalStateException e) { + // This means all the positional args were okay, so the expected error is from the combination + return "[%] has arguments with incompatible types [" + types.get(0).typeName() + "] and [" + types.get(1).typeName() + "]"; - @Override - protected long expectedValue(long lhs, long rhs) { - return lhs % rhs; + } } @Override - protected long expectedUnsignedLongValue(long lhs, long rhs) { - BigInteger lhsBI = unsignedLongAsBigInteger(lhs); - BigInteger rhsBI = unsignedLongAsBigInteger(rhs); - return asLongUnsigned(lhsBI.mod(rhsBI).longValue()); + protected Expression build(Source source, List args) { + return new Mod(source, args.get(0), args.get(1)); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index 80deb0ea83d86..f6aeb89faff0e 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -657,11 +657,9 @@ public void testOutOfRangeFilterPushdown() { new OutOfRangeTestCase("byte", smallerThanInteger, largerThanInteger), new OutOfRangeTestCase("short", smallerThanInteger, largerThanInteger), new OutOfRangeTestCase("integer", smallerThanInteger, largerThanInteger), - new OutOfRangeTestCase("long", smallerThanLong, largerThanLong), + new OutOfRangeTestCase("long", smallerThanLong, largerThanLong) // TODO: add unsigned_long https://github.com/elastic/elasticsearch/issues/102935 // TODO: add half_float, float https://github.com/elastic/elasticsearch/issues/100130 - new OutOfRangeTestCase("double", "-1.0/0.0", "1.0/0.0"), - new OutOfRangeTestCase("scaled_float", "-1.0/0.0", "1.0/0.0") ); final String LT = "<"; @@ -678,8 +676,7 @@ public void testOutOfRangeFilterPushdown() { GT + testCase.tooLow, GTE + testCase.tooLow, NEQ + testCase.tooHigh, - NEQ + testCase.tooLow, - NEQ + "0.0/0.0" + NEQ + testCase.tooLow ); List alwaysFalsePredicates = List.of( LT + testCase.tooLow, @@ -687,12 +684,7 @@ public void testOutOfRangeFilterPushdown() { GT + testCase.tooHigh, GTE + testCase.tooHigh, EQ + testCase.tooHigh, - EQ + testCase.tooLow, - LT + "0.0/0.0", - LTE + "0.0/0.0", - GT + "0.0/0.0", - GTE + "0.0/0.0", - EQ + "0.0/0.0" + EQ + testCase.tooLow ); for (String truePredicate : trueForSingleValuesPredicates) { @@ -700,6 +692,7 @@ public void testOutOfRangeFilterPushdown() { var query = "from test | where " + comparison; Source expectedSource = new Source(1, 18, comparison); + logger.info("Query: " + query); EsQueryExec actualQueryExec = doTestOutOfRangeFilterPushdown(query, allTypeMappingAnalyzer); assertThat(actualQueryExec.query(), is(instanceOf(SingleValueQuery.Builder.class))); diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index cfb21ad5a1d94..21bd73c3821c4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -24,6 +24,9 @@ import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; @@ -202,6 +205,30 @@ public static List getNamedWriteables() { new NamedWriteableRegistry.Entry(TaskSettings.class, CohereRerankTaskSettings.NAME, CohereRerankTaskSettings::new) ); + // Azure OpenAI + namedWriteables.add( + new NamedWriteableRegistry.Entry( + AzureOpenAiSecretSettings.class, + AzureOpenAiSecretSettings.NAME, + AzureOpenAiSecretSettings::new + ) + ); + + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + AzureOpenAiEmbeddingsServiceSettings.NAME, + AzureOpenAiEmbeddingsServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + TaskSettings.class, + AzureOpenAiEmbeddingsTaskSettings.NAME, + AzureOpenAiEmbeddingsTaskSettings::new + ) + ); + return namedWriteables; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 3b2c0b3c4ac3e..f41f9a97cec18 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -56,6 +56,7 @@ import org.elasticsearch.xpack.inference.rest.RestInferenceAction; import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction; import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService; import org.elasticsearch.xpack.inference.services.cohere.CohereService; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; import org.elasticsearch.xpack.inference.services.elser.ElserInternalService; @@ -176,6 +177,7 @@ public List getInferenceServiceFactories() { context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()), context -> new OpenAiService(httpFactory.get(), serviceComponents.get()), context -> new CohereService(httpFactory.get(), serviceComponents.get()), + context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get()), ElasticsearchInternalService::new ); } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreator.java
new file mode 100644
index 0000000000000..39eaaceae08bc
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreator.java
@@ -0,0 +1,35 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.azureopenai;
+
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
+
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the Azure OpenAI model type.
+ */
+public class AzureOpenAiActionCreator implements AzureOpenAiActionVisitor {
+    private final Sender sender;
+    private final ServiceComponents serviceComponents;
+
+    public AzureOpenAiActionCreator(Sender sender, ServiceComponents serviceComponents) {
+        this.sender = Objects.requireNonNull(sender);
+        this.serviceComponents = Objects.requireNonNull(serviceComponents);
+    }
+
+    @Override
+    public ExecutableAction create(AzureOpenAiEmbeddingsModel model, Map<String, Object> taskSettings) {
+        var overriddenModel = AzureOpenAiEmbeddingsModel.of(model, taskSettings);
+        return new AzureOpenAiEmbeddingsAction(sender, overriddenModel, serviceComponents);
+    }
+}
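Since this creator is the entry point for the Azure OpenAI action wiring, a quick sketch of the visitor shape may help: the model hands itself back to the visitor, so the creator can build the right ExecutableAction per model type without instanceof dispatch at the call site. Names below are illustrative stand-ins, not the plugin's classes, and Runnable stands in for ExecutableAction:

interface ActionVisitor {
    Runnable create(EmbeddingsModel model); // one overload per supported model type
}

record EmbeddingsModel(String deploymentId) {
    Runnable accept(ActionVisitor visitor) {
        return visitor.create(this); // double dispatch: overload resolution picks the action
    }
}

class ActionCreator implements ActionVisitor {
    @Override
    public Runnable create(EmbeddingsModel model) {
        return () -> System.out.println("embed via " + model.deploymentId());
    }
}

The visitor interface itself follows:

diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionVisitor.java
new file mode 100644
index 0000000000000..49d1ce61b12dd
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionVisitor.java
@@ -0,0 +1,17 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.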
+ */ + +package org.elasticsearch.xpack.inference.external.action.azureopenai; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; + +import java.util.Map; + +public interface AzureOpenAiActionVisitor { + ExecutableAction create(AzureOpenAiEmbeddingsModel model, Map taskSettings); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsAction.java new file mode 100644 index 0000000000000..a682ad2bb23d5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsAction.java @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.azureopenai; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.AzureOpenAiEmbeddingsExecutableRequestCreator; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; + +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException; + +public class AzureOpenAiEmbeddingsAction implements ExecutableAction { + + private final String errorMessage; + private final AzureOpenAiEmbeddingsExecutableRequestCreator requestCreator; + private final Sender sender; + + public AzureOpenAiEmbeddingsAction(Sender sender, AzureOpenAiEmbeddingsModel model, ServiceComponents serviceComponents) { + Objects.requireNonNull(serviceComponents); + Objects.requireNonNull(model); + this.sender = Objects.requireNonNull(sender); + requestCreator = new AzureOpenAiEmbeddingsExecutableRequestCreator(model, serviceComponents.truncator()); + errorMessage = constructFailedToSendRequestMessage(model.getUri(), "Azure OpenAI embeddings"); + } + + @Override + public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) { + try { + ActionListener wrappedListener = wrapFailuresInElasticsearchException(errorMessage, listener); + + sender.send(requestCreator, inferenceInputs, timeout, wrappedListener); + } catch (ElasticsearchException e) { + listener.onFailure(e); + } catch (Exception e) { + listener.onFailure(createInternalServerError(e, errorMessage)); + } + } +} diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiAccount.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiAccount.java new file mode 100644 index 0000000000000..db1f91cc751ee --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiAccount.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.azureopenai; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; + +import java.util.Objects; + +public record AzureOpenAiAccount( + String resourceName, + String deploymentId, + String apiVersion, + @Nullable SecureString apiKey, + @Nullable SecureString entraId +) { + + public AzureOpenAiAccount { + Objects.requireNonNull(resourceName); + Objects.requireNonNull(deploymentId); + Objects.requireNonNull(apiVersion); + Objects.requireNonNullElse(apiKey, entraId); + } + + public static AzureOpenAiAccount fromModel(AzureOpenAiEmbeddingsModel model) { + return new AzureOpenAiAccount( + model.getServiceSettings().resourceName(), + model.getServiceSettings().deploymentId(), + model.getServiceSettings().apiVersion(), + model.getSecretSettings().apiKey(), + model.getSecretSettings().entraId() + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiResponseHandler.java new file mode 100644 index 0000000000000..2f72088327468 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiResponseHandler.java @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.xpack.inference.external.azureopenai;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
+import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
+import org.elasticsearch.xpack.inference.external.openai.OpenAiResponseHandler;
+import org.elasticsearch.xpack.inference.external.request.Request;
+
+import static org.elasticsearch.xpack.inference.external.http.retry.ResponseHandlerUtils.getFirstHeaderOrUnknown;
+
+public class AzureOpenAiResponseHandler extends OpenAiResponseHandler {
+
+    /**
+     * These headers for Azure OpenAI are mostly the same as the OpenAI ones, with the major exception
+     * that no information is returned about the request limit or the token limit.
+     *
+     * Microsoft does not seem to have published anything in its docs about this, but more
+     * information can be found in the following Medium article and accompanying code:
+     * - https://pablo-81685.medium.com/azure-openais-api-headers-unpacked-6dbe881e732a
+     * - https://github.com/pablosalvador10/gbbai-azure-ai-aoai
+     */
+    static final String REMAINING_REQUESTS = "x-ratelimit-remaining-requests";
+    // The remaining number of tokens that are permitted before exhausting the rate limit.
+    static final String REMAINING_TOKENS = "x-ratelimit-remaining-tokens";
+
+    public AzureOpenAiResponseHandler(String requestType, ResponseParser parseFunction) {
+        super(requestType, parseFunction);
+    }
+
+    @Override
+    protected RetryException buildExceptionHandling429(Request request, HttpResult result) {
+        return new RetryException(true, buildError(buildRateLimitErrorMessage(result), request, result));
+    }
+
+    static String buildRateLimitErrorMessage(HttpResult result) {
+        var response = result.response();
+        var remainingTokens = getFirstHeaderOrUnknown(response, REMAINING_TOKENS);
+        var remainingRequests = getFirstHeaderOrUnknown(response, REMAINING_REQUESTS);
+        var usageMessage = Strings.format("Remaining tokens [%s]. Remaining requests [%s].", remainingTokens, remainingRequests);
+
+        return RATE_LIMIT + ". " + usageMessage;
+    }
+
+}
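Two things are worth noting in this handler: Azure returns the remaining-quota headers but, unlike OpenAI proper, no limit counterparts, and a missing header must not break error reporting. A small sketch of the header-to-message step, assuming Apache HttpComponents types similar to what HttpResult exposes; the helper and the message prefix are illustrative, not the plugin's RATE_LIMIT constant:

import org.apache.http.HttpResponse;

// Hypothetical helper mirroring how the 429 message above is assembled from the two
// Azure rate-limit headers, substituting "unknown" when a header is absent.
public final class RateLimitMessage {
    private RateLimitMessage() {}

    static String headerOrUnknown(HttpResponse response, String name) {
        var header = response.getFirstHeader(name);
        return header == null ? "unknown" : header.getValue();
    }

    static String build(HttpResponse response) {
        var remainingTokens = headerOrUnknown(response, "x-ratelimit-remaining-tokens");
        var remainingRequests = headerOrUnknown(response, "x-ratelimit-remaining-requests");
        return "Received a rate limit status code. Remaining tokens [" + remainingTokens
            + "]. Remaining requests [" + remainingRequests + "].";
    }
}

The request creator that consumes this handler is next:

diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsExecutableRequestCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsExecutableRequestCreator.java
new file mode 100644
index 0000000000000..b3f53d5f3f236
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsExecutableRequestCreator.java
@@ -0,0 +1,63 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.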
+ */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.http.client.protocol.HttpClientContext; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.azureopenai.AzureOpenAiAccount; +import org.elasticsearch.xpack.inference.external.azureopenai.AzureOpenAiResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.openai.OpenAiEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; + +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.inference.common.Truncator.truncate; + +public class AzureOpenAiEmbeddingsExecutableRequestCreator implements ExecutableRequestCreator { + + private static final Logger logger = LogManager.getLogger(AzureOpenAiEmbeddingsExecutableRequestCreator.class); + + private static final ResponseHandler HANDLER = createEmbeddingsHandler(); + + private static ResponseHandler createEmbeddingsHandler() { + return new AzureOpenAiResponseHandler("azure openai text embedding", OpenAiEmbeddingsResponseEntity::fromResponse); + } + + private final Truncator truncator; + private final AzureOpenAiEmbeddingsModel model; + private final AzureOpenAiAccount account; + + public AzureOpenAiEmbeddingsExecutableRequestCreator(AzureOpenAiEmbeddingsModel model, Truncator truncator) { + this.model = Objects.requireNonNull(model); + this.account = AzureOpenAiAccount.fromModel(model); + this.truncator = Objects.requireNonNull(truncator); + } + + @Override + public Runnable create( + String query, + List<String> input, + RequestSender requestSender, + Supplier<Boolean> hasRequestCompletedFunction, + HttpClientContext context, + ActionListener<InferenceServiceResults> listener + ) { + var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); + AzureOpenAiEmbeddingsRequest request = new AzureOpenAiEmbeddingsRequest(truncator, account, truncatedInput, model); + return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiChatCompletionResponseHandler.java index 5924356e610a3..7ca7cf0422fd9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiChatCompletionResponseHandler.java @@ -18,7 +18,7 @@ public OpenAiChatCompletionResponseHandler(String requestType, ResponseParser pa } @Override - RetryException buildExceptionHandling429(Request request, HttpResult result) { + protected RetryException buildExceptionHandling429(Request request, HttpResult result) { // We don't retry if the chat completion input is too large return new
RetryException(false, buildError(RATE_LIMIT, request, result)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java index db7ca8d6bdc63..c23b94351c187 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java @@ -83,7 +83,7 @@ void checkForFailureStatusCode(Request request, HttpResult result) throws RetryE } } - RetryException buildExceptionHandling429(Request request, HttpResult result) { + protected RetryException buildExceptionHandling429(Request request, HttpResult result) { return new RetryException(true, buildError(buildRateLimitErrorMessage(result), request, result)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequest.java new file mode 100644 index 0000000000000..c943d5f54b4ff --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequest.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.azureopenai; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.apache.http.message.BasicHeader; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.azureopenai.AzureOpenAiAccount; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; +import static org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils.API_KEY_HEADER; +import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings.API_KEY; +import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings.ENTRA_ID; + +public class AzureOpenAiEmbeddingsRequest implements AzureOpenAiRequest { + private static final String MISSING_AUTHENTICATION_ERROR_MESSAGE = + "The request does not have any authentication methods set. 
One of [%s] or [%s] is required."; + + private final Truncator truncator; + private final AzureOpenAiAccount account; + private final Truncator.TruncationResult truncationResult; + private final URI uri; + private final AzureOpenAiEmbeddingsModel model; + + public AzureOpenAiEmbeddingsRequest( + Truncator truncator, + AzureOpenAiAccount account, + Truncator.TruncationResult input, + AzureOpenAiEmbeddingsModel model + ) { + this.truncator = Objects.requireNonNull(truncator); + this.account = Objects.requireNonNull(account); + this.truncationResult = Objects.requireNonNull(input); + this.model = Objects.requireNonNull(model); + this.uri = model.getUri(); + } + + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(uri); + + String requestEntity = Strings.toString( + new AzureOpenAiEmbeddingsRequestEntity( + truncationResult.input(), + model.getTaskSettings().user(), + model.getServiceSettings().dimensions(), + model.getServiceSettings().dimensionsSetByUser() + ) + ); + + ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8)); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType())); + + var entraId = model.getSecretSettings().entraId(); + var apiKey = model.getSecretSettings().apiKey(); + + if (entraId != null && entraId.isEmpty() == false) { + httpPost.setHeader(createAuthBearerHeader(entraId)); + } else if (apiKey != null && apiKey.isEmpty() == false) { + httpPost.setHeader(new BasicHeader(API_KEY_HEADER, apiKey.toString())); + } else { + // should never happen due to the checks on the secret settings, but just in case + ValidationException validationException = new ValidationException(); + validationException.addValidationError(Strings.format(MISSING_AUTHENTICATION_ERROR_MESSAGE, API_KEY, ENTRA_ID)); + throw validationException; + } + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public Request truncate() { + var truncatedInput = truncator.truncate(truncationResult.input()); + + return new AzureOpenAiEmbeddingsRequest(truncator, account, truncatedInput, model); + } + + @Override + public boolean[] getTruncationInfo() { + return truncationResult.truncated().clone(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestEntity.java new file mode 100644 index 0000000000000..2a9a93e99d4e4 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestEntity.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.azureopenai; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record AzureOpenAiEmbeddingsRequestEntity( + List input, + @Nullable String user, + @Nullable Integer dimensions, + boolean dimensionsSetByUser +) implements ToXContentObject { + + private static final String INPUT_FIELD = "input"; + private static final String USER_FIELD = "user"; + private static final String DIMENSIONS_FIELD = "dimensions"; + + public AzureOpenAiEmbeddingsRequestEntity { + Objects.requireNonNull(input); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INPUT_FIELD, input); + + if (user != null) { + builder.field(USER_FIELD, user); + } + + if (dimensionsSetByUser && dimensions != null) { + builder.field(DIMENSIONS_FIELD, dimensions); + } + + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiRequest.java new file mode 100644 index 0000000000000..edb7c70b3903e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiRequest.java @@ -0,0 +1,12 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.azureopenai; + +import org.elasticsearch.xpack.inference.external.request.Request; + +public interface AzureOpenAiRequest extends Request {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiUtils.java new file mode 100644 index 0000000000000..16a02a4c06c1c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiUtils.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.azureopenai; + +public class AzureOpenAiUtils { + + public static final String HOST_SUFFIX = "openai.azure.com"; + public static final String OPENAI_PATH = "openai"; + public static final String DEPLOYMENTS_PATH = "deployments"; + public static final String EMBEDDINGS_PATH = "embeddings"; + public static final String API_VERSION_PARAMETER = "api-version"; + public static final String API_KEY_HEADER = "api-key"; + + private AzureOpenAiUtils() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 72808b6de8132..1631755149578 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -139,6 +139,10 @@ public static String invalidValue(String settingName, String scope, String inval ); } + public static String invalidSettingError(String settingName, String scope) { + return Strings.format("[%s] does not allow the setting [%s]", scope, settingName); + } + // TODO improve URI validation logic public static URI convertToUri(@Nullable String url, String settingName, String settingScope, ValidationException validationException) { try { @@ -186,6 +190,21 @@ public static SecureString extractRequiredSecureString( return new SecureString(Objects.requireNonNull(requiredField).toCharArray()); } + public static SecureString extractOptionalSecureString( + Map map, + String settingName, + String scope, + ValidationException validationException + ) { + String optionalField = extractOptionalString(map, settingName, scope, validationException); + + if (validationException.validationErrors().isEmpty() == false || optionalField == null) { + return null; + } + + return new SecureString(optionalField.toCharArray()); + } + public static SimilarityMeasure extractSimilarity(Map map, String scope, ValidationException validationException) { return extractOptionalEnum( map, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiModel.java new file mode 100644 index 0000000000000..66070cab0e517 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiModel.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.azureopenai; + +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionVisitor; + +import java.net.URI; +import java.util.Map; + +public abstract class AzureOpenAiModel extends Model { + + protected URI uri; + + public AzureOpenAiModel(ModelConfigurations configurations, ModelSecrets secrets) { + super(configurations, secrets); + } + + protected AzureOpenAiModel(AzureOpenAiModel model, TaskSettings taskSettings) { + super(model, taskSettings); + this.uri = model.getUri(); + } + + protected AzureOpenAiModel(AzureOpenAiModel model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + this.uri = model.getUri(); + } + + public abstract ExecutableAction accept(AzureOpenAiActionVisitor creator, Map taskSettings); + + public URI getUri() { + return uri; + } + + // Needed for testing + public void setUri(URI newUri) { + this.uri = newUri; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettings.java new file mode 100644 index 0000000000000..f871fe6c080a1 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettings.java @@ -0,0 +1,101 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.azureopenai; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalSecureString; + +public record AzureOpenAiSecretSettings(@Nullable SecureString apiKey, @Nullable SecureString entraId) implements SecretSettings { + + public static final String NAME = "azure_openai_secret_settings"; + public static final String API_KEY = "api_key"; + public static final String ENTRA_ID = "entra_id"; + + public static AzureOpenAiSecretSettings fromMap(@Nullable Map map) { + if (map == null) { + return null; + } + + ValidationException validationException = new ValidationException(); + SecureString secureApiToken = extractOptionalSecureString(map, API_KEY, ModelSecrets.SECRET_SETTINGS, validationException); + SecureString secureEntraId = extractOptionalSecureString(map, ENTRA_ID, ModelSecrets.SECRET_SETTINGS, validationException); + + if (secureApiToken == null && secureEntraId == null) { + validationException.addValidationError( + format("[secret_settings] must have either the [%s] or the [%s] key set", API_KEY, ENTRA_ID) + ); + } + + if (secureApiToken != null && secureEntraId != null) { + validationException.addValidationError( + format("[secret_settings] must have only one of the [%s] or the [%s] key set", API_KEY, ENTRA_ID) + ); + } + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AzureOpenAiSecretSettings(secureApiToken, secureEntraId); + } + + public AzureOpenAiSecretSettings { + Objects.requireNonNullElse(apiKey, entraId); + } + + public AzureOpenAiSecretSettings(StreamInput in) throws IOException { + this(in.readOptionalSecureString(), in.readOptionalSecureString()); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + if (apiKey != null) { + builder.field(API_KEY, apiKey.toString()); + } + + if (entraId != null) { + builder.field(ENTRA_ID, entraId.toString()); + } + + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_AZURE_OPENAI_EMBEDDINGS; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalSecureString(apiKey); + out.writeOptionalSecureString(entraId); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java new file mode 100644 index 0000000000000..f20c262053d10 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -0,0 
+1,296 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.azureopenai; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; +import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionCreator; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; + +public class AzureOpenAiService extends SenderService { + public static final String NAME = "azureopenai"; + + public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { + super(factory, serviceComponents); + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String inferenceEntityId, + TaskType taskType, + Map config, + Set platformArchitectures, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map 
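/* Note: createModel below receives serviceSettingsMap a second time as the secrets argument; for an inline request the api_key or entra_id appears to arrive inside service_settings rather than in a separate secrets map (an inference from AzureOpenAiSecretSettings.fromMap, not stated in this patch). */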
taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + AzureOpenAiModel model = createModel( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + + private static AzureOpenAiModel createModelFromPersistent( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secretSettings, + String failureMessage + ) { + return createModel( + inferenceEntityId, + taskType, + serviceSettings, + taskSettings, + secretSettings, + failureMessage, + ConfigurationParseContext.PERSISTENT + ); + } + + private static AzureOpenAiModel createModel( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secretSettings, + String failureMessage, + ConfigurationParseContext context + ) { + if (taskType == TaskType.TEXT_EMBEDDING) { + return new AzureOpenAiEmbeddingsModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + context + ); + } + + throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + } + + @Override + public AzureOpenAiModel parsePersistedConfigWithSecrets( + String inferenceEntityId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + secretSettingsMap, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + @Override + public AzureOpenAiModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + null, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + @Override + protected void doInfer( + Model model, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof AzureOpenAiModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + AzureOpenAiModel azureOpenAiModel = (AzureOpenAiModel) model; + var actionCreator = new AzureOpenAiActionCreator(getSender(), getServiceComponents()); + + var action = azureOpenAiModel.accept(actionCreator, taskSettings); + action.execute(new DocumentsOnlyInput(input), timeout, listener); + } + + @Override + protected void doInfer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + throw new UnsupportedOperationException("Azure OpenAI service does not support inference with query input"); + } + + @Override + protected void 
doChunkedInfer( + Model model, + String query, + List<String> input, + Map<String, Object> taskSettings, + InputType inputType, + ChunkingOptions chunkingOptions, + TimeValue timeout, + ActionListener<List<ChunkedInferenceServiceResults>> listener + ) { + ActionListener<InferenceServiceResults> inferListener = listener.delegateFailureAndWrap( + (delegate, response) -> delegate.onResponse(translateToChunkedResults(input, response)) + ); + + doInfer(model, input, taskSettings, inputType, timeout, inferListener); + } + + private static List<ChunkedInferenceServiceResults> translateToChunkedResults( + List<String> inputs, + InferenceServiceResults inferenceResults + ) { + if (inferenceResults instanceof TextEmbeddingResults textEmbeddingResults) { + return ChunkedTextEmbeddingResults.of(inputs, textEmbeddingResults); + } else if (inferenceResults instanceof ErrorInferenceResults error) { + return List.of(new ErrorChunkedInferenceResults(error.getException())); + } else { + throw createInvalidChunkedResultException(inferenceResults.getWriteableName()); + } + } + + /** + * For text embedding models, get the embedding size and + * update the service settings. + * + * @param model The new model + * @param listener The listener + */ + @Override + public void checkModelConfig(Model model, ActionListener<Model> listener) { + if (model instanceof AzureOpenAiEmbeddingsModel embeddingsModel) { + ServiceUtils.getEmbeddingSize( + model, + this, + listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size))) + ); + } else { + listener.onResponse(model); + } + } + + private AzureOpenAiEmbeddingsModel updateModelWithEmbeddingDetails(AzureOpenAiEmbeddingsModel model, int embeddingSize) { + if (model.getServiceSettings().dimensionsSetByUser() + && model.getServiceSettings().dimensions() != null + && model.getServiceSettings().dimensions() != embeddingSize) { + throw new ElasticsearchStatusException( + Strings.format( + "The retrieved embeddings size [%s] does not match the size specified in the settings [%s]. " + + "Please recreate the [%s] configuration with the correct dimensions", + embeddingSize, + model.getServiceSettings().dimensions(), + model.getConfigurations().getInferenceEntityId() + ), + RestStatus.BAD_REQUEST + ); + } + + var similarityFromModel = model.getServiceSettings().similarity(); + var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel; + + AzureOpenAiEmbeddingsServiceSettings serviceSettings = new AzureOpenAiEmbeddingsServiceSettings( + model.getServiceSettings().resourceName(), + model.getServiceSettings().deploymentId(), + model.getServiceSettings().apiVersion(), + embeddingSize, + model.getServiceSettings().dimensionsSetByUser(), + model.getServiceSettings().maxInputTokens(), + similarityToUse + ); + + return new AzureOpenAiEmbeddingsModel(model, serviceSettings); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_AZURE_OPENAI_EMBEDDINGS; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceFields.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceFields.java new file mode 100644 index 0000000000000..a3786ff27224b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceFields.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements.
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.azureopenai; + +public class AzureOpenAiServiceFields { + + public static final String RESOURCE_NAME = "resource_name"; + public static final String DEPLOYMENT_ID = "deployment_id"; + public static final String API_VERSION = "api_version"; + public static final String USER = "user"; +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModel.java new file mode 100644 index 0000000000000..4c3272013f0e2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModel.java @@ -0,0 +1,116 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.azureopenai.embeddings; + +import org.apache.http.client.utils.URIBuilder; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionVisitor; +import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiModel; +import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; + +import static org.elasticsearch.core.Strings.format; + +public class AzureOpenAiEmbeddingsModel extends AzureOpenAiModel { + + public static AzureOpenAiEmbeddingsModel of(AzureOpenAiEmbeddingsModel model, Map taskSettings) { + if (taskSettings == null || taskSettings.isEmpty()) { + return model; + } + + var requestTaskSettings = AzureOpenAiEmbeddingsRequestTaskSettings.fromMap(taskSettings); + return new AzureOpenAiEmbeddingsModel(model, AzureOpenAiEmbeddingsTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); + } + + public AzureOpenAiEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + AzureOpenAiEmbeddingsServiceSettings.fromMap(serviceSettings, context), + AzureOpenAiEmbeddingsTaskSettings.fromMap(taskSettings), + AzureOpenAiSecretSettings.fromMap(secrets) + ); + } + + // Should only be used directly for testing + AzureOpenAiEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + AzureOpenAiEmbeddingsServiceSettings serviceSettings, + AzureOpenAiEmbeddingsTaskSettings taskSettings, + @Nullable AzureOpenAiSecretSettings secrets + ) { + super(new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, 
taskSettings), new ModelSecrets(secrets)); + try { + this.uri = getEmbeddingsUri(serviceSettings.resourceName(), serviceSettings.deploymentId(), serviceSettings.apiVersion()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + private AzureOpenAiEmbeddingsModel(AzureOpenAiEmbeddingsModel originalModel, AzureOpenAiEmbeddingsTaskSettings taskSettings) { + super(originalModel, taskSettings); + } + + public AzureOpenAiEmbeddingsModel(AzureOpenAiEmbeddingsModel originalModel, AzureOpenAiEmbeddingsServiceSettings serviceSettings) { + super(originalModel, serviceSettings); + } + + @Override + public AzureOpenAiEmbeddingsServiceSettings getServiceSettings() { + return (AzureOpenAiEmbeddingsServiceSettings) super.getServiceSettings(); + } + + @Override + public AzureOpenAiEmbeddingsTaskSettings getTaskSettings() { + return (AzureOpenAiEmbeddingsTaskSettings) super.getTaskSettings(); + } + + @Override + public AzureOpenAiSecretSettings getSecretSettings() { + return (AzureOpenAiSecretSettings) super.getSecretSettings(); + } + + @Override + public ExecutableAction accept(AzureOpenAiActionVisitor creator, Map taskSettings) { + return creator.create(this, taskSettings); + } + + public static URI getEmbeddingsUri(String resourceName, String deploymentId, String apiVersion) throws URISyntaxException { + String hostname = format("%s.%s", resourceName, AzureOpenAiUtils.HOST_SUFFIX); + return new URIBuilder().setScheme("https") + .setHost(hostname) + .setPathSegments( + AzureOpenAiUtils.OPENAI_PATH, + AzureOpenAiUtils.DEPLOYMENTS_PATH, + deploymentId, + AzureOpenAiUtils.EMBEDDINGS_PATH + ) + .addParameter(AzureOpenAiUtils.API_VERSION_PARAMETER, apiVersion) + .build(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettings.java new file mode 100644 index 0000000000000..dc7012203a9c8 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettings.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.azureopenai.embeddings; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; + +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.USER; + +/** + * This class handles extracting Azure OpenAI task settings from a request. The difference between this class and + * {@link AzureOpenAiEmbeddingsTaskSettings} is that this class considers all fields as optional. It will not throw an error if a field + * is missing. This allows overriding persistent task settings. 
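+ * + * As an illustrative sketch (the value is invented; only the {@code user} key is defined by this class), a request that + * overrides the stored user would carry task settings such as: + * <pre>{@code + * "task_settings": { + * "user": "example-end-user" + * } + * }</pre>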
+ * @param user a unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse + */ +public record AzureOpenAiEmbeddingsRequestTaskSettings(@Nullable String user) { + private static final Logger logger = LogManager.getLogger(AzureOpenAiEmbeddingsRequestTaskSettings.class); + + public static final AzureOpenAiEmbeddingsRequestTaskSettings EMPTY_SETTINGS = new AzureOpenAiEmbeddingsRequestTaskSettings(null); + + /** + * Extracts the task settings from a map. All settings are considered optional and the absence of a setting + * does not throw an error. + * + * @param map the settings received from a request + * @return a {@link AzureOpenAiEmbeddingsRequestTaskSettings} + */ + public static AzureOpenAiEmbeddingsRequestTaskSettings fromMap(Map map) { + if (map.isEmpty()) { + return AzureOpenAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS; + } + + ValidationException validationException = new ValidationException(); + + String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AzureOpenAiEmbeddingsRequestTaskSettings(user); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java new file mode 100644 index 0000000000000..c3d9e3eb69a5d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java @@ -0,0 +1,282 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.azureopenai.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceUtils; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; +import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.API_VERSION; +import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.DEPLOYMENT_ID; +import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.RESOURCE_NAME; + +/** + * Defines the service settings for interacting with Azure OpenAI's text embedding models.
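+ * + * As an illustrative sketch (values invented; the keys mirror the field constants referenced below), the service + * settings of an endpoint might look like: + * <pre>{@code + * "service_settings": { + * "resource_name": "my-azure-resource", + * "deployment_id": "my-deployment", + * "api_version": "2024-02-01", + * "dimensions": 1536 + * } + * }</pre>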
+ */ +public class AzureOpenAiEmbeddingsServiceSettings implements ServiceSettings { + + public static final String NAME = "azure_openai_embeddings_service_settings"; + + static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user"; + + public static AzureOpenAiEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + var settings = fromMap(map, validationException, context); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AzureOpenAiEmbeddingsServiceSettings(settings); + } + + private static CommonFields fromMap( + Map map, + ValidationException validationException, + ConfigurationParseContext context + ) { + String resourceName = extractRequiredString(map, RESOURCE_NAME, ModelConfigurations.SERVICE_SETTINGS, validationException); + String deploymentId = extractRequiredString(map, DEPLOYMENT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + String apiVersion = extractRequiredString(map, API_VERSION, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer dims = removeAsType(map, DIMENSIONS, Integer.class); + Integer maxTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + + Boolean dimensionsSetByUser = extractOptionalBoolean( + map, + DIMENSIONS_SET_BY_USER, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + + switch (context) { + case REQUEST -> { + if (dimensionsSetByUser != null) { + validationException.addValidationError( + ServiceUtils.invalidSettingError(DIMENSIONS_SET_BY_USER, ModelConfigurations.SERVICE_SETTINGS) + ); + } + dimensionsSetByUser = dims != null; + } + case PERSISTENT -> { + if (dimensionsSetByUser == null) { + validationException.addValidationError( + ServiceUtils.missingSettingErrorMsg(DIMENSIONS_SET_BY_USER, ModelConfigurations.SERVICE_SETTINGS) + ); + } + } + } + + return new CommonFields( + resourceName, + deploymentId, + apiVersion, + dims, + Boolean.TRUE.equals(dimensionsSetByUser), + maxTokens, + similarity + ); + } + + private record CommonFields( + String resourceName, + String deploymentId, + String apiVersion, + @Nullable Integer dimensions, + Boolean dimensionsSetByUser, + @Nullable Integer maxInputTokens, + @Nullable SimilarityMeasure similarity + ) {} + + private final String resourceName; + private final String deploymentId; + private final String apiVersion; + private final Integer dimensions; + private final Boolean dimensionsSetByUser; + private final Integer maxInputTokens; + private final SimilarityMeasure similarity; + + public AzureOpenAiEmbeddingsServiceSettings( + String resourceName, + String deploymentId, + String apiVersion, + @Nullable Integer dimensions, + Boolean dimensionsSetByUser, + @Nullable Integer maxInputTokens, + @Nullable SimilarityMeasure similarity + ) { + this.resourceName = resourceName; + this.deploymentId = deploymentId; + this.apiVersion = apiVersion; + this.dimensions = dimensions; + this.dimensionsSetByUser = Objects.requireNonNull(dimensionsSetByUser); + this.maxInputTokens = maxInputTokens; + this.similarity = similarity; + } + + public AzureOpenAiEmbeddingsServiceSettings(StreamInput in) throws IOException { + resourceName = in.readString(); + deploymentId = in.readString(); + apiVersion = in.readString(); + dimensions = in.readOptionalVInt(); + dimensionsSetByUser = 
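/* written unconditionally as a plain boolean by writeTo below, so a non-optional read is safe here */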
in.readBoolean(); + maxInputTokens = in.readOptionalVInt(); + similarity = in.readOptionalEnum(SimilarityMeasure.class); + } + + private AzureOpenAiEmbeddingsServiceSettings(CommonFields fields) { + this( + fields.resourceName, + fields.deploymentId, + fields.apiVersion, + fields.dimensions, + fields.dimensionsSetByUser, + fields.maxInputTokens, + fields.similarity + ); + } + + public String resourceName() { + return resourceName; + } + + public String deploymentId() { + return deploymentId; + } + + public String apiVersion() { + return apiVersion; + } + + @Override + public Integer dimensions() { + return dimensions; + } + + public Boolean dimensionsSetByUser() { + return dimensionsSetByUser; + } + + public Integer maxInputTokens() { + return maxInputTokens; + } + + @Override + public SimilarityMeasure similarity() { + return similarity; + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.FLOAT; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + toXContentFragmentOfExposedFields(builder, params); + + builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser); + + builder.endObject(); + return builder; + } + + private void toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(RESOURCE_NAME, resourceName); + builder.field(DEPLOYMENT_ID, deploymentId); + builder.field(API_VERSION, apiVersion); + + if (dimensions != null) { + builder.field(DIMENSIONS, dimensions); + } + if (maxInputTokens != null) { + builder.field(MAX_INPUT_TOKENS, maxInputTokens); + } + if (similarity != null) { + builder.field(SIMILARITY, similarity); + } + } + + @Override + public ToXContentObject getFilteredXContentObject() { + return (builder, params) -> { + builder.startObject(); + + toXContentFragmentOfExposedFields(builder, params); + + builder.endObject(); + return builder; + }; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_AZURE_OPENAI_EMBEDDINGS; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(resourceName); + out.writeString(deploymentId); + out.writeString(apiVersion); + out.writeOptionalVInt(dimensions); + out.writeBoolean(dimensionsSetByUser); + out.writeOptionalVInt(maxInputTokens); + out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion())); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AzureOpenAiEmbeddingsServiceSettings that = (AzureOpenAiEmbeddingsServiceSettings) o; + + return Objects.equals(resourceName, that.resourceName) + && Objects.equals(deploymentId, that.deploymentId) + && Objects.equals(apiVersion, that.apiVersion) + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser) + && Objects.equals(maxInputTokens, that.maxInputTokens) + && Objects.equals(similarity, that.similarity); + } + + @Override + public int hashCode() { + return Objects.hash(resourceName, deploymentId, apiVersion, dimensions, dimensionsSetByUser, maxInputTokens, similarity); + } +} diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettings.java new file mode 100644 index 0000000000000..49329a55a18ef --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettings.java @@ -0,0 +1,114 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.azureopenai.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.USER; + +/** + * Defines the task settings for the Azure OpenAI service. + * + * User is an optional unique identifier representing the end-user, which can help OpenAI monitor and detect abuse; + * see the OpenAI docs for more details. + */ +public class AzureOpenAiEmbeddingsTaskSettings implements TaskSettings { + + public static final String NAME = "azure_openai_embeddings_task_settings"; + + public static AzureOpenAiEmbeddingsTaskSettings fromMap(Map<String, Object> map) { + ValidationException validationException = new ValidationException(); + + String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException); + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AzureOpenAiEmbeddingsTaskSettings(user); + } + + /** + * Creates a new {@link AzureOpenAiEmbeddingsTaskSettings} object by overriding the values in originalSettings with the ones + * passed in via requestSettings if the fields are not null. + * @param originalSettings the original {@link AzureOpenAiEmbeddingsTaskSettings} from the inference entity configuration from storage + * @param requestSettings the {@link AzureOpenAiEmbeddingsRequestTaskSettings} from the request + * @return a new {@link AzureOpenAiEmbeddingsTaskSettings} + */ + public static AzureOpenAiEmbeddingsTaskSettings of( + AzureOpenAiEmbeddingsTaskSettings originalSettings, + AzureOpenAiEmbeddingsRequestTaskSettings requestSettings + ) { + var userToUse = requestSettings.user() == null ?
originalSettings.user : requestSettings.user(); + return new AzureOpenAiEmbeddingsTaskSettings(userToUse); + } + + private final String user; + + public AzureOpenAiEmbeddingsTaskSettings(@Nullable String user) { + this.user = user; + } + + public AzureOpenAiEmbeddingsTaskSettings(StreamInput in) throws IOException { + this.user = in.readOptionalString(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (user != null) { + builder.field(USER, user); + } + builder.endObject(); + return builder; + } + + public String user() { + return user; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_AZURE_OPENAI_EMBEDDINGS; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(user); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AzureOpenAiEmbeddingsTaskSettings that = (AzureOpenAiEmbeddingsTaskSettings) o; + return Objects.equals(user, that.user); + } + + @Override + public int hashCode() { + return Objects.hash(user); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java new file mode 100644 index 0000000000000..4bdba67beec17 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java @@ -0,0 +1,454 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.action.azureopenai; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockRequest; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel; +import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsRequestTaskSettingsTests.getRequestTaskSettingsMap; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class AzureOpenAiActionCreatorTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testCreate_AzureOpenAiEmbeddingsModel() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = senderFactory.createSender("test_service")) { + sender.start(); + + 
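// The Azure handler parses bodies with OpenAiEmbeddingsResponseEntity::fromResponse, so the fixture below mocks an // OpenAI-style embeddings payload: an "object"/"data" wrapper holding one "embedding" array per input.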
String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createModel("resource", "deployment", "apiversion", "orig_user", "apikey", null, "id"); + model.setUri(new URI(getUrl(webServer))); + var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool)); + var overriddenTaskSettings = getRequestTaskSettingsMap("overridden_user"); + var action = (AzureOpenAiEmbeddingsAction) actionCreator.create(model, overriddenTaskSettings); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(webServer.requests(), hasSize(1)); + validateRequestWithApiKey(webServer.requests().get(0), "apikey"); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + validateRequestMapWithUser(requestMap, List.of("abc"), "overridden_user"); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testCreate_AzureOpenAiEmbeddingsModel_WithoutUser() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = senderFactory.createSender("test_service")) { + sender.start(); + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createModel("resource", "deployment", "apiversion", null, "apikey", null, "id"); + model.setUri(new URI(getUrl(webServer))); + var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool)); + var overriddenTaskSettings = getRequestTaskSettingsMap(null); + var action = (AzureOpenAiEmbeddingsAction) actionCreator.create(model, overriddenTaskSettings); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(webServer.requests(), hasSize(1)); + validateRequestWithApiKey(webServer.requests().get(0), "apikey"); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + validateRequestMapWithUser(requestMap, List.of("abc"), null); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testCreate_AzureOpenAiEmbeddingsModel_FailsFromInvalidResponseFormat() throws IOException { + // timeout as zero for no retries + var settings = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + + try (var sender = senderFactory.createSender("test_service")) { + sender.start(); 
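+            // The response below deliberately renames the required "data" field to
+            // "data_does_not_exist" so response parsing fails; with the retry timeout
+            // above set to zero, the failure surfaces immediately instead of being retried.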
+ + String responseJson = """ + { + "object": "list", + "data_does_not_exist": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createModel("resource", "deployment", "apiversion", null, "apikey", null, "id"); + model.setUri(new URI(getUrl(webServer))); + var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool)); + var overriddenTaskSettings = getRequestTaskSettingsMap("overridden_user"); + var action = (AzureOpenAiEmbeddingsAction) actionCreator.create(model, overriddenTaskSettings); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is(format("Failed to send Azure OpenAI embeddings request to [%s]", getUrl(webServer))) + ); + assertThat(thrownException.getCause().getMessage(), is("Failed to find required field [data] in OpenAI embeddings response")); + + assertThat(webServer.requests(), hasSize(1)); + validateRequestWithApiKey(webServer.requests().get(0), "apikey"); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + validateRequestMapWithUser(requestMap, List.of("abc"), "overridden_user"); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusCode() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = senderFactory.createSender("test_service")) { + sender.start(); + + // note - there is no complete documentation on Azure's error messages + // but this error and response has been verified manually via CURL + var contentTooLargeErrorMessage = + "This model's maximum context length is 8192 tokens, however you requested 13531 tokens (13531 in your prompt;" + + "0 for the completion). 
Please reduce your prompt; or completion length."; + + String responseJsonContentTooLarge = Strings.format(""" + { + "error": { + "message": "%s", + "type": "invalid_request_error", + "param": null, + "code": null + } + } + """, contentTooLargeErrorMessage); + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(413).setBody(responseJsonContentTooLarge)); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createModel("resource", "deployment", "apiversion", null, "apikey", null, "id"); + model.setUri(new URI(getUrl(webServer))); + var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool)); + var overriddenTaskSettings = getRequestTaskSettingsMap("overridden_user"); + var action = (AzureOpenAiEmbeddingsAction) actionCreator.create(model, overriddenTaskSettings); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abcd")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(webServer.requests(), hasSize(2)); + { + validateRequestWithApiKey(webServer.requests().get(0), "apikey"); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + validateRequestMapWithUser(requestMap, List.of("abcd"), "overridden_user"); + } + { + validateRequestWithApiKey(webServer.requests().get(1), "apikey"); + + var requestMap = entityAsMap(webServer.requests().get(1).getBody()); + validateRequestMapWithUser(requestMap, List.of("ab"), "overridden_user"); + } + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusCode() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = senderFactory.createSender("test_service")) { + sender.start(); + + // note - there is no complete documentation on Azure's error messages + // but this error and response has been verified manually via CURL + var contentTooLargeErrorMessage = + "This model's maximum context length is 8192 tokens, however you requested 13531 tokens (13531 in your prompt;" + + "0 for the completion). 
Please reduce your prompt; or completion length."; + + String responseJsonContentTooLarge = Strings.format(""" + { + "error": { + "message": "%s", + "type": "invalid_request_error", + "param": null, + "code": null + } + } + """, contentTooLargeErrorMessage); + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(400).setBody(responseJsonContentTooLarge)); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createModel("resource", "deployment", "apiversion", null, "apikey", null, "id"); + model.setUri(new URI(getUrl(webServer))); + var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool)); + var overriddenTaskSettings = getRequestTaskSettingsMap("overridden_user"); + var action = (AzureOpenAiEmbeddingsAction) actionCreator.create(model, overriddenTaskSettings); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abcd")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(webServer.requests(), hasSize(2)); + { + validateRequestWithApiKey(webServer.requests().get(0), "apikey"); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + validateRequestMapWithUser(requestMap, List.of("abcd"), "overridden_user"); + } + { + validateRequestWithApiKey(webServer.requests().get(1), "apikey"); + + var requestMap = entityAsMap(webServer.requests().get(1).getBody()); + validateRequestMapWithUser(requestMap, List.of("ab"), "overridden_user"); + } + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testExecute_TruncatesInputBeforeSending() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = senderFactory.createSender("test_service")) { + sender.start(); + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + // truncated to 1 token = 3 characters + var model = createModel("resource", "deployment", "apiversion", null, false, 1, null, null, "apikey", null, "id"); + model.setUri(new URI(getUrl(webServer))); + var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool)); + var overriddenTaskSettings = getRequestTaskSettingsMap("overridden_user"); + var action = (AzureOpenAiEmbeddingsAction) actionCreator.create(model, overriddenTaskSettings); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("super long input")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(webServer.requests(), hasSize(1)); + validateRequestWithApiKey(webServer.requests().get(0), "apikey"); + + var requestMap = 
entityAsMap(webServer.requests().get(0).getBody());
+            validateRequestMapWithUser(requestMap, List.of("sup"), "overridden_user");
+        } catch (URISyntaxException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    private void validateRequestMapWithUser(Map<String, Object> requestMap, List<String> input, @Nullable String user) {
+        var expectedSize = user == null ? 1 : 2;
+
+        assertThat(requestMap.size(), is(expectedSize));
+        assertThat(requestMap.get("input"), is(input));
+
+        if (user != null) {
+            assertThat(requestMap.get("user"), is(user));
+        }
+    }
+
+    private void validateRequestWithApiKey(MockRequest request, String apiKey) {
+        assertNull(request.getUri().getQuery());
+        assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
+        assertThat(request.getHeader(AzureOpenAiUtils.API_KEY_HEADER), equalTo(apiKey));
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java
new file mode 100644
index 0000000000000..e8eac1a13b180
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java
@@ -0,0 +1,219 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.azureopenai;
+
+import org.apache.http.HttpHeaders;
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.http.MockResponse;
+import org.elasticsearch.test.http.MockWebServer;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.ServiceComponentsTests;
+import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
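+
+// These tests drive AzureOpenAiEmbeddingsAction along two paths: a MockWebServer-backed
+// sender for the end-to-end happy path, and a mocked Sender for verifying how failures
+// are translated into ElasticsearchExceptions.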
+import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation;
+import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
+import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+
+public class AzureOpenAiEmbeddingsActionTests extends ESTestCase {
+    private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
+    private final MockWebServer webServer = new MockWebServer();
+    private ThreadPool threadPool;
+    private HttpClientManager clientManager;
+
+    @Before
+    public void init() throws Exception {
+        webServer.start();
+        threadPool = createThreadPool(inferenceUtilityPool());
+        clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
+    }
+
+    @After
+    public void shutdown() throws IOException {
+        clientManager.close();
+        terminate(threadPool);
+        webServer.close();
+    }
+
+    public void testExecute_ReturnsSuccessfulResponse() throws IOException {
+        var senderFactory = new HttpRequestSender.Factory(
+            ServiceComponentsTests.createWithEmptySettings(threadPool),
+            clientManager,
+            mockClusterServiceEmpty()
+        );
+
+        try (var sender = senderFactory.createSender("test_service")) {
+            sender.start();
+
+            String responseJson = """
+                {
+                    "object": "list",
+                    "data": [
+                        {
+                            "object": "embedding",
+                            "index": 0,
+                            "embedding": [
+                                0.0123,
+                                -0.0123
+                            ]
+                        }
+                    ],
+                    "model": "text-embedding-ada-002-v2",
+                    "usage": {
+                        "prompt_tokens": 8,
+                        "total_tokens": 8
+                    }
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id");
+
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+            var result = listener.actionGet(TIMEOUT);
+
+            assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
+            assertThat(webServer.requests(), hasSize(1));
+            assertNull(webServer.requests().get(0).getUri().getQuery());
+            assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
+            assertThat(webServer.requests().get(0).getHeader(AzureOpenAiUtils.API_KEY_HEADER), equalTo("apikey"));
+
+            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+            assertThat(requestMap.size(), is(2));
+            assertThat(requestMap.get("input"), is(List.of("abc")));
+            assertThat(requestMap.get("user"), is("user"));
+        }
+    }
+
+    public void testExecute_ThrowsElasticsearchException() {
+        var sender = mock(Sender.class);
+        doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any());
+
+        var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id");
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+
+        assertThat(thrownException.getMessage(), is("failed"));
+    }
+
+    public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() {
+        var sender = mock(Sender.class);
+
+        doAnswer(invocation -> {
+            @SuppressWarnings("unchecked")
+            ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[1];
+            listener.onFailure(new IllegalStateException("failed"));
+
+            return Void.TYPE;
+        }).when(sender).send(any(), any(), any(), any());
+
+        var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id");
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+
+        assertThat(thrownException.getMessage(), is(format("Failed to send Azure OpenAI embeddings request to [%s]", getUrl(webServer))));
+    }
+
+    public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled_WhenUrlIsNull() {
+        var sender = mock(Sender.class);
+
+        doAnswer(invocation -> {
+            @SuppressWarnings("unchecked")
+            ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[1];
+            listener.onFailure(new IllegalStateException("failed"));
+
+            return Void.TYPE;
+        }).when(sender).send(any(), any(), any(), any());
+
+        var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id");
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+
+        assertThat(thrownException.getMessage(), is(format("Failed to send Azure OpenAI embeddings request to [%s]", getUrl(webServer))));
+    }
+
+    public void testExecute_ThrowsException() {
+        var sender = mock(Sender.class);
+        doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any());
+
+        var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id");
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+
+        assertThat(thrownException.getMessage(), is(format("Failed to send Azure OpenAI embeddings request to [%s]", getUrl(webServer))));
+    }
+
+    private AzureOpenAiEmbeddingsAction createAction(
+        String resourceName,
+        String deploymentId,
+        String apiVersion,
+        @Nullable String user,
+        String apiKey,
+        Sender sender,
+        String inferenceEntityId
+    ) {
+        AzureOpenAiEmbeddingsModel model = null;
+        try {
+            model = createModel(resourceName, deploymentId, apiVersion, user, apiKey, null, inferenceEntityId);
+            model.setUri(new URI(getUrl(webServer)));
+            var action = new AzureOpenAiEmbeddingsAction(sender, model, createWithEmptySettings(threadPool));
+            return action;
+        } catch (URISyntaxException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiResponseHandlerTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiResponseHandlerTests.java new file mode 100644 index 0000000000000..b18d9d76651d5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiResponseHandlerTests.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.azureopenai; + +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.apache.http.message.BasicHeader; +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.nio.charset.StandardCharsets; + +import static org.hamcrest.Matchers.containsString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class AzureOpenAiResponseHandlerTests extends ESTestCase { + + public void testBuildRateLimitErrorMessage() { + int statusCode = 429; + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + var response = mock(HttpResponse.class); + when(response.getStatusLine()).thenReturn(statusLine); + var httpResult = new HttpResult(response, new byte[] {}); + + { + when(response.getFirstHeader(AzureOpenAiResponseHandler.REMAINING_REQUESTS)).thenReturn( + new BasicHeader(AzureOpenAiResponseHandler.REMAINING_REQUESTS, "2999") + ); + when(response.getFirstHeader(AzureOpenAiResponseHandler.REMAINING_TOKENS)).thenReturn( + new BasicHeader(AzureOpenAiResponseHandler.REMAINING_TOKENS, "99800") + ); + + var error = AzureOpenAiResponseHandler.buildRateLimitErrorMessage(httpResult); + assertThat(error, containsString("Remaining tokens [99800]. Remaining requests [2999]")); + } + + { + when(response.getFirstHeader(AzureOpenAiResponseHandler.REMAINING_TOKENS)).thenReturn(null); + var error = AzureOpenAiResponseHandler.buildRateLimitErrorMessage(httpResult); + assertThat(error, containsString("Remaining tokens [unknown]. Remaining requests [2999]")); + } + + { + when(response.getFirstHeader(AzureOpenAiResponseHandler.REMAINING_REQUESTS)).thenReturn( + new BasicHeader(AzureOpenAiResponseHandler.REMAINING_REQUESTS, "2999") + ); + when(response.getFirstHeader(AzureOpenAiResponseHandler.REMAINING_TOKENS)).thenReturn(null); + var error = AzureOpenAiResponseHandler.buildRateLimitErrorMessage(httpResult); + assertThat(error, containsString("Remaining tokens [unknown]. Remaining requests [2999]")); + } + } + + private static HttpResult createContentTooLargeResult(int statusCode) { + return createResult( + statusCode, + "This model's maximum context length is 8192 tokens, however you requested 13531 tokens (13531 in your prompt;" + + "0 for the completion). Please reduce your prompt; or completion length." 
+ ); + } + + private static HttpResult createResult(int statusCode, String message) { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + var httpResponse = mock(HttpResponse.class); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + + String responseJson = Strings.format(""" + { + "error": { + "message": "%s", + "type": "content_too_large", + "param": null, + "code": null + } + } + """, message); + + return new HttpResult(httpResponse, responseJson.getBytes(StandardCharsets.UTF_8)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestEntityTests.java new file mode 100644 index 0000000000000..14283ed53eed9 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestEntityTests.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.azureopenai; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.CoreMatchers.is; + +public class AzureOpenAiEmbeddingsRequestEntityTests extends ESTestCase { + + public void testXContent_WritesUserWhenDefined() throws IOException { + var entity = new AzureOpenAiEmbeddingsRequestEntity(List.of("abc"), "testuser", null, false); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"input":["abc"],"user":"testuser"}""")); + } + + public void testXContent_DoesNotWriteUserWhenItIsNull() throws IOException { + var entity = new AzureOpenAiEmbeddingsRequestEntity(List.of("abc"), null, null, false); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"input":["abc"]}""")); + } + + public void testXContent_DoesNotWriteDimensionsWhenNotSetByUser() throws IOException { + var entity = new AzureOpenAiEmbeddingsRequestEntity(List.of("abc"), null, 100, false); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"input":["abc"]}""")); + } + + public void testXContent_DoesNotWriteDimensionsWhenNull_EvenIfSetByUserIsTrue() throws IOException { + var entity = new AzureOpenAiEmbeddingsRequestEntity(List.of("abc"), null, null, true); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"input":["abc"]}""")); + } + + public 
void testXContent_WritesDimensionsWhenNonNull_AndSetByUserIsTrue() throws IOException { + var entity = new AzureOpenAiEmbeddingsRequestEntity(List.of("abc"), null, 100, true); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"input":["abc"],"dimensions":100}""")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestTests.java new file mode 100644 index 0000000000000..8e7c831a9820f --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestTests.java @@ -0,0 +1,118 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.azureopenai; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.common.TruncatorTests; +import org.elasticsearch.xpack.inference.external.azureopenai.AzureOpenAiAccount; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils.API_KEY_HEADER; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class AzureOpenAiEmbeddingsRequestTests extends ESTestCase { + public void testCreateRequest_WithApiKeyDefined() throws IOException, URISyntaxException { + var request = createRequest("resource", "deployment", "apiVersion", "apikey", null, "abc", "user"); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var expectedUri = AzureOpenAiEmbeddingsModel.getEmbeddingsUri("resource", "deployment", "apiVersion").toString(); + assertThat(httpPost.getURI().toString(), is(expectedUri)); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(API_KEY_HEADER).getValue(), is("apikey")); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap.get("input"), is(List.of("abc"))); + assertThat(requestMap.get("user"), is("user")); + } + + public void testCreateRequest_WithEntraIdDefined() throws IOException, URISyntaxException { + var request = createRequest("resource", 
"deployment", "apiVersion", null, "entraId", "abc", "user"); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var expectedUri = AzureOpenAiEmbeddingsModel.getEmbeddingsUri("resource", "deployment", "apiVersion").toString(); + assertThat(httpPost.getURI().toString(), is(expectedUri)); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer entraId")); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap.get("input"), is(List.of("abc"))); + assertThat(requestMap.get("user"), is("user")); + } + + public void testTruncate_ReducesInputTextSizeByHalf() throws IOException { + var request = createRequest("resource", "deployment", "apiVersion", "apikey", null, "abcd", null); + var truncatedRequest = request.truncate(); + + var httpRequest = truncatedRequest.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(1)); + assertThat(requestMap.get("input"), is(List.of("ab"))); + } + + public void testIsTruncated_ReturnsTrue() { + var request = createRequest("resource", "deployment", "apiVersion", "apikey", null, "abcd", null); + assertFalse(request.getTruncationInfo()[0]); + + var truncatedRequest = request.truncate(); + assertTrue(truncatedRequest.getTruncationInfo()[0]); + } + + public static AzureOpenAiEmbeddingsRequest createRequest( + String resourceName, + String deploymentId, + String apiVersion, + @Nullable String apiKey, + @Nullable String entraId, + String input, + @Nullable String user + ) { + var embeddingsModel = AzureOpenAiEmbeddingsModelTests.createModel( + resourceName, + deploymentId, + apiVersion, + user, + apiKey, + entraId, + "id" + ); + var account = AzureOpenAiAccount.fromModel(embeddingsModel); + + return new AzureOpenAiEmbeddingsRequest( + TruncatorTests.createTruncator(), + account, + new Truncator.TruncationResult(List.of(input), new boolean[] { false }), + embeddingsModel + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettingsTests.java new file mode 100644 index 0000000000000..97fa6efc962bb --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettingsTests.java @@ -0,0 +1,160 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.hamcrest.CoreMatchers;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings.API_KEY;
+import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings.ENTRA_ID;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.is;
+
+public class AzureOpenAiSecretSettingsTests extends AbstractWireSerializingTestCase<AzureOpenAiSecretSettings> {
+
+    public static AzureOpenAiSecretSettings createRandom() {
+        return new AzureOpenAiSecretSettings(
+            new SecureString(randomAlphaOfLength(15).toCharArray()),
+            new SecureString(randomAlphaOfLength(15).toCharArray())
+        );
+    }
+
+    public void testFromMap_ApiKey_Only() {
+        var serviceSettings = AzureOpenAiSecretSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiSecretSettings.API_KEY, "abc")));
+        assertThat(new AzureOpenAiSecretSettings(new SecureString("abc".toCharArray()), null), is(serviceSettings));
+    }
+
+    public void testFromMap_EntraId_Only() {
+        var serviceSettings = AzureOpenAiSecretSettings.fromMap(new HashMap<>(Map.of(ENTRA_ID, "xyz")));
+        assertThat(new AzureOpenAiSecretSettings(null, new SecureString("xyz".toCharArray())), is(serviceSettings));
+    }
+
+    public void testFromMap_ReturnsNull_WhenMapIsNull() {
+        assertNull(AzureOpenAiSecretSettings.fromMap(null));
+    }
+
+    public void testFromMap_MissingApiKeyAndEntraId_ThrowsError() {
+        var thrownException = expectThrows(ValidationException.class, () -> AzureOpenAiSecretSettings.fromMap(new HashMap<>()));
+
+        assertThat(
+            thrownException.getMessage(),
+            containsString(
+                Strings.format(
+                    "[secret_settings] must have either the [%s] or the [%s] key set",
+                    AzureOpenAiSecretSettings.API_KEY,
+                    ENTRA_ID
+                )
+            )
+        );
+    }
+
+    public void testFromMap_HasBothApiKeyAndEntraId_ThrowsError() {
+        var mapValues = getAzureOpenAiSecretSettingsMap("apikey", "entraid");
+        var thrownException = expectThrows(ValidationException.class, () -> AzureOpenAiSecretSettings.fromMap(mapValues));
+
+        assertThat(
+            thrownException.getMessage(),
+            containsString(
+                Strings.format(
+                    "[secret_settings] must have only one of the [%s] or the [%s] key set",
+                    AzureOpenAiSecretSettings.API_KEY,
+                    ENTRA_ID
+                )
+            )
+        );
+    }
+
+    public void testFromMap_EmptyApiKey_ThrowsError() {
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> AzureOpenAiSecretSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiSecretSettings.API_KEY, "")))
+        );
+
+        assertThat(
+            thrownException.getMessage(),
+            containsString(
+                Strings.format(
+                    "[secret_settings] Invalid value empty string. [%s] must be a non-empty string",
+                    AzureOpenAiSecretSettings.API_KEY
+                )
+            )
+        );
+    }
+
+    public void testFromMap_EmptyEntraId_ThrowsError() {
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> AzureOpenAiSecretSettings.fromMap(new HashMap<>(Map.of(ENTRA_ID, "")))
+        );
+
+        assertThat(
+            thrownException.getMessage(),
+            containsString(Strings.format("[secret_settings] Invalid value empty string. [%s] must be a non-empty string", ENTRA_ID))
+        );
+    }
+
+    // test toXContent
+    public void testToXContent_WritesApiKeyOnlyWhenEntraIdIsNull() throws IOException {
+        var testSettings = new AzureOpenAiSecretSettings(new SecureString("apikey".toCharArray()), null);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        testSettings.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        var expectedResult = Strings.format("{\"%s\":\"apikey\"}", API_KEY);
+        assertThat(xContentResult, CoreMatchers.is(expectedResult));
+    }
+
+    public void testToXContent_WritesEntraIdOnlyWhenApiKeyIsNull() throws IOException {
+        var testSettings = new AzureOpenAiSecretSettings(null, new SecureString("entraid".toCharArray()));
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        testSettings.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        var expectedResult = Strings.format("{\"%s\":\"entraid\"}", ENTRA_ID);
+        assertThat(xContentResult, CoreMatchers.is(expectedResult));
+    }
+
+    @Override
+    protected Writeable.Reader<AzureOpenAiSecretSettings> instanceReader() {
+        return AzureOpenAiSecretSettings::new;
+    }
+
+    @Override
+    protected AzureOpenAiSecretSettings createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected AzureOpenAiSecretSettings mutateInstance(AzureOpenAiSecretSettings instance) throws IOException {
+        return createRandom();
+    }
+
+    public static Map<String, Object> getAzureOpenAiSecretSettingsMap(@Nullable String apiKey, @Nullable String entraId) {
+        var map = new HashMap<String, Object>();
+        if (apiKey != null) {
+            map.put(AzureOpenAiSecretSettings.API_KEY, apiKey);
+        }
+        if (entraId != null) {
+            map.put(ENTRA_ID, entraId);
+        }
+        return map;
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
new file mode 100644
index 0000000000000..4e65d987a26ad
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
@@ -0,0 +1,1180 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ * + * this file was contributed to by a generative AI + */ + +package org.elasticsearch.xpack.inference.services.azureopenai; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils.API_KEY_HEADER; +import static org.elasticsearch.xpack.inference.results.ChunkedTextEmbeddingResultsTests.asMapWithListsInsteadOfArrays; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.Utils.getInvalidModel; +import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettingsTests.getAzureOpenAiSecretSettingsMap; +import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettingsTests.getPersistentAzureOpenAiServiceSettingsMap; 
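+
+// The settings-map helpers imported here build the service, task, and secret settings maps
+// used throughout these tests; the "extra_key" variants verify that unknown settings are
+// rejected when creating an endpoint but tolerated when parsing previously persisted configs.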
+import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettingsTests.getRequestAzureOpenAiServiceSettingsMap; +import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettingsTests.getAzureOpenAiRequestTaskSettingsMap; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class AzureOpenAiServiceTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModel() throws IOException { + try (var service = createAzureOpenAiService()) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); + + var embeddingsModel = (AzureOpenAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name")); + assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id")); + assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version")); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getRequestAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null), + getAzureOpenAiRequestTaskSettingsMap("user"), + getAzureOpenAiSecretSettingsMap("secret", null) + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { + try (var service = createAzureOpenAiService()) { + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), is("The [azureopenai] service does not support task type [sparse_embedding]")); + } + ); + + service.parseRequestConfig( + "id", + TaskType.SPARSE_EMBEDDING, + getRequestConfigMap( + getRequestAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null), + getAzureOpenAiRequestTaskSettingsMap("user"), + getAzureOpenAiSecretSettingsMap("secret", null) + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void 
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createAzureOpenAiService()) { + var config = getRequestConfigMap( + getRequestAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null), + getAzureOpenAiRequestTaskSettingsMap("user"), + getAzureOpenAiSecretSettingsMap("secret", null) + ); + config.put("extra_key", "value"); + + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat( + exception.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service") + ); + } + ); + + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { + try (var service = createAzureOpenAiService()) { + var serviceSettings = getRequestAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null); + serviceSettings.put("extra_key", "value"); + + var config = getRequestConfigMap( + serviceSettings, + getAzureOpenAiRequestTaskSettingsMap("user"), + getAzureOpenAiSecretSettingsMap("secret", null) + ); + + ActionListener modelVerificationListener = ActionListener.wrap((model) -> { + fail("Expected exception, but got model: " + model); + }, e -> { + assertThat(e, instanceOf(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service") + ); + }); + + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { + try (var service = createAzureOpenAiService()) { + var taskSettingsMap = getAzureOpenAiRequestTaskSettingsMap("user"); + taskSettingsMap.put("extra_key", "value"); + + var config = getRequestConfigMap( + getRequestAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null), + taskSettingsMap, + getAzureOpenAiSecretSettingsMap("secret", null) + ); + + ActionListener modelVerificationListener = ActionListener.wrap((model) -> { + fail("Expected exception, but got model: " + model); + }, e -> { + assertThat(e, instanceOf(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service") + ); + }); + + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { + try (var service = createAzureOpenAiService()) { + var secretSettingsMap = getAzureOpenAiSecretSettingsMap("secret", null); + secretSettingsMap.put("extra_key", "value"); + + var config = getRequestConfigMap( + getRequestAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null), + getAzureOpenAiRequestTaskSettingsMap("user"), + secretSettingsMap + ); + + ActionListener modelVerificationListener = ActionListener.wrap((model) -> { + fail("Expected exception, but got model: " + model); + }, e -> { + assertThat(e, instanceOf(ElasticsearchStatusException.class)); + assertThat( 
+ e.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service") + ); + }); + + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + } + } + + public void testParseRequestConfig_MovesModel() throws IOException { + try (var service = createAzureOpenAiService()) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); + + var embeddingsModel = (AzureOpenAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name")); + assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id")); + assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version")); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getRequestAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null), + getAzureOpenAiRequestTaskSettingsMap("user"), + getAzureOpenAiSecretSettingsMap("secret", null) + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnAzureOpenAiEmbeddingsModel() throws IOException { + try (var service = createAzureOpenAiService()) { + var persistedConfig = getPersistedConfigMap( + getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", 100, 512), + getAzureOpenAiRequestTaskSettingsMap("user"), + getAzureOpenAiSecretSettingsMap("secret", null) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); + + var embeddingsModel = (AzureOpenAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name")); + assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id")); + assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { + try (var service = createAzureOpenAiService()) { + var persistedConfig = getPersistedConfigMap( + getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null), + getAzureOpenAiRequestTaskSettingsMap("user"), + getAzureOpenAiSecretSettingsMap("secret", null) + ); + + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> service.parsePersistedConfigWithSecrets( + "id", + TaskType.SPARSE_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ) + ); + + assertThat( + thrownException.getMessage(), + is("Failed to parse stored model [id] for [azureopenai] service, please delete and add the service again") + ); + } + } + + public void 
testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException {
+        try (var service = createAzureOpenAiService()) {
+            var persistedConfig = getPersistedConfigMap(
+                getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", 100, 512),
+                getAzureOpenAiRequestTaskSettingsMap("user"),
+                getAzureOpenAiSecretSettingsMap("secret", null)
+            );
+            persistedConfig.config().put("extra_key", "value");
+
+            var model = service.parsePersistedConfigWithSecrets(
+                "id",
+                TaskType.TEXT_EMBEDDING,
+                persistedConfig.config(),
+                persistedConfig.secrets()
+            );
+
+            assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class));
+
+            var embeddingsModel = (AzureOpenAiEmbeddingsModel) model;
+            assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name"));
+            assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id"));
+            assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version"));
+            assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100));
+            assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512));
+            assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
+            assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
+        }
+    }
+
+    public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException {
+        try (var service = createAzureOpenAiService()) {
+            var secretSettingsMap = getAzureOpenAiSecretSettingsMap("secret", null);
+            secretSettingsMap.put("extra_key", "value");
+
+            var persistedConfig = getPersistedConfigMap(
+                getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", 100, 512),
+                getAzureOpenAiRequestTaskSettingsMap("user"),
+                secretSettingsMap
+            );
+
+            var model = service.parsePersistedConfigWithSecrets(
+                "id",
+                TaskType.TEXT_EMBEDDING,
+                persistedConfig.config(),
+                persistedConfig.secrets()
+            );
+
+            assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class));
+
+            var embeddingsModel = (AzureOpenAiEmbeddingsModel) model;
+            assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name"));
+            assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id"));
+            assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version"));
+            assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100));
+            assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512));
+            assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
+            assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
+        }
+    }
+
+    public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException {
+        try (var service = createAzureOpenAiService()) {
+            var persistedConfig = getPersistedConfigMap(
+                getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", 100, 512),
+                getAzureOpenAiRequestTaskSettingsMap("user"),
+                getAzureOpenAiSecretSettingsMap("secret", null)
+            );
+            persistedConfig.secrets().put("extra_key", "value");
+
+            var model = service.parsePersistedConfigWithSecrets(
+                "id",
+                TaskType.TEXT_EMBEDDING,
+                persistedConfig.config(),
+                persistedConfig.secrets()
+            );
+
+            assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class));
+
+            var embeddingsModel = (AzureOpenAiEmbeddingsModel) model;
+            assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name"));
+            
assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id")); + assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + try (var service = createAzureOpenAiService()) { + var serviceSettingsMap = getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", 100, 512); + serviceSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + serviceSettingsMap, + getAzureOpenAiRequestTaskSettingsMap("user"), + getAzureOpenAiSecretSettingsMap("secret", null) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); + + var embeddingsModel = (AzureOpenAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name")); + assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id")); + assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + try (var service = createAzureOpenAiService()) { + var taskSettingsMap = getAzureOpenAiRequestTaskSettingsMap("user"); + taskSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", 100, 512), + taskSettingsMap, + getAzureOpenAiSecretSettingsMap("secret", null) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); + + var embeddingsModel = (AzureOpenAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name")); + assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id")); + assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfig_CreatesAnAzureOpenAiEmbeddingsModel() throws IOException { + try (var service = createAzureOpenAiService()) { + var persistedConfig = getPersistedConfigMap( + getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null), + 
getAzureOpenAiRequestTaskSettingsMap("user") + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); + + var embeddingsModel = (AzureOpenAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name")); + assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id")); + assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version")); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() throws IOException { + try (var service = createAzureOpenAiService()) { + var persistedConfig = getPersistedConfigMap( + getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null), + getAzureOpenAiRequestTaskSettingsMap("user") + ); + + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) + ); + + assertThat( + thrownException.getMessage(), + is("Failed to parse stored model [id] for [azureopenai] service, please delete and add the service again") + ); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createAzureOpenAiService()) { + var persistedConfig = getPersistedConfigMap( + getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null), + getAzureOpenAiRequestTaskSettingsMap("user") + ); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); + + var embeddingsModel = (AzureOpenAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name")); + assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id")); + assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version")); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + try (var service = createAzureOpenAiService()) { + var serviceSettingsMap = getPersistentAzureOpenAiServiceSettingsMap( + "resource_name", + "deployment_id", + "api_version", + null, + null + ); + serviceSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getAzureOpenAiRequestTaskSettingsMap("user")); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); + + var embeddingsModel = (AzureOpenAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name")); + assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id")); + assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version")); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void 
testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + try (var service = createAzureOpenAiService()) { + var taskSettingsMap = getAzureOpenAiRequestTaskSettingsMap("user"); + taskSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null), + taskSettingsMap + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); + + var embeddingsModel = (AzureOpenAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name")); + assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id")); + assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version")); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testInfer_ThrowsErrorWhenModelIsNotAzureOpenAiModel() throws IOException { + var sender = mock(Sender.class); + + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender(anyString())).thenReturn(sender); + + var mockModel = getInvalidModel("model_id", "service_name"); + + try (var service = new AzureOpenAiService(factory, createWithEmptySettings(threadPool))) { + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); + service.infer( + mockModel, + null, + List.of(""), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") + ); + + verify(factory, times(1)).createSender(anyString()); + verify(sender, times(1)).start(); + } + + verify(sender, times(1)).close(); + verifyNoMoreInteractions(factory); + verifyNoMoreInteractions(sender); + } + + public void testInfer_SendsRequest() throws IOException, URISyntaxException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = AzureOpenAiEmbeddingsModelTests.createModel("resource", "deployment", "apiversion", "user", "apikey", null, "id"); + model.setUri(new URI(getUrl(webServer))); + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), Matchers.is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(webServer.requests().get(0).getHeader(API_KEY_HEADER), equalTo("apikey")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), Matchers.is(2)); + assertThat(requestMap.get("input"), Matchers.is(List.of("abc"))); + assertThat(requestMap.get("user"), Matchers.is("user")); + } + } + + public void testCheckModelConfig_IncludesMaxTokens() throws IOException, URISyntaxException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = AzureOpenAiEmbeddingsModelTests.createModel( + "resource", + "deployment", + "apiversion", + null, + false, + 100, + null, + "user", + "apikey", + null, + "id" + ); + model.setUri(new URI(getUrl(webServer))); + + PlainActionFuture<Model> listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + + var result = listener.actionGet(TIMEOUT); + assertThat( + result, + is( + AzureOpenAiEmbeddingsModelTests.createModel( + "resource", + "deployment", + "apiversion", + 2, + false, + 100, + SimilarityMeasure.DOT_PRODUCT, + "user", + "apikey", + null, + "id" + ) + ) + ); + + assertThat(webServer.requests(), hasSize(1)); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "user", "user"))); + } + } + + public void testCheckModelConfig_HasSimilarity() throws IOException, URISyntaxException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = AzureOpenAiEmbeddingsModelTests.createModel( + "resource", + "deployment", + "apiversion", + null, + false, + null, + SimilarityMeasure.COSINE, + "user", + "apikey", + null, + "id" + ); + model.setUri(new URI(getUrl(webServer))); + + PlainActionFuture<Model> listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + + var result = listener.actionGet(TIMEOUT); + assertThat( + result, + is( + AzureOpenAiEmbeddingsModelTests.createModel( + "resource", + "deployment", + "apiversion", + 2, + false, + null, + SimilarityMeasure.COSINE, + "user", + "apikey", + null, + "id" + ) + ) + ); + + assertThat(webServer.requests(), hasSize(1)); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "user", "user"))); + } + } + + public void testCheckModelConfig_AddsDefaultSimilarityDotProduct() throws IOException, URISyntaxException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = AzureOpenAiEmbeddingsModelTests.createModel( + "resource", + "deployment", + "apiversion", + null, + false, + null, + null, + "user", + "apikey", + null, + "id" + ); + model.setUri(new URI(getUrl(webServer))); + + PlainActionFuture<Model> listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + + var result = listener.actionGet(TIMEOUT); + assertThat( + result, + is( + AzureOpenAiEmbeddingsModelTests.createModel( + "resource", + "deployment", + "apiversion", + 2, + false, + null, + SimilarityMeasure.DOT_PRODUCT, + "user", + "apikey", + null, + "id" + ) + ) + ); + + assertThat(webServer.requests(), hasSize(1)); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "user", "user"))); + } + } + + public void testCheckModelConfig_ThrowsIfEmbeddingSizeDoesNotMatchValueSetByUser() throws IOException, URISyntaxException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = AzureOpenAiEmbeddingsModelTests.createModel( + "resource", + "deployment", + "apiversion", + 3, + true, + 100, + null, + "user", + "apikey", + null, + "id" + ); + model.setUri(new URI(getUrl(webServer))); + + PlainActionFuture<Model> listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + exception.getMessage(), + is( + "The retrieved embeddings size [2] does not match the size specified in the settings [3]. " + + "Please recreate the [id] configuration with the correct dimensions" + ) + ); + + assertThat(webServer.requests(), hasSize(1)); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "user", "user", "dimensions", 3))); + } + } + + public void testCheckModelConfig_ReturnsNewModelReference_AndDoesNotSendDimensionsField_WhenNotSetByUser() throws IOException, + URISyntaxException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = AzureOpenAiEmbeddingsModelTests.createModel( + "resource", + "deployment", + "apiversion", + 100, + false, + 100, + null, + "user", + "apikey", + null, + "id" + ); + model.setUri(new URI(getUrl(webServer))); + + PlainActionFuture<Model> listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + + var result = listener.actionGet(TIMEOUT); + assertThat( + result, + is( + AzureOpenAiEmbeddingsModelTests.createModel( + "resource", + "deployment", + "apiversion", + 2, + false, + 100, + SimilarityMeasure.DOT_PRODUCT, + "user", + "apikey", + null, + "id" + ) + ) + ); + + assertThat(webServer.requests(), hasSize(1)); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "user", "user"))); + } + } + + public void testInfer_UnauthorisedResponse() throws IOException, URISyntaxException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "error": { + "message": "Incorrect API key provided:", + "type": "invalid_request_error", + "param": null, + "code": "invalid_api_key" + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); + + var model = AzureOpenAiEmbeddingsModelTests.createModel("resource", "deployment", "apiversion", "user", "apikey", null, "id"); + model.setUri(new URI(getUrl(webServer))); + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(error.getMessage(), containsString("Received an authentication error status code for request")); + assertThat(error.getMessage(), containsString("Error message: [Incorrect API key provided:]")); + assertThat(webServer.requests(), hasSize(1)); + } + } + + public void testChunkedInfer_CallsInfer_ConvertsFloatResponse() throws IOException, URISyntaxException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "object": "list", + "data": [ + { + 
"object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = AzureOpenAiEmbeddingsModelTests.createModel("resource", "deployment", "apiversion", "user", "apikey", null, "id"); + model.setUri(new URI(getUrl(webServer))); + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + List.of("abc"), + new HashMap<>(), + InputType.INGEST, + new ChunkingOptions(null, null), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT).get(0); + assertThat(result, CoreMatchers.instanceOf(ChunkedTextEmbeddingResults.class)); + + assertThat( + asMapWithListsInsteadOfArrays((ChunkedTextEmbeddingResults) result), + Matchers.is( + Map.of( + ChunkedTextEmbeddingResults.FIELD_NAME, + List.of( + Map.of( + ChunkedNlpInferenceResults.TEXT, + "abc", + ChunkedNlpInferenceResults.INFERENCE, + List.of((double) 0.0123f, (double) -0.0123f) + ) + ) + ) + ) + ); + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(webServer.requests().get(0).getHeader(API_KEY_HEADER), equalTo("apikey")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), Matchers.is(2)); + assertThat(requestMap.get("input"), Matchers.is(List.of("abc"))); + assertThat(requestMap.get("user"), Matchers.is("user")); + } + } + + private AzureOpenAiService createAzureOpenAiService() { + return new AzureOpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + } + + private Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map secretSettings + ) { + var builtServiceSettings = new HashMap<>(); + builtServiceSettings.putAll(serviceSettings); + builtServiceSettings.putAll(secretSettings); + + return new HashMap<>( + Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings) + ); + } + + private PeristedConfig getPersistedConfigMap( + Map serviceSettings, + Map taskSettings, + Map secretSettings + ) { + + return new PeristedConfig( + new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), + new HashMap<>(Map.of(ModelSecrets.SECRET_SETTINGS, secretSettings)) + ); + } + + private PeristedConfig getPersistedConfigMap(Map serviceSettings, Map taskSettings) { + + return new PeristedConfig( + new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), + null + ); + } + + private record PeristedConfig(Map config, Map secrets) {} +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModelTests.java new file mode 100644 index 0000000000000..f161cd0b823fe --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModelTests.java @@ -0,0 +1,121 @@ +/* + * Copyright 
Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.azureopenai.embeddings; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings; + +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettingsTests.getAzureOpenAiRequestTaskSettingsMap; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +public class AzureOpenAiEmbeddingsModelTests extends ESTestCase { + + public void testOverrideWith_OverridesUser() { + var model = createModel("resource", "deployment", "apiversion", null, "api_key", null, "id"); + var requestTaskSettingsMap = getAzureOpenAiRequestTaskSettingsMap("user_override"); + + var overriddenModel = AzureOpenAiEmbeddingsModel.of(model, requestTaskSettingsMap); + + assertThat(overriddenModel, is(createModel("resource", "deployment", "apiversion", "user_override", "api_key", null, "id"))); + } + + public void testOverrideWith_EmptyMap() { + var model = createModel("resource", "deployment", "apiversion", null, "api_key", null, "id"); + + var requestTaskSettingsMap = Map.<String, Object>of(); + + var overriddenModel = AzureOpenAiEmbeddingsModel.of(model, requestTaskSettingsMap); + assertThat(overriddenModel, sameInstance(model)); + } + + public void testOverrideWith_NullMap() { + var model = createModel("resource", "deployment", "apiversion", null, "api_key", null, "id"); + + var overriddenModel = AzureOpenAiEmbeddingsModel.of(model, null); + assertThat(overriddenModel, sameInstance(model)); + } + + public void testCreateModel_FromUpdatedServiceSettings() { + var model = createModel("resource", "deployment", "apiversion", "user", "api_key", null, "id"); + var updatedSettings = new AzureOpenAiEmbeddingsServiceSettings( + "resource", + "deployment", + "override_apiversion", + null, + false, + null, + null + ); + + var overriddenModel = new AzureOpenAiEmbeddingsModel(model, updatedSettings); + + assertThat(overriddenModel, is(createModel("resource", "deployment", "override_apiversion", "user", "api_key", null, "id"))); + } + + public static AzureOpenAiEmbeddingsModel createModel( + String resourceName, + String deploymentId, + String apiVersion, + String user, + @Nullable String apiKey, + @Nullable String entraId, + String inferenceEntityId + ) { + var secureApiKey = apiKey != null ? new SecureString(apiKey.toCharArray()) : null; + var secureEntraId = entraId != null ? 
new SecureString(entraId.toCharArray()) : null; + return new AzureOpenAiEmbeddingsModel( + inferenceEntityId, + TaskType.TEXT_EMBEDDING, + "service", + new AzureOpenAiEmbeddingsServiceSettings(resourceName, deploymentId, apiVersion, null, false, null, null), + new AzureOpenAiEmbeddingsTaskSettings(user), + new AzureOpenAiSecretSettings(secureApiKey, secureEntraId) + ); + } + + public static AzureOpenAiEmbeddingsModel createModel( + String resourceName, + String deploymentId, + String apiVersion, + @Nullable Integer dimensions, + Boolean dimensionsSetByUser, + @Nullable Integer maxInputTokens, + @Nullable SimilarityMeasure similarity, + @Nullable String user, + @Nullable String apiKey, + @Nullable String entraId, + String inferenceEntityId + ) { + var secureApiKey = apiKey != null ? new SecureString(apiKey.toCharArray()) : null; + var secureEntraId = entraId != null ? new SecureString(entraId.toCharArray()) : null; + + return new AzureOpenAiEmbeddingsModel( + inferenceEntityId, + TaskType.TEXT_EMBEDDING, + "service", + new AzureOpenAiEmbeddingsServiceSettings( + resourceName, + deploymentId, + apiVersion, + dimensions, + dimensionsSetByUser, + maxInputTokens, + similarity + ), + new AzureOpenAiEmbeddingsTaskSettings(user), + new AzureOpenAiSecretSettings(secureApiKey, secureEntraId) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettingsTests.java new file mode 100644 index 0000000000000..3ff73e0f23656 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettingsTests.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.azureopenai.embeddings; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields; + +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class AzureOpenAiEmbeddingsRequestTaskSettingsTests extends ESTestCase { + public void testFromMap_ReturnsEmptySettings_WhenTheMapIsEmpty() { + var settings = AzureOpenAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(Map.of())); + assertThat(settings, is(AzureOpenAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS)); + } + + public void testFromMap_ReturnsEmptySettings_WhenTheMapDoesNotContainTheFields() { + var settings = AzureOpenAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(Map.of("key", "model"))); + assertNull(settings.user()); + } + + public void testFromMap_ReturnsUser() { + var settings = AzureOpenAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user"))); + assertThat(settings.user(), is("user")); + } + + public void testFromMap_WhenUserIsEmpty_ThrowsValidationException() { + var exception = expectThrows( + ValidationException.class, + () -> AzureOpenAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, ""))) + ); + + assertThat(exception.getMessage(), containsString("[user] must be a non-empty string")); + } + + public static Map<String, Object> getRequestTaskSettingsMap(@Nullable String user) { + var map = new HashMap<String, Object>(); + + if (user != null) { + map.put(AzureOpenAiServiceFields.USER, user); + } + + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java new file mode 100644 index 0000000000000..be184956b2034 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java @@ -0,0 +1,389 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.azureopenai.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class AzureOpenAiEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase<AzureOpenAiEmbeddingsServiceSettings> { + + private static AzureOpenAiEmbeddingsServiceSettings createRandom() { + var resourceName = randomAlphaOfLength(8); + var deploymentId = randomAlphaOfLength(8); + var apiVersion = randomAlphaOfLength(8); + Integer dims = randomBoolean() ? 1536 : null; + Integer maxInputTokens = randomBoolean() ? null : randomIntBetween(128, 256); + return new AzureOpenAiEmbeddingsServiceSettings( + resourceName, + deploymentId, + apiVersion, + dims, + randomBoolean(), + maxInputTokens, + null + ); + } + + public void testFromMap_Request_CreatesSettingsCorrectly() { + var resourceName = "this-resource"; + var deploymentId = "this-deployment"; + var apiVersion = "2024-01-01"; + var dims = 1536; + var maxInputTokens = 512; + var serviceSettings = AzureOpenAiEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + AzureOpenAiServiceFields.RESOURCE_NAME, + resourceName, + AzureOpenAiServiceFields.DEPLOYMENT_ID, + deploymentId, + AzureOpenAiServiceFields.API_VERSION, + apiVersion, + ServiceFields.DIMENSIONS, + dims, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + SIMILARITY, + SimilarityMeasure.COSINE.toString() + ) + ), + ConfigurationParseContext.REQUEST + ); + + assertThat( + serviceSettings, + is( + new AzureOpenAiEmbeddingsServiceSettings( + resourceName, + deploymentId, + apiVersion, + dims, + true, + maxInputTokens, + SimilarityMeasure.COSINE + ) + ) + ); + } + + public void testFromMap_Request_DimensionsSetByUser_IsFalse_WhenDimensionsAreNotPresent() { + var resourceName = "this-resource"; + var deploymentId = "this-deployment"; + var apiVersion = "2024-01-01"; + var maxInputTokens = 512; + var serviceSettings = AzureOpenAiEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + AzureOpenAiServiceFields.RESOURCE_NAME, + resourceName, + AzureOpenAiServiceFields.DEPLOYMENT_ID, + deploymentId, + AzureOpenAiServiceFields.API_VERSION, + apiVersion, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens + ) + ), + ConfigurationParseContext.REQUEST + ); + + assertThat( + serviceSettings, + is(new AzureOpenAiEmbeddingsServiceSettings(resourceName, deploymentId, apiVersion, null, false, maxInputTokens, null)) + ); + } + + public void 
testFromMap_Request_DimensionsSetByUser_ShouldThrowWhenPresent() { + var resourceName = "this-resource"; + var deploymentId = "this-deployment"; + var apiVersion = "2024-01-01"; + var maxInputTokens = 512; + var thrownException = expectThrows( + ValidationException.class, + () -> AzureOpenAiEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + AzureOpenAiServiceFields.RESOURCE_NAME, + resourceName, + AzureOpenAiServiceFields.DEPLOYMENT_ID, + deploymentId, + AzureOpenAiServiceFields.API_VERSION, + apiVersion, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + ServiceFields.DIMENSIONS, + 1024, + DIMENSIONS_SET_BY_USER, + false + ) + ), + ConfigurationParseContext.REQUEST + ) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString( + Strings.format("Validation Failed: 1: [service_settings] does not allow the setting [%s];", DIMENSIONS_SET_BY_USER) + ) + ); + } + + public void testFromMap_Persistent_CreatesSettingsCorrectly() { + var resourceName = "this-resource"; + var deploymentId = "this-deployment"; + var apiVersion = "2024-01-01"; + var dims = 1536; + var maxInputTokens = 512; + + var serviceSettings = AzureOpenAiEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + AzureOpenAiServiceFields.RESOURCE_NAME, + resourceName, + AzureOpenAiServiceFields.DEPLOYMENT_ID, + deploymentId, + AzureOpenAiServiceFields.API_VERSION, + apiVersion, + ServiceFields.DIMENSIONS, + dims, + DIMENSIONS_SET_BY_USER, + false, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + SIMILARITY, + SimilarityMeasure.DOT_PRODUCT.toString() + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new AzureOpenAiEmbeddingsServiceSettings( + resourceName, + deploymentId, + apiVersion, + dims, + false, + maxInputTokens, + SimilarityMeasure.DOT_PRODUCT + ) + ) + ); + } + + public void testFromMap_PersistentContext_DoesNotThrowException_WhenDimensionsIsNull() { + var resourceName = "this-resource"; + var deploymentId = "this-deployment"; + var apiVersion = "2024-01-01"; + + var settings = AzureOpenAiEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + AzureOpenAiServiceFields.RESOURCE_NAME, + resourceName, + AzureOpenAiServiceFields.DEPLOYMENT_ID, + deploymentId, + AzureOpenAiServiceFields.API_VERSION, + apiVersion, + DIMENSIONS_SET_BY_USER, + true + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat(settings, is(new AzureOpenAiEmbeddingsServiceSettings(resourceName, deploymentId, apiVersion, null, true, null, null))); + } + + public void testFromMap_PersistentContext_DoesNotThrowException_WhenSimilarityIsPresent() { + var resourceName = "this-resource"; + var deploymentId = "this-deployment"; + var apiVersion = "2024-01-01"; + + var settings = AzureOpenAiEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + AzureOpenAiServiceFields.RESOURCE_NAME, + resourceName, + AzureOpenAiServiceFields.DEPLOYMENT_ID, + deploymentId, + AzureOpenAiServiceFields.API_VERSION, + apiVersion, + DIMENSIONS_SET_BY_USER, + true, + SIMILARITY, + SimilarityMeasure.COSINE.toString() + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + settings, + is(new AzureOpenAiEmbeddingsServiceSettings(resourceName, deploymentId, apiVersion, null, true, null, SimilarityMeasure.COSINE)) + ); + } + + public void testFromMap_PersistentContext_ThrowsException_WhenDimensionsSetByUserIsNull() { + var resourceName = "this-resource"; + var deploymentId = "this-deployment"; + var apiVersion = 
"2024-01-01"; + + var exception = expectThrows( + ValidationException.class, + () -> AzureOpenAiEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + AzureOpenAiServiceFields.RESOURCE_NAME, + resourceName, + AzureOpenAiServiceFields.DEPLOYMENT_ID, + deploymentId, + AzureOpenAiServiceFields.API_VERSION, + apiVersion, + ServiceFields.DIMENSIONS, + 1 + ) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + + assertThat( + exception.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not contain the required setting [dimensions_set_by_user];") + ); + } + + public void testToXContent_WritesDimensionsSetByUserTrue() throws IOException { + var entity = new AzureOpenAiEmbeddingsServiceSettings("resource", "deployment", "apiVersion", null, true, null, null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(""" + {"resource_name":"resource","deployment_id":"deployment","api_version":"apiVersion",""" + """ + "dimensions_set_by_user":true}""")); + } + + public void testToXContent_WritesAllValues() throws IOException { + var entity = new AzureOpenAiEmbeddingsServiceSettings("resource", "deployment", "apiVersion", 1024, false, 512, null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(""" + {"resource_name":"resource","deployment_id":"deployment","api_version":"apiVersion",""" + """ + "dimensions":1024,"max_input_tokens":512,"dimensions_set_by_user":false}""")); + } + + public void testToFilteredXContent_WritesAllValues_ExceptDimensionsSetByUser() throws IOException { + var entity = new AzureOpenAiEmbeddingsServiceSettings("resource", "deployment", "apiVersion", 1024, false, 512, null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + var filteredXContent = entity.getFilteredXContentObject(); + filteredXContent.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(""" + {"resource_name":"resource","deployment_id":"deployment","api_version":"apiVersion",""" + """ + "dimensions":1024,"max_input_tokens":512}""")); + } + + @Override + protected Writeable.Reader instanceReader() { + return AzureOpenAiEmbeddingsServiceSettings::new; + } + + @Override + protected AzureOpenAiEmbeddingsServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected AzureOpenAiEmbeddingsServiceSettings mutateInstance(AzureOpenAiEmbeddingsServiceSettings instance) throws IOException { + return createRandom(); + } + + public static Map getPersistentAzureOpenAiServiceSettingsMap( + String resourceName, + String deploymentId, + String apiVersion, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens + ) { + var map = new HashMap(); + + map.put(AzureOpenAiServiceFields.RESOURCE_NAME, resourceName); + map.put(AzureOpenAiServiceFields.DEPLOYMENT_ID, deploymentId); + map.put(AzureOpenAiServiceFields.API_VERSION, apiVersion); + + if (dimensions != null) { + map.put(ServiceFields.DIMENSIONS, dimensions); + map.put(DIMENSIONS_SET_BY_USER, true); + } else { + map.put(DIMENSIONS_SET_BY_USER, false); + } + + if (maxInputTokens != null) { + map.put(ServiceFields.MAX_INPUT_TOKENS, maxInputTokens); + } + + return map; + } + + 
public static Map<String, Object> getRequestAzureOpenAiServiceSettingsMap( + String resourceName, + String deploymentId, + String apiVersion, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens + ) { + var map = new HashMap<String, Object>(); + + map.put(AzureOpenAiServiceFields.RESOURCE_NAME, resourceName); + map.put(AzureOpenAiServiceFields.DEPLOYMENT_ID, deploymentId); + map.put(AzureOpenAiServiceFields.API_VERSION, apiVersion); + + if (dimensions != null) { + map.put(ServiceFields.DIMENSIONS, dimensions); + } + + if (maxInputTokens != null) { + map.put(ServiceFields.MAX_INPUT_TOKENS, maxInputTokens); + } + + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettingsTests.java new file mode 100644 index 0000000000000..cc2d8b9b67620 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettingsTests.java @@ -0,0 +1,107 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.azureopenai.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class AzureOpenAiEmbeddingsTaskSettingsTests extends AbstractWireSerializingTestCase<AzureOpenAiEmbeddingsTaskSettings> { + + public static AzureOpenAiEmbeddingsTaskSettings createRandomWithUser() { + return new AzureOpenAiEmbeddingsTaskSettings(randomAlphaOfLength(15)); + } + + /** + * The created settings can have the user set to null. + */ + public static AzureOpenAiEmbeddingsTaskSettings createRandom() { + var user = randomBoolean() ? randomAlphaOfLength(15) : null; + return new AzureOpenAiEmbeddingsTaskSettings(user); + } + + public void testFromMap_WithUser() { + assertEquals( + new AzureOpenAiEmbeddingsTaskSettings("user"), + AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user"))) + ); + } + + public void testFromMap_UserIsEmptyString() { + var thrownException = expectThrows( + ValidationException.class, + () -> AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, ""))) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is(Strings.format("Validation Failed: 1: [task_settings] Invalid value empty string. [user] must be a non-empty string;")) + ); + } + + public void testFromMap_MissingUser_DoesNotThrowException() { + var taskSettings = AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of())); + assertNull(taskSettings.user()); + } + + public void testOverrideWith_KeepsOriginalValuesWithOverridesAreNull() { + var taskSettings = AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user"))); + + var overriddenTaskSettings = AzureOpenAiEmbeddingsTaskSettings.of( + taskSettings, + AzureOpenAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS + ); + MatcherAssert.assertThat(overriddenTaskSettings, is(taskSettings)); + } + + public void testOverrideWith_UsesOverriddenSettings() { + var taskSettings = AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user"))); + + var requestTaskSettings = AzureOpenAiEmbeddingsRequestTaskSettings.fromMap( + new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user2")) + ); + + var overriddenTaskSettings = AzureOpenAiEmbeddingsTaskSettings.of(taskSettings, requestTaskSettings); + MatcherAssert.assertThat(overriddenTaskSettings, is(new AzureOpenAiEmbeddingsTaskSettings("user2"))); + } + + @Override + protected Writeable.Reader<AzureOpenAiEmbeddingsTaskSettings> instanceReader() { + return AzureOpenAiEmbeddingsTaskSettings::new; + } + + @Override + protected AzureOpenAiEmbeddingsTaskSettings createTestInstance() { + return createRandomWithUser(); + } + + @Override + protected AzureOpenAiEmbeddingsTaskSettings mutateInstance(AzureOpenAiEmbeddingsTaskSettings instance) throws IOException { + return createRandomWithUser(); + } + + public static Map<String, Object> getAzureOpenAiRequestTaskSettingsMap(@Nullable String user) { + var map = new HashMap<String, Object>(); + + if (user != null) { + map.put(AzureOpenAiServiceFields.USER, user); + } + + return map; + } +} diff --git a/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/MultiNodesStatsTests.java b/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/MultiNodesStatsTests.java index c8aae302e357b..3c085b9bb2820 100644 --- a/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/MultiNodesStatsTests.java +++ b/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/MultiNodesStatsTests.java @@ -87,9 +87,7 @@ public void testMultipleNodes() throws Exception { assertThat(((StringTerms) aggregation).getBuckets().size(), equalTo(nbNodes)); for (String nodeName : internalCluster().getNodeNames()) { - StringTerms.Bucket bucket = ((StringTerms) aggregation).getBucketByKey( - internalCluster().clusterService(nodeName).localNode().getId() - ); + StringTerms.Bucket bucket = ((StringTerms) aggregation).getBucketByKey(getNodeId(nodeName)); // At least 1 doc must exist per node, but it can be more than 1 // because the first node may have already collected many node stats documents // whereas the last node just started to collect node stats. 
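The `getNodeId(nodeName)` calls introduced in the monitoring tests above (and in the searchable-snapshots test further below) replace the longer `internalCluster().clusterService(nodeName).localNode().getId()` lookup. A minimal sketch of what such a helper presumably looks like, assuming it is hosted in the shared internal-cluster test infrastructure rather than defined per test:

    protected String getNodeId(String nodeName) {
        // Resolve the node's ephemeral id from its ClusterService, exactly as the inlined expression did.
        return internalCluster().clusterService(nodeName).localNode().getId();
    }
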
diff --git a/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/exporter/local/LocalExporterIntegTests.java b/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/exporter/local/LocalExporterIntegTests.java index ef4f22f852b37..69ac9d4ddd876 100644 --- a/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/exporter/local/LocalExporterIntegTests.java +++ b/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/exporter/local/LocalExporterIntegTests.java @@ -173,7 +173,7 @@ public void testExport() throws Exception { aggregation.getBuckets().size() ); for (String nodeName : internalCluster().getNodeNames()) { - String nodeId = internalCluster().clusterService(nodeName).localNode().getId(); + String nodeId = getNodeId(nodeName); Terms.Bucket bucket = aggregation.getBucketByKey(nodeId); assertTrue("No bucket found for node id [" + nodeId + "]", bucket != null); assertTrue(bucket.getDocCount() >= 1L); @@ -208,7 +208,7 @@ public void testExport() throws Exception { response -> { Terms aggregation = response.getAggregations().get("agg_nodes_ids"); for (String nodeName : internalCluster().getNodeNames()) { - String nodeId = internalCluster().clusterService(nodeName).localNode().getId(); + String nodeId = getNodeId(nodeName); Terms.Bucket bucket = aggregation.getBucketByKey(nodeId); assertTrue("No bucket found for node id [" + nodeId + "]", bucket != null); assertTrue(bucket.getDocCount() >= 1L); diff --git a/x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/persistence/ProfilingDataStreamManagerTests.java b/x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/persistence/ProfilingDataStreamManagerTests.java index f2245baafe0c0..8414b5cab0f08 100644 --- a/x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/persistence/ProfilingDataStreamManagerTests.java +++ b/x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/persistence/ProfilingDataStreamManagerTests.java @@ -478,17 +478,10 @@ private ClusterState createClusterState( for (ProfilingDataStreamManager.ProfilingDataStream existingDataStream : existingDataStreams) { String writeIndexName = String.format(Locale.ROOT, ".ds-%s", existingDataStream.getName()); Index writeIndex = new Index(writeIndexName, writeIndexName); - DataStream ds = new DataStream( - existingDataStream.getName(), - List.of(writeIndex), - 1, - Map.of(), - false, - false, - false, - false, - IndexMode.STANDARD - ); + DataStream ds = DataStream.builder(existingDataStream.getName(), List.of(writeIndex)) + .setMetadata(Map.of()) + .setIndexMode(IndexMode.STANDARD) + .build(); metadataBuilder.put(ds); IndexMetadata.Builder builder = new IndexMetadata.Builder(writeIndexName); builder.state(state); diff --git a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/PrevalidateNodeRemovalWithSearchableSnapshotIntegTests.java b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/PrevalidateNodeRemovalWithSearchableSnapshotIntegTests.java index a651c4b30fcb1..37e2427ae6891 100644 --- a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/PrevalidateNodeRemovalWithSearchableSnapshotIntegTests.java +++ b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/PrevalidateNodeRemovalWithSearchableSnapshotIntegTests.java 
@@ -63,7 +63,7 @@ public void testNodeRemovalFromClusterWihRedSearchableSnapshotIndex() throws Exc PrevalidateNodeRemovalRequest.Builder req = PrevalidateNodeRemovalRequest.builder(); switch (randomIntBetween(0, 2)) { case 0 -> req.setNames(node2); - case 1 -> req.setIds(internalCluster().clusterService(node2).localNode().getId()); + case 1 -> req.setIds(getNodeId(node2)); case 2 -> req.setExternalIds(internalCluster().clusterService(node2).localNode().getExternalId()); default -> throw new IllegalStateException("Unexpected value"); } diff --git a/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/java/org/elasticsearch/xpack/remotecluster/RemoteClusterSecurityEsqlIT.java b/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/java/org/elasticsearch/xpack/remotecluster/RemoteClusterSecurityEsqlIT.java index e181a3542d446..2c393ea7ed1df 100644 --- a/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/java/org/elasticsearch/xpack/remotecluster/RemoteClusterSecurityEsqlIT.java +++ b/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/java/org/elasticsearch/xpack/remotecluster/RemoteClusterSecurityEsqlIT.java @@ -36,7 +36,10 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -83,7 +86,11 @@ public class RemoteClusterSecurityEsqlIT extends AbstractRemoteClusterSecurityTe { "search": [ { - "names": ["index*", "not_found_index", "employees"] + "names": ["index*", "not_found_index", "employees", "employees2"] + }, + { + "names": ["employees3"], + "query": {"term" : {"department" : "engineering"}} } ] }"""); @@ -188,40 +195,8 @@ public void populateData() throws Exception { performRequestWithAdminUser(client, new Request("DELETE", "/countries")); }; // Fulfilling cluster - { - setupEnrich.accept(fulfillingClusterClient); - Request createIndex = new Request("PUT", "employees"); - createIndex.setJsonEntity(""" - { - "mappings": { - "properties": { - "emp_id": { "type": "keyword" }, - "department": {"type": "keyword" } - } - } - } - """); - assertOK(performRequestAgainstFulfillingCluster(createIndex)); - final Request bulkRequest = new Request("POST", "/_bulk?refresh=true"); - bulkRequest.setJsonEntity(Strings.format(""" - { "index": { "_index": "employees" } } - { "emp_id": "1", "department" : "engineering" } - { "index": { "_index": "employees" } } - { "emp_id": "3", "department" : "sales" } - { "index": { "_index": "employees" } } - { "emp_id": "5", "department" : "marketing" } - { "index": { "_index": "employees" } } - { "emp_id": "7", "department" : "engineering" } - { "index": { "_index": "employees" } } - { "emp_id": "9", "department" : "sales" } - """)); - assertOK(performRequestAgainstFulfillingCluster(bulkRequest)); - } - // Querying cluster - // Index some documents, to use them in a mixed-cluster search - setupEnrich.accept(client()); - Request createIndex = new Request("PUT", "employees"); - createIndex.setJsonEntity(""" + setupEnrich.accept(fulfillingClusterClient); + String employeesMapping = """ { "mappings": { "properties": { @@ -230,9 +205,57 @@ public void populateData() throws Exception { } } } - """); + """; + Request createIndex = new Request("PUT", "employees"); + createIndex.setJsonEntity(employeesMapping); + 
assertOK(performRequestAgainstFulfillingCluster(createIndex)); + Request createIndex2 = new Request("PUT", "employees2"); + createIndex2.setJsonEntity(employeesMapping); + assertOK(performRequestAgainstFulfillingCluster(createIndex2)); + Request createIndex3 = new Request("PUT", "employees3"); + createIndex3.setJsonEntity(employeesMapping); + assertOK(performRequestAgainstFulfillingCluster(createIndex3)); + Request bulkRequest = new Request("POST", "/_bulk?refresh=true"); + bulkRequest.setJsonEntity(Strings.format(""" + { "index": { "_index": "employees" } } + { "emp_id": "1", "department" : "engineering" } + { "index": { "_index": "employees" } } + { "emp_id": "3", "department" : "sales" } + { "index": { "_index": "employees" } } + { "emp_id": "5", "department" : "marketing" } + { "index": { "_index": "employees" } } + { "emp_id": "7", "department" : "engineering" } + { "index": { "_index": "employees" } } + { "emp_id": "9", "department" : "sales" } + { "index": { "_index": "employees2" } } + { "emp_id": "11", "department" : "engineering" } + { "index": { "_index": "employees2" } } + { "emp_id": "13", "department" : "sales" } + { "index": { "_index": "employees3" } } + { "emp_id": "21", "department" : "engineering" } + { "index": { "_index": "employees3" } } + { "emp_id": "23", "department" : "sales" } + { "index": { "_index": "employees3" } } + { "emp_id": "25", "department" : "engineering" } + { "index": { "_index": "employees3" } } + { "emp_id": "27", "department" : "sales" } + """)); + assertOK(performRequestAgainstFulfillingCluster(bulkRequest)); + + // Querying cluster + // Index some documents, to use them in a mixed-cluster search + setupEnrich.accept(client()); + + createIndex = new Request("PUT", "employees"); + createIndex.setJsonEntity(employeesMapping); assertOK(adminClient().performRequest(createIndex)); - final Request bulkRequest = new Request("POST", "/_bulk?refresh=true"); + createIndex2 = new Request("PUT", "employees2"); + createIndex2.setJsonEntity(employeesMapping); + assertOK(adminClient().performRequest(createIndex2)); + createIndex3 = new Request("PUT", "employees3"); + createIndex3.setJsonEntity(employeesMapping); + assertOK(adminClient().performRequest(createIndex3)); + bulkRequest = new Request("POST", "/_bulk?refresh=true"); bulkRequest.setJsonEntity(Strings.format(""" { "index": { "_index": "employees" } } { "emp_id": "2", "department" : "management" } @@ -242,6 +265,14 @@ { "emp_id": "6", "department" : "marketing"} { "index": { "_index": "employees"} } { "emp_id": "8", "department" : "support"} + { "index": { "_index": "employees2"} } + { "emp_id": "10", "department" : "management"} + { "index": { "_index": "employees2"} } + { "emp_id": "12", "department" : "engineering"} + { "index": { "_index": "employees3"} } + { "emp_id": "20", "department" : "management"} + { "index": { "_index": "employees3"} } + { "emp_id": "22", "department" : "engineering"} """)); assertOK(client().performRequest(bulkRequest)); @@ -259,7 +290,7 @@ public void populateData() throws Exception { "remote_indices": [ { "names": ["employees"], - "privileges": ["read", "read_cross_cluster"], + "privileges": ["read"], "clusters": ["my_remote_cluster"] } ] @@ -278,56 +309,300 @@ public void populateData() throws Exception { public void wipeData() throws Exception { CheckedConsumer<RestClient, Exception> wipe = client -> { performRequestWithAdminUser(client, new Request("DELETE", "/employees")); + performRequestWithAdminUser(client, new Request("DELETE", 
"/employees2")); + performRequestWithAdminUser(client, new Request("DELETE", "/employees3")); performRequestWithAdminUser(client, new Request("DELETE", "/_enrich/policy/countries")); }; wipe.accept(fulfillingClusterClient); wipe.accept(client()); } - @AwaitsFix(bugUrl = "cross-clusters query doesn't work with RCS 2.0") + @SuppressWarnings("unchecked") public void testCrossClusterQuery() throws Exception { configureRemoteCluster(); populateData(); + + // query remote cluster only + Response response = performRequestWithRemoteSearchUser(esqlRequest(""" + FROM my_remote_cluster:employees + | SORT emp_id ASC + | LIMIT 2 + | KEEP emp_id, department""")); + assertOK(response); + assertRemoteOnlyResults(response); + + // query remote and local cluster + response = performRequestWithRemoteSearchUser(esqlRequest(""" + FROM my_remote_cluster:employees,employees + | SORT emp_id ASC + | LIMIT 10""")); + assertOK(response); + assertRemoteAndLocalResults(response); + + // query remote cluster only - but also include employees2 which the user does not have access to + response = performRequestWithRemoteSearchUser(esqlRequest(""" + FROM my_remote_cluster:employees,my_remote_cluster:employees2 + | SORT emp_id ASC + | LIMIT 2 + | KEEP emp_id, department""")); + assertOK(response); + assertRemoteOnlyResults(response); // same as above since the user only has access to employees + + // query remote and local cluster - but also include employees2 which the user does not have access to + response = performRequestWithRemoteSearchUser(esqlRequest(""" + FROM my_remote_cluster:employees,my_remote_cluster:employees2,employees,employees2 + | SORT emp_id ASC + | LIMIT 10""")); + assertOK(response); + assertRemoteAndLocalResults(response); // same as above since the user only has access to employees + + // update role to include both employees and employees2 for the remote cluster + final var putRoleRequest = new Request("PUT", "/_security/role/" + REMOTE_SEARCH_ROLE); + putRoleRequest.setJsonEntity(""" + { + "indices": [{"names": [""], "privileges": ["read_cross_cluster"]}], + "remote_indices": [ + { + "names": ["employees*"], + "privileges": ["read"], + "clusters": ["my_remote_cluster"] + } + ] + }"""); + response = adminClient().performRequest(putRoleRequest); + assertOK(response); + + // query remote cluster only - but also include employees2 which the user now access + response = performRequestWithRemoteSearchUser(esqlRequest(""" + FROM my_remote_cluster:employees,my_remote_cluster:employees2 + | SORT emp_id ASC + | LIMIT 2 + | KEEP emp_id, department""")); + assertOK(response); + assertRemoteOnlyAgainst2IndexResults(response); + } + + @SuppressWarnings("unchecked") + public void testCrossClusterQueryWithRemoteDLSAndFLS() throws Exception { + configureRemoteCluster(); + populateData(); + + // ensure user has access to the employees3 index + final var putRoleRequest = new Request("PUT", "/_security/role/" + REMOTE_SEARCH_ROLE); + putRoleRequest.setJsonEntity(""" + { + "indices": [{"names": [""], "privileges": ["read_cross_cluster"]}], + "remote_indices": [ + { + "names": ["employees*"], + "privileges": ["read"], + "clusters": ["my_remote_cluster"] + + } + ] + }"""); + Response response = adminClient().performRequest(putRoleRequest); + assertOK(response); + + response = performRequestWithRemoteSearchUser(esqlRequest(""" + FROM my_remote_cluster:employees3 + | SORT emp_id ASC + | LIMIT 10 + | KEEP emp_id, department""")); + assertOK(response); + + Map responseAsMap = entityAsMap(response); + List columns = (List) 
responseAsMap.get("columns"); + List values = (List) responseAsMap.get("values"); + assertEquals(2, columns.size()); + assertEquals(2, values.size()); + List flatList = values.stream() + .flatMap(innerList -> innerList instanceof List ? ((List) innerList).stream() : Stream.empty()) + .collect(Collectors.toList()); + // the APIKey has DLS set to : "query": {"term" : {"department" : "engineering"}} + assertThat(flatList, containsInAnyOrder("21", "25", "engineering", "engineering")); + + // add DLS to the remote indices in the role to restrict access to only emp_id = 21 + putRoleRequest.setJsonEntity(""" + { + "indices": [{"names": [""], "privileges": ["read_cross_cluster"]}], + "remote_indices": [ + { + "names": ["employees*"], + "privileges": ["read"], + "clusters": ["my_remote_cluster"], + "query": {"term" : {"emp_id" : "21"}} + + } + ] + }"""); + response = adminClient().performRequest(putRoleRequest); + assertOK(response); + + response = performRequestWithRemoteSearchUser(esqlRequest(""" + FROM my_remote_cluster:employees3 + | SORT emp_id ASC + | LIMIT 2 + | KEEP emp_id, department""")); + assertOK(response); + + responseAsMap = entityAsMap(response); + columns = (List) responseAsMap.get("columns"); + values = (List) responseAsMap.get("values"); + assertEquals(2, columns.size()); + assertEquals(1, values.size()); + flatList = values.stream() + .flatMap(innerList -> innerList instanceof List ? ((List) innerList).stream() : Stream.empty()) + .collect(Collectors.toList()); + // the APIKey has DLS set to : "query": {"term" : {"department" : "engineering"}} + // AND this role has DLS set to: "query": {"term" : {"emp_id" : "21"}} + assertThat(flatList, containsInAnyOrder("21", "engineering")); + + // add FLS to the remote indices in the role to restrict access to only access department + putRoleRequest.setJsonEntity(""" + { + "indices": [{"names": [""], "privileges": ["read_cross_cluster"]}], + "remote_indices": [ + { + "names": ["employees*"], + "privileges": ["read"], + "clusters": ["my_remote_cluster"], + "query": {"term" : {"emp_id" : "21"}}, + "field_security": {"grant": [ "department" ]} + } + ] + }"""); + response = adminClient().performRequest(putRoleRequest); + assertOK(response); + + response = performRequestWithRemoteSearchUser(esqlRequest(""" + FROM my_remote_cluster:employees3 + | LIMIT 2 + """)); + assertOK(response); + responseAsMap = entityAsMap(response); + columns = (List) responseAsMap.get("columns"); + values = (List) responseAsMap.get("values"); + assertEquals(1, columns.size()); + assertEquals(1, values.size()); + flatList = values.stream() + .flatMap(innerList -> innerList instanceof List ? ((List) innerList).stream() : Stream.empty()) + .collect(Collectors.toList()); + // the APIKey has DLS set to : "query": {"term" : {"department" : "engineering"}} + // AND this role has DLS set to: "query": {"term" : {"emp_id" : "21"}} + // AND this role has FLS set to: "field_security": {"grant": [ "department" ]} + assertThat(flatList, containsInAnyOrder("engineering")); + } + + public void testCrossClusterQueryAgainstInvalidRemote() throws Exception { + configureRemoteCluster(); + populateData(); + + // avoids getting 404 errors + updateClusterSettings( + randomBoolean() + ? 
Settings.builder().put("cluster.remote.invalid_remote.seeds", fulfillingCluster.getRemoteClusterServerEndpoint(0)).build() + : Settings.builder() + .put("cluster.remote.invalid_remote.mode", "proxy") + .put("cluster.remote.invalid_remote.proxy_address", fulfillingCluster.getRemoteClusterServerEndpoint(0)) + .build() + ); + + // invalid remote with local index should return local results + var q = "FROM invalid_remote:employees,employees | SORT emp_id DESC | LIMIT 10"; + Response response = performRequestWithRemoteSearchUser(esqlRequest(q)); + assertOK(response); + assertLocalOnlyResults(response); + + // only calling an invalid remote should error + ResponseException error = expectThrows(ResponseException.class, () -> { + var q2 = "FROM invalid_remote:employees | SORT emp_id DESC | LIMIT 10"; + performRequestWithRemoteSearchUser(esqlRequest(q2)); + }); + assertThat(error.getResponse().getStatusLine().getStatusCode(), equalTo(401)); + assertThat(error.getMessage(), containsString("unable to find apikey")); + } + + @SuppressWarnings("unchecked") + public void testCrossClusterQueryWithOnlyRemotePrivs() throws Exception { + configureRemoteCluster(); + populateData(); + // Query cluster - { + var putRoleRequest = new Request("PUT", "/_security/role/" + REMOTE_SEARCH_ROLE); + putRoleRequest.setJsonEntity(""" { - Response response = performRequestWithRemoteSearchUser(esqlRequest(""" - FROM my_remote_cluster:employees - | SORT emp_id ASC - | LIMIT 2 - | KEEP emp_id, department""")); - assertOK(response); - Map values = entityAsMap(response); - } + "indices": [{"names": [""], "privileges": ["read_cross_cluster"]}], + "remote_indices": [ + { + "names": ["employees"], + "privileges": ["read"], + "clusters": ["my_remote_cluster"] + } + ] + }"""); + assertOK(adminClient().performRequest(putRoleRequest)); + + // query appropriate privs + Response response = performRequestWithRemoteSearchUser(esqlRequest(""" + FROM my_remote_cluster:employees + | SORT emp_id ASC + | LIMIT 2 + | KEEP emp_id, department""")); + assertOK(response); + assertRemoteOnlyResults(response); + + // without the remote index priv + putRoleRequest.setJsonEntity(""" { - Response response = performRequestWithRemoteSearchUser(esqlRequest(""" - FROM my_remote_cluster:employees,employees - | SORT emp_id ASC - | LIMIT 10""")); - assertOK(response); + "indices": [{"names": [""], "privileges": ["read_cross_cluster"]}], + "remote_indices": [ + { + "names": ["idontexist"], + "privileges": ["read"], + "clusters": ["my_remote_cluster"] + } + ] + }"""); + assertOK(adminClient().performRequest(putRoleRequest)); - } - // Check that authentication fails if we use a non-existent API key - updateClusterSettings( - randomBoolean() - ? 
Settings.builder() - .put("cluster.remote.invalid_remote.seeds", fulfillingCluster.getRemoteClusterServerEndpoint(0)) - .build() - : Settings.builder() - .put("cluster.remote.invalid_remote.mode", "proxy") - .put("cluster.remote.invalid_remote.proxy_address", fulfillingCluster.getRemoteClusterServerEndpoint(0)) - .build() - ); - for (String indices : List.of("my_remote_cluster:employees,employees", "my_remote_cluster:employees")) { - ResponseException error = expectThrows(ResponseException.class, () -> { - var q = "FROM " + indices + "| SORT emp_id DESC | LIMIT 10"; - performRequestWithLocalSearchUser(esqlRequest(q)); - }); - assertThat(error.getResponse().getStatusLine().getStatusCode(), equalTo(403)); - assertThat(error.getResponse().getStatusLine().getStatusCode(), equalTo(401)); - assertThat(error.getMessage(), containsString("unable to find apikey")); - } - } + ResponseException error = expectThrows(ResponseException.class, () -> performRequestWithRemoteSearchUser(esqlRequest(""" + FROM my_remote_cluster:employees + | SORT emp_id ASC + | LIMIT 2 + | KEEP emp_id, department"""))); + assertThat(error.getResponse().getStatusLine().getStatusCode(), equalTo(400)); + assertThat(error.getMessage(), containsString("Unknown index [my_remote_cluster:employees]")); + + // no local privs at all will fail + final var putRoleNoLocalPrivs = new Request("PUT", "/_security/role/" + REMOTE_SEARCH_ROLE); + putRoleNoLocalPrivs.setJsonEntity(""" + { + "indices": [], + "remote_indices": [ + { + "names": ["employees"], + "privileges": ["read"], + "clusters": ["my_remote_cluster"] + } + ] + }"""); + assertOK(adminClient().performRequest(putRoleNoLocalPrivs)); + + error = expectThrows(ResponseException.class, () -> { performRequestWithRemoteSearchUser(esqlRequest(""" + FROM my_remote_cluster:employees + | SORT emp_id ASC + | LIMIT 2 + | KEEP emp_id, department""")); }); + + assertThat(error.getResponse().getStatusLine().getStatusCode(), equalTo(403)); + assertThat( + error.getMessage(), + containsString( + "action [indices:data/read/esql] is unauthorized for user [remote_search_user] with effective roles [remote_search], " + + "this action is granted by the index privileges [read,read_cross_cluster,all]" + ) + ); } @AwaitsFix(bugUrl = "cross-clusters enrich doesn't work with RCS 2.0") @@ -360,7 +635,7 @@ public void testCrossClusterEnrich() throws Exception { "remote_indices": [ { "names": ["employees"], - "privileges": ["read", "read_cross_cluster"], + "privileges": ["read"], "clusters": ["my_remote_cluster"] } ] @@ -434,4 +709,79 @@ private Response performRequestWithLocalSearchUser(final Request request) throws ); return client().performRequest(request); } + + @SuppressWarnings("unchecked") + private void assertRemoteOnlyResults(Response response) throws IOException { + Map responseAsMap = entityAsMap(response); + List columns = (List) responseAsMap.get("columns"); + List values = (List) responseAsMap.get("values"); + assertEquals(2, columns.size()); + assertEquals(2, values.size()); + List flatList = values.stream() + .flatMap(innerList -> innerList instanceof List ? 
((List) innerList).stream() : Stream.empty()) + .collect(Collectors.toList()); + assertThat(flatList, containsInAnyOrder("1", "3", "engineering", "sales")); + } + + @SuppressWarnings("unchecked") + private void assertRemoteOnlyAgainst2IndexResults(Response response) throws IOException { + Map responseAsMap = entityAsMap(response); + List columns = (List) responseAsMap.get("columns"); + List values = (List) responseAsMap.get("values"); + assertEquals(2, columns.size()); + assertEquals(2, values.size()); + List flatList = values.stream() + .flatMap(innerList -> innerList instanceof List ? ((List) innerList).stream() : Stream.empty()) + .collect(Collectors.toList()); + assertThat(flatList, containsInAnyOrder("1", "11", "engineering", "engineering")); + } + + @SuppressWarnings("unchecked") + private void assertLocalOnlyResults(Response response) throws IOException { + Map responseAsMap = entityAsMap(response); + List columns = (List) responseAsMap.get("columns"); + List values = (List) responseAsMap.get("values"); + assertEquals(2, columns.size()); + assertEquals(4, values.size()); + List flatList = values.stream() + .flatMap(innerList -> innerList instanceof List ? ((List) innerList).stream() : Stream.empty()) + .collect(Collectors.toList()); + // local results + assertThat(flatList, containsInAnyOrder("2", "4", "6", "8", "support", "management", "engineering", "marketing")); + } + + @SuppressWarnings("unchecked") + private void assertRemoteAndLocalResults(Response response) throws IOException { + Map responseAsMap = entityAsMap(response); + List columns = (List) responseAsMap.get("columns"); + List values = (List) responseAsMap.get("values"); + assertEquals(2, columns.size()); + assertEquals(9, values.size()); + List flatList = values.stream() + .flatMap(innerList -> innerList instanceof List ? ((List) innerList).stream() : Stream.empty()) + .collect(Collectors.toList()); + assertThat( + flatList, + containsInAnyOrder( + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "engineering", + "engineering", + "engineering", + "management", + "sales", + "sales", + "marketing", + "marketing", + "support" + ) + ); + } } diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DataStreamSecurityIT.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DataStreamSecurityIT.java index 5efa2aa46c7bc..96284b2826e48 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DataStreamSecurityIT.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DataStreamSecurityIT.java @@ -89,17 +89,9 @@ public ClusterState execute(ClusterState currentState) throws Exception { String brokenIndexName = shouldBreakIndexName ? 
original.getIndices().get(0).getName() + "-broken" : original.getIndices().get(0).getName(); - DataStream broken = new DataStream( - original.getName(), - List.of(new Index(brokenIndexName, "broken"), original.getIndices().get(1)), - original.getGeneration(), - original.getMetadata(), - original.isHidden(), - original.isReplicated(), - original.isSystem(), - original.isAllowCustomRouting(), - original.getIndexMode() - ); + DataStream broken = original.copy() + .setIndices(List.of(new Index(brokenIndexName, "broken"), original.getIndices().get(1))) + .build(); brokenDataStreamHolder.set(broken); return ClusterState.builder(currentState) .metadata(Metadata.builder(currentState.getMetadata()).put(broken).build()) diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authz/store/DisableNativeRoleMappingsStoreTests.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authz/store/DisableNativeRoleMappingsStoreTests.java new file mode 100644 index 0000000000000..4f56d783e117c --- /dev/null +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authz/store/DisableNativeRoleMappingsStoreTests.java @@ -0,0 +1,157 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.security.authz.store; + +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.RequestOptions; +import org.elasticsearch.client.ResponseException; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.SecurityIntegTestCase; +import org.elasticsearch.test.SecuritySettingsSource; +import org.elasticsearch.test.SecuritySettingsSourceField; +import org.elasticsearch.xpack.core.security.action.rolemapping.DeleteRoleMappingRequest; +import org.elasticsearch.xpack.core.security.action.rolemapping.PutRoleMappingRequest; +import org.elasticsearch.xpack.core.security.authc.RealmConfig; +import org.elasticsearch.xpack.core.security.authc.support.UserRoleMapper; +import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken; +import org.elasticsearch.xpack.core.security.authc.support.mapper.ExpressionRoleMapping; +import org.elasticsearch.xpack.security.authc.support.mapper.NativeRoleMappingStore; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutionException; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.emptyIterable; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class DisableNativeRoleMappingsStoreTests extends SecurityIntegTestCase { + + @Override + protected Collection> nodePlugins() { + List> plugins = new ArrayList<>(super.nodePlugins()); + plugins.add(PrivateCustomPlugin.class); + return plugins; + } + + @Override + protected boolean addMockHttpTransport() { + return false; // need real http + } + + public void testPutRoleMappingDisallowed() { + // transport action + NativeRoleMappingStore 
nativeRoleMappingStore = internalCluster().getInstance(NativeRoleMappingStore.class); + PlainActionFuture future = new PlainActionFuture<>(); + nativeRoleMappingStore.putRoleMapping(new PutRoleMappingRequest(), future); + ExecutionException e = expectThrows(ExecutionException.class, future::get); + assertThat(e.getMessage(), containsString("Native role mapping management is disabled")); + // rest request + Request request = new Request("POST", "_security/role_mapping/" + randomAlphaOfLength(8)); + RequestOptions.Builder options = request.getOptions().toBuilder(); + options.addHeader( + "Authorization", + UsernamePasswordToken.basicAuthHeaderValue( + SecuritySettingsSource.TEST_USER_NAME, + new SecureString(SecuritySettingsSourceField.TEST_PASSWORD.toCharArray()) + ) + ); + request.setOptions(options); + ResponseException e2 = expectThrows(ResponseException.class, () -> getRestClient().performRequest(request)); + assertThat(e2.getMessage(), containsString("Native role mapping management is not enabled in this Elasticsearch instance")); + assertThat(e2.getResponse().getStatusLine().getStatusCode(), is(410)); // gone + } + + public void testDeleteRoleMappingDisallowed() { + // transport action + NativeRoleMappingStore nativeRoleMappingStore = internalCluster().getInstance(NativeRoleMappingStore.class); + PlainActionFuture future = new PlainActionFuture<>(); + nativeRoleMappingStore.deleteRoleMapping(new DeleteRoleMappingRequest(), future); + ExecutionException e = expectThrows(ExecutionException.class, future::get); + assertThat(e.getMessage(), containsString("Native role mapping management is disabled")); + // rest request + Request request = new Request("DELETE", "_security/role_mapping/" + randomAlphaOfLength(8)); + RequestOptions.Builder options = request.getOptions().toBuilder(); + options.addHeader( + "Authorization", + UsernamePasswordToken.basicAuthHeaderValue( + SecuritySettingsSource.TEST_USER_NAME, + new SecureString(SecuritySettingsSourceField.TEST_PASSWORD.toCharArray()) + ) + ); + request.setOptions(options); + ResponseException e2 = expectThrows(ResponseException.class, () -> getRestClient().performRequest(request)); + assertThat(e2.getMessage(), containsString("Native role mapping management is not enabled in this Elasticsearch instance")); + assertThat(e2.getResponse().getStatusLine().getStatusCode(), is(410)); // gone + } + + public void testGetRoleMappingDisallowed() throws Exception { + // transport action + NativeRoleMappingStore nativeRoleMappingStore = internalCluster().getInstance(NativeRoleMappingStore.class); + PlainActionFuture> future = new PlainActionFuture<>(); + nativeRoleMappingStore.getRoleMappings(randomFrom(Set.of(randomAlphaOfLength(8)), null), future); + assertThat(future.get(), emptyIterable()); + // rest request + Request request = new Request("GET", "_security/role_mapping/" + randomAlphaOfLength(8)); + RequestOptions.Builder options = request.getOptions().toBuilder(); + options.addHeader( + "Authorization", + UsernamePasswordToken.basicAuthHeaderValue( + SecuritySettingsSource.TEST_USER_NAME, + new SecureString(SecuritySettingsSourceField.TEST_PASSWORD.toCharArray()) + ) + ); + request.setOptions(options); + ResponseException e2 = expectThrows(ResponseException.class, () -> getRestClient().performRequest(request)); + assertThat(e2.getMessage(), containsString("Native role mapping management is not enabled in this Elasticsearch instance")); + assertThat(e2.getResponse().getStatusLine().getStatusCode(), is(410)); // gone + } + + public void 
testResolveRoleMappings() throws Exception { + NativeRoleMappingStore nativeRoleMappingStore = internalCluster().getInstance(NativeRoleMappingStore.class); + UserRoleMapper.UserData userData = new UserRoleMapper.UserData( + randomAlphaOfLength(4), + null, + randomFrom(Set.of(randomAlphaOfLength(4)), Set.of()), + Map.of(), + mock(RealmConfig.class) + ); + PlainActionFuture> future = new PlainActionFuture<>(); + nativeRoleMappingStore.resolveRoles(userData, future); + assertThat(future.get(), emptyIterable()); + } + + public static class PrivateCustomPlugin extends Plugin { + + public static final Setting NATIVE_ROLE_MAPPINGS_SETTING = Setting.boolSetting( + "xpack.security.authc.native_role_mappings.enabled", + true, + Setting.Property.NodeScope + ); + + public PrivateCustomPlugin() {} + + @Override + public Settings additionalSettings() { + return Settings.builder().put(NATIVE_ROLE_MAPPINGS_SETTING.getKey(), false).build(); + } + + @Override + public List> getSettings() { + return List.of(NATIVE_ROLE_MAPPINGS_SETTING); + } + } +} diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java index f4457dcbbfaa9..837c58ab6542d 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java @@ -296,6 +296,7 @@ import org.elasticsearch.xpack.security.authz.AuthorizationDenialMessages; import org.elasticsearch.xpack.security.authz.AuthorizationService; import org.elasticsearch.xpack.security.authz.DlsFlsRequestCacheDifferentiator; +import org.elasticsearch.xpack.security.authz.FileRoleValidator; import org.elasticsearch.xpack.security.authz.ReservedRoleNameChecker; import org.elasticsearch.xpack.security.authz.SecuritySearchOperationListener; import org.elasticsearch.xpack.security.authz.accesscontrol.OptOutQueryCache; @@ -588,6 +589,7 @@ public class Security extends Plugin private final SetOnce> reloadableComponents = new SetOnce<>(); private final SetOnce authorizationDenialMessages = new SetOnce<>(); private final SetOnce reservedRoleNameCheckerFactory = new SetOnce<>(); + private final SetOnce fileRoleValidator = new SetOnce<>(); private final SetOnce secondaryAuthActions = new SetOnce<>(); public Security(Settings settings) { @@ -828,7 +830,6 @@ Collection createComponents( dlsBitsetCache.set(new DocumentSubsetBitsetCache(settings, threadPool)); final FieldPermissionsCache fieldPermissionsCache = new FieldPermissionsCache(settings); - this.fileRolesStore.set(new FileRolesStore(settings, environment, resourceWatcherService, getLicenseState(), xContentRegistry)); final NativeRolesStore nativeRolesStore = new NativeRolesStore( settings, client, @@ -859,6 +860,12 @@ Collection createComponents( if (reservedRoleNameCheckerFactory.get() == null) { reservedRoleNameCheckerFactory.set(new ReservedRoleNameChecker.Factory.Default()); } + if (fileRoleValidator.get() == null) { + fileRoleValidator.set(new FileRoleValidator.Default()); + } + this.fileRolesStore.set( + new FileRolesStore(settings, environment, resourceWatcherService, getLicenseState(), xContentRegistry, fileRoleValidator.get()) + ); final ReservedRoleNameChecker reservedRoleNameChecker = reservedRoleNameCheckerFactory.get().create(fileRolesStore.get()::exists); components.add(new PluginComponentBinding<>(ReservedRoleNameChecker.class, reservedRoleNameChecker)); @@ -2118,6 +2125,7 @@ public void 
loadExtensions(ExtensionLoader loader) { loadSingletonExtensionAndSetOnce(loader, hasPrivilegesRequestBuilderFactory, HasPrivilegesRequestBuilderFactory.class); loadSingletonExtensionAndSetOnce(loader, authorizationDenialMessages, AuthorizationDenialMessages.class); loadSingletonExtensionAndSetOnce(loader, reservedRoleNameCheckerFactory, ReservedRoleNameChecker.Factory.class); + loadSingletonExtensionAndSetOnce(loader, fileRoleValidator, FileRoleValidator.class); loadSingletonExtensionAndSetOnce(loader, secondaryAuthActions, SecondaryAuthActions.class); } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/mapper/NativeRoleMappingStore.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/mapper/NativeRoleMappingStore.java index 4abf2e53d0264..926626f2eaf10 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/mapper/NativeRoleMappingStore.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/mapper/NativeRoleMappingStore.java @@ -48,7 +48,6 @@ import java.io.IOException; import java.util.Arrays; import java.util.Collection; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -85,6 +84,17 @@ */ public class NativeRoleMappingStore implements UserRoleMapper { + /** + * This setting is never registered by the security plugin - in order to disable the native role APIs + * another plugin must register it as a boolean setting and cause it to be set to `false`. + * + * If this setting is set to false then + *
+ * <ul>
+ * <li>the Rest APIs for native role mappings management are disabled.</li>
+ * <li>The native role mappings store will not map any roles to any user.</li>
+ * </ul>
+ */ + public static final String NATIVE_ROLE_MAPPINGS_ENABLED = "xpack.security.authc.native_role_mappings.enabled"; private static final Logger logger = LogManager.getLogger(NativeRoleMappingStore.class); static final String DOC_TYPE_FIELD = "doc_type"; static final String DOC_TYPE_ROLE_MAPPING = "role-mapping"; @@ -105,6 +115,7 @@ public class NativeRoleMappingStore implements UserRoleMapper { private final List realmsToRefresh = new CopyOnWriteArrayList<>(); private final boolean lastLoadCacheEnabled; private final AtomicReference> lastLoadRef = new AtomicReference<>(null); + private final boolean enabled; public NativeRoleMappingStore(Settings settings, Client client, SecurityIndexManager securityIndex, ScriptService scriptService) { this.settings = settings; @@ -112,16 +123,7 @@ public NativeRoleMappingStore(Settings settings, Client client, SecurityIndexMan this.securityIndex = securityIndex; this.scriptService = scriptService; this.lastLoadCacheEnabled = LAST_LOAD_CACHE_ENABLED_SETTING.get(settings); - } - - private static String getNameFromId(String id) { - assert id.startsWith(ID_PREFIX); - return id.substring(ID_PREFIX.length()); - } - - // package-private for testing - static String getIdForName(String name) { - return ID_PREFIX + name; + this.enabled = settings.getAsBoolean(NATIVE_ROLE_MAPPINGS_ENABLED, true); } /** @@ -129,6 +131,10 @@ static String getIdForName(String name) { * package private for unit testing */ protected void loadMappings(ActionListener> listener) { + if (enabled == false) { + listener.onResponse(List.of()); + return; + } if (securityIndex.isIndexUpToDate() == false) { listener.onFailure( new IllegalStateException( @@ -164,32 +170,21 @@ protected void loadMappings(ActionListener> listener () -> format("failed to load role mappings from index [%s] skipping all mappings.", SECURITY_MAIN_ALIAS), ex ); - listener.onResponse(Collections.emptyList()); + listener.onResponse(List.of()); })), doc -> buildMapping(getNameFromId(doc.getId()), doc.getSourceRef()) ); } } - protected static ExpressionRoleMapping buildMapping(String id, BytesReference source) { - try ( - XContentParser parser = XContentHelper.createParserNotCompressed( - LoggingDeprecationHandler.XCONTENT_PARSER_CONFIG, - source, - XContentType.JSON - ) - ) { - return ExpressionRoleMapping.parse(id, parser); - } catch (Exception e) { - logger.warn(() -> "Role mapping [" + id + "] cannot be parsed and will be skipped", e); - return null; - } - } - /** * Stores (create or update) a single mapping in the index */ public void putRoleMapping(PutRoleMappingRequest request, ActionListener listener) { + if (enabled == false) { + listener.onFailure(new IllegalStateException("Native role mapping management is disabled")); + return; + } // Validate all templates before storing the role mapping for (TemplateRoleName templateRoleName : request.getRoleTemplates()) { templateRoleName.validate(scriptService); @@ -201,6 +196,10 @@ public void putRoleMapping(PutRoleMappingRequest request, ActionListener listener) { + if (enabled == false) { + listener.onFailure(new IllegalStateException("Native role mapping management is disabled")); + return; + } modifyMapping(request.getName(), this::innerDeleteMapping, request, listener); } @@ -229,6 +228,10 @@ private void modifyMapping( } private void innerPutMapping(PutRoleMappingRequest request, ActionListener listener) { + if (enabled == false) { + listener.onFailure(new IllegalStateException("Native role mapping management is disabled")); + return; + } final ExpressionRoleMapping 
mapping = request.getMapping(); securityIndex.prepareIndexIfNeededThenExecute(listener::onFailure, () -> { final XContentBuilder xContentBuilder; @@ -266,6 +269,10 @@ public void onFailure(Exception e) { } private void innerDeleteMapping(DeleteRoleMappingRequest request, ActionListener listener) { + if (enabled == false) { + listener.onFailure(new IllegalStateException("Native role mapping management is disabled")); + return; + } final SecurityIndexManager frozenSecurityIndex = securityIndex.defensiveCopy(); if (frozenSecurityIndex.indexExists() == false) { listener.onResponse(false); @@ -307,7 +314,9 @@ public void onFailure(Exception e) { * Otherwise it retrieves the specified mappings by name. */ public void getRoleMappings(Set names, ActionListener> listener) { - if (names == null || names.isEmpty()) { + if (enabled == false) { + listener.onResponse(List.of()); + } else if (names == null || names.isEmpty()) { getMappings(listener); } else { getMappings(listener.safeMap(mappings -> mappings.stream().filter(m -> names.contains(m.getName())).toList())); @@ -315,10 +324,14 @@ public void getRoleMappings(Set names, ActionListener> listener) { + if (enabled == false) { + listener.onResponse(List.of()); + return; + } final SecurityIndexManager frozenSecurityIndex = securityIndex.defensiveCopy(); if (frozenSecurityIndex.indexExists() == false) { logger.debug("The security index does not exist - no role mappings can be loaded"); - listener.onResponse(Collections.emptyList()); + listener.onResponse(List.of()); return; } final List lastLoad = lastLoadRef.get(); @@ -329,7 +342,7 @@ private void getMappings(ActionListener> listener) { listener.onResponse(lastLoad); } else { logger.debug("The security index exists but is closed - no role mappings can be loaded"); - listener.onResponse(Collections.emptyList()); + listener.onResponse(List.of()); } } else if (frozenSecurityIndex.isAvailable(SEARCH_SHARDS) == false) { final ElasticsearchException unavailableReason = frozenSecurityIndex.getUnavailableReason(SEARCH_SHARDS); @@ -365,20 +378,15 @@ List getLastLoad() { * */ public void usageStats(ActionListener> listener) { - if (securityIndex.indexIsClosed() || securityIndex.isAvailable(SEARCH_SHARDS) == false) { - reportStats(listener, Collections.emptyList()); + if (enabled == false) { + reportStats(listener, List.of()); + } else if (securityIndex.indexIsClosed() || securityIndex.isAvailable(SEARCH_SHARDS) == false) { + reportStats(listener, List.of()); } else { getMappings(ActionListener.wrap(mappings -> reportStats(listener, mappings), listener::onFailure)); } } - private static void reportStats(ActionListener> listener, List mappings) { - Map usageStats = new HashMap<>(); - usageStats.put("size", mappings.size()); - usageStats.put("enabled", mappings.stream().filter(ExpressionRoleMapping::isEnabled).count()); - listener.onResponse(usageStats); - } - public void onSecurityIndexStateChange(SecurityIndexManager.State previousState, SecurityIndexManager.State currentState) { if (isMoveFromRedToNonRed(previousState, currentState) || isIndexDeleted(previousState, currentState) @@ -388,28 +396,6 @@ public void onSecurityIndexStateChange(SecurityIndexManager.State previousState, } } - private void refreshRealms(ActionListener listener, Result result) { - if (realmsToRefresh.isEmpty()) { - listener.onResponse(result); - return; - } - - final String[] realmNames = this.realmsToRefresh.toArray(Strings.EMPTY_ARRAY); - executeAsyncWithOrigin( - client, - SECURITY_ORIGIN, - ClearRealmCacheAction.INSTANCE, - new 
ClearRealmCacheRequest().realms(realmNames), - ActionListener.wrap(response -> { - logger.debug(() -> format("Cleared cached in realms [%s] due to role mapping change", Arrays.toString(realmNames))); - listener.onResponse(result); - }, ex -> { - logger.warn(() -> "Failed to clear cache for realms [" + Arrays.toString(realmNames) + "]", ex); - listener.onFailure(ex); - }) - ); - } - @Override public void resolveRoles(UserData user, ActionListener> listener) { getRoleMappings(null, ActionListener.wrap(mappings -> { @@ -438,4 +424,57 @@ public void resolveRoles(UserData user, ActionListener> listener) { public void refreshRealmOnChange(CachingRealm realm) { realmsToRefresh.add(realm.name()); } + + private void refreshRealms(ActionListener listener, Result result) { + if (enabled == false || realmsToRefresh.isEmpty()) { + listener.onResponse(result); + return; + } + final String[] realmNames = this.realmsToRefresh.toArray(Strings.EMPTY_ARRAY); + executeAsyncWithOrigin( + client, + SECURITY_ORIGIN, + ClearRealmCacheAction.INSTANCE, + new ClearRealmCacheRequest().realms(realmNames), + ActionListener.wrap(response -> { + logger.debug(() -> format("Cleared cached in realms [%s] due to role mapping change", Arrays.toString(realmNames))); + listener.onResponse(result); + }, ex -> { + logger.warn(() -> "Failed to clear cache for realms [" + Arrays.toString(realmNames) + "]", ex); + listener.onFailure(ex); + }) + ); + } + + protected static ExpressionRoleMapping buildMapping(String id, BytesReference source) { + try ( + XContentParser parser = XContentHelper.createParserNotCompressed( + LoggingDeprecationHandler.XCONTENT_PARSER_CONFIG, + source, + XContentType.JSON + ) + ) { + return ExpressionRoleMapping.parse(id, parser); + } catch (Exception e) { + logger.warn(() -> "Role mapping [" + id + "] cannot be parsed and will be skipped", e); + return null; + } + } + + // package-private for testing + static String getIdForName(String name) { + return ID_PREFIX + name; + } + + private static void reportStats(ActionListener> listener, List mappings) { + Map usageStats = new HashMap<>(); + usageStats.put("size", mappings.size()); + usageStats.put("enabled", mappings.stream().filter(ExpressionRoleMapping::isEnabled).count()); + listener.onResponse(usageStats); + } + + private static String getNameFromId(String id) { + assert id.startsWith(ID_PREFIX); + return id.substring(ID_PREFIX.length()); + } } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/FileRoleValidator.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/FileRoleValidator.java new file mode 100644 index 0000000000000..9f4705d34b320 --- /dev/null +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/FileRoleValidator.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.security.authz; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.xpack.core.security.authz.RoleDescriptor; + +/** + * Provides a check which will be applied to roles in the file-based roles store. 
+ */ +@FunctionalInterface +public interface FileRoleValidator { + ActionRequestValidationException validatePredefinedRole(RoleDescriptor roleDescriptor); + + /** + * The default file role validator used in stateful Elasticsearch, a no-op. + */ + class Default implements FileRoleValidator { + @Override + public ActionRequestValidationException validatePredefinedRole(RoleDescriptor roleDescriptor) { + return null; + } + } +} diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/FileRolesStore.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/FileRolesStore.java index d7c8f11c467f2..d769e44f2d38d 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/FileRolesStore.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/FileRolesStore.java @@ -12,6 +12,7 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.util.set.Sets; @@ -36,13 +37,13 @@ import org.elasticsearch.xpack.core.security.authz.support.DLSRoleQueryValidator; import org.elasticsearch.xpack.core.security.support.NoOpLogger; import org.elasticsearch.xpack.core.security.support.Validation; +import org.elasticsearch.xpack.security.authz.FileRoleValidator; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -55,7 +56,9 @@ import java.util.stream.Collectors; import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; import static java.util.Collections.unmodifiableMap; +import static java.util.Collections.unmodifiableSet; import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.core.security.SecurityField.DOCUMENT_LEVEL_SECURITY_FEATURE; @@ -67,6 +70,7 @@ public class FileRolesStore implements BiConsumer, ActionListener>> listeners = new ArrayList<>(); @@ -78,9 +82,10 @@ public FileRolesStore( Environment env, ResourceWatcherService watcherService, XPackLicenseState licenseState, - NamedXContentRegistry xContentRegistry + NamedXContentRegistry xContentRegistry, + FileRoleValidator roleValidator ) throws IOException { - this(settings, env, watcherService, null, licenseState, xContentRegistry); + this(settings, env, watcherService, null, roleValidator, licenseState, xContentRegistry); } FileRolesStore( @@ -88,11 +93,13 @@ public FileRolesStore( Environment env, ResourceWatcherService watcherService, Consumer> listener, + FileRoleValidator roleValidator, XPackLicenseState licenseState, NamedXContentRegistry xContentRegistry ) throws IOException { this.settings = settings; this.file = resolveFile(env); + this.roleValidator = roleValidator; if (listener != null) { listeners.add(listener); } @@ -101,7 +108,7 @@ public FileRolesStore( FileWatcher watcher = new FileWatcher(file.getParent()); watcher.addListener(new FileListener()); watcherService.add(watcher, ResourceWatcherService.Frequency.HIGH); - permissions = parseFile(file, logger, settings, licenseState, xContentRegistry); + permissions = parseFile(file, logger, settings, 
licenseState, xContentRegistry, roleValidator); } @Override @@ -176,27 +183,45 @@ public static Path resolveFile(Environment env) { } public static Set parseFileForRoleNames(Path path, Logger logger) { - // EMPTY is safe here because we never use namedObject as we are just parsing role names - return parseRoleDescriptors(path, logger, false, Settings.EMPTY, NamedXContentRegistry.EMPTY).keySet(); - } + if (logger == null) { + logger = NoOpLogger.INSTANCE; + } + + Map roles = new HashMap<>(); + logger.trace("attempting to read roles file located at [{}]", path.toAbsolutePath()); + if (Files.exists(path)) { + try { + List roleSegments = roleSegments(path); + for (String segment : roleSegments) { + RoleDescriptor rd = parseRoleDescriptor( + segment, + path, + logger, + false, + Settings.EMPTY, + NamedXContentRegistry.EMPTY, + new FileRoleValidator.Default() + ); + if (rd != null) { + roles.put(rd.getName(), rd); + } + } + } catch (IOException ioe) { + logger.error(() -> format("failed to read roles file [%s]. skipping all roles...", path.toAbsolutePath()), ioe); + return emptySet(); + } + } + return unmodifiableSet(roles.keySet()); - public static Map parseFile( - Path path, - Logger logger, - Settings settings, - XPackLicenseState licenseState, - NamedXContentRegistry xContentRegistry - ) { - return parseFile(path, logger, true, settings, licenseState, xContentRegistry); } public static Map parseFile( Path path, Logger logger, - boolean resolvePermission, Settings settings, XPackLicenseState licenseState, - NamedXContentRegistry xContentRegistry + NamedXContentRegistry xContentRegistry, + FileRoleValidator roleValidator ) { if (logger == null) { logger = NoOpLogger.INSTANCE; @@ -210,7 +235,7 @@ public static Map parseFile( final boolean isDlsLicensed = DOCUMENT_LEVEL_SECURITY_FEATURE.checkWithoutTracking(licenseState); for (String segment : roleSegments) { - RoleDescriptor descriptor = parseRoleDescriptor(segment, path, logger, resolvePermission, settings, xContentRegistry); + RoleDescriptor descriptor = parseRoleDescriptor(segment, path, logger, true, settings, xContentRegistry, roleValidator); if (descriptor != null) { if (ReservedRolesStore.isReserved(descriptor.getName())) { logger.warn( @@ -243,36 +268,6 @@ public static Map parseFile( return unmodifiableMap(roles); } - public static Map parseRoleDescriptors( - Path path, - Logger logger, - boolean resolvePermission, - Settings settings, - NamedXContentRegistry xContentRegistry - ) { - if (logger == null) { - logger = NoOpLogger.INSTANCE; - } - - Map roles = new HashMap<>(); - logger.trace("attempting to read roles file located at [{}]", path.toAbsolutePath()); - if (Files.exists(path)) { - try { - List roleSegments = roleSegments(path); - for (String segment : roleSegments) { - RoleDescriptor rd = parseRoleDescriptor(segment, path, logger, resolvePermission, settings, xContentRegistry); - if (rd != null) { - roles.put(rd.getName(), rd); - } - } - } catch (IOException ioe) { - logger.error(() -> format("failed to read roles file [%s]. 
skipping all roles...", path.toAbsolutePath()), ioe); - return emptyMap(); - } - } - return unmodifiableMap(roles); - } - @Nullable static RoleDescriptor parseRoleDescriptor( String segment, @@ -280,7 +275,8 @@ static RoleDescriptor parseRoleDescriptor( Logger logger, boolean resolvePermissions, Settings settings, - NamedXContentRegistry xContentRegistry + NamedXContentRegistry xContentRegistry, + FileRoleValidator roleValidator ) { String roleName = null; XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withRegistry(xContentRegistry) @@ -311,7 +307,7 @@ static RoleDescriptor parseRoleDescriptor( // we pass true as last parameter because we do not want to reject files if field permissions // are given in 2.x syntax RoleDescriptor descriptor = RoleDescriptor.parse(roleName, parser, true, false); - return checkDescriptor(descriptor, path, logger, settings, xContentRegistry); + return checkDescriptor(descriptor, path, logger, settings, xContentRegistry, roleValidator); } else { logger.error("invalid role definition [{}] in roles file [{}]. skipping role...", roleName, path.toAbsolutePath()); return null; @@ -344,7 +340,8 @@ private static RoleDescriptor checkDescriptor( Path path, Logger logger, Settings settings, - NamedXContentRegistry xContentRegistry + NamedXContentRegistry xContentRegistry, + FileRoleValidator roleValidator ) { String roleName = descriptor.getName(); // first check if FLS/DLS is enabled on the role... @@ -374,6 +371,10 @@ private static RoleDescriptor checkDescriptor( } } } + ActionRequestValidationException ex = roleValidator.validatePredefinedRole(descriptor); + if (ex != null) { + throw ex; + } return descriptor; } @@ -417,7 +418,7 @@ public synchronized void onFileChanged(Path file) { if (file.equals(FileRolesStore.this.file)) { final Map previousPermissions = permissions; try { - permissions = parseFile(file, logger, settings, licenseState, xContentRegistry); + permissions = parseFile(file, logger, settings, licenseState, xContentRegistry, roleValidator); } catch (Exception e) { logger.error( () -> format("could not reload roles file [%s]. Current roles remain unmodified", file.toAbsolutePath()), @@ -431,7 +432,7 @@ public synchronized void onFileChanged(Path file) { .map(Map.Entry::getKey) .collect(Collectors.toSet()); final Set addedRoles = Sets.difference(permissions.keySet(), previousPermissions.keySet()); - final Set changedRoles = Collections.unmodifiableSet(Sets.union(changedOrMissingRoles, addedRoles)); + final Set changedRoles = unmodifiableSet(Sets.union(changedOrMissingRoles, addedRoles)); if (changedRoles.isEmpty() == false) { logger.info("updated roles (roles file [{}] {})", file.toAbsolutePath(), Files.exists(file) ? "changed" : "removed"); listeners.forEach(c -> c.accept(changedRoles)); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/NativeRoleMappingBaseRestHandler.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/NativeRoleMappingBaseRestHandler.java new file mode 100644 index 0000000000000..e0d692814988b --- /dev/null +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/NativeRoleMappingBaseRestHandler.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.security.rest.action.rolemapping; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.security.authc.support.mapper.NativeRoleMappingStore; +import org.elasticsearch.xpack.security.rest.action.SecurityBaseRestHandler; + +abstract class NativeRoleMappingBaseRestHandler extends SecurityBaseRestHandler { + + private static final Logger logger = LogManager.getLogger(NativeRoleMappingBaseRestHandler.class); + + NativeRoleMappingBaseRestHandler(Settings settings, XPackLicenseState licenseState) { + super(settings, licenseState); + } + + @Override + protected Exception innerCheckFeatureAvailable(RestRequest request) { + Boolean nativeRoleMappingsEnabled = settings.getAsBoolean(NativeRoleMappingStore.NATIVE_ROLE_MAPPINGS_ENABLED, true); + if (nativeRoleMappingsEnabled == false) { + logger.debug( + "Attempt to call [{} {}] but [{}] is [{}]", + request.method(), + request.rawPath(), + NativeRoleMappingStore.NATIVE_ROLE_MAPPINGS_ENABLED, + settings.get(NativeRoleMappingStore.NATIVE_ROLE_MAPPINGS_ENABLED) + ); + return new ElasticsearchStatusException( + "Native role mapping management is not enabled in this Elasticsearch instance", + RestStatus.GONE + ); + } else { + return null; + } + } +} diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestDeleteRoleMappingAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestDeleteRoleMappingAction.java index ee1952e359dd2..5964228009c4b 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestDeleteRoleMappingAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestDeleteRoleMappingAction.java @@ -19,7 +19,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.security.action.rolemapping.DeleteRoleMappingRequestBuilder; import org.elasticsearch.xpack.core.security.action.rolemapping.DeleteRoleMappingResponse; -import org.elasticsearch.xpack.security.rest.action.SecurityBaseRestHandler; import java.io.IOException; import java.util.List; @@ -30,7 +29,7 @@ * Rest endpoint to delete a role-mapping from the {@link org.elasticsearch.xpack.security.authc.support.mapper.NativeRoleMappingStore} */ @ServerlessScope(Scope.INTERNAL) -public class RestDeleteRoleMappingAction extends SecurityBaseRestHandler { +public class RestDeleteRoleMappingAction extends NativeRoleMappingBaseRestHandler { public RestDeleteRoleMappingAction(Settings settings, XPackLicenseState licenseState) { super(settings, licenseState); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestGetRoleMappingsAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestGetRoleMappingsAction.java index 36b3f05668d0a..7a3378d843bca 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestGetRoleMappingsAction.java +++ 
b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestGetRoleMappingsAction.java @@ -20,7 +20,6 @@ import org.elasticsearch.xpack.core.security.action.rolemapping.GetRoleMappingsRequestBuilder; import org.elasticsearch.xpack.core.security.action.rolemapping.GetRoleMappingsResponse; import org.elasticsearch.xpack.core.security.authc.support.mapper.ExpressionRoleMapping; -import org.elasticsearch.xpack.security.rest.action.SecurityBaseRestHandler; import java.io.IOException; import java.util.List; @@ -31,7 +30,7 @@ * Rest endpoint to retrieve a role-mapping from the org.elasticsearch.xpack.security.authc.support.mapper.NativeRoleMappingStore */ @ServerlessScope(Scope.INTERNAL) -public class RestGetRoleMappingsAction extends SecurityBaseRestHandler { +public class RestGetRoleMappingsAction extends NativeRoleMappingBaseRestHandler { public RestGetRoleMappingsAction(Settings settings, XPackLicenseState licenseState) { super(settings, licenseState); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestPutRoleMappingAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestPutRoleMappingAction.java index bb6b07c1c3c95..e7e24037543fa 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestPutRoleMappingAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestPutRoleMappingAction.java @@ -19,7 +19,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.security.action.rolemapping.PutRoleMappingRequestBuilder; import org.elasticsearch.xpack.core.security.action.rolemapping.PutRoleMappingResponse; -import org.elasticsearch.xpack.security.rest.action.SecurityBaseRestHandler; import java.io.IOException; import java.util.List; @@ -33,7 +32,7 @@ * @see org.elasticsearch.xpack.security.authc.support.mapper.NativeRoleMappingStore */ @ServerlessScope(Scope.INTERNAL) -public class RestPutRoleMappingAction extends SecurityBaseRestHandler { +public class RestPutRoleMappingAction extends NativeRoleMappingBaseRestHandler { public RestPutRoleMappingAction(Settings settings, XPackLicenseState licenseState) { super(settings, licenseState); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java index ca08f63a09bb0..462b41a519460 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java @@ -77,7 +77,11 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor "internal:admin/ccr/restore/session/clear", "indices:internal/admin/ccr/restore/session/clear", "internal:admin/ccr/restore/file_chunk/get", - "indices:internal/admin/ccr/restore/file_chunk/get" + "indices:internal/admin/ccr/restore/file_chunk/get", + "internal:data/read/esql/open_exchange", + "cluster:internal:data/read/esql/open_exchange", + "internal:data/read/esql/exchange", + "cluster:internal:data/read/esql/exchange" ); private final AuthenticationService authcService; diff --git 
a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authz/store/FileRolesStoreTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authz/store/FileRolesStoreTests.java index 0f9dd06983792..65f2919541e07 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authz/store/FileRolesStoreTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authz/store/FileRolesStoreTests.java @@ -39,6 +39,7 @@ import org.elasticsearch.xpack.core.security.authz.privilege.IndexPrivilege; import org.elasticsearch.xpack.core.security.authz.store.ReservedRolesStore; import org.elasticsearch.xpack.core.security.support.Automatons; +import org.elasticsearch.xpack.security.authz.FileRoleValidator; import org.junit.BeforeClass; import java.io.BufferedWriter; @@ -104,7 +105,8 @@ public void testParseFile() throws Exception { logger, Settings.builder().put(XPackSettings.DLS_FLS_ENABLED.getKey(), true).build(), TestUtils.newTestLicenseState(), - xContentRegistry() + xContentRegistry(), + new FileRoleValidator.Default() ); assertThat(roles, notNullValue()); assertThat(roles.size(), is(10)); @@ -295,7 +297,8 @@ public void testParseFileWithRemoteIndices() throws IllegalAccessException, IOEx logger, Settings.builder().put(XPackSettings.DLS_FLS_ENABLED.getKey(), true).build(), TestUtils.newTestLicenseState(), - xContentRegistry() + xContentRegistry(), + new FileRoleValidator.Default() ); assertThat(roles, notNullValue()); assertThat(roles.size(), is(2)); @@ -359,7 +362,8 @@ public void testParseFileWithFLSAndDLSDisabled() throws Exception { logger, Settings.builder().put(XPackSettings.DLS_FLS_ENABLED.getKey(), false).build(), TestUtils.newTestLicenseState(), - xContentRegistry() + xContentRegistry(), + new FileRoleValidator.Default() ); assertThat(roles, notNullValue()); assertThat(roles.size(), is(7)); @@ -410,7 +414,14 @@ public void testParseFileWithFLSAndDLSUnlicensed() throws Exception { events.clear(); MockLicenseState licenseState = mock(MockLicenseState.class); when(licenseState.isAllowed(DOCUMENT_LEVEL_SECURITY_FEATURE)).thenReturn(false); - Map roles = FileRolesStore.parseFile(path, logger, Settings.EMPTY, licenseState, xContentRegistry()); + Map roles = FileRolesStore.parseFile( + path, + logger, + Settings.EMPTY, + licenseState, + xContentRegistry(), + new FileRoleValidator.Default() + ); assertThat(roles, notNullValue()); assertThat(roles.size(), is(10)); assertNotNull(roles.get("role_fields")); @@ -445,7 +456,8 @@ public void testDefaultRolesFile() throws Exception { logger, Settings.EMPTY, TestUtils.newTestLicenseState(), - xContentRegistry() + xContentRegistry(), + new FileRoleValidator.Default() ); assertThat(roles, notNullValue()); assertThat(roles.size(), is(0)); @@ -474,7 +486,7 @@ public void testAutoReload() throws Exception { FileRolesStore store = new FileRolesStore(settings, env, watcherService, roleSet -> { modifiedRoles.addAll(roleSet); latch.countDown(); - }, TestUtils.newTestLicenseState(), xContentRegistry()); + }, new FileRoleValidator.Default(), TestUtils.newTestLicenseState(), xContentRegistry()); Set descriptors = store.roleDescriptors(Collections.singleton("role1")); assertThat(descriptors, notNullValue()); @@ -534,7 +546,7 @@ public void testAutoReload() throws Exception { if (roleSet.contains("dummy1")) { truncateLatch.countDown(); } - }, TestUtils.newTestLicenseState(), xContentRegistry()); + }, new FileRoleValidator.Default(), TestUtils.newTestLicenseState(), 
@@ -563,7 +575,7 @@ public void testAutoReload() throws Exception {
             if (roleSet.contains("dummy2")) {
                 modifyLatch.countDown();
             }
-        }, TestUtils.newTestLicenseState(), xContentRegistry());
+        }, new FileRoleValidator.Default(), TestUtils.newTestLicenseState(), xContentRegistry());
 
         try (BufferedWriter writer = Files.newBufferedWriter(tmp, StandardCharsets.UTF_8, StandardOpenOption.TRUNCATE_EXISTING)) {
             writer.append("role5:").append(System.lineSeparator());
@@ -596,7 +608,8 @@ public void testThatEmptyFileDoesNotResultInLoop() throws Exception {
             logger,
             Settings.EMPTY,
             TestUtils.newTestLicenseState(),
-            xContentRegistry()
+            xContentRegistry(),
+            new FileRoleValidator.Default()
         );
         assertThat(roles.keySet(), is(empty()));
     }
@@ -611,7 +624,8 @@ public void testThatInvalidRoleDefinitions() throws Exception {
             logger,
             Settings.EMPTY,
             TestUtils.newTestLicenseState(),
-            xContentRegistry()
+            xContentRegistry(),
+            new FileRoleValidator.Default()
         );
         assertThat(roles.size(), is(1));
         assertThat(roles, hasKey("valid_role"));
@@ -660,7 +674,8 @@ public void testReservedRoles() throws Exception {
             logger,
             Settings.EMPTY,
             TestUtils.newTestLicenseState(),
-            xContentRegistry()
+            xContentRegistry(),
+            new FileRoleValidator.Default()
        );
         assertThat(roles, notNullValue());
         assertThat(roles.size(), is(2));
@@ -696,7 +711,8 @@ public void testUsageStats() throws Exception {
             env,
             mock(ResourceWatcherService.class),
             TestUtils.newTestLicenseState(),
-            xContentRegistry()
+            xContentRegistry(),
+            new FileRoleValidator.Default()
         );
 
         Map<String, Object> usageStats = store.usageStats();
@@ -723,14 +739,16 @@ public void testExists() throws Exception {
             env,
             mock(ResourceWatcherService.class),
             TestUtils.newTestLicenseState(),
-            xContentRegistry()
+            xContentRegistry(),
+            new FileRoleValidator.Default()
         );
         Map<String, RoleDescriptor> roles = FileRolesStore.parseFile(
             path,
             logger,
             Settings.builder().put(XPackSettings.DLS_FLS_ENABLED.getKey(), true).build(),
             TestUtils.newTestLicenseState(),
-            xContentRegistry()
+            xContentRegistry(),
+            new FileRoleValidator.Default()
         );
         assertThat(roles, notNullValue());
         assertThat(roles.size(), is(10));
@@ -745,7 +763,15 @@ public void testBWCFieldPermissions() throws IOException {
         Path path = getDataPath("roles2xformat.yml");
         byte[] bytes = Files.readAllBytes(path);
         String roleString = new String(bytes, Charset.defaultCharset());
-        RoleDescriptor role = FileRolesStore.parseRoleDescriptor(roleString, path, logger, true, Settings.EMPTY, xContentRegistry());
+        RoleDescriptor role = FileRolesStore.parseRoleDescriptor(
+            roleString,
+            path,
+            logger,
+            true,
+            Settings.EMPTY,
+            xContentRegistry(),
+            new FileRoleValidator.Default()
+        );
         RoleDescriptor.IndicesPrivileges indicesPrivileges = role.getIndicesPrivileges()[0];
         assertThat(indicesPrivileges.getGrantedFields(), arrayContaining("foo", "boo"));
         assertNull(indicesPrivileges.getDeniedFields());
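Every parseFile and parseRoleDescriptor call, and every FileRolesStore constructor, now takes a FileRoleValidator argument, but the interface itself is outside this excerpt. A minimal shape inferred only from the new FileRoleValidator.Default() call sites; the method name and return type below are assumptions:

import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.xpack.core.security.authz.RoleDescriptor;

// Sketch inferred from the call sites above; the real interface in
// org.elasticsearch.xpack.security.authz may declare a different method.
public interface FileRoleValidator {

    // Returns a validation error for a role parsed from roles.yml, or null to accept it.
    ActionRequestValidationException validatePredefinedRole(RoleDescriptor roleDescriptor);

    // Default behavior: file-based roles get no validation beyond parsing, which is
    // consistent with the tests above expecting existing fixtures to keep loading.
    class Default implements FileRoleValidator {
        @Override
        public ActionRequestValidationException validatePredefinedRole(RoleDescriptor roleDescriptor) {
            return null;
        }
    }
}

Threading the validator through as a constructor parameter presumably gives other distributions a seam to impose stricter rules on file-based roles without changing the store itself.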
diff --git a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/DesiredBalanceShutdownIT.java b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/DesiredBalanceShutdownIT.java
index ceedda30626c6..ce1704639527d 100644
--- a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/DesiredBalanceShutdownIT.java
+++ b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/DesiredBalanceShutdownIT.java
@@ -11,7 +11,6 @@
 import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata;
-import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.plugins.Plugin;
@@ -36,7 +35,7 @@ protected Collection<Class<? extends Plugin>> nodePlugins() {
 
     public void testDesiredBalanceWithShutdown() throws Exception {
         final var oldNodeName = internalCluster().startNode();
-        final var oldNodeId = internalCluster().getInstance(ClusterService.class, oldNodeName).localNode().getId();
+        final var oldNodeId = getNodeId(oldNodeName);
 
         createIndex(
             INDEX,
diff --git a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownPluginsIT.java b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownPluginsIT.java
index 3a1280307b739..c87fa08e8c972 100644
--- a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownPluginsIT.java
+++ b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownPluginsIT.java
@@ -9,10 +9,7 @@
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 
-import org.elasticsearch.action.admin.cluster.node.info.NodeInfo;
-import org.elasticsearch.action.admin.cluster.node.info.NodesInfoResponse;
 import org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata;
-import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.plugins.ShutdownAwarePlugin;
 import org.elasticsearch.test.ESIntegTestCase;
@@ -44,21 +41,8 @@ public void testShutdownAwarePlugin() throws Exception {
         final String shutdownNode;
         final String remainNode;
 
-        NodesInfoResponse nodes = clusterAdmin().prepareNodesInfo().clear().get();
-        final String node1Id = nodes.getNodes()
-            .stream()
-            .map(NodeInfo::getNode)
-            .filter(node -> node.getName().equals(node1))
-            .map(DiscoveryNode::getId)
-            .findFirst()
-            .orElseThrow();
-        final String node2Id = nodes.getNodes()
-            .stream()
-            .map(NodeInfo::getNode)
-            .filter(node -> node.getName().equals(node2))
-            .map(DiscoveryNode::getId)
-            .findFirst()
-            .orElseThrow();
+        final String node1Id = getNodeId(node1);
+        final String node2Id = getNodeId(node2);
 
         if (randomBoolean()) {
             shutdownNode = node1Id;
diff --git a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownReadinessIT.java b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownReadinessIT.java
index af0713665731c..6dfbb8360e763 100644
--- a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownReadinessIT.java
+++ b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownReadinessIT.java
@@ -7,10 +7,7 @@
 
 package org.elasticsearch.xpack.shutdown;
 
-import org.elasticsearch.action.admin.cluster.node.info.NodeInfo;
-import org.elasticsearch.action.admin.cluster.node.info.NodesInfoResponse;
 import org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata;
-import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.plugins.Plugin;
@@ -93,17 +90,6 @@ private void deleteNodeShutdown(String nodeId) {
         assertAcked(client().execute(DeleteShutdownNodeAction.INSTANCE, new DeleteShutdownNodeAction.Request(nodeId)));
     }
 
-    private String getNodeId(String nodeName) {
-        NodesInfoResponse nodes = clusterAdmin().prepareNodesInfo().clear().get();
-        return nodes.getNodes()
-            .stream()
-            .map(NodeInfo::getNode)
-            .filter(node -> node.getName().equals(nodeName))
-            .map(DiscoveryNode::getId)
-            .findFirst()
-            .orElseThrow();
-    }
-
     private void assertNoShuttingDownNodes(String nodeId) throws ExecutionException, InterruptedException {
         var response = client().execute(GetShutdownStatusAction.INSTANCE, new GetShutdownStatusAction.Request(nodeId)).get();
         assertThat(response.getShutdownStatuses(), empty());
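The hunk above deletes NodeShutdownReadinessIT's private getNodeId helper, and the other shutdown ITs in this diff call getNodeId(...) without defining it, so the helper has presumably been hoisted into a shared test superclass that is not shown here. Its body would mirror the deleted copies:

import org.elasticsearch.action.admin.cluster.node.info.NodeInfo;
import org.elasticsearch.action.admin.cluster.node.info.NodesInfoResponse;
import org.elasticsearch.cluster.node.DiscoveryNode;

// Identical to the implementations removed in this diff; where it now lives
// (for example a base class such as ESIntegTestCase) and its visibility are
// assumptions, since the new declaration is not part of this excerpt.
protected String getNodeId(String nodeName) {
    // Resolve a node name to its ephemeral node id via the nodes-info API.
    NodesInfoResponse nodes = clusterAdmin().prepareNodesInfo().clear().get();
    return nodes.getNodes()
        .stream()
        .map(NodeInfo::getNode)
        .filter(node -> node.getName().equals(nodeName))
        .map(DiscoveryNode::getId)
        .findFirst()
        .orElseThrow();
}

Centralizing the lookup removes four near-identical stream pipelines and is what allows the import deletions in each of these test classes.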
diff --git a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownShardsIT.java b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownShardsIT.java
index fad05e6f213d5..fda2a5755be55 100644
--- a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownShardsIT.java
+++ b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownShardsIT.java
@@ -8,13 +8,11 @@
 package org.elasticsearch.xpack.shutdown;
 
 import org.elasticsearch.action.admin.cluster.allocation.ClusterAllocationExplainResponse;
-import org.elasticsearch.action.admin.cluster.node.info.NodeInfo;
 import org.elasticsearch.action.admin.cluster.node.info.NodesInfoResponse;
 import org.elasticsearch.action.index.IndexRequestBuilder;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata;
-import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.routing.RoutingNodesHelper;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.ShardRoutingState;
@@ -456,17 +454,6 @@ private String findIdOfNodeWithPrimaryShard(String indexName) {
         );
     }
 
-    private String getNodeId(String nodeName) {
-        NodesInfoResponse nodes = clusterAdmin().prepareNodesInfo().clear().get();
-        return nodes.getNodes()
-            .stream()
-            .map(NodeInfo::getNode)
-            .filter(node -> node.getName().equals(nodeName))
-            .map(DiscoveryNode::getId)
-            .findFirst()
-            .orElseThrow();
-    }
-
     private void putNodeShutdown(String nodeId, SingleNodeShutdownMetadata.Type type, String nodeReplacementName) throws Exception {
         assertAcked(
             client().execute(
diff --git a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownTasksIT.java b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownTasksIT.java
index 7c32311237c57..dc4e6b9c53fda 100644
--- a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownTasksIT.java
+++ b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownTasksIT.java
@@ -12,8 +12,6 @@
 import org.elasticsearch.ResourceAlreadyExistsException;
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.action.admin.cluster.node.info.NodeInfo;
-import org.elasticsearch.action.admin.cluster.node.info.NodesInfoResponse;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.cluster.ClusterChangedEvent;
 import org.elasticsearch.cluster.ClusterState;
@@ -78,21 +76,8 @@ public void testTasksAreNotAssignedToShuttingDownNode() throws Exception {
         final String shutdownNode;
         final String candidateNode;
 
-        NodesInfoResponse nodes = clusterAdmin().prepareNodesInfo().clear().get();
-        final String node1Id = nodes.getNodes()
-            .stream()
-            .map(NodeInfo::getNode)
-            .filter(node -> node.getName().equals(node1))
-            .map(DiscoveryNode::getId)
-            .findFirst()
-            .orElseThrow();
-        final String node2Id = nodes.getNodes()
-            .stream()
-            .map(NodeInfo::getNode)
-            .filter(node -> node.getName().equals(node2))
-            .map(DiscoveryNode::getId)
-            .findFirst()
-            .orElseThrow();
+        final String node1Id = getNodeId(node1);
+        final String node2Id = getNodeId(node2);
 
         if (randomBoolean()) {
             shutdownNode = node1Id;
diff --git a/x-pack/plugin/write-load-forecaster/src/test/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecasterTests.java b/x-pack/plugin/write-load-forecaster/src/test/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecasterTests.java
index c7efb27509ef7..790af0a201578 100644
--- a/x-pack/plugin/write-load-forecaster/src/test/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecasterTests.java
+++ b/x-pack/plugin/write-load-forecaster/src/test/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecasterTests.java
@@ -25,8 +25,8 @@
 import org.junit.Before;
 
 import java.util.ArrayList;
-import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 import java.util.OptionalDouble;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -321,16 +321,10 @@ private IndexMetadata createIndexMetadata(
     }
 
     private DataStream createDataStream(String name, List<Index> backingIndices) {
-        return new DataStream(
-            name,
-            backingIndices,
-            backingIndices.size(),
-            Collections.emptyMap(),
-            false,
-            false,
-            false,
-            false,
-            IndexMode.STANDARD
-        );
+        return DataStream.builder(name, backingIndices)
+            .setGeneration(backingIndices.size())
+            .setMetadata(Map.of())
+            .setIndexMode(IndexMode.STANDARD)
+            .build();
     }
 }
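The last hunk replaces DataStream's nine-argument constructor with its builder. The builder names each non-default field (generation, metadata, index mode) at the call site, and the four booleans that were previously passed positionally as false simply fall back to the builder's defaults, so this test no longer needs editing every time a parameter is appended to the constructor.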