diff --git a/CHANGELOG.md b/CHANGELOG.md index 93349b974..59d55b7d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Fix bug where embedding is missing when ingested document has "." in field name, and mismatches fieldMap config ([#1062](https://github.com/opensearch-project/neural-search/pull/1062)) - Fix explain exception in hybrid queries with partial subquery matches ([#1123](https://github.com/opensearch-project/neural-search/pull/1123)) - Handle pagination_depth when from =0 and removes default value of pagination_depth ([#1132](https://github.com/opensearch-project/neural-search/pull/1132)) +- Fix single shard pagination issue of from ([#1140](https://github.com/opensearch-project/neural-search/pull/1140)) ### Infrastructure - Update batch related tests to use batch_size in processor & refactor BWC version check ([#852](https://github.com/opensearch-project/neural-search/pull/852)) - Fix CI for JDK upgrade towards 21 ([#835](https://github.com/opensearch-project/neural-search/pull/835)) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index 1ddfe75a6..bf95e400f 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -113,7 +113,16 @@ private int getFromValueIfSingleShard(final NormalizationProcessorWorkflowExecut if (searchPhaseContext.getNumShards() > 1 || request.fetchSearchResultOptional.isEmpty()) { return -1; } - return searchPhaseContext.getRequest().source().from(); + int from = searchPhaseContext.getRequest().source().from(); + // for the initial searchRequest, it creates a default search context which sets the value of + // from to 0 if it's -1. 
That's not the case with SearchPhaseContext, that's why we need to + // explicitly set it to 0 for the single shard case + // Ref: + // https://github.com/opensearch-project/OpenSearch/blob/2.18/server/src/main/java/org/opensearch/search/DefaultSearchContext.java#L288 + if (from == -1) { + return 0; + } + return from; } /** diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java index 61828d822..d98e6a88d 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java @@ -398,6 +398,73 @@ public void testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResu TestUtils.assertFetchResultScores(fetchSearchResult, 4); } + public void testNormalization_whenOneShardAndFromIsNegativeOne_thenSuccess() { + NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + + int shardId = 0; + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + + // Setup query search results + List<QuerySearchResult> querySearchResults = new ArrayList<>(); + FetchSearchResult fetchSearchResult = new FetchSearchResult(); + QuerySearchResult querySearchResult = new QuerySearchResult(); + querySearchResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + new ScoreDoc(0, 0.5f), + new ScoreDoc(2, 0.3f), + new ScoreDoc(4, 0.25f), + new ScoreDoc(10, 0.2f), + createStartStopElementForHybridSearchResults(0) } + ), + 0.5f + ), + null + ); + 
querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class); + when(shardSearchRequest.requestCache()).thenReturn(Boolean.TRUE); + querySearchResult.setShardSearchRequest(shardSearchRequest); + querySearchResults.add(querySearchResult); + SearchHits searchHits = getSearchHits(); + fetchSearchResult.hits(searchHits); + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + // searchSourceBuilder.from(); if no from is defined here it would initialize it to -1 + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + when(searchRequest.source()).thenReturn(searchSourceBuilder); + when(searchPhaseContext.getNumShards()).thenReturn(1); + NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + // Setup fetch search result + fetchSearchResult.hits(searchHits); + + normalizationProcessorWorkflow.execute(normalizationExecuteDTO); + + // Verify that the fetch result has been updated correctly + TestUtils.assertQueryResultScores(querySearchResults); + TestUtils.assertFetchResultScores(fetchSearchResult, 4); + } + public void testNormalization_whenFromIsGreaterThanResultsSize_thenFail() { NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())