Commit
Fix explain exception in hybrid queries with partial subquery matches (#1123) (#1137)

* Fixed exception for explain in hybrid query when there is a partial match in subqueries

Signed-off-by: Martin Gaievski <[email protected]>
(cherry picked from commit 8c743ec)

Co-authored-by: Martin Gaievski <[email protected]>
1 parent 212e782 commit 20339b5
Showing 14 changed files with 893 additions and 88 deletions.
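
To illustrate the root cause before the per-file diff: when a document matches only some subqueries of a hybrid query, appending its normalized scores to a growing per-document list leaves list positions out of step with the subquery indices, which is what the explain path later tripped over. The minimal, self-contained Java sketch below (hypothetical class name, document key, and score values; it mirrors, but does not call, the plugin's new ScoreNormalizationUtil.setNormalizedScore) contrasts the old append-based accumulation with the index-aligned write introduced here.

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Standalone sketch (hypothetical names, not the plugin's classes): "doc1" matches
// only subqueries 0 and 2 out of 3 subqueries in a hybrid query.
public class PartialMatchAlignmentSketch {

    public static void main(String[] args) {
        int numberOfSubQueries = 3;
        int[] matchedSubQueryIndices = { 0, 2 };
        float[] matchedScores = { 0.9f, 0.4f };

        // Previous behaviour: scores were appended per document, so a partial match
        // produced a 2-element list whose positions no longer correspond to subquery indices.
        Map<String, List<Float>> appended = new HashMap<>();
        for (float score : matchedScores) {
            appended.computeIfAbsent("doc1", k -> new ArrayList<>()).add(score);
        }
        System.out.println("appended (misaligned): " + appended.get("doc1")); // [0.9, 0.4]

        // Behaviour after this commit: pre-fill one 0.0 slot per subquery and write by
        // subquery index, mirroring ScoreNormalizationUtil.setNormalizedScore.
        Map<String, List<Float>> aligned = new HashMap<>();
        for (int i = 0; i < matchedSubQueryIndices.length; i++) {
            setNormalizedScore(aligned, "doc1", matchedSubQueryIndices[i], numberOfSubQueries, matchedScores[i]);
        }
        System.out.println("aligned by subquery index: " + aligned.get("doc1")); // [0.9, 0.0, 0.4]
    }

    // Simplified analogue of the new helper, keyed by a plain String instead of DocIdAtSearchShard.
    static void setNormalizedScore(Map<String, List<Float>> normalizedScores, String docId,
                                   int subQueryIndex, int numberOfSubQueries, float normalizedScore) {
        List<Float> scores = normalizedScores.computeIfAbsent(
            docId, k -> new ArrayList<>(Collections.nCopies(numberOfSubQueries, 0.0f)));
        scores.set(subQueryIndex, normalizedScore);
    }
}

Running the sketch prints a two-element list for the append path and a three-element, zero-padded list for the index-aligned path; the latter is the shape the explanation processor can pair with per-subquery details.
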
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -30,6 +30,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fixed document source and score field mismatch in sorted hybrid queries ([#1043](https://github.com/opensearch-project/neural-search/pull/1043))
- Update NeuralQueryBuilder doEquals() and doHashCode() to cater the missing parameters information ([#1045](https://github.com/opensearch-project/neural-search/pull/1045)).
- Fix bug where embedding is missing when ingested document has "." in field name, and mismatches fieldMap config ([#1062](https://github.com/opensearch-project/neural-search/pull/1062))
- Fix explain exception in hybrid queries with partial subquery matches ([#1123](https://github.com/opensearch-project/neural-search/pull/1123))
- Handle pagination_depth when from =0 and removes default value of pagination_depth ([#1132](https://github.com/opensearch-project/neural-search/pull/1132))
### Infrastructure
- Update batch related tests to use batch_size in processor & refactor BWC version check ([#852](https://github.com/opensearch-project/neural-search/pull/852))
ExplanationResponseProcessor.java
@@ -6,6 +6,8 @@

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.lucene.search.Explanation;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
@@ -21,6 +23,7 @@
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;

@@ -32,6 +35,7 @@
*/
@Getter
@AllArgsConstructor
@Log4j2
public class ExplanationResponseProcessor implements SearchResponseProcessor {

public static final String TYPE = "hybrid_score_explanation";
@@ -99,16 +103,40 @@ public SearchResponse processResponse(
ExplanationDetails normalizationExplanation = combinedExplainDetail.getNormalizationExplanations();
ExplanationDetails combinationExplanation = combinedExplainDetail.getCombinationExplanations();
// Create normalized explanations for each detail
Explanation[] normalizedExplanation = new Explanation[queryLevelExplanation.getDetails().length];
for (int i = 0; i < queryLevelExplanation.getDetails().length; i++) {
normalizedExplanation[i] = Explanation.match(
// normalized score
normalizationExplanation.getScoreDetails().get(i).getKey(),
// description of normalized score
normalizationExplanation.getScoreDetails().get(i).getValue(),
// shard level details
queryLevelExplanation.getDetails()[i]
if (normalizationExplanation.getScoreDetails().size() != queryLevelExplanation.getDetails().length) {
log.error(
String.format(
Locale.ROOT,
"length of query level explanations %d must match length of explanations after normalization %d",
queryLevelExplanation.getDetails().length,
normalizationExplanation.getScoreDetails().size()
)
);
throw new IllegalStateException("mismatch in number of query level explanations and normalization explanations");
}
List<Explanation> normalizedExplanation = new ArrayList<>(queryLevelExplanation.getDetails().length);
int normalizationExplanationIndex = 0;
for (Explanation queryExplanation : queryLevelExplanation.getDetails()) {
// adding only explanations where this hit has matched
if (Float.compare(queryExplanation.getValue().floatValue(), 0.0f) > 0) {
Pair<Float, String> normalizedScoreDetails = normalizationExplanation.getScoreDetails()
.get(normalizationExplanationIndex);
if (Objects.isNull(normalizedScoreDetails)) {
throw new IllegalStateException("normalized score details must not be null");
}
normalizedExplanation.add(
Explanation.match(
// normalized score
normalizedScoreDetails.getKey(),
// description of normalized score
normalizedScoreDetails.getValue(),
// shard level details
queryExplanation
)
);
}
// we increment index in all cases, scores in query explanation can be 0.0
normalizationExplanationIndex++;
}
// Create and set final explanation combining all components
Float finalScore = Float.isNaN(searchHit.getScore()) ? 0.0f : searchHit.getScore();
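For example, with hypothetical numbers: if a hit matched subqueries 0 and 2 of three subqueries, the query-level details could carry scores [2.3, 0.0, 1.1] and the normalization score details [0.9, 0.0, 0.4]; the loop above emits normalized explanations only for subqueries 0 and 2, yet still advances normalizationExplanationIndex past subquery 1, so 0.4 is paired with the third subquery rather than the second.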
L2ScoreNormalizationTechnique.java
@@ -75,12 +75,19 @@ public Map<DocIdAtSearchShard, ExplanationDetails> explain(List<CompoundTopDocs>
continue;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j);
int numberOfSubQueries = topDocsPerSubQuery.size();
for (int subQueryIndex = 0; subQueryIndex < numberOfSubQueries; subQueryIndex++) {
TopDocs subQueryTopDoc = topDocsPerSubQuery.get(subQueryIndex);
for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(scoreDoc.doc, compoundQueryTopDocs.getSearchShard());
float normalizedScore = normalizeSingleScore(scoreDoc.score, normsPerSubquery.get(j));
normalizedScores.computeIfAbsent(docIdAtSearchShard, k -> new ArrayList<>()).add(normalizedScore);
float normalizedScore = normalizeSingleScore(scoreDoc.score, normsPerSubquery.get(subQueryIndex));
ScoreNormalizationUtil.setNormalizedScore(
normalizedScores,
docIdAtSearchShard,
subQueryIndex,
numberOfSubQueries,
normalizedScore
);
scoreDoc.score = normalizedScore;
}
}
MinMaxScoreNormalizationTechnique.java
@@ -4,7 +4,6 @@
*/
package org.opensearch.neuralsearch.processor.normalization;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
@@ -92,16 +91,23 @@ public Map<DocIdAtSearchShard, ExplanationDetails> explain(final List<CompoundTo
continue;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j);
int numberOfSubQueries = topDocsPerSubQuery.size();
for (int subQueryIndex = 0; subQueryIndex < numberOfSubQueries; subQueryIndex++) {
TopDocs subQueryTopDoc = topDocsPerSubQuery.get(subQueryIndex);
for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(scoreDoc.doc, compoundQueryTopDocs.getSearchShard());
float normalizedScore = normalizeSingleScore(
scoreDoc.score,
minMaxScores.getMinScoresPerSubquery()[j],
minMaxScores.getMaxScoresPerSubquery()[j]
minMaxScores.getMinScoresPerSubquery()[subQueryIndex],
minMaxScores.getMaxScoresPerSubquery()[subQueryIndex]
);
ScoreNormalizationUtil.setNormalizedScore(
normalizedScores,
docIdAtSearchShard,
subQueryIndex,
numberOfSubQueries,
normalizedScore
);
normalizedScores.computeIfAbsent(docIdAtSearchShard, k -> new ArrayList<>()).add(normalizedScore);
scoreDoc.score = normalizedScore;
}
}
RRFNormalizationTechnique.java
@@ -6,22 +6,24 @@

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Locale;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.stream.IntStream;

import org.apache.commons.lang3.Range;
import org.apache.commons.lang3.math.NumberUtils;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.common.TriConsumer;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;

import lombok.ToString;
import org.opensearch.neuralsearch.processor.NormalizeScoresDTO;
import org.opensearch.neuralsearch.processor.SearchShard;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
@@ -65,7 +67,7 @@ public RRFNormalizationTechnique(final Map<String, Object> params, final ScoreNo
public void normalize(final NormalizeScoresDTO normalizeScoresDTO) {
final List<CompoundTopDocs> queryTopDocs = normalizeScoresDTO.getQueryTopDocs();
for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
processTopDocs(compoundQueryTopDocs, (docId, score) -> {});
processTopDocs(compoundQueryTopDocs, (docId, score, subQueryIndex) -> {});
}
}

@@ -79,31 +81,51 @@ public Map<DocIdAtSearchShard, ExplanationDetails> explain(List<CompoundTopDocs>
Map<DocIdAtSearchShard, List<Float>> normalizedScores = new HashMap<>();

for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
if (Objects.isNull(compoundQueryTopDocs)) {
continue;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
int numberOfSubQueries = topDocsPerSubQuery.size();
processTopDocs(
compoundQueryTopDocs,
(docId, score) -> normalizedScores.computeIfAbsent(docId, k -> new ArrayList<>()).add(score)
(docId, score, subQueryIndex) -> ScoreNormalizationUtil.setNormalizedScore(
normalizedScores,
docId,
subQueryIndex,
numberOfSubQueries,
score
)
);
}

return getDocIdAtQueryForNormalization(normalizedScores, this);
}

private void processTopDocs(CompoundTopDocs compoundQueryTopDocs, BiConsumer<DocIdAtSearchShard, Float> scoreProcessor) {
private void processTopDocs(CompoundTopDocs compoundQueryTopDocs, TriConsumer<DocIdAtSearchShard, Float, Integer> scoreProcessor) {
if (Objects.isNull(compoundQueryTopDocs)) {
return;
}

compoundQueryTopDocs.getTopDocs().forEach(topDocs -> {
IntStream.range(0, topDocs.scoreDocs.length).forEach(position -> {
float normalizedScore = calculateNormalizedScore(position);
DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(
topDocs.scoreDocs[position].doc,
compoundQueryTopDocs.getSearchShard()
);
scoreProcessor.accept(docIdAtSearchShard, normalizedScore);
topDocs.scoreDocs[position].score = normalizedScore;
});
});
List<TopDocs> topDocsList = compoundQueryTopDocs.getTopDocs();
SearchShard searchShard = compoundQueryTopDocs.getSearchShard();

for (int topDocsIndex = 0; topDocsIndex < topDocsList.size(); topDocsIndex++) {
processTopDocsEntry(topDocsList.get(topDocsIndex), searchShard, topDocsIndex, scoreProcessor);
}
}

private void processTopDocsEntry(
TopDocs topDocs,
SearchShard searchShard,
int topDocsIndex,
TriConsumer<DocIdAtSearchShard, Float, Integer> scoreProcessor
) {
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
float normalizedScore = calculateNormalizedScore(Arrays.asList(topDocs.scoreDocs).indexOf(scoreDoc));
DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(scoreDoc.doc, searchShard);
scoreProcessor.apply(docIdAtSearchShard, normalizedScore, topDocsIndex);
scoreDoc.score = normalizedScore;
}
}

private float calculateNormalizedScore(int position) {
ScoreNormalizationUtil.java
@@ -5,7 +5,9 @@
package org.opensearch.neuralsearch.processor.normalization;

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
@@ -54,4 +56,30 @@ public void validateParams(final Map<String, Object> actualParams, final Set<Str
}
}
}

/**
* Sets a normalized score for a specific document at a specific subquery index
*
* @param normalizedScores map of document IDs to their list of scores
* @param docIdAtSearchShard document ID
* @param subQueryIndex index of the subquery
* @param numberOfSubQueries total number of subqueries in the hybrid query
* @param normalizedScore normalized score to set
*/
public static void setNormalizedScore(
Map<DocIdAtSearchShard, List<Float>> normalizedScores,
DocIdAtSearchShard docIdAtSearchShard,
int subQueryIndex,
int numberOfSubQueries,
float normalizedScore
) {
List<Float> scores = normalizedScores.get(docIdAtSearchShard);
if (Objects.isNull(scores)) {
scores = new ArrayList<>(numberOfSubQueries);
for (int i = 0; i < numberOfSubQueries; i++) {
scores.add(0.0f);
}
normalizedScores.put(docIdAtSearchShard, scores);
}
scores.set(subQueryIndex, normalizedScore);
}
}
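Note that the L2 and min-max explain paths shown above, as well as the RRF path via its TriConsumer callback, all write through this helper, so every document's normalized-score list ends up with exactly numberOfSubQueries entries, holding 0.0 at the positions of subqueries the document did not match.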
HybridQueryWeight.java
@@ -11,6 +11,8 @@
import java.util.concurrent.Callable;
import java.util.stream.Collectors;

import lombok.AccessLevel;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
@@ -33,6 +35,7 @@
public final class HybridQueryWeight extends Weight {

// The Weights for our subqueries, in 1-1 correspondence
@Getter(AccessLevel.PACKAGE)
private final List<Weight> weights;

private final ScoreMode scoreMode;
@@ -157,10 +160,13 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
if (e.isMatch()) {
match = true;
double score = e.getValue().doubleValue();
subsOnMatch.add(e);
max = Math.max(max, score);
} else if (!match) {
subsOnNoMatch.add(e);
subsOnMatch.add(e);
} else {
if (!match) {
subsOnNoMatch.add(e);
}
subsOnMatch.add(e);
}
}
if (match) {