Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle pagination_depth when from =0 #1132

Merged
merged 4 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fixed document source and score field mismatch in sorted hybrid queries ([#1043](https://github.com/opensearch-project/neural-search/pull/1043))
- Update NeuralQueryBuilder doEquals() and doHashCode() to cater the missing parameters information ([#1045](https://github.com/opensearch-project/neural-search/pull/1045)).
- Fix bug where embedding is missing when ingested document has "." in field name, and mismatches fieldMap config ([#1062](https://github.com/opensearch-project/neural-search/pull/1062))
- Handle pagination_depth when from = 0 and remove the default value of pagination_depth ([#1132](https://github.com/opensearch-project/neural-search/pull/1132))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder<HybridQueryBu
private Integer paginationDepth;

static final int MAX_NUMBER_OF_SUB_QUERIES = 5;
private final static int DEFAULT_PAGINATION_DEPTH = 10;
private static final int LOWER_BOUND_OF_PAGINATION_DEPTH = 0;

public HybridQueryBuilder(StreamInput in) throws IOException {
Expand Down Expand Up @@ -167,7 +166,7 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio
public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOException {
float boost = AbstractQueryBuilder.DEFAULT_BOOST;

int paginationDepth = DEFAULT_PAGINATION_DEPTH;
Integer paginationDepth = null;
final List<QueryBuilder> queries = new ArrayList<>();
String queryName = null;

Expand Down Expand Up @@ -324,7 +323,7 @@ private Collection<Query> toQueries(Collection<QueryBuilder> queryBuilders, Quer
return queries;
}

private static void validatePaginationDepth(final int paginationDepth, final QueryShardContext queryShardContext) {
private static void validatePaginationDepth(final Integer paginationDepth, final QueryShardContext queryShardContext) {
if (Objects.isNull(paginationDepth)) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -485,14 +485,20 @@ private ReduceableSearchResult reduceSearchResults(final List<ReduceableSearchRe
*/
private static int getSubqueryResultsRetrievalSize(final SearchContext searchContext) {
HybridQuery hybridQuery = unwrapHybridQuery(searchContext);
int paginationDepth = hybridQuery.getQueryContext().getPaginationDepth();
Integer paginationDepth = hybridQuery.getQueryContext().getPaginationDepth();

// Switch to from+size retrieval size during standard hybrid query execution.
if (searchContext.from() == 0) {
return searchContext.size();
// Pagination is expected to work only when pagination_depth is provided in the search request.
if (Objects.isNull(paginationDepth) && searchContext.from() > 0) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "pagination_depth param is missing in the search request"));
}

log.info("pagination_depth is {}", paginationDepth);
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
return paginationDepth;
if (Objects.nonNull(paginationDepth)) {
return paginationDepth;
}

// Switch to from+size retrieval size during standard hybrid query execution where from is 0.
return searchContext.size();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() {
.endObject()
.endObject()
.endArray()
.field("pagination_depth", 10)
.endObject();

NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -870,40 +870,6 @@ public void testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSucc
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
}

/**
 * Runs a hybrid query that wraps a single match_all sub-query, without ever setting
 * pagination_depth on the builder, and asserts that the search still returns results.
 * Resources created by the test are removed in the finally block.
 */
@SneakyThrows
public void testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful() {
try {
// Set the CONCURRENT_SEGMENT_SEARCH_ENABLED cluster setting to false for this test.
updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false);
initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);
// The builder only receives a match_all sub-query; paginationDepth(...) is never called on it.
HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder();
hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder());

// NOTE(review): the positional arguments 10 and 2 are presumably the result size and the
// "from" offset (the test name says from > 0) — confirm against the search(...) helper signature.
Map<String, Object> searchResponseAsMap = search(
TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD,
hybridQueryBuilderOnlyMatchAll,
null,
10,
Map.of("search_pipeline", SEARCH_PIPELINE),
null,
null,
null,
false,
null,
2
);

// Expect 2 hits in the response, a total hit value of 4, and a relation of RELATION_EQUAL_TO.
assertEquals(2, getHitCount(searchResponseAsMap));
Map<String, Object> total = getTotalHits(searchResponseAsMap);
assertNotNull(total.get("value"));
assertEquals(4, total.get("value"));
assertNotNull(total.get("relation"));
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
} finally {
// Clean up the index and search pipeline regardless of the test outcome.
wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE);
}
}

@SneakyThrows
public void testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail() {
try {
Expand All @@ -912,6 +878,7 @@ public void testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail() {
createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);
HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder();
hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder());
hybridQueryBuilderOnlyMatchAll.paginationDepth(10);

ResponseException responseException = assertThrows(
ResponseException.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ public void testReduce_whenMatchedDocsAndSortingIsApplied_thenSuccessful() {
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build();
HybridQueryContext hybridQueryContext = HybridQueryContext.builder().build();

HybridQuery hybridQueryWithMatchAll = new HybridQuery(
List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)),
Expand Down Expand Up @@ -633,7 +633,7 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build();
HybridQueryContext hybridQueryContext = HybridQueryContext.builder().build();

HybridQuery hybridQueryWithTerm = new HybridQuery(
List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)),
Expand Down Expand Up @@ -1169,4 +1169,41 @@ public void testScrollWithHybridQuery_thenFail() {
illegalArgumentException.getMessage()
);
}

@SneakyThrows
public void testCreateCollectorManager_whenPaginationDepthIsEqualToNullAndFromIsGreaterThanZero_thenFail() {
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
SearchContext searchContext = mock(SearchContext.class);
// From >0
when(searchContext.from()).thenReturn(5);
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1);
// pagination_depth is not set on the query context, so with from > 0 creating the collector manager must fail
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
HybridQuery hybridQuery = new HybridQuery(
List.of(termSubQuery.toQuery(mockQueryShardContext)),
HybridQueryContext.builder().build()
);

when(searchContext.query()).thenReturn(hybridQuery);
ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
IndexReader indexReader = mock(IndexReader.class);
when(indexSearcher.getIndexReader()).thenReturn(indexReader);
when(searchContext.searcher()).thenReturn(indexSearcher);
MapperService mapperService = createMapperService();
when(searchContext.mapperService()).thenReturn(mapperService);

Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> classCollectorManagerMap = new HashMap<>();
when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap);
when(searchContext.shouldUseConcurrentSearch()).thenReturn(false);

IllegalArgumentException illegalArgumentException = assertThrows(
IllegalArgumentException.class,
() -> HybridCollectorManager.createHybridCollectorManager(searchContext)
);
assertEquals(
String.format(Locale.ROOT, "pagination_depth param is missing in the search request"),
illegalArgumentException.getMessage()
);
}
}
Loading