diff --git a/plugin/build.gradle b/plugin/build.gradle index add7262f37..e387dda980 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -188,6 +188,10 @@ integTest { if (System.getProperty("test.debug") != null) { jvmArgs '-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=*:5005' } + + // Set this to true this if you want to see the logs in the terminal test output. + // note: if left false the log output will still show in your IDE + testLogging.showStandardStreams = true } testClusters.integTest { @@ -197,7 +201,7 @@ testClusters.integTest { // When running integration tests it doesn't forward the --debug-jvm to the cluster anymore // i.e. we have to use a custom property to flag when we want to debug elasticsearch JVM // since we also support multi node integration tests we increase debugPort per node - if (System.getProperty("opensearch.debug") != null) { + if (System.getProperty("cluster.debug") != null) { def debugPort = 5005 nodes.forEach { node -> node.jvmArgs("-agentlib:jdwp=transport=dt_socket,server=n,suspend=y,address=*:${debugPort}") diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java index 030e2253c6..162bf9e201 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java @@ -11,8 +11,10 @@ import java.util.Optional; import java.util.stream.Collectors; +import org.opensearch.ExceptionsHelper; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; @@ -20,6 +22,7 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; @@ -27,9 +30,12 @@ import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import com.google.common.annotations.VisibleForTesting; + import lombok.extern.log4j.Log4j2; @Log4j2 @@ -60,7 +66,7 @@ protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { User user = RestActionUtils.getUserContext(client); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); + ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); List excludes = Optional .ofNullable(request.source()) .map(SearchSourceBuilder::fetchSource) @@ -78,16 +84,43 @@ private void search(SearchRequest request, ActionListener action excludes.toArray(new String[0]) ); request.source().fetchSource(rebuiltFetchSourceContext); + + ActionListener doubleWrappedListener = ActionListener + .wrap(wrappedListener::onResponse, e -> wrapListenerToHandleConnectorIndexNotFound(e, actionListener)); + if (connectorAccessControlHelper.skipConnectorAccessControl(user)) { - client.search(request, wrappedListener); + client.search(request, doubleWrappedListener); } else { SearchSourceBuilder sourceBuilder = connectorAccessControlHelper.addUserBackendRolesFilter(user, request.source()); request.source(sourceBuilder); - client.search(request, wrappedListener); + client.search(request, doubleWrappedListener); } } catch (Exception e) { log.error(e.getMessage(), e); actionListener.onFailure(e); } } + + @VisibleForTesting + public static void wrapListenerToHandleConnectorIndexNotFound(Exception e, ActionListener listener) { + if (ExceptionsHelper.unwrapCause(e) instanceof IndexNotFoundException) { + log.debug("Connectors index not created yet, therefore we will swallow the exception and return an empty search result"); + final InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty(); + final SearchResponse emptySearchResponse = new SearchResponse( + internalSearchResponse, + null, + 0, + 0, + 0, + 0, + null, + new ShardSearchFailure[] {}, + SearchResponse.Clusters.EMPTY, + null + ); + listener.onResponse(emptySearchResponse); + } else { + listener.onFailure(e); + } + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 1756cbd161..964f79751d 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -95,7 +95,15 @@ public void testDeleteConnector() throws IOException { assertEquals("deleted", (String) responseMap.get("result")); } - public void testSearchConnectors() throws IOException { + public void testSearchConnectors_beforeConnectorCreation() throws IOException { + String searchEntity = "{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " },\n" + " \"size\": 1000\n" + "}"; + Response response = TestHelper + .makeRequest(client(), "GET", "/_plugins/_ml/connectors/_search", null, TestHelper.toHttpEntity(searchEntity), null); + Map responseMap = parseResponseToMap(response); + assertEquals((Double) 0.0, (Double) ((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); + } + + public void testSearchConnectors_afterConnectorCreation() throws IOException { createConnector(completionModelConnectorEntity); String searchEntity = "{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " },\n" + " \"size\": 1000\n" + "}"; Response response = TestHelper diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchConnectorActionTests.java index 97ed37230a..6ff65715ba 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchConnectorActionTests.java @@ -20,10 +20,12 @@ import org.apache.lucene.search.TotalHits; import org.hamcrest.Matchers; +import org.junit.Assert; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchWrapperException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; @@ -35,6 +37,8 @@ import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.action.connector.SearchConnectorTransportAction; import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; import org.opensearch.ml.utils.TestHelper; import org.opensearch.rest.RestChannel; @@ -190,4 +194,51 @@ public void testPrepareRequest_timeout() throws Exception { RestResponse restResponse = responseCaptor.getValue(); assertEquals(RestStatus.REQUEST_TIMEOUT, restResponse.status()); } + + public void testDoubleWrapper_handleIndexNotFound() { + final IndexNotFoundException indexNotFoundException = new IndexNotFoundException("Index not found", ML_CONNECTOR_INDEX); + final DummyActionListener actionListener = new DummyActionListener(); + + SearchConnectorTransportAction.wrapListenerToHandleConnectorIndexNotFound(indexNotFoundException, actionListener); + Assert.assertTrue(actionListener.success); + } + + public void testDoubleWrapper_handleIndexNotFoundWrappedException() { + final WrappedException wrappedException = new WrappedException(); + final DummyActionListener actionListener = new DummyActionListener(); + + SearchConnectorTransportAction.wrapListenerToHandleConnectorIndexNotFound(wrappedException, actionListener); + Assert.assertTrue(actionListener.success); + } + + public void testDoubleWrapper_notRelatedException() { + final RuntimeException exception = new RuntimeException("some random exception"); + final DummyActionListener actionListener = new DummyActionListener(); + + SearchConnectorTransportAction.wrapListenerToHandleConnectorIndexNotFound(exception, actionListener); + Assert.assertFalse(actionListener.success); + } + + public class DummyActionListener implements ActionListener { + public boolean success = false; + + @Override + public void onResponse(SearchResponse searchResponse) { + logger.info("success"); + this.success = true; + } + + @Override + public void onFailure(Exception e) { + logger.error("failure", e); + this.success = false; + } + } + + public static class WrappedException extends Exception implements OpenSearchWrapperException { + @Override + public synchronized Throwable getCause() { + return new IndexNotFoundException("Index not found", ML_CONNECTOR_INDEX); + } + } }