Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Reciprocal Rank Fusion (RRF) in hybrid query #1086

Merged
merged 16 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
c2ddbab
Reciprocal Rank Fusion (RRF) normalization technique in hybrid query …
Johnsonisaacn Oct 18, 2024
3a8dee6
Reciprocal Rank Fusion (RRF) normalization technique in hybrid query …
Johnsonisaacn Oct 18, 2024
2f886fa
Add integration and unit tests for missing RRF coverage (#997)
ryanbogan Dec 3, 2024
7d6599c
Integrate explainability for hybrid query into RRF processor (#1037)
martin-gaievski Dec 23, 2024
9ce78f5
Support of new k-NN query parameter expand_nested. (#1013)
bzhangam Dec 18, 2024
1339c79
[Enhancement] Implement pruning for neural sparse search (#988)
zhichao-aws Dec 18, 2024
0afd102
Remove mistakenly added code from HybridSearchIT. (#1032)
bzhangam Dec 18, 2024
d5a176e
Fix bug where ingestion failed for input document containing list of …
yizheliu-amazon Jan 3, 2025
503189c
add support for builder constructor in neural query builder (#1047)
will-hwang Jan 7, 2025
1315ab2
add hybrid search with rescore IT (#1066)
will-hwang Jan 8, 2025
597d2b4
Fix bug where document embedding fails to be generated due to documen…
yizheliu-amazon Jan 8, 2025
6be95ce
Clean up unused validateFieldName() and use existing methods for Text…
yizheliu-amazon Jan 8, 2025
98ab28f
Correct NeuralQueryBuilder doEquals() and doHashCode(). (#1045)
bzhangam Jan 9, 2025
f5ae67a
Reciprocal Rank Fusion (RRF) normalization technique in hybrid query …
Johnsonisaacn Oct 18, 2024
6e5596d
Add integration and unit tests for missing RRF coverage (#997)
ryanbogan Dec 3, 2024
312c7f7
Integrate explainability for hybrid query into RRF processor (#1037)
martin-gaievski Dec 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x)
### Features
- Pagination in Hybrid query ([#1048](https://github.com/opensearch-project/neural-search/pull/1048))
- Implement Reciprocal Rank Fusion score normalization/combination technique in hybrid query ([#874](https://github.com/opensearch-project/neural-search/pull/874))
### Enhancements
- Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970))
- Support new knn query parameter expand_nested ([#1013](https://github.com/opensearch-project/neural-search/pull/1013))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ private void validateIndexQuery(final String modelId) {
.modelId(modelId)
.maxDistance(100000f)
.build();

Map<String, Object> responseWithMaxDistanceQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1);
assertNotNull(responseWithMaxDistanceQuery);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,24 @@
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor;
import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor;
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.TextChunkingProcessor;
import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.RRFProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.ExplanationResponseProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextChunkingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.processor.rerank.RerankProcessor;
Expand Down Expand Up @@ -157,7 +159,9 @@ public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchPhaseR
) {
return Map.of(
NormalizationProcessor.TYPE,
new NormalizationProcessorFactory(normalizationProcessorWorkflow, scoreNormalizationFactory, scoreCombinationFactory)
new NormalizationProcessorFactory(normalizationProcessorWorkflow, scoreNormalizationFactory, scoreCombinationFactory),
RRFProcessor.TYPE,
new RRFProcessorFactory(normalizationProcessorWorkflow, scoreNormalizationFactory, scoreCombinationFactory)
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;

import java.util.Optional;

/**
* Base class for all score hybridization processors. This class is responsible for executing the score hybridization process.
* It is a pipeline processor that is executed after the query phase and before the fetch phase.
*/
public abstract class AbstractScoreHybridizationProcessor implements SearchPhaseResultsProcessor {
/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor. This method is called when there is no pipeline context
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext
) {
hybridizeScores(searchPhaseResult, searchPhaseContext, Optional.empty());
}

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor. This method is called when there is pipeline context
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
* @param requestContext {@link PipelineProcessingContext} processing context of search pipeline
* @param <Result>
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext,
final PipelineProcessingContext requestContext
) {
hybridizeScores(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext));
}

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor
* @param searchPhaseResult
* @param searchPhaseContext
* @param requestContextOptional
* @param <Result>
*/
abstract <Result extends SearchPhaseResult> void hybridizeScores(
SearchPhaseResults<Result> searchPhaseResult,
SearchPhaseContext searchPhaseContext,
Optional<PipelineProcessingContext> requestContextOptional
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,9 @@ public SearchResponse processResponse(
);
}
// Create and set final explanation combining all components
Float finalScore = Float.isNaN(searchHit.getScore()) ? 0.0f : searchHit.getScore();
Explanation finalExplanation = Explanation.match(
searchHit.getScore(),
finalScore,
// combination level explanation is always a single detail
combinationExplanation.getScoreDetails().get(0).getValue(),
normalizedExplanation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ Map<String, Object> buildMapWithTargetKeys(IngestDocument ingestDocument) {
buildNestedMap(originalKey, targetKey, sourceAndMetadataMap, treeRes);
mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey));
} else {
mapWithProcessorKeys.put(String.valueOf(targetKey), normalizeSourceValue(sourceAndMetadataMap.get(originalKey)));
mapWithProcessorKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey));
}
}
return mapWithProcessorKeys;
Expand Down Expand Up @@ -357,9 +357,8 @@ void buildNestedMap(String parentKey, Object processorKey, Map<String, Object> s
}
treeRes.merge(parentKey, next, REMAPPING_FUNCTION);
} else {
Object parentValue = sourceAndMetadataMap.get(parentKey);
String key = String.valueOf(processorKey);
treeRes.put(key, normalizeSourceValue(parentValue));
treeRes.put(key, sourceAndMetadataMap.get(parentKey));
}
}

Expand Down Expand Up @@ -404,7 +403,7 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
indexName,
clusterService,
environment,
true
false
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.query.QuerySearchResult;

import java.util.List;
import java.util.Optional;

/**
* DTO object to hold data in NormalizationProcessorWorkflow class
* in NormalizationProcessorWorkflow.
*/
@AllArgsConstructor
@Builder
@Getter
public class NormalizationExecuteDTO {
@NonNull
private List<QuerySearchResult> querySearchResults;
@NonNull
private Optional<FetchSearchResult> fetchSearchResultOptional;
@NonNull
private ScoreNormalizationTechnique normalizationTechnique;
@NonNull
private ScoreCombinationTechnique combinationTechnique;
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.query.QuerySearchResult;

import lombok.AllArgsConstructor;
Expand All @@ -33,7 +31,7 @@
*/
@Log4j2
@AllArgsConstructor
public class NormalizationProcessor implements SearchPhaseResultsProcessor {
public class NormalizationProcessor extends AbstractScoreHybridizationProcessor {
public static final String TYPE = "normalization-processor";

private final String tag;
Expand All @@ -42,38 +40,8 @@ public class NormalizationProcessor implements SearchPhaseResultsProcessor {
private final ScoreCombinationTechnique combinationTechnique;
private final NormalizationProcessorWorkflow normalizationWorkflow;

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor. This method is called when there is no pipeline context
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext
) {
prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.empty());
}

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
* @param requestContext {@link PipelineProcessingContext} processing context of search pipeline
* @param <Result>
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext,
final PipelineProcessingContext requestContext
) {
prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext));
}

private <Result extends SearchPhaseResult> void prepareAndExecuteNormalizationWorkflow(
<Result extends SearchPhaseResult> void hybridizeScores(
SearchPhaseResults<Result> searchPhaseResult,
SearchPhaseContext searchPhaseContext,
Optional<PipelineProcessingContext> requestContextOptional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,12 @@
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.processor.combination.CombineScoresDto;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.explain.CombinedExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
import org.opensearch.neuralsearch.processor.explain.ExplanationPayload;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
Expand Down Expand Up @@ -57,44 +55,30 @@ public class NormalizationProcessorWorkflow {

/**
* Start execution of this workflow
* @param querySearchResults input data with QuerySearchResult from multiple shards
* @param normalizationTechnique technique for score normalization
* @param combinationTechnique technique for score combination
* @param request contains querySearchResults input data with QuerySearchResult
* from multiple shards, fetchSearchResultOptional, normalizationTechnique technique for score normalization
* combinationTechnique technique for score combination, and nullable rankConstant only used in RRF technique
*/
public void execute(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional,
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique,
final SearchPhaseContext searchPhaseContext
) {
NormalizationProcessorWorkflowExecuteRequest request = NormalizationProcessorWorkflowExecuteRequest.builder()
.querySearchResults(querySearchResults)
.fetchSearchResultOptional(fetchSearchResultOptional)
.normalizationTechnique(normalizationTechnique)
.combinationTechnique(combinationTechnique)
.explain(false)
.searchPhaseContext(searchPhaseContext)
.build();
execute(request);
}

public void execute(final NormalizationProcessorWorkflowExecuteRequest request) {
List<QuerySearchResult> querySearchResults = request.getQuerySearchResults();
Optional<FetchSearchResult> fetchSearchResultOptional = request.getFetchSearchResultOptional();

// save original state
List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);
List<Integer> unprocessedDocIds = unprocessedDocIds(request.getQuerySearchResults());

// pre-process data
log.debug("Pre-process query results");
List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(querySearchResults);

explain(request, queryTopDocs);

// Data transfer object for score normalization used to pass nullable rankConstant which is only used in RRF
NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder()
.queryTopDocs(queryTopDocs)
.normalizationTechnique(request.getNormalizationTechnique())
.build();

// normalize
log.debug("Do score normalization");
scoreNormalizer.normalizeScores(queryTopDocs, request.getNormalizationTechnique());
scoreNormalizer.normalizeScores(normalizeScoresDTO);

CombineScoresDto combineScoresDTO = CombineScoresDto.builder()
.queryTopDocs(queryTopDocs)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;

import java.util.List;

/**
* DTO object to hold data required for score normalization.
*/
@AllArgsConstructor
@Builder
@Getter
public class NormalizeScoresDTO {
@NonNull
private List<CompoundTopDocs> queryTopDocs;
@NonNull
private ScoreNormalizationTechnique normalizationTechnique;
}
Loading
Loading