Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#1536: Refactor OpenSearchQueryRequest and move includes to builder #320

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ public class OpenSearchQueryRequest implements OpenSearchRequest {
@ToString.Exclude
private final OpenSearchExprValueFactory exprValueFactory;


/**
* List of includes expected in the response.
*/
@EqualsAndHashCode.Exclude
@ToString.Exclude
private final List<String> includes;

/**
* Indicate the search already done.
*/
Expand All @@ -61,40 +69,38 @@ public class OpenSearchQueryRequest implements OpenSearchRequest {
* Constructor of OpenSearchQueryRequest.
*/
public OpenSearchQueryRequest(String indexName, int size,
OpenSearchExprValueFactory factory) {
this(new IndexName(indexName), size, factory);
OpenSearchExprValueFactory factory, List<String> includes) {
this(new IndexName(indexName), size, factory, includes);
}

/**
* Constructor of OpenSearchQueryRequest.
*/
public OpenSearchQueryRequest(IndexName indexName, int size,
OpenSearchExprValueFactory factory) {
OpenSearchExprValueFactory factory, List<String> includes) {
this.indexName = indexName;
this.sourceBuilder = new SearchSourceBuilder();
sourceBuilder.from(0);
sourceBuilder.size(size);
sourceBuilder.timeout(DEFAULT_QUERY_TIMEOUT);
this.exprValueFactory = factory;
this.includes = includes;
}

/**
* Constructor of OpenSearchQueryRequest.
*/
public OpenSearchQueryRequest(IndexName indexName, SearchSourceBuilder sourceBuilder,
OpenSearchExprValueFactory factory) {
OpenSearchExprValueFactory factory, List<String> includes) {
this.indexName = indexName;
this.sourceBuilder = sourceBuilder;
this.exprValueFactory = factory;
this.includes = includes;
}

@Override
public OpenSearchResponse search(Function<SearchRequest, SearchResponse> searchAction,
Function<SearchScrollRequest, SearchResponse> scrollAction) {
FetchSourceContext fetchSource = this.sourceBuilder.fetchSource();
List<String> includes = fetchSource != null && fetchSource.includes() != null
? Arrays.asList(fetchSource.includes())
: List.of();
if (searchDone) {
return new OpenSearchResponse(SearchHits.empty(), exprValueFactory, includes);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
* OpenSearch search request.
*/
public interface OpenSearchRequest extends Writeable {

/**
* Default query timeout in minutes.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import static org.opensearch.search.sort.SortOrder.ASC;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -98,23 +99,27 @@ public OpenSearchRequestBuilder(int requestedTotalSize,
public OpenSearchRequest build(OpenSearchRequest.IndexName indexName,
int maxResultWindow, TimeValue scrollTimeout) {
int size = requestedTotalSize;
FetchSourceContext fetchSource = this.sourceBuilder.fetchSource();
List<String> includes = fetchSource != null
? Arrays.asList(fetchSource.includes())
: List.of();
if (pageSize == null) {
if (startFrom + size > maxResultWindow) {
sourceBuilder.size(maxResultWindow - startFrom);
return new OpenSearchScrollRequest(
indexName, scrollTimeout, sourceBuilder, exprValueFactory);
indexName, scrollTimeout, sourceBuilder, exprValueFactory, includes);
} else {
sourceBuilder.from(startFrom);
sourceBuilder.size(requestedTotalSize);
return new OpenSearchQueryRequest(indexName, sourceBuilder, exprValueFactory);
return new OpenSearchQueryRequest(indexName, sourceBuilder, exprValueFactory, includes);
}
} else {
if (startFrom != 0) {
throw new UnsupportedOperationException("Non-zero offset is not supported with pagination");
}
sourceBuilder.size(pageSize);
return new OpenSearchScrollRequest(indexName, scrollTimeout,
sourceBuilder, exprValueFactory);
sourceBuilder, exprValueFactory, includes);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,16 @@ public class OpenSearchScrollRequest implements OpenSearchRequest {
private boolean needClean = true;

@Getter
@EqualsAndHashCode.Exclude
@ToString.Exclude
private final List<String> includes;

/** Constructor. */
public OpenSearchScrollRequest(IndexName indexName,
TimeValue scrollTimeout,
SearchSourceBuilder sourceBuilder,
OpenSearchExprValueFactory exprValueFactory) {
OpenSearchExprValueFactory exprValueFactory,
List<String> includes) {
this.indexName = indexName;
this.scrollTimeout = scrollTimeout;
this.exprValueFactory = exprValueFactory;
Expand All @@ -86,9 +89,7 @@ public OpenSearchScrollRequest(IndexName indexName,
.scroll(scrollTimeout)
.source(sourceBuilder);

includes = sourceBuilder.fetchSource() == null
? List.of()
: Arrays.asList(sourceBuilder.fetchSource().includes());
this.includes = includes;
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ void search() {
// Verify response for first scroll request
OpenSearchScrollRequest request = new OpenSearchScrollRequest(
new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1),
new SearchSourceBuilder(), factory);
new SearchSourceBuilder(), factory, List.of("id"));
OpenSearchResponse response1 = client.search(request);
assertFalse(response1.isEmpty());

Expand Down Expand Up @@ -358,7 +358,7 @@ void cleanup() {

OpenSearchScrollRequest request = new OpenSearchScrollRequest(
new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1),
new SearchSourceBuilder(), factory);
new SearchSourceBuilder(), factory, List.of());
request.setScrollId("scroll123");
// Enforce cleaning by setting a private field.
FieldUtils.writeField(request, "needClean", true, true);
Expand All @@ -375,7 +375,7 @@ void cleanup() {
void cleanup_without_scrollId() {
OpenSearchScrollRequest request = new OpenSearchScrollRequest(
new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1),
new SearchSourceBuilder(), factory);
new SearchSourceBuilder(), factory, List.of());
client.cleanup(request);
verify(nodeClient, never()).prepareClearScroll();
}
Expand All @@ -387,7 +387,7 @@ void cleanup_rethrows_exception() {

OpenSearchScrollRequest request = new OpenSearchScrollRequest(
new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1),
new SearchSourceBuilder(), factory);
new SearchSourceBuilder(), factory, List.of());
request.setScrollId("scroll123");
// Enforce cleaning by setting a private field.
FieldUtils.writeField(request, "needClean", true, true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ void search() throws IOException {
// Verify response for first scroll request
OpenSearchScrollRequest request = new OpenSearchScrollRequest(
new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1),
new SearchSourceBuilder(), factory);
new SearchSourceBuilder(), factory, List.of("id"));
OpenSearchResponse response1 = client.search(request);
assertFalse(response1.isEmpty());

Expand All @@ -329,7 +329,7 @@ void search_with_IOException() throws IOException {
IllegalStateException.class,
() -> client.search(new OpenSearchScrollRequest(
new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1),
new SearchSourceBuilder(), factory)));
new SearchSourceBuilder(), factory, List.of())));
}

@Test
Expand All @@ -351,7 +351,7 @@ void scroll_with_IOException() throws IOException {
// First request run successfully
OpenSearchScrollRequest scrollRequest = new OpenSearchScrollRequest(
new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1),
new SearchSourceBuilder(), factory);
new SearchSourceBuilder(), factory, List.of());
client.search(scrollRequest);
assertThrows(
IllegalStateException.class, () -> client.search(scrollRequest));
Expand All @@ -370,7 +370,7 @@ void schedule() {
void cleanup() {
OpenSearchScrollRequest request = new OpenSearchScrollRequest(
new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1),
new SearchSourceBuilder(), factory);
new SearchSourceBuilder(), factory, List.of());
// Enforce cleaning by setting a private field.
FieldUtils.writeField(request, "needClean", true, true);
request.setScrollId("scroll123");
Expand All @@ -383,7 +383,7 @@ void cleanup() {
void cleanup_without_scrollId() throws IOException {
OpenSearchScrollRequest request = new OpenSearchScrollRequest(
new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1),
new SearchSourceBuilder(), factory);
new SearchSourceBuilder(), factory, List.of());
client.cleanup(request);
verify(restClient, never()).clearScroll(any(), any());
}
Expand All @@ -395,7 +395,7 @@ void cleanup_with_IOException() {

OpenSearchScrollRequest request = new OpenSearchScrollRequest(
new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1),
new SearchSourceBuilder(), factory);
new SearchSourceBuilder(), factory, List.of());
// Enforce cleaning by setting a private field.
FieldUtils.writeField(request, "needClean", true, true);
request.setScrollId("scroll123");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static org.mockito.Mockito.when;
import static org.opensearch.sql.opensearch.request.OpenSearchRequest.DEFAULT_QUERY_TIMEOUT;

import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;
import org.apache.lucene.search.TotalHits;
Expand Down Expand Up @@ -68,27 +69,25 @@ public class OpenSearchQueryRequestTest {
private OpenSearchExprValueFactory factory;

private final OpenSearchQueryRequest request =
new OpenSearchQueryRequest("test", 200, factory);
new OpenSearchQueryRequest("test", 200, factory, List.of());

private final OpenSearchQueryRequest remoteRequest =
new OpenSearchQueryRequest("ccs:test", 200, factory);
new OpenSearchQueryRequest("ccs:test", 200, factory, List.of());

@Test
void search() {
OpenSearchQueryRequest request = new OpenSearchQueryRequest(
new OpenSearchRequest.IndexName("test"),
sourceBuilder,
factory
factory,
List.of()
);

when(sourceBuilder.fetchSource()).thenReturn(fetchSourceContext);
when(fetchSourceContext.includes()).thenReturn(null);
when(searchAction.apply(any())).thenReturn(searchResponse);
when(searchResponse.getHits()).thenReturn(searchHits);
when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit});

OpenSearchResponse searchResponse = request.search(searchAction, scrollAction);
verify(fetchSourceContext, times(1)).includes();
assertFalse(searchResponse.isEmpty());
searchResponse = request.search(searchAction, scrollAction);
assertTrue(searchResponse.isEmpty());
Expand All @@ -100,15 +99,14 @@ void search_withoutContext() {
OpenSearchQueryRequest request = new OpenSearchQueryRequest(
new OpenSearchRequest.IndexName("test"),
sourceBuilder,
factory
factory,
List.of()
);

when(sourceBuilder.fetchSource()).thenReturn(null);
when(searchAction.apply(any())).thenReturn(searchResponse);
when(searchResponse.getHits()).thenReturn(searchHits);
when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit});
OpenSearchResponse searchResponse = request.search(searchAction, scrollAction);
verify(sourceBuilder, times(1)).fetchSource();
assertFalse(searchResponse.isEmpty());
assertFalse(request.hasAnotherBatch());
}
Expand All @@ -118,18 +116,16 @@ void search_withIncludes() {
OpenSearchQueryRequest request = new OpenSearchQueryRequest(
new OpenSearchRequest.IndexName("test"),
sourceBuilder,
factory
factory,
List.of()
);

String[] includes = {"_id", "_index"};
when(sourceBuilder.fetchSource()).thenReturn(fetchSourceContext);
when(fetchSourceContext.includes()).thenReturn(includes);
when(searchAction.apply(any())).thenReturn(searchResponse);
when(searchResponse.getHits()).thenReturn(searchHits);
when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit});

OpenSearchResponse searchResponse = request.search(searchAction, scrollAction);
verify(fetchSourceContext, times(2)).includes();
assertFalse(searchResponse.isEmpty());

searchResponse = request.search(searchAction, scrollAction);
Expand Down
Loading