From 4ab60826b8f5a9a646442059cdd93f52dafbc9ce Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Thu, 15 Dec 2022 14:15:54 -0500 Subject: [PATCH 1/3] Refactor fuzziness interface on query builders (#5433) (#5584) * Refactor Object to Fuzziness type for all query builders Signed-off-by: noCharger * Revise on bwc Signed-off-by: noCharger * Update change log Signed-off-by: noCharger Signed-off-by: noCharger Co-authored-by: Daniel (dB.) Doubrovkine (cherry picked from commit d3f6dfab89720df155ee9cb4789aa642c2238057) Signed-off-by: Louis Chu Signed-off-by: noCharger Signed-off-by: Louis Chu Co-authored-by: Daniel (dB.) Doubrovkine --- CHANGELOG.md | 2 ++ .../search/query/MultiMatchQueryIT.java | 3 +- .../search/query/SearchQueryIT.java | 11 +++---- .../org/opensearch/common/unit/Fuzziness.java | 10 +++++++ .../query/MatchBoolPrefixQueryBuilder.java | 30 +++++++++++-------- .../index/query/MatchQueryBuilder.java | 11 +++++-- .../index/query/MultiMatchQueryBuilder.java | 9 ++++++ .../index/query/QueryStringQueryBuilder.java | 3 +- .../common/unit/FuzzinessTests.java | 12 ++++++++ .../MatchBoolPrefixQueryBuilderTests.java | 5 ++++ .../index/query/MatchQueryBuilderTests.java | 5 ++++ .../query/MultiMatchQueryBuilderTests.java | 5 ++++ .../query/QueryStringQueryBuilderTests.java | 10 +++++-- 13 files changed, 90 insertions(+), 26 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 490b6ceaa8578..67abde81edef0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Pre conditions check before updating weighted routing metadata ([#4955](https://github.com/opensearch-project/OpenSearch/pull/4955)) ### Deprecated +- Refactor fuzziness interface on query builders ([#5433](https://github.com/opensearch-project/OpenSearch/pull/5433)) + ### Removed ### Fixed - Fix 1.x compatibility bug with stored Tasks ([#5412](https://github.com/opensearch-project/OpenSearch/pull/5412)) diff --git a/server/src/internalClusterTest/java/org/opensearch/search/query/MultiMatchQueryIT.java b/server/src/internalClusterTest/java/org/opensearch/search/query/MultiMatchQueryIT.java index d87bbfb1fb69c..79527039a50f5 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/query/MultiMatchQueryIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/query/MultiMatchQueryIT.java @@ -37,6 +37,7 @@ import org.opensearch.action.index.IndexRequestBuilder; import org.opensearch.action.search.SearchResponse; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.Fuzziness; import org.opensearch.common.util.set.Sets; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; @@ -1024,7 +1025,7 @@ public void testFuzzyFieldLevelBoosting() throws InterruptedException, Execution SearchResponse searchResponse = client().prepareSearch(idx) .setExplain(true) - .setQuery(multiMatchQuery("foo").field("title", 100).field("body").fuzziness(0)) + .setQuery(multiMatchQuery("foo").field("title", 100).field("body").fuzziness(Fuzziness.ZERO)) .get(); SearchHit[] hits = searchResponse.getHits().getHits(); assertNotEquals("both documents should be on different shards", hits[0].getShard().getShardId(), hits[1].getShard().getShardId()); diff --git a/server/src/internalClusterTest/java/org/opensearch/search/query/SearchQueryIT.java b/server/src/internalClusterTest/java/org/opensearch/search/query/SearchQueryIT.java index e90d4e8e12c10..d32487df10b38 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/query/SearchQueryIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/query/SearchQueryIT.java @@ -49,6 +49,7 @@ import org.opensearch.common.regex.Regex; import org.opensearch.common.settings.Settings; import org.opensearch.common.time.DateFormatter; +import org.opensearch.common.unit.Fuzziness; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; @@ -762,21 +763,21 @@ public void testMatchQueryFuzzy() throws Exception { client().prepareIndex("test").setId("2").setSource("text", "Unity") ); - SearchResponse searchResponse = client().prepareSearch().setQuery(matchQuery("text", "uniy").fuzziness("0")).get(); + SearchResponse searchResponse = client().prepareSearch().setQuery(matchQuery("text", "uniy").fuzziness(Fuzziness.ZERO)).get(); assertHitCount(searchResponse, 0L); - searchResponse = client().prepareSearch().setQuery(matchQuery("text", "uniy").fuzziness("1")).get(); + searchResponse = client().prepareSearch().setQuery(matchQuery("text", "uniy").fuzziness(Fuzziness.ONE)).get(); assertHitCount(searchResponse, 2L); assertSearchHits(searchResponse, "1", "2"); - searchResponse = client().prepareSearch().setQuery(matchQuery("text", "uniy").fuzziness("AUTO")).get(); + searchResponse = client().prepareSearch().setQuery(matchQuery("text", "uniy").fuzziness(Fuzziness.AUTO)).get(); assertHitCount(searchResponse, 2L); assertSearchHits(searchResponse, "1", "2"); - searchResponse = client().prepareSearch().setQuery(matchQuery("text", "uniy").fuzziness("AUTO:5,7")).get(); + searchResponse = client().prepareSearch().setQuery(matchQuery("text", "uniy").fuzziness(Fuzziness.customAuto(5, 7))).get(); assertHitCount(searchResponse, 0L); - searchResponse = client().prepareSearch().setQuery(matchQuery("text", "unify").fuzziness("AUTO:5,7")).get(); + searchResponse = client().prepareSearch().setQuery(matchQuery("text", "unify").fuzziness(Fuzziness.customAuto(5, 7))).get(); assertHitCount(searchResponse, 1L); assertSearchHits(searchResponse, "2"); } diff --git a/server/src/main/java/org/opensearch/common/unit/Fuzziness.java b/server/src/main/java/org/opensearch/common/unit/Fuzziness.java index c3b6ea6b8c23d..28947b3936843 100644 --- a/server/src/main/java/org/opensearch/common/unit/Fuzziness.java +++ b/server/src/main/java/org/opensearch/common/unit/Fuzziness.java @@ -139,6 +139,16 @@ public static Fuzziness build(Object fuzziness) { return new Fuzziness(string); } + /*** + * Creates a {@link Fuzziness} instance from lowDistance and highDistance. + * where the edit distance is 0 for strings shorter than lowDistance, + * 1 for strings where its length between lowDistance and highDistance (inclusive), + * and 2 for strings longer than highDistance. + */ + public static Fuzziness customAuto(int lowDistance, int highDistance) { + return new Fuzziness("AUTO", lowDistance, highDistance); + } + private static Fuzziness parseCustomAuto(final String string) { assert string.toUpperCase(Locale.ROOT).startsWith(AUTO.asString() + ":"); String[] fuzzinessLimit = string.substring(AUTO.asString().length() + 1).split(","); diff --git a/server/src/main/java/org/opensearch/index/query/MatchBoolPrefixQueryBuilder.java b/server/src/main/java/org/opensearch/index/query/MatchBoolPrefixQueryBuilder.java index f8f84c52309d5..f901fac22d7ae 100644 --- a/server/src/main/java/org/opensearch/index/query/MatchBoolPrefixQueryBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/MatchBoolPrefixQueryBuilder.java @@ -175,12 +175,19 @@ public String minimumShouldMatch() { return this.minimumShouldMatch; } + @Deprecated /** Sets the fuzziness used when evaluated to a fuzzy query type. Defaults to "AUTO". */ public MatchBoolPrefixQueryBuilder fuzziness(Object fuzziness) { this.fuzziness = Fuzziness.build(fuzziness); return this; } + /** Sets the fuzziness used when evaluated to a fuzzy query type. Defaults to "AUTO". */ + public MatchBoolPrefixQueryBuilder fuzziness(Fuzziness fuzziness) { + this.fuzziness = fuzziness; + return this; + } + /** Gets the fuzziness used when evaluated to a fuzzy query type. */ public Fuzziness fuzziness() { return this.fuzziness; @@ -348,19 +355,16 @@ public static MatchBoolPrefixQueryBuilder fromXContent(XContentParser parser) th } } - MatchBoolPrefixQueryBuilder queryBuilder = new MatchBoolPrefixQueryBuilder(fieldName, value); - queryBuilder.analyzer(analyzer); - queryBuilder.operator(operator); - queryBuilder.minimumShouldMatch(minimumShouldMatch); - queryBuilder.boost(boost); - queryBuilder.queryName(queryName); - if (fuzziness != null) { - queryBuilder.fuzziness(fuzziness); - } - queryBuilder.prefixLength(prefixLength); - queryBuilder.maxExpansions(maxExpansion); - queryBuilder.fuzzyTranspositions(fuzzyTranspositions); - queryBuilder.fuzzyRewrite(fuzzyRewrite); + MatchBoolPrefixQueryBuilder queryBuilder = new MatchBoolPrefixQueryBuilder(fieldName, value).analyzer(analyzer) + .operator(operator) + .minimumShouldMatch(minimumShouldMatch) + .boost(boost) + .queryName(queryName) + .fuzziness(fuzziness) + .prefixLength(prefixLength) + .maxExpansions(maxExpansion) + .fuzzyTranspositions(fuzzyTranspositions) + .fuzzyRewrite(fuzzyRewrite); return queryBuilder; } diff --git a/server/src/main/java/org/opensearch/index/query/MatchQueryBuilder.java b/server/src/main/java/org/opensearch/index/query/MatchQueryBuilder.java index 380e8722daca9..8dbe9392bdd95 100644 --- a/server/src/main/java/org/opensearch/index/query/MatchQueryBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/MatchQueryBuilder.java @@ -208,12 +208,19 @@ public String analyzer() { return this.analyzer; } + @Deprecated /** Sets the fuzziness used when evaluated to a fuzzy query type. Defaults to "AUTO". */ public MatchQueryBuilder fuzziness(Object fuzziness) { this.fuzziness = Fuzziness.build(fuzziness); return this; } + /** Sets the fuzziness used when evaluated to a fuzzy query type. Defaults to "AUTO". */ + public MatchQueryBuilder fuzziness(Fuzziness fuzziness) { + this.fuzziness = fuzziness; + return this; + } + /** Gets the fuzziness used when evaluated to a fuzzy query type. */ public Fuzziness fuzziness() { return this.fuzziness; @@ -565,9 +572,7 @@ public static MatchQueryBuilder fromXContent(XContentParser parser) throws IOExc matchQuery.operator(operator); matchQuery.analyzer(analyzer); matchQuery.minimumShouldMatch(minimumShouldMatch); - if (fuzziness != null) { - matchQuery.fuzziness(fuzziness); - } + matchQuery.fuzziness(fuzziness); matchQuery.fuzzyRewrite(fuzzyRewrite); matchQuery.prefixLength(prefixLength); matchQuery.fuzzyTranspositions(fuzzyTranspositions); diff --git a/server/src/main/java/org/opensearch/index/query/MultiMatchQueryBuilder.java b/server/src/main/java/org/opensearch/index/query/MultiMatchQueryBuilder.java index 4e7b6fb51291b..f7009f05f348d 100644 --- a/server/src/main/java/org/opensearch/index/query/MultiMatchQueryBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/MultiMatchQueryBuilder.java @@ -404,6 +404,7 @@ public int slop() { return slop; } + @Deprecated /** * Sets the fuzziness used when evaluated to a fuzzy query type. Defaults to "AUTO". */ @@ -414,6 +415,14 @@ public MultiMatchQueryBuilder fuzziness(Object fuzziness) { return this; } + /** + * Sets the fuzziness used when evaluated to a fuzzy query type. Defaults to "AUTO". + */ + public MultiMatchQueryBuilder fuzziness(Fuzziness fuzziness) { + this.fuzziness = fuzziness; + return this; + } + public Fuzziness fuzziness() { return fuzziness; } diff --git a/server/src/main/java/org/opensearch/index/query/QueryStringQueryBuilder.java b/server/src/main/java/org/opensearch/index/query/QueryStringQueryBuilder.java index 32337f5df34c5..4ee790291f453 100644 --- a/server/src/main/java/org/opensearch/index/query/QueryStringQueryBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/QueryStringQueryBuilder.java @@ -79,6 +79,7 @@ public class QueryStringQueryBuilder extends AbstractQueryBuilder Date: Fri, 16 Dec 2022 09:03:59 -0800 Subject: [PATCH 2/3] Backporting Auto release workflow to 2.x (#5582) * Backporting auto release workflow to 2.x Signed-off-by: Sarat Vemulapalli * Adding Changelog Signed-off-by: Sarat Vemulapalli * Adding Changelog Signed-off-by: Sarat Vemulapalli Signed-off-by: Sarat Vemulapalli --- .github/workflows/auto-release.yml | 29 +++++++++++++++++++++++++++++ CHANGELOG.md | 1 + 2 files changed, 30 insertions(+) create mode 100644 .github/workflows/auto-release.yml diff --git a/.github/workflows/auto-release.yml b/.github/workflows/auto-release.yml new file mode 100644 index 0000000000000..b8d3912c5864a --- /dev/null +++ b/.github/workflows/auto-release.yml @@ -0,0 +1,29 @@ +name: Releases + +on: + push: + tags: + - '*' + +jobs: + + build: + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - name: GitHub App token + id: github_app_token + uses: tibdex/github-app-token@v1.5.0 + with: + app_id: ${{ secrets.APP_ID }} + private_key: ${{ secrets.APP_PRIVATE_KEY }} + installation_id: 22958780 + - name: Get tag + id: tag + uses: dawidd6/action-get-tag@v1 + - uses: actions/checkout@v2 + - uses: ncipollo/release-action@v1 + with: + github_token: ${{ steps.github_app_token.outputs.token }} + bodyFile: release-notes/opensearch.release-notes-${{steps.tag.outputs.tag}}.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 67abde81edef0..ed39f6057767f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Added jackson dependency to server ([#5366] (https://github.com/opensearch-project/OpenSearch/pull/5366)) - Added experimental extensions to main ([#5347](https://github.com/opensearch-project/OpenSearch/pull/5347)) - Adding support to register settings dynamically ([#5495](https://github.com/opensearch-project/OpenSearch/pull/5495)) +- Adding auto release workflow ([#5582](https://github.com/opensearch-project/OpenSearch/pull/5582)) ### Dependencies - Bump bcpg-fips from 1.0.5.1 to 1.0.7.1 ([#5148](https://github.com/opensearch-project/OpenSearch/pull/5148)) From c0eea5c2f2e890e0f63f9ab06caa909b7f9e996e Mon Sep 17 00:00:00 2001 From: Ryan Bogan <10944539+ryanbogan@users.noreply.github.com> Date: Fri, 16 Dec 2022 15:05:28 -0800 Subject: [PATCH 3/3] Added extension Points, initial REST implementation and registering Transport Actions for extensions (#5596) Signed-off-by: Ryan Bogan Signed-off-by: Ryan Bogan --- CHANGELOG.md | 2 +- ...t.java => InitializeExtensionRequest.java} | 36 +- ....java => InitializeExtensionResponse.java} | 8 +- .../extensions/AcknowledgedResponse.java | 68 ++++ .../extensions/ExtensionReader.java | 45 +++ .../extensions/ExtensionsManager.java | 241 +++++++++---- .../extensions/OpenSearchRequest.java | 73 ++++ .../RegisterTransportActionsRequest.java | 79 +++++ .../rest/RegisterRestActionsRequest.java | 72 ++++ .../rest/RegisterRestActionsResponse.java | 41 +++ .../rest/RestActionsRequestHandler.java | 62 ++++ .../rest/RestExecuteOnExtensionRequest.java | 77 +++++ .../rest/RestExecuteOnExtensionResponse.java | 112 ++++++ .../rest/RestSendToExtensionAction.java | 186 ++++++++++ .../extensions/rest/package-info.java | 10 + .../index/AcknowledgedResponse.java | 42 --- .../main/java/org/opensearch/node/Node.java | 4 +- .../transport/TransportService.java | 2 +- .../extensions/ExtensionsManagerTests.java | 320 ++++++++++++------ .../RegisterTransportActionsRequestTests.java | 42 +++ .../rest/RegisterRestActionsTests.java | 62 ++++ .../rest/RestExecuteOnExtensionTests.java | 94 +++++ .../rest/RestSendToExtensionActionTests.java | 159 +++++++++ 23 files changed, 1599 insertions(+), 238 deletions(-) rename server/src/main/java/org/opensearch/discovery/{PluginRequest.java => InitializeExtensionRequest.java} (60%) rename server/src/main/java/org/opensearch/discovery/{PluginResponse.java => InitializeExtensionResponse.java} (88%) create mode 100644 server/src/main/java/org/opensearch/extensions/AcknowledgedResponse.java create mode 100644 server/src/main/java/org/opensearch/extensions/ExtensionReader.java create mode 100644 server/src/main/java/org/opensearch/extensions/OpenSearchRequest.java create mode 100644 server/src/main/java/org/opensearch/extensions/RegisterTransportActionsRequest.java create mode 100644 server/src/main/java/org/opensearch/extensions/rest/RegisterRestActionsRequest.java create mode 100644 server/src/main/java/org/opensearch/extensions/rest/RegisterRestActionsResponse.java create mode 100644 server/src/main/java/org/opensearch/extensions/rest/RestActionsRequestHandler.java create mode 100644 server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionRequest.java create mode 100644 server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionResponse.java create mode 100644 server/src/main/java/org/opensearch/extensions/rest/RestSendToExtensionAction.java create mode 100644 server/src/main/java/org/opensearch/extensions/rest/package-info.java delete mode 100644 server/src/main/java/org/opensearch/index/AcknowledgedResponse.java create mode 100644 server/src/test/java/org/opensearch/extensions/RegisterTransportActionsRequestTests.java create mode 100644 server/src/test/java/org/opensearch/extensions/rest/RegisterRestActionsTests.java create mode 100644 server/src/test/java/org/opensearch/extensions/rest/RestExecuteOnExtensionTests.java create mode 100644 server/src/test/java/org/opensearch/extensions/rest/RestSendToExtensionActionTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index ed39f6057767f..afcb67923152e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,9 +13,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Reject bulk requests with invalid actions ([#5299](https://github.com/opensearch-project/OpenSearch/issues/5299)) - Add max_shard_size parameter for shrink API ([#5229](https://github.com/opensearch-project/OpenSearch/pull/5229)) - Added jackson dependency to server ([#5366] (https://github.com/opensearch-project/OpenSearch/pull/5366)) -- Added experimental extensions to main ([#5347](https://github.com/opensearch-project/OpenSearch/pull/5347)) - Adding support to register settings dynamically ([#5495](https://github.com/opensearch-project/OpenSearch/pull/5495)) - Adding auto release workflow ([#5582](https://github.com/opensearch-project/OpenSearch/pull/5582)) +- Added experimental support for extensions ([#5347](https://github.com/opensearch-project/OpenSearch/pull/5347)), ([#5518](https://github.com/opensearch-project/OpenSearch/pull/5518)) ### Dependencies - Bump bcpg-fips from 1.0.5.1 to 1.0.7.1 ([#5148](https://github.com/opensearch-project/OpenSearch/pull/5148)) diff --git a/server/src/main/java/org/opensearch/discovery/PluginRequest.java b/server/src/main/java/org/opensearch/discovery/InitializeExtensionRequest.java similarity index 60% rename from server/src/main/java/org/opensearch/discovery/PluginRequest.java rename to server/src/main/java/org/opensearch/discovery/InitializeExtensionRequest.java index 7992de4342d86..b83e9080fa452 100644 --- a/server/src/main/java/org/opensearch/discovery/PluginRequest.java +++ b/server/src/main/java/org/opensearch/discovery/InitializeExtensionRequest.java @@ -15,62 +15,58 @@ import org.opensearch.transport.TransportRequest; import java.io.IOException; -import java.util.List; import java.util.Objects; /** - * PluginRequest to intialize plugin + * InitializeExtensionRequest to intialize plugin * * @opensearch.internal */ -public class PluginRequest extends TransportRequest { +public class InitializeExtensionRequest extends TransportRequest { private final DiscoveryNode sourceNode; - /* - * TODO change DiscoveryNode to Extension information - */ - private final List extensions; + private final DiscoveryExtensionNode extension; - public PluginRequest(DiscoveryNode sourceNode, List extensions) { + public InitializeExtensionRequest(DiscoveryNode sourceNode, DiscoveryExtensionNode extension) { this.sourceNode = sourceNode; - this.extensions = extensions; + this.extension = extension; } - public PluginRequest(StreamInput in) throws IOException { + public InitializeExtensionRequest(StreamInput in) throws IOException { super(in); sourceNode = new DiscoveryNode(in); - extensions = in.readList(DiscoveryExtensionNode::new); + extension = new DiscoveryExtensionNode(in); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); sourceNode.writeTo(out); - out.writeList(extensions); - } - - public List getExtensions() { - return extensions; + extension.writeTo(out); } public DiscoveryNode getSourceNode() { return sourceNode; } + public DiscoveryExtensionNode getExtension() { + return extension; + } + @Override public String toString() { - return "PluginRequest{" + "sourceNode=" + sourceNode + ", extensions=" + extensions + '}'; + return "InitializeExtensionsRequest{" + "sourceNode=" + sourceNode + ", extension=" + extension + '}'; } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - PluginRequest that = (PluginRequest) o; - return Objects.equals(sourceNode, that.sourceNode) && Objects.equals(extensions, that.extensions); + InitializeExtensionRequest that = (InitializeExtensionRequest) o; + return Objects.equals(sourceNode, that.sourceNode) && Objects.equals(extension, that.extension); } @Override public int hashCode() { - return Objects.hash(sourceNode, extensions); + return Objects.hash(sourceNode, extension); } } diff --git a/server/src/main/java/org/opensearch/discovery/PluginResponse.java b/server/src/main/java/org/opensearch/discovery/InitializeExtensionResponse.java similarity index 88% rename from server/src/main/java/org/opensearch/discovery/PluginResponse.java rename to server/src/main/java/org/opensearch/discovery/InitializeExtensionResponse.java index f8f20214e5846..3be97816dc541 100644 --- a/server/src/main/java/org/opensearch/discovery/PluginResponse.java +++ b/server/src/main/java/org/opensearch/discovery/InitializeExtensionResponse.java @@ -44,14 +44,14 @@ * * @opensearch.internal */ -public class PluginResponse extends TransportResponse { +public class InitializeExtensionResponse extends TransportResponse { private String name; - public PluginResponse(String name) { + public InitializeExtensionResponse(String name) { this.name = name; } - public PluginResponse(StreamInput in) throws IOException { + public InitializeExtensionResponse(StreamInput in) throws IOException { name = in.readString(); } @@ -77,7 +77,7 @@ public String toString() { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - PluginResponse that = (PluginResponse) o; + InitializeExtensionResponse that = (InitializeExtensionResponse) o; return Objects.equals(name, that.name); } diff --git a/server/src/main/java/org/opensearch/extensions/AcknowledgedResponse.java b/server/src/main/java/org/opensearch/extensions/AcknowledgedResponse.java new file mode 100644 index 0000000000000..be7eb9c03076e --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/AcknowledgedResponse.java @@ -0,0 +1,68 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions; + +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportResponse; +import java.io.IOException; +import java.util.Objects; + +/** + * Generic boolean response indicating the status of some previous request sent to the SDK + * + * @opensearch.internal + */ +public class AcknowledgedResponse extends TransportResponse { + + private final boolean status; + + /** + * @param status Boolean indicating the status of the parse request sent to the SDK + */ + public AcknowledgedResponse(boolean status) { + this.status = status; + } + + public AcknowledgedResponse(StreamInput in) throws IOException { + super(in); + this.status = in.readBoolean(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeBoolean(status); + } + + @Override + public String toString() { + return "AcknowledgedResponse{" + "status=" + this.status + "}"; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AcknowledgedResponse that = (AcknowledgedResponse) o; + return Objects.equals(this.status, that.status); + } + + @Override + public int hashCode() { + return Objects.hash(status); + } + + /** + * Returns a boolean indicating the success of the request sent to the SDK + */ + public boolean getStatus() { + return this.status; + } + +} diff --git a/server/src/main/java/org/opensearch/extensions/ExtensionReader.java b/server/src/main/java/org/opensearch/extensions/ExtensionReader.java new file mode 100644 index 0000000000000..e54e3a6a4f940 --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/ExtensionReader.java @@ -0,0 +1,45 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions; + +import java.net.UnknownHostException; +import org.opensearch.cluster.node.DiscoveryNode; + +/** + * Reference to a method that transports a parse request to an extension. By convention, this method takes + * a category class used to identify the reader defined within the JVM that the extension is running on. + * Additionally, this method takes in the extension's corresponding DiscoveryNode and a byte array (context) that the + * extension's reader will be applied to. + * + * By convention the extensions' reader is a constructor that takes StreamInput as an argument for most classes and a static method for things like enums. + * Classes will implement this via a constructor (or a static method in the case of enumerations), it's something that should + * look like: + *

+ * public MyClass(final StreamInput in) throws IOException {
+ * *     this.someValue = in.readVInt();
+ *     this.someMap = in.readMapOfLists(StreamInput::readString, StreamInput::readString);
+ * }
+ * 
+ * + * @opensearch.internal + */ +@FunctionalInterface +public interface ExtensionReader { + + /** + * Transports category class, and StreamInput (context), to the extension identified by the Discovery Node + * + * @param extensionNode Discovery Node identifying the Extension + * @param categoryClass Super class that the reader extends + * @param context Some context to transport + * @throws UnknownHostException if the extension node host IP address could not be determined + */ + void parse(DiscoveryNode extensionNode, Class categoryClass, Object context) throws UnknownHostException; + +} diff --git a/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java b/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java index b809f2e35a483..be843fe35a5f9 100644 --- a/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java +++ b/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java @@ -17,6 +17,7 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; @@ -34,17 +35,19 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.transport.TransportAddress; -import org.opensearch.discovery.PluginRequest; -import org.opensearch.discovery.PluginResponse; +import org.opensearch.discovery.InitializeExtensionRequest; +import org.opensearch.discovery.InitializeExtensionResponse; import org.opensearch.extensions.ExtensionsSettings.Extension; +import org.opensearch.extensions.rest.RegisterRestActionsRequest; +import org.opensearch.extensions.rest.RestActionsRequestHandler; import org.opensearch.index.IndexModule; import org.opensearch.index.IndexService; -import org.opensearch.index.AcknowledgedResponse; import org.opensearch.index.IndicesModuleRequest; import org.opensearch.index.IndicesModuleResponse; import org.opensearch.index.shard.IndexEventListener; import org.opensearch.indices.cluster.IndicesClusterStateService; import org.opensearch.plugins.PluginInfo; +import org.opensearch.rest.RestController; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportException; import org.opensearch.transport.TransportResponse; @@ -55,7 +58,7 @@ import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; /** - * The main class for Plugin Extensibility + * The main class for managing Extension communication with the OpenSearch Node. * * @opensearch.internal */ @@ -66,6 +69,11 @@ public class ExtensionsManager { public static final String REQUEST_EXTENSION_CLUSTER_STATE = "internal:discovery/clusterstate"; public static final String REQUEST_EXTENSION_LOCAL_NODE = "internal:discovery/localnode"; public static final String REQUEST_EXTENSION_CLUSTER_SETTINGS = "internal:discovery/clustersettings"; + public static final String REQUEST_EXTENSION_REGISTER_REST_ACTIONS = "internal:discovery/registerrestactions"; + public static final String REQUEST_OPENSEARCH_NAMED_WRITEABLE_REGISTRY = "internal:discovery/namedwriteableregistry"; + public static final String REQUEST_OPENSEARCH_PARSE_NAMED_WRITEABLE = "internal:discovery/parsenamedwriteable"; + public static final String REQUEST_REST_EXECUTE_ON_EXTENSION_ACTION = "internal:extensions/restexecuteonextensiontaction"; + public static final String REQUEST_EXTENSION_REGISTER_TRANSPORT_ACTIONS = "internal:discovery/registertransportactions"; private static final Logger logger = LogManager.getLogger(ExtensionsManager.class); @@ -78,28 +86,46 @@ public static enum RequestType { REQUEST_EXTENSION_CLUSTER_STATE, REQUEST_EXTENSION_LOCAL_NODE, REQUEST_EXTENSION_CLUSTER_SETTINGS, + REQUEST_EXTENSION_REGISTER_REST_ACTIONS, CREATE_COMPONENT, ON_INDEX_MODULE, GET_SETTINGS }; + /** + * Enum for OpenSearch Requests + * + * @opensearch.internal + */ + public static enum OpenSearchRequestType { + REQUEST_OPENSEARCH_NAMED_WRITEABLE_REGISTRY + } + private final Path extensionsPath; - private final List uninitializedExtensions; + // A list of initialized extensions, a subset of the values of map below which includes all extensions private List extensions; + private Map extensionIdMap; + private RestActionsRequestHandler restActionsRequestHandler; private TransportService transportService; private ClusterService clusterService; public ExtensionsManager() { this.extensionsPath = Path.of(""); - this.uninitializedExtensions = new ArrayList(); } + /** + * Instantiate a new ExtensionsManager object to handle requests and responses from extensions. This is called during Node bootstrap. + * + * @param settings Settings from the node the orchestrator is running on. + * @param extensionsPath Path to a directory containing extensions. + * @throws IOException If the extensions discovery file is not properly retrieved. + */ public ExtensionsManager(Settings settings, Path extensionsPath) throws IOException { logger.info("ExtensionsManager initialized"); this.extensionsPath = extensionsPath; this.transportService = null; - this.uninitializedExtensions = new ArrayList(); this.extensions = new ArrayList(); + this.extensionIdMap = new HashMap(); this.clusterService = null; /* @@ -109,16 +135,34 @@ public ExtensionsManager(Settings settings, Path extensionsPath) throws IOExcept } - public void setTransportService(TransportService transportService) { + /** + * Initializes the {@link RestActionsRequestHandler}, {@link TransportService} and {@link ClusterService}. This is called during Node bootstrap. + * Lists/maps of extensions have already been initialized but not yet populated. + * + * @param restController The RestController on which to register Rest Actions. + * @param transportService The Node's transport service. + * @param clusterService The Node's cluster service. + */ + public void initializeServicesAndRestHandler( + RestController restController, + TransportService transportService, + ClusterService clusterService + ) { + this.restActionsRequestHandler = new RestActionsRequestHandler(restController, extensionIdMap, transportService); this.transportService = transportService; - registerRequestHandler(); - } - - public void setClusterService(ClusterService clusterService) { this.clusterService = clusterService; + registerRequestHandler(); } private void registerRequestHandler() { + transportService.registerRequestHandler( + REQUEST_EXTENSION_REGISTER_REST_ACTIONS, + ThreadPool.Names.GENERIC, + false, + false, + RegisterRestActionsRequest::new, + ((request, channel, task) -> channel.sendResponse(restActionsRequestHandler.handleRegisterRestActionsRequest(request))) + ); transportService.registerRequestHandler( REQUEST_EXTENSION_CLUSTER_STATE, ThreadPool.Names.GENERIC, @@ -143,6 +187,14 @@ private void registerRequestHandler() { ExtensionRequest::new, ((request, channel, task) -> channel.sendResponse(handleExtensionRequest(request))) ); + transportService.registerRequestHandler( + REQUEST_EXTENSION_REGISTER_TRANSPORT_ACTIONS, + ThreadPool.Names.GENERIC, + false, + false, + RegisterTransportActionsRequest::new, + ((request, channel, task) -> channel.sendResponse(handleRegisterTransportActionsRequest(request))) + ); } /* @@ -164,7 +216,7 @@ private void discover() throws IOException { for (Extension extension : extensions) { loadExtension(extension); } - if (!uninitializedExtensions.isEmpty()) { + if (!extensionIdMap.isEmpty()) { logger.info("Loaded all extensions"); } } else { @@ -177,9 +229,11 @@ private void discover() throws IOException { * @param extension The extension to be loaded */ private void loadExtension(Extension extension) throws IOException { - try { - uninitializedExtensions.add( - new DiscoveryExtensionNode( + if (extensionIdMap.containsKey(extension.getUniqueId())) { + logger.info("Duplicate uniqueId " + extension.getUniqueId() + ". Did not load extension: " + extension); + } else { + try { + DiscoveryExtensionNode discoveryExtensionNode = new DiscoveryExtensionNode( extension.getName(), extension.getUniqueId(), // placeholder for ephemeral id, will change with POC discovery @@ -199,42 +253,51 @@ private void loadExtension(Extension extension) throws IOException { new ArrayList(), Boolean.parseBoolean(extension.hasNativeController()) ) - ) - ); - logger.info("Loaded extension: " + extension); - } catch (IllegalArgumentException e) { - throw e; + ); + extensionIdMap.put(extension.getUniqueId(), discoveryExtensionNode); + logger.info("Loaded extension with uniqueId " + extension.getUniqueId() + ": " + extension); + } catch (IllegalArgumentException e) { + throw e; + } } } + /** + * Iterate through all extensions and initialize them. Initialized extensions will be added to the {@link #extensions}. + */ public void initialize() { - for (DiscoveryNode extensionNode : uninitializedExtensions) { - initializeExtension(extensionNode); + for (DiscoveryExtensionNode extension : extensionIdMap.values()) { + initializeExtension(extension); } } - private void initializeExtension(DiscoveryNode extensionNode) { + private void initializeExtension(DiscoveryExtensionNode extension) { - final TransportResponseHandler pluginResponseHandler = new TransportResponseHandler() { + final CompletableFuture inProgressFuture = new CompletableFuture<>(); + final TransportResponseHandler initializeExtensionResponseHandler = new TransportResponseHandler< + InitializeExtensionResponse>() { @Override - public PluginResponse read(StreamInput in) throws IOException { - return new PluginResponse(in); + public InitializeExtensionResponse read(StreamInput in) throws IOException { + return new InitializeExtensionResponse(in); } @Override - public void handleResponse(PluginResponse response) { - for (DiscoveryExtensionNode extension : uninitializedExtensions) { + public void handleResponse(InitializeExtensionResponse response) { + for (DiscoveryExtensionNode extension : extensionIdMap.values()) { if (extension.getName().equals(response.getName())) { extensions.add(extension); + logger.info("Initialized extension: " + extension.getName()); break; } } + inProgressFuture.complete(response); } @Override public void handleException(TransportException exp) { - logger.error(new ParameterizedMessage("Plugin request failed"), exp); + logger.error(new ParameterizedMessage("Extension initialization failed"), exp); + inProgressFuture.completeExceptionally(exp); } @Override @@ -243,39 +306,63 @@ public String executor() { } }; try { - transportService.connectToExtensionNode(extensionNode); + logger.info("Sending extension request type: " + REQUEST_EXTENSION_ACTION_NAME); + transportService.connectToExtensionNode(extension); transportService.sendRequest( - extensionNode, + extension, REQUEST_EXTENSION_ACTION_NAME, - new PluginRequest(transportService.getLocalNode(), new ArrayList(uninitializedExtensions)), - pluginResponseHandler + new InitializeExtensionRequest(transportService.getLocalNode(), extension), + initializeExtensionResponseHandler ); + // TODO: make asynchronous + inProgressFuture.get(100, TimeUnit.SECONDS); } catch (Exception e) { - throw e; + try { + throw e; + } catch (Exception e1) { + logger.error(e.toString()); + } } } + /** + * Handles a {@link RegisterTransportActionsRequest}. + * + * @param transportActionsRequest The request to handle. + * @return A {@link AcknowledgedResponse} indicating success. + * @throws Exception if the request is not handled properly. + */ + TransportResponse handleRegisterTransportActionsRequest(RegisterTransportActionsRequest transportActionsRequest) throws Exception { + /* + * TODO: https://github.com/opensearch-project/opensearch-sdk-java/issues/107 + * Register these new Transport Actions with ActionModule + * and add support for NodeClient to recognise these actions when making transport calls. + */ + return new AcknowledgedResponse(true); + } + + /** + * Handles an {@link ExtensionRequest}. + * + * @param extensionRequest The request to handle, of a type defined in the {@link RequestType} enum. + * @return an Response matching the request. + * @throws Exception if the request is not handled properly. + */ TransportResponse handleExtensionRequest(ExtensionRequest extensionRequest) throws Exception { - // Read enum - if (extensionRequest.getRequestType() == RequestType.REQUEST_EXTENSION_CLUSTER_STATE) { - ClusterStateResponse clusterStateResponse = new ClusterStateResponse( - clusterService.getClusterName(), - clusterService.state(), - false - ); - return clusterStateResponse; - } else if (extensionRequest.getRequestType() == RequestType.REQUEST_EXTENSION_LOCAL_NODE) { - LocalNodeResponse localNodeResponse = new LocalNodeResponse(clusterService); - return localNodeResponse; - } else if (extensionRequest.getRequestType() == RequestType.REQUEST_EXTENSION_CLUSTER_SETTINGS) { - ClusterSettingsResponse clusterSettingsResponse = new ClusterSettingsResponse(clusterService); - return clusterSettingsResponse; + switch (extensionRequest.getRequestType()) { + case REQUEST_EXTENSION_CLUSTER_STATE: + return new ClusterStateResponse(clusterService.getClusterName(), clusterService.state(), false); + case REQUEST_EXTENSION_LOCAL_NODE: + return new LocalNodeResponse(clusterService); + case REQUEST_EXTENSION_CLUSTER_SETTINGS: + return new ClusterSettingsResponse(clusterService); + default: + throw new IllegalStateException("Handler not present for the provided request"); } - throw new IllegalStateException("Handler not present for the provided request: " + extensionRequest.getRequestType()); } public void onIndexModule(IndexModule indexModule) throws UnknownHostException { - for (DiscoveryNode extensionNode : uninitializedExtensions) { + for (DiscoveryNode extensionNode : extensionIdMap.values()) { onIndexModule(indexModule, extensionNode); } } @@ -331,20 +418,22 @@ public void beforeIndexRemoved( String indexName = indexService.index().getName(); logger.info("Index Name" + indexName.toString()); try { - logger.info("Sending request of index name to extension"); + logger.info("Sending extension request type: " + INDICES_EXTENSION_NAME_ACTION_NAME); transportService.sendRequest( extensionNode, INDICES_EXTENSION_NAME_ACTION_NAME, new IndicesModuleRequest(indexModule), acknowledgedResponseHandler ); - /* - * Making async synchronous for now. - */ + // TODO: make asynchronous inProgressIndexNameFuture.get(100, TimeUnit.SECONDS); logger.info("Received ack response from Extension"); } catch (Exception e) { - logger.error(e.toString()); + try { + throw e; + } catch (Exception e1) { + logger.error(e.toString()); + } } } }); @@ -365,20 +454,22 @@ public String executor() { }; try { - logger.info("Sending request to extension"); + logger.info("Sending extension request type: " + INDICES_EXTENSION_POINT_ACTION_NAME); transportService.sendRequest( extensionNode, INDICES_EXTENSION_POINT_ACTION_NAME, new IndicesModuleRequest(indexModule), indicesModuleResponseHandler ); - /* - * Making async synchronous for now. - */ + // TODO: make asynchronous inProgressFuture.get(100, TimeUnit.SECONDS); logger.info("Received response from Extension"); } catch (Exception e) { - logger.error(e.toString()); + try { + throw e; + } catch (Exception e1) { + logger.error(e.toString()); + } } } @@ -421,10 +512,6 @@ public Path getExtensionsPath() { return extensionsPath; } - public List getUninitializedExtensions() { - return uninitializedExtensions; - } - public List getExtensions() { return extensions; } @@ -437,4 +524,28 @@ public ClusterService getClusterService() { return clusterService; } + public static String getRequestExtensionRegisterRestActions() { + return REQUEST_EXTENSION_REGISTER_REST_ACTIONS; + } + + public static String getRequestOpensearchNamedWriteableRegistry() { + return REQUEST_OPENSEARCH_NAMED_WRITEABLE_REGISTRY; + } + + public static String getRequestOpensearchParseNamedWriteable() { + return REQUEST_OPENSEARCH_PARSE_NAMED_WRITEABLE; + } + + public static String getRequestRestExecuteOnExtensionAction() { + return REQUEST_REST_EXECUTE_ON_EXTENSION_ACTION; + } + + public Map getExtensionIdMap() { + return extensionIdMap; + } + + public RestActionsRequestHandler getRestActionsRequestHandler() { + return restActionsRequestHandler; + } + } diff --git a/server/src/main/java/org/opensearch/extensions/OpenSearchRequest.java b/server/src/main/java/org/opensearch/extensions/OpenSearchRequest.java new file mode 100644 index 0000000000000..62e66f09eb856 --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/OpenSearchRequest.java @@ -0,0 +1,73 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; + +import java.io.IOException; +import java.util.Objects; + +/** + * Request from OpenSearch to an Extension + * + * @opensearch.internal + */ +public class OpenSearchRequest extends TransportRequest { + + private static final Logger logger = LogManager.getLogger(OpenSearchRequest.class); + private ExtensionsManager.OpenSearchRequestType requestType; + + /** + * @param requestType String identifying the default extension point to invoke on the extension + */ + public OpenSearchRequest(ExtensionsManager.OpenSearchRequestType requestType) { + this.requestType = requestType; + } + + /** + * @param in StreamInput from which a string identifying the default extension point to invoke on the extension is read from + */ + public OpenSearchRequest(StreamInput in) throws IOException { + super(in); + this.requestType = in.readEnum(ExtensionsManager.OpenSearchRequestType.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeEnum(requestType); + } + + @Override + public String toString() { + return "OpenSearchRequest{" + "requestType=" + requestType + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OpenSearchRequest that = (OpenSearchRequest) o; + return Objects.equals(requestType, that.requestType); + } + + @Override + public int hashCode() { + return Objects.hash(requestType); + } + + public ExtensionsManager.OpenSearchRequestType getRequestType() { + return this.requestType; + } + +} diff --git a/server/src/main/java/org/opensearch/extensions/RegisterTransportActionsRequest.java b/server/src/main/java/org/opensearch/extensions/RegisterTransportActionsRequest.java new file mode 100644 index 0000000000000..a3603aaf22dd0 --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/RegisterTransportActionsRequest.java @@ -0,0 +1,79 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions; + +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * Request to register extension Transport actions + * + * @opensearch.internal + */ +public class RegisterTransportActionsRequest extends TransportRequest { + private Map transportActions; + + public RegisterTransportActionsRequest(Map transportActions) { + this.transportActions = new HashMap<>(transportActions); + } + + public RegisterTransportActionsRequest(StreamInput in) throws IOException { + super(in); + Map actions = new HashMap<>(); + int actionCount = in.readVInt(); + for (int i = 0; i < actionCount; i++) { + try { + String actionName = in.readString(); + Class transportAction = Class.forName(in.readString()); + actions.put(actionName, transportAction); + } catch (ClassNotFoundException e) { + throw new IllegalArgumentException("Could not read transport action"); + } + } + this.transportActions = actions; + } + + public Map getTransportActions() { + return transportActions; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeVInt(this.transportActions.size()); + for (Map.Entry action : transportActions.entrySet()) { + out.writeString(action.getKey()); + out.writeString(action.getValue().getName()); + } + } + + @Override + public String toString() { + return "TransportActionsRequest{actions=" + transportActions + "}"; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + RegisterTransportActionsRequest that = (RegisterTransportActionsRequest) obj; + return Objects.equals(transportActions, that.transportActions); + } + + @Override + public int hashCode() { + return Objects.hash(transportActions); + } +} diff --git a/server/src/main/java/org/opensearch/extensions/rest/RegisterRestActionsRequest.java b/server/src/main/java/org/opensearch/extensions/rest/RegisterRestActionsRequest.java new file mode 100644 index 0000000000000..8c190ff416a62 --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/rest/RegisterRestActionsRequest.java @@ -0,0 +1,72 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.rest; + +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +/** + * Request to register extension REST actions + * + * @opensearch.internal + */ +public class RegisterRestActionsRequest extends TransportRequest { + private String uniqueId; + private List restActions; + + public RegisterRestActionsRequest(String uniqueId, List restActions) { + this.uniqueId = uniqueId; + this.restActions = new ArrayList<>(restActions); + } + + public RegisterRestActionsRequest(StreamInput in) throws IOException { + super(in); + uniqueId = in.readString(); + restActions = in.readStringList(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(uniqueId); + out.writeStringCollection(restActions); + } + + public String getUniqueId() { + return uniqueId; + } + + public List getRestActions() { + return new ArrayList<>(restActions); + } + + @Override + public String toString() { + return "RestActionsRequest{uniqueId=" + uniqueId + ", restActions=" + restActions + "}"; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + RegisterRestActionsRequest that = (RegisterRestActionsRequest) obj; + return Objects.equals(uniqueId, that.uniqueId) && Objects.equals(restActions, that.restActions); + } + + @Override + public int hashCode() { + return Objects.hash(uniqueId, restActions); + } +} diff --git a/server/src/main/java/org/opensearch/extensions/rest/RegisterRestActionsResponse.java b/server/src/main/java/org/opensearch/extensions/rest/RegisterRestActionsResponse.java new file mode 100644 index 0000000000000..c0a79ad32ce89 --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/rest/RegisterRestActionsResponse.java @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.rest; + +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportResponse; + +import java.io.IOException; + +/** + * Response to register REST Actions request. + * + * @opensearch.internal + */ +public class RegisterRestActionsResponse extends TransportResponse { + private String response; + + public RegisterRestActionsResponse(String response) { + this.response = response; + } + + public RegisterRestActionsResponse(StreamInput in) throws IOException { + response = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(response); + } + + public String getResponse() { + return response; + } +} diff --git a/server/src/main/java/org/opensearch/extensions/rest/RestActionsRequestHandler.java b/server/src/main/java/org/opensearch/extensions/rest/RestActionsRequestHandler.java new file mode 100644 index 0000000000000..e24f5d519bf81 --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/rest/RestActionsRequestHandler.java @@ -0,0 +1,62 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.rest; + +import org.opensearch.extensions.DiscoveryExtensionNode; +import org.opensearch.rest.RestController; +import org.opensearch.rest.RestHandler; +import org.opensearch.transport.TransportResponse; +import org.opensearch.transport.TransportService; + +import java.util.Map; + +/** + * Handles requests to register extension REST actions. + * + * @opensearch.internal + */ +public class RestActionsRequestHandler { + + private final RestController restController; + private final Map extensionIdMap; + private final TransportService transportService; + + /** + * Instantiates a new REST Actions Request Handler using the Node's RestController. + * + * @param restController The Node's {@link RestController}. + * @param extensionIdMap A map of extension uniqueId to DiscoveryExtensionNode + * @param transportService The Node's transportService + */ + public RestActionsRequestHandler( + RestController restController, + Map extensionIdMap, + TransportService transportService + ) { + this.restController = restController; + this.extensionIdMap = extensionIdMap; + this.transportService = transportService; + } + + /** + * Handles a {@link RegisterRestActionsRequest}. + * + * @param restActionsRequest The request to handle. + * @return A {@link RegisterRestActionsResponse} indicating success. + * @throws Exception if the request is not handled properly. + */ + public TransportResponse handleRegisterRestActionsRequest(RegisterRestActionsRequest restActionsRequest) throws Exception { + DiscoveryExtensionNode discoveryExtensionNode = extensionIdMap.get(restActionsRequest.getUniqueId()); + RestHandler handler = new RestSendToExtensionAction(restActionsRequest, discoveryExtensionNode, transportService); + restController.registerHandler(handler); + return new RegisterRestActionsResponse( + "Registered extension " + restActionsRequest.getUniqueId() + " to handle REST Actions " + restActionsRequest.getRestActions() + ); + } +} diff --git a/server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionRequest.java b/server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionRequest.java new file mode 100644 index 0000000000000..128dad2645b42 --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionRequest.java @@ -0,0 +1,77 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.rest; + +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestRequest.Method; +import org.opensearch.transport.TransportRequest; + +import java.io.IOException; +import java.util.Objects; + +/** + * Request to execute REST actions on extension node + * + * @opensearch.internal + */ +public class RestExecuteOnExtensionRequest extends TransportRequest { + + private Method method; + private String uri; + + public RestExecuteOnExtensionRequest(Method method, String uri) { + this.method = method; + this.uri = uri; + } + + public RestExecuteOnExtensionRequest(StreamInput in) throws IOException { + super(in); + try { + method = RestRequest.Method.valueOf(in.readString()); + } catch (IllegalArgumentException e) { + throw new IOException(e); + } + uri = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(method.name()); + out.writeString(uri); + } + + public Method getMethod() { + return method; + } + + public String getUri() { + return uri; + } + + @Override + public String toString() { + return "RestExecuteOnExtensionRequest{method=" + method + ", uri=" + uri + "}"; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + RestExecuteOnExtensionRequest that = (RestExecuteOnExtensionRequest) obj; + return Objects.equals(method, that.method) && Objects.equals(uri, that.uri); + } + + @Override + public int hashCode() { + return Objects.hash(method, uri); + } +} diff --git a/server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionResponse.java b/server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionResponse.java new file mode 100644 index 0000000000000..b7d7aae3faaab --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionResponse.java @@ -0,0 +1,112 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.rest; + +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestResponse; +import org.opensearch.rest.RestStatus; +import org.opensearch.transport.TransportResponse; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** + * Response to execute REST Actions on the extension node. Wraps the components of a {@link RestResponse}. + * + * @opensearch.internal + */ +public class RestExecuteOnExtensionResponse extends TransportResponse { + private RestStatus status; + private String contentType; + private byte[] content; + private Map> headers; + + /** + * Instantiate this object with a status and response string. + * + * @param status The REST status. + * @param responseString The response content as a String. + */ + public RestExecuteOnExtensionResponse(RestStatus status, String responseString) { + this(status, BytesRestResponse.TEXT_CONTENT_TYPE, responseString.getBytes(StandardCharsets.UTF_8), Collections.emptyMap()); + } + + /** + * Instantiate this object with the components of a {@link RestResponse}. + * + * @param status The REST status. + * @param contentType The type of the content. + * @param content The content. + * @param headers The headers. + */ + public RestExecuteOnExtensionResponse(RestStatus status, String contentType, byte[] content, Map> headers) { + setStatus(status); + setContentType(contentType); + setContent(content); + setHeaders(headers); + } + + /** + * Instantiate this object from a Transport Stream + * + * @param in The stream input. + * @throws IOException on transport failure. + */ + public RestExecuteOnExtensionResponse(StreamInput in) throws IOException { + setStatus(RestStatus.readFrom(in)); + setContentType(in.readString()); + setContent(in.readByteArray()); + setHeaders(in.readMapOfLists(StreamInput::readString, StreamInput::readString)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + RestStatus.writeTo(out, status); + out.writeString(contentType); + out.writeByteArray(content); + out.writeMapOfLists(headers, StreamOutput::writeString, StreamOutput::writeString); + } + + public RestStatus getStatus() { + return status; + } + + public void setStatus(RestStatus status) { + this.status = status; + } + + public String getContentType() { + return contentType; + } + + public void setContentType(String contentType) { + this.contentType = contentType; + } + + public byte[] getContent() { + return content; + } + + public void setContent(byte[] content) { + this.content = content; + } + + public Map> getHeaders() { + return headers; + } + + public void setHeaders(Map> headers) { + this.headers = Map.copyOf(headers); + } +} diff --git a/server/src/main/java/org/opensearch/extensions/rest/RestSendToExtensionAction.java b/server/src/main/java/org/opensearch/extensions/rest/RestSendToExtensionAction.java new file mode 100644 index 0000000000000..8f5a2e6b1c8a5 --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/rest/RestSendToExtensionAction.java @@ -0,0 +1,186 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.rest; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.extensions.DiscoveryExtensionNode; +import org.opensearch.extensions.ExtensionsManager; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestRequest.Method; +import org.opensearch.rest.RestStatus; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportException; +import org.opensearch.transport.TransportResponseHandler; +import org.opensearch.transport.TransportService; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.unmodifiableList; + +/** + * An action that forwards REST requests to an extension + */ +public class RestSendToExtensionAction extends BaseRestHandler { + + private static final String SEND_TO_EXTENSION_ACTION = "send_to_extension_action"; + private static final Logger logger = LogManager.getLogger(RestSendToExtensionAction.class); + private static final String CONSUMED_PARAMS_KEY = "extension.consumed.parameters"; + + private final List routes; + private final String uriPrefix; + private final DiscoveryExtensionNode discoveryExtensionNode; + private final TransportService transportService; + + /** + * Instantiates this object using a {@link RegisterRestActionsRequest} to populate the routes. + * + * @param restActionsRequest A request encapsulating a list of Strings with the API methods and URIs. + * @param transportService The OpenSearch transport service + * @param discoveryExtensionNode The extension node to which to send actions + */ + public RestSendToExtensionAction( + RegisterRestActionsRequest restActionsRequest, + DiscoveryExtensionNode discoveryExtensionNode, + TransportService transportService + ) { + this.uriPrefix = "/_extensions/_" + restActionsRequest.getUniqueId(); + List restActionsAsRoutes = new ArrayList<>(); + for (String restAction : restActionsRequest.getRestActions()) { + RestRequest.Method method; + String uri; + try { + int delim = restAction.indexOf(' '); + method = RestRequest.Method.valueOf(restAction.substring(0, delim)); + uri = uriPrefix + restAction.substring(delim).trim(); + } catch (IndexOutOfBoundsException | IllegalArgumentException e) { + throw new IllegalArgumentException(restAction + " does not begin with a valid REST method"); + } + logger.info("Registering: " + method + " " + uri); + restActionsAsRoutes.add(new Route(method, uri)); + } + this.routes = unmodifiableList(restActionsAsRoutes); + this.discoveryExtensionNode = discoveryExtensionNode; + this.transportService = transportService; + } + + @Override + public String getName() { + return SEND_TO_EXTENSION_ACTION; + } + + @Override + public List routes() { + return this.routes; + } + + @Override + public RestChannelConsumer prepareRequest(final RestRequest request, final NodeClient client) throws IOException { + Method method = request.getHttpRequest().method(); + String uri = request.getHttpRequest().uri(); + if (uri.startsWith(uriPrefix)) { + uri = uri.substring(uriPrefix.length()); + } + String message = "Forwarding the request " + method + " " + uri + " to " + discoveryExtensionNode; + logger.info(message); + // Initialize response. Values will be changed in the handler. + final RestExecuteOnExtensionResponse restExecuteOnExtensionResponse = new RestExecuteOnExtensionResponse( + RestStatus.INTERNAL_SERVER_ERROR, + BytesRestResponse.TEXT_CONTENT_TYPE, + message.getBytes(StandardCharsets.UTF_8), + emptyMap() + ); + final CompletableFuture inProgressFuture = new CompletableFuture<>(); + final TransportResponseHandler restExecuteOnExtensionResponseHandler = new TransportResponseHandler< + RestExecuteOnExtensionResponse>() { + + @Override + public RestExecuteOnExtensionResponse read(StreamInput in) throws IOException { + return new RestExecuteOnExtensionResponse(in); + } + + @Override + public void handleResponse(RestExecuteOnExtensionResponse response) { + logger.info("Received response from extension: {}", response.getStatus()); + restExecuteOnExtensionResponse.setStatus(response.getStatus()); + restExecuteOnExtensionResponse.setContentType(response.getContentType()); + restExecuteOnExtensionResponse.setContent(response.getContent()); + // Extract the consumed parameters from the header + Map> headers = response.getHeaders(); + List consumedParams = headers.get(CONSUMED_PARAMS_KEY); + if (consumedParams != null) { + consumedParams.stream().forEach(p -> request.param(p)); + } + Map> headersWithoutConsumedParams = headers.entrySet() + .stream() + .filter(e -> !e.getKey().equals(CONSUMED_PARAMS_KEY)) + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())); + restExecuteOnExtensionResponse.setHeaders(headersWithoutConsumedParams); + inProgressFuture.complete(response); + } + + @Override + public void handleException(TransportException exp) { + logger.error("REST request failed", exp); + // Status is already defaulted to 500 (INTERNAL_SERVER_ERROR) + byte[] responseBytes = ("Request failed: " + exp.getMessage()).getBytes(StandardCharsets.UTF_8); + restExecuteOnExtensionResponse.setContent(responseBytes); + inProgressFuture.completeExceptionally(exp); + } + + @Override + public String executor() { + return ThreadPool.Names.GENERIC; + } + }; + try { + transportService.sendRequest( + discoveryExtensionNode, + ExtensionsManager.REQUEST_REST_EXECUTE_ON_EXTENSION_ACTION, + new RestExecuteOnExtensionRequest(method, uri), + restExecuteOnExtensionResponseHandler + ); + try { + // TODO: make asynchronous + inProgressFuture.get(5, TimeUnit.SECONDS); + } catch (InterruptedException e) { + return channel -> channel.sendResponse( + new BytesRestResponse(RestStatus.REQUEST_TIMEOUT, "No response from extension to request.") + ); + } + } catch (Exception e) { + logger.info("Failed to send REST Actions to extension " + discoveryExtensionNode.getName(), e); + } + BytesRestResponse restResponse = new BytesRestResponse( + restExecuteOnExtensionResponse.getStatus(), + restExecuteOnExtensionResponse.getContentType(), + restExecuteOnExtensionResponse.getContent() + ); + for (Entry> headerEntry : restExecuteOnExtensionResponse.getHeaders().entrySet()) { + for (String value : headerEntry.getValue()) { + restResponse.addHeader(headerEntry.getKey(), value); + } + } + + return channel -> channel.sendResponse(restResponse); + } +} diff --git a/server/src/main/java/org/opensearch/extensions/rest/package-info.java b/server/src/main/java/org/opensearch/extensions/rest/package-info.java new file mode 100644 index 0000000000000..5a52a295da6ad --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/rest/package-info.java @@ -0,0 +1,10 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** REST Actions classes for the extensions package. OpenSearch extensions provide extensibility to OpenSearch.*/ +package org.opensearch.extensions.rest; diff --git a/server/src/main/java/org/opensearch/index/AcknowledgedResponse.java b/server/src/main/java/org/opensearch/index/AcknowledgedResponse.java deleted file mode 100644 index 5993a81158d30..0000000000000 --- a/server/src/main/java/org/opensearch/index/AcknowledgedResponse.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.index; - -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.transport.TransportResponse; - -import java.io.IOException; - -/** - * Response for index name of onIndexModule extension point - * - * @opensearch.internal - */ -public class AcknowledgedResponse extends TransportResponse { - private boolean requestAck; - - public AcknowledgedResponse(StreamInput in) throws IOException { - this.requestAck = in.readBoolean(); - } - - public AcknowledgedResponse(Boolean requestAck) { - this.requestAck = requestAck; - } - - public void AcknowledgedResponse(StreamInput in) throws IOException { - this.requestAck = in.readBoolean(); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeBoolean(requestAck); - } - -} diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index bea3b1ac12451..1d981938abf24 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -785,6 +785,7 @@ protected Node( modules.add(actionModule); final RestController restController = actionModule.getRestController(); + final NetworkModule networkModule = new NetworkModule( settings, pluginsService.filterPlugins(NetworkPlugin.class), @@ -829,8 +830,7 @@ protected Node( taskHeaders ); if (FeatureFlags.isEnabled(FeatureFlags.EXTENSIONS)) { - this.extensionsManager.setTransportService(transportService); - this.extensionsManager.setClusterService(clusterService); + this.extensionsManager.initializeServicesAndRestHandler(restController, transportService, clusterService); } final GatewayMetaState gatewayMetaState = new GatewayMetaState(); final ResponseCollectorService responseCollectorService = new ResponseCollectorService(clusterService); diff --git a/server/src/main/java/org/opensearch/transport/TransportService.java b/server/src/main/java/org/opensearch/transport/TransportService.java index d9a234b7e5682..edfa8596ba5d6 100644 --- a/server/src/main/java/org/opensearch/transport/TransportService.java +++ b/server/src/main/java/org/opensearch/transport/TransportService.java @@ -776,7 +776,7 @@ public final void sendRequest( final TransportResponseHandler handler ) { try { - logger.info("Action: " + action); + logger.debug("Action: " + action); final TransportResponseHandler delegate; if (request.getParentTask().isSet()) { // TODO: capture the connection instead so that we can cancel child tasks on the remote connections. diff --git a/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java b/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java index cbd86378c0fac..d45e51ea2bbc8 100644 --- a/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java +++ b/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java @@ -30,7 +30,9 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Objects; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; import org.apache.logging.log4j.Level; import org.apache.logging.log4j.LogManager; @@ -38,6 +40,7 @@ import org.junit.Before; import org.opensearch.Version; import org.opensearch.action.admin.cluster.state.ClusterStateResponse; +import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.ClusterSettingsResponse; import org.opensearch.cluster.LocalNodeResponse; import org.opensearch.cluster.metadata.IndexMetadata; @@ -45,7 +48,10 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.PathUtils; +import org.opensearch.common.io.stream.NamedWriteable; import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.network.NetworkService; import org.opensearch.common.settings.Settings; import org.opensearch.common.transport.TransportAddress; @@ -54,6 +60,8 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.env.Environment; import org.opensearch.env.TestEnvironment; +import org.opensearch.extensions.rest.RegisterRestActionsRequest; +import org.opensearch.extensions.rest.RegisterRestActionsResponse; import org.opensearch.index.IndexModule; import org.opensearch.index.IndexSettings; import org.opensearch.index.analysis.AnalysisRegistry; @@ -61,27 +69,58 @@ import org.opensearch.index.engine.InternalEngineFactory; import org.opensearch.indices.breaker.NoneCircuitBreakerService; import org.opensearch.plugins.PluginInfo; +import org.opensearch.rest.RestController; import org.opensearch.test.IndexSettingsModule; import org.opensearch.test.MockLogAppender; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.transport.MockTransportService; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.ConnectTransportException; import org.opensearch.transport.Transport; +import org.opensearch.transport.TransportResponse; import org.opensearch.transport.TransportService; import org.opensearch.transport.nio.MockNioTransport; +import org.opensearch.usage.UsageService; public class ExtensionsManagerTests extends OpenSearchTestCase { private TransportService transportService; + private RestController restController; private ClusterService clusterService; private MockNioTransport transport; + private Path extensionDir; private final ThreadPool threadPool = new TestThreadPool(ExtensionsManagerTests.class.getSimpleName()); private final Settings settings = Settings.builder() .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) .put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()) .build(); + private final List extensionsYmlLines = Arrays.asList( + "extensions:", + " - name: firstExtension", + " uniqueId: uniqueid1", + " hostName: 'myIndependentPluginHost1'", + " hostAddress: '127.0.0.0'", + " port: '9300'", + " version: '0.0.7'", + " description: Fake description 1", + " opensearchVersion: '3.0.0'", + " javaVersion: '14'", + " className: fakeClass1", + " customFolderName: fakeFolder1", + " hasNativeController: false", + " - name: secondExtension", + " uniqueId: 'uniqueid2'", + " hostName: 'myIndependentPluginHost2'", + " hostAddress: '127.0.0.1'", + " port: '9301'", + " version: '3.14.16'", + " description: Fake description 2", + " opensearchVersion: '2.0.0'", + " javaVersion: '17'", + " className: fakeClass2", + " customFolderName: fakeFolder2", + " hasNativeController: true" + ); @Before public void setup() throws Exception { @@ -112,9 +151,19 @@ public void setup() throws Exception { null, Collections.emptySet() ); + restController = new RestController( + emptySet(), + null, + new NodeClient(Settings.EMPTY, threadPool), + new NoneCircuitBreakerService(), + new UsageService() + ); clusterService = createClusterService(threadPool); + + extensionDir = createTempDir(); } + @Override @After public void tearDown() throws Exception { super.tearDown(); @@ -122,36 +171,9 @@ public void tearDown() throws Exception { ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); } - public void testExtensionsDiscovery() throws Exception { + public void testDiscover() throws Exception { Path extensionDir = createTempDir(); - List extensionsYmlLines = Arrays.asList( - "extensions:", - " - name: firstExtension", - " uniqueId: uniqueid1", - " hostName: 'myIndependentPluginHost1'", - " hostAddress: '127.0.0.0'", - " port: '9300'", - " version: '0.0.7'", - " description: Fake description 1", - " opensearchVersion: '3.0.0'", - " javaVersion: '14'", - " className: fakeClass1", - " customFolderName: fakeFolder1", - " hasNativeController: false", - " - name: secondExtension", - " uniqueId: 'uniqueid2'", - " hostName: 'myIndependentPluginHost2'", - " hostAddress: '127.0.0.1'", - " port: '9301'", - " version: '3.14.16'", - " description: Fake description 2", - " opensearchVersion: '2.0.0'", - " javaVersion: '17'", - " className: fakeClass2", - " customFolderName: fakeFolder2", - " hasNativeController: true" - ); Files.write(extensionDir.resolve("extensions.yml"), extensionsYmlLines, StandardCharsets.UTF_8); ExtensionsManager extensionsManager = new ExtensionsManager(settings, extensionDir); @@ -203,7 +225,48 @@ public void testExtensionsDiscovery() throws Exception { ) ) ); - assertEquals(expectedUninitializedExtensions, extensionsManager.getUninitializedExtensions()); + assertEquals(expectedUninitializedExtensions.size(), extensionsManager.getExtensionIdMap().values().size()); + assertTrue(expectedUninitializedExtensions.containsAll(extensionsManager.getExtensionIdMap().values())); + assertTrue(extensionsManager.getExtensionIdMap().values().containsAll(expectedUninitializedExtensions)); + } + + public void testNonUniqueExtensionsDiscovery() throws Exception { + Path extensionDir = createTempDir(); + + List nonUniqueYmlLines = extensionsYmlLines.stream() + .map(s -> s.replace("uniqueid2", "uniqueid1")) + .collect(Collectors.toList()); + Files.write(extensionDir.resolve("extensions.yml"), nonUniqueYmlLines, StandardCharsets.UTF_8); + + ExtensionsManager extensionsManager = new ExtensionsManager(settings, extensionDir); + + List expectedUninitializedExtensions = new ArrayList(); + + expectedUninitializedExtensions.add( + new DiscoveryExtensionNode( + "firstExtension", + "uniqueid1", + "uniqueid1", + "myIndependentPluginHost1", + "127.0.0.0", + new TransportAddress(InetAddress.getByName("127.0.0.0"), 9300), + new HashMap(), + Version.fromString("3.0.0"), + new PluginInfo( + "firstExtension", + "Fake description 1", + "0.0.7", + Version.fromString("3.0.0"), + "14", + "fakeClass1", + new ArrayList(), + false + ) + ) + ); + assertEquals(expectedUninitializedExtensions.size(), extensionsManager.getExtensionIdMap().values().size()); + assertTrue(expectedUninitializedExtensions.containsAll(extensionsManager.getExtensionIdMap().values())); + assertTrue(extensionsManager.getExtensionIdMap().values().containsAll(expectedUninitializedExtensions)); } public void testNonAccessibleDirectory() throws Exception { @@ -216,8 +279,6 @@ public void testNonAccessibleDirectory() throws Exception { } public void testNoExtensionsFile() throws Exception { - Path extensionDir = createTempDir(); - Settings settings = Settings.builder().build(); try (MockLogAppender mockLogAppender = MockLogAppender.createForLoggers(LogManager.getLogger(ExtensionsManager.class))) { @@ -240,8 +301,8 @@ public void testNoExtensionsFile() throws Exception { public void testEmptyExtensionsFile() throws Exception { Path extensionDir = createTempDir(); - List extensionsYmlLines = Arrays.asList(); - Files.write(extensionDir.resolve("extensions.yml"), extensionsYmlLines, StandardCharsets.UTF_8); + List emptyExtensionsYmlLines = Arrays.asList(); + Files.write(extensionDir.resolve("extensions.yml"), emptyExtensionsYmlLines, StandardCharsets.UTF_8); Settings settings = Settings.builder().build(); @@ -251,69 +312,114 @@ public void testEmptyExtensionsFile() throws Exception { public void testInitialize() throws Exception { Path extensionDir = createTempDir(); - List extensionsYmlLines = Arrays.asList( - "extensions:", - " - name: firstExtension", - " uniqueId: uniqueid1", - " hostName: 'myIndependentPluginHost1'", - " hostAddress: '127.0.0.0'", - " port: '9300'", - " version: '0.0.7'", - " description: Fake description 1", - " opensearchVersion: '3.0.0'", - " javaVersion: '14'", - " className: fakeClass1", - " customFolderName: fakeFolder1", - " hasNativeController: false", - " - name: secondExtension", - " uniqueId: 'uniqueid2'", - " hostName: 'myIndependentPluginHost2'", - " hostAddress: '127.0.0.1'", - " port: '9301'", - " version: '3.14.16'", - " description: Fake description 2", - " opensearchVersion: '2.0.0'", - " javaVersion: '17'", - " className: fakeClass2", - " customFolderName: fakeFolder2", - " hasNativeController: true" - ); Files.write(extensionDir.resolve("extensions.yml"), extensionsYmlLines, StandardCharsets.UTF_8); ExtensionsManager extensionsManager = new ExtensionsManager(settings, extensionDir); transportService.start(); transportService.acceptIncomingRequests(); - extensionsManager.setTransportService(transportService); + extensionsManager.initializeServicesAndRestHandler(restController, transportService, clusterService); + + try (MockLogAppender mockLogAppender = MockLogAppender.createForLoggers(LogManager.getLogger(ExtensionsManager.class))) { + + mockLogAppender.addExpectation( + new MockLogAppender.SeenEventExpectation( + "Connect Transport Exception 1", + "org.opensearch.extensions.ExtensionsManager", + Level.ERROR, + "ConnectTransportException[[firstExtension][127.0.0.0:9300] connect_timeout[30s]]" + ) + ); + + mockLogAppender.addExpectation( + new MockLogAppender.SeenEventExpectation( + "Connect Transport Exception 2", + "org.opensearch.extensions.ExtensionsManager", + Level.ERROR, + "ConnectTransportException[[secondExtension][127.0.0.1:9301] connect_exception]; nested: ConnectException[Connection refused];" + ) + ); - expectThrows(ConnectTransportException.class, () -> extensionsManager.initialize()); + extensionsManager.initialize(); + // Test needs to be changed to mock the connection between the local node and an extension. Assert statment is commented out for + // now. + // Link to issue: https://github.com/opensearch-project/OpenSearch/issues/4045 + // mockLogAppender.assertAllExpectationsMatched(); + } } - public void testHandleExtensionRequest() throws Exception { + public void testHandleRegisterRestActionsRequest() throws Exception { + + Path extensionDir = createTempDir(); + + Files.write(extensionDir.resolve("extensions.yml"), extensionsYmlLines, StandardCharsets.UTF_8); + + ExtensionsManager extensionsManager = new ExtensionsManager(settings, extensionDir); + + extensionsManager.initializeServicesAndRestHandler(restController, transportService, clusterService); + String uniqueIdStr = "uniqueid1"; + List actionsList = List.of("GET /foo", "PUT /bar", "POST /baz"); + RegisterRestActionsRequest registerActionsRequest = new RegisterRestActionsRequest(uniqueIdStr, actionsList); + TransportResponse response = extensionsManager.getRestActionsRequestHandler() + .handleRegisterRestActionsRequest(registerActionsRequest); + assertEquals(RegisterRestActionsResponse.class, response.getClass()); + assertTrue(((RegisterRestActionsResponse) response).getResponse().contains(uniqueIdStr)); + assertTrue(((RegisterRestActionsResponse) response).getResponse().contains(actionsList.toString())); + } + + public void testHandleRegisterRestActionsRequestWithInvalidMethod() throws Exception { + + Path extensionDir = createTempDir(); + + ExtensionsManager extensionsManager = new ExtensionsManager(settings, extensionDir); + + extensionsManager.initializeServicesAndRestHandler(restController, transportService, clusterService); + String uniqueIdStr = "uniqueid1"; + List actionsList = List.of("FOO /foo", "PUT /bar", "POST /baz"); + RegisterRestActionsRequest registerActionsRequest = new RegisterRestActionsRequest(uniqueIdStr, actionsList); + expectThrows( + IllegalArgumentException.class, + () -> extensionsManager.getRestActionsRequestHandler().handleRegisterRestActionsRequest(registerActionsRequest) + ); + } + + public void testHandleRegisterRestActionsRequestWithInvalidUri() throws Exception { Path extensionDir = createTempDir(); ExtensionsManager extensionsManager = new ExtensionsManager(settings, extensionDir); - extensionsManager.setTransportService(transportService); - extensionsManager.setClusterService(clusterService); + extensionsManager.initializeServicesAndRestHandler(restController, transportService, clusterService); + String uniqueIdStr = "uniqueid1"; + List actionsList = List.of("GET", "PUT /bar", "POST /baz"); + RegisterRestActionsRequest registerActionsRequest = new RegisterRestActionsRequest(uniqueIdStr, actionsList); + expectThrows( + IllegalArgumentException.class, + () -> extensionsManager.getRestActionsRequestHandler().handleRegisterRestActionsRequest(registerActionsRequest) + ); + } + + public void testHandleExtensionRequest() throws Exception { + + ExtensionsManager extensionsManager = new ExtensionsManager(settings, extensionDir); + + extensionsManager.initializeServicesAndRestHandler(restController, transportService, clusterService); ExtensionRequest clusterStateRequest = new ExtensionRequest(ExtensionsManager.RequestType.REQUEST_EXTENSION_CLUSTER_STATE); - assertEquals(extensionsManager.handleExtensionRequest(clusterStateRequest).getClass(), ClusterStateResponse.class); + assertEquals(ClusterStateResponse.class, extensionsManager.handleExtensionRequest(clusterStateRequest).getClass()); ExtensionRequest clusterSettingRequest = new ExtensionRequest(ExtensionsManager.RequestType.REQUEST_EXTENSION_CLUSTER_SETTINGS); - assertEquals(extensionsManager.handleExtensionRequest(clusterSettingRequest).getClass(), ClusterSettingsResponse.class); + assertEquals(ClusterSettingsResponse.class, extensionsManager.handleExtensionRequest(clusterSettingRequest).getClass()); ExtensionRequest localNodeRequest = new ExtensionRequest(ExtensionsManager.RequestType.REQUEST_EXTENSION_LOCAL_NODE); - assertEquals(extensionsManager.handleExtensionRequest(localNodeRequest).getClass(), LocalNodeResponse.class); + assertEquals(LocalNodeResponse.class, extensionsManager.handleExtensionRequest(localNodeRequest).getClass()); ExtensionRequest exceptionRequest = new ExtensionRequest(ExtensionsManager.RequestType.GET_SETTINGS); Exception exception = expectThrows(IllegalStateException.class, () -> extensionsManager.handleExtensionRequest(exceptionRequest)); - assertEquals(exception.getMessage(), "Handler not present for the provided request: " + exceptionRequest.getRequestType()); + assertEquals("Handler not present for the provided request", exception.getMessage()); } public void testRegisterHandler() throws Exception { - Path extensionDir = createTempDir(); ExtensionsManager extensionsManager = new ExtensionsManager(settings, extensionDir); @@ -329,49 +435,57 @@ public void testRegisterHandler() throws Exception { ) ); - extensionsManager.setTransportService(mockTransportService); - verify(mockTransportService, times(3)).registerRequestHandler(anyString(), anyString(), anyBoolean(), anyBoolean(), any(), any()); + extensionsManager.initializeServicesAndRestHandler(restController, mockTransportService, clusterService); + verify(mockTransportService, times(5)).registerRequestHandler(anyString(), anyString(), anyBoolean(), anyBoolean(), any(), any()); } - public void testOnIndexModule() throws Exception { + private static class Example implements NamedWriteable { + public static final String INVALID_NAME = "invalid_name"; + public static final String NAME = "example"; + private final String message; - Path extensionDir = createTempDir(); + Example(String message) { + this.message = message; + } - List extensionsYmlLines = Arrays.asList( - "extensions:", - " - name: firstExtension", - " uniqueId: uniqueid1", - " hostName: 'myIndependentPluginHost1'", - " hostAddress: '127.0.0.0'", - " port: '9300'", - " version: '0.0.7'", - " description: Fake description 1", - " opensearchVersion: '3.0.0'", - " javaVersion: '14'", - " className: fakeClass1", - " customFolderName: fakeFolder1", - " hasNativeController: false", - " - name: secondExtension", - " uniqueId: 'uniqueid2'", - " hostName: 'myIndependentPluginHost2'", - " hostAddress: '127.0.0.1'", - " port: '9301'", - " version: '3.14.16'", - " description: Fake description 2", - " opensearchVersion: '2.0.0'", - " javaVersion: '17'", - " className: fakeClass2", - " customFolderName: fakeFolder2", - " hasNativeController: true" - ); + Example(StreamInput in) throws IOException { + this.message = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(message); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) { + return false; + } + Example that = (Example) o; + return Objects.equals(message, that.message); + } + + @Override + public int hashCode() { + return Objects.hash(message); + } + } + + public void testOnIndexModule() throws Exception { Files.write(extensionDir.resolve("extensions.yml"), extensionsYmlLines, StandardCharsets.UTF_8); ExtensionsManager extensionsManager = new ExtensionsManager(settings, extensionDir); transportService.start(); transportService.acceptIncomingRequests(); - extensionsManager.setTransportService(transportService); + extensionsManager.initializeServicesAndRestHandler(restController, transportService, clusterService); Environment environment = TestEnvironment.newEnvironment(settings); AnalysisRegistry emptyAnalysisRegistry = new AnalysisRegistry( diff --git a/server/src/test/java/org/opensearch/extensions/RegisterTransportActionsRequestTests.java b/server/src/test/java/org/opensearch/extensions/RegisterTransportActionsRequestTests.java new file mode 100644 index 0000000000000..ed36cc5290bb1 --- /dev/null +++ b/server/src/test/java/org/opensearch/extensions/RegisterTransportActionsRequestTests.java @@ -0,0 +1,42 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions; + +import org.junit.Before; +import org.opensearch.common.collect.Map; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; + +public class RegisterTransportActionsRequestTests extends OpenSearchTestCase { + private RegisterTransportActionsRequest originalRequest; + + @Before + public void setup() { + this.originalRequest = new RegisterTransportActionsRequest(Map.of("testAction", Map.class)); + } + + public void testRegisterTransportActionsRequest() throws IOException { + BytesStreamOutput output = new BytesStreamOutput(); + originalRequest.writeTo(output); + StreamInput input = output.bytes().streamInput(); + RegisterTransportActionsRequest parsedRequest = new RegisterTransportActionsRequest(input); + assertEquals(parsedRequest.getTransportActions(), originalRequest.getTransportActions()); + assertEquals(parsedRequest.getTransportActions().get("testAction"), originalRequest.getTransportActions().get("testAction")); + assertEquals(parsedRequest.getTransportActions().size(), originalRequest.getTransportActions().size()); + assertEquals(parsedRequest.hashCode(), originalRequest.hashCode()); + assertTrue(originalRequest.equals(parsedRequest)); + } + + public void testToString() { + assertEquals(originalRequest.toString(), "TransportActionsRequest{actions={testAction=class org.opensearch.common.collect.Map}}"); + } +} diff --git a/server/src/test/java/org/opensearch/extensions/rest/RegisterRestActionsTests.java b/server/src/test/java/org/opensearch/extensions/rest/RegisterRestActionsTests.java new file mode 100644 index 0000000000000..a8f1739ce82f2 --- /dev/null +++ b/server/src/test/java/org/opensearch/extensions/rest/RegisterRestActionsTests.java @@ -0,0 +1,62 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.rest; + +import java.util.List; + +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.io.stream.BytesStreamInput; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.test.OpenSearchTestCase; + +public class RegisterRestActionsTests extends OpenSearchTestCase { + + public void testRegisterRestActionsRequest() throws Exception { + String uniqueIdStr = "uniqueid1"; + List expected = List.of("GET /foo", "PUT /bar", "POST /baz"); + RegisterRestActionsRequest registerRestActionsRequest = new RegisterRestActionsRequest(uniqueIdStr, expected); + + assertEquals(uniqueIdStr, registerRestActionsRequest.getUniqueId()); + List restActions = registerRestActionsRequest.getRestActions(); + assertEquals(expected.size(), restActions.size()); + assertTrue(restActions.containsAll(expected)); + assertTrue(expected.containsAll(restActions)); + + try (BytesStreamOutput out = new BytesStreamOutput()) { + registerRestActionsRequest.writeTo(out); + out.flush(); + try (BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()))) { + registerRestActionsRequest = new RegisterRestActionsRequest(in); + + assertEquals(uniqueIdStr, registerRestActionsRequest.getUniqueId()); + restActions = registerRestActionsRequest.getRestActions(); + assertEquals(expected.size(), restActions.size()); + assertTrue(restActions.containsAll(expected)); + assertTrue(expected.containsAll(restActions)); + } + } + } + + public void testRegisterRestActionsResponse() throws Exception { + String response = "This is a response"; + RegisterRestActionsResponse registerRestActionsResponse = new RegisterRestActionsResponse(response); + + assertEquals(response, registerRestActionsResponse.getResponse()); + + try (BytesStreamOutput out = new BytesStreamOutput()) { + registerRestActionsResponse.writeTo(out); + out.flush(); + try (BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()))) { + registerRestActionsResponse = new RegisterRestActionsResponse(in); + + assertEquals(response, registerRestActionsResponse.getResponse()); + } + } + } +} diff --git a/server/src/test/java/org/opensearch/extensions/rest/RestExecuteOnExtensionTests.java b/server/src/test/java/org/opensearch/extensions/rest/RestExecuteOnExtensionTests.java new file mode 100644 index 0000000000000..98521ddcf1e26 --- /dev/null +++ b/server/src/test/java/org/opensearch/extensions/rest/RestExecuteOnExtensionTests.java @@ -0,0 +1,94 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.rest; + +import org.opensearch.rest.RestStatus; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.io.stream.BytesStreamInput; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestRequest.Method; +import org.opensearch.test.OpenSearchTestCase; + +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +public class RestExecuteOnExtensionTests extends OpenSearchTestCase { + + public void testRestExecuteOnExtensionRequest() throws Exception { + Method expectedMethod = Method.GET; + String expectedUri = "/test/uri"; + RestExecuteOnExtensionRequest request = new RestExecuteOnExtensionRequest(expectedMethod, expectedUri); + + assertEquals(expectedMethod, request.getMethod()); + assertEquals(expectedUri, request.getUri()); + + try (BytesStreamOutput out = new BytesStreamOutput()) { + request.writeTo(out); + out.flush(); + try (BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()))) { + request = new RestExecuteOnExtensionRequest(in); + + assertEquals(expectedMethod, request.getMethod()); + assertEquals(expectedUri, request.getUri()); + } + } + } + + public void testRestExecuteOnExtensionResponse() throws Exception { + RestStatus expectedStatus = RestStatus.OK; + String expectedContentType = BytesRestResponse.TEXT_CONTENT_TYPE; + String expectedResponse = "Test response"; + byte[] expectedResponseBytes = expectedResponse.getBytes(StandardCharsets.UTF_8); + + RestExecuteOnExtensionResponse response = new RestExecuteOnExtensionResponse(expectedStatus, expectedResponse); + + assertEquals(expectedStatus, response.getStatus()); + assertEquals(expectedContentType, response.getContentType()); + assertArrayEquals(expectedResponseBytes, response.getContent()); + assertEquals(0, response.getHeaders().size()); + + String headerKey = "foo"; + List headerValueList = List.of("bar", "baz"); + Map> expectedHeaders = Map.of(headerKey, headerValueList); + + response = new RestExecuteOnExtensionResponse(expectedStatus, expectedContentType, expectedResponseBytes, expectedHeaders); + + assertEquals(expectedStatus, response.getStatus()); + assertEquals(expectedContentType, response.getContentType()); + assertArrayEquals(expectedResponseBytes, response.getContent()); + + assertEquals(1, expectedHeaders.keySet().size()); + assertTrue(expectedHeaders.containsKey(headerKey)); + + List fooList = expectedHeaders.get(headerKey); + assertEquals(2, fooList.size()); + assertTrue(fooList.containsAll(headerValueList)); + + try (BytesStreamOutput out = new BytesStreamOutput()) { + response.writeTo(out); + out.flush(); + try (BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()))) { + response = new RestExecuteOnExtensionResponse(in); + + assertEquals(expectedStatus, response.getStatus()); + assertEquals(expectedContentType, response.getContentType()); + assertArrayEquals(expectedResponseBytes, response.getContent()); + + assertEquals(1, expectedHeaders.keySet().size()); + assertTrue(expectedHeaders.containsKey(headerKey)); + + fooList = expectedHeaders.get(headerKey); + assertEquals(2, fooList.size()); + assertTrue(fooList.containsAll(headerValueList)); + } + } + } +} diff --git a/server/src/test/java/org/opensearch/extensions/rest/RestSendToExtensionActionTests.java b/server/src/test/java/org/opensearch/extensions/rest/RestSendToExtensionActionTests.java new file mode 100644 index 0000000000000..2a593a8d251e9 --- /dev/null +++ b/server/src/test/java/org/opensearch/extensions/rest/RestSendToExtensionActionTests.java @@ -0,0 +1,159 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.rest; + +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.Version; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.common.network.NetworkService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.transport.TransportAddress; +import org.opensearch.common.util.PageCacheRecycler; +import org.opensearch.extensions.DiscoveryExtensionNode; +import org.opensearch.indices.breaker.NoneCircuitBreakerService; +import org.opensearch.plugins.PluginInfo; +import org.opensearch.rest.RestHandler.Route; +import org.opensearch.rest.RestRequest.Method; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.transport.MockTransportService; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.nio.MockNioTransport; + +public class RestSendToExtensionActionTests extends OpenSearchTestCase { + + private TransportService transportService; + private MockNioTransport transport; + private DiscoveryExtensionNode discoveryExtensionNode; + private final ThreadPool threadPool = new TestThreadPool(RestSendToExtensionActionTests.class.getSimpleName()); + + @Before + public void setup() throws Exception { + Settings settings = Settings.builder().put("cluster.name", "test").build(); + transport = new MockNioTransport( + settings, + Version.CURRENT, + threadPool, + new NetworkService(Collections.emptyList()), + PageCacheRecycler.NON_RECYCLING_INSTANCE, + new NamedWriteableRegistry(Collections.emptyList()), + new NoneCircuitBreakerService() + ); + transportService = new MockTransportService( + settings, + transport, + threadPool, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + (boundAddress) -> new DiscoveryNode( + "test_node", + "test_node", + boundAddress.publishAddress(), + emptyMap(), + emptySet(), + Version.CURRENT + ), + null, + Collections.emptySet() + ); + discoveryExtensionNode = new DiscoveryExtensionNode( + "firstExtension", + "uniqueid1", + "uniqueid1", + "myIndependentPluginHost1", + "127.0.0.0", + new TransportAddress(InetAddress.getByName("127.0.0.0"), 9300), + new HashMap(), + Version.fromString("3.0.0"), + new PluginInfo( + "firstExtension", + "Fake description 1", + "0.0.7", + Version.fromString("3.0.0"), + "14", + "fakeClass1", + new ArrayList(), + false + ) + ); + } + + @Override + @After + public void tearDown() throws Exception { + super.tearDown(); + transportService.close(); + ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); + } + + public void testRestSendToExtensionAction() throws Exception { + RegisterRestActionsRequest registerRestActionRequest = new RegisterRestActionsRequest( + "uniqueid1", + List.of("GET /foo", "PUT /bar", "POST /baz") + ); + RestSendToExtensionAction restSendToExtensionAction = new RestSendToExtensionAction( + registerRestActionRequest, + discoveryExtensionNode, + transportService + ); + + assertEquals("send_to_extension_action", restSendToExtensionAction.getName()); + List expected = new ArrayList<>(); + String uriPrefix = "/_extensions/_uniqueid1"; + expected.add(new Route(Method.GET, uriPrefix + "/foo")); + expected.add(new Route(Method.PUT, uriPrefix + "/bar")); + expected.add(new Route(Method.POST, uriPrefix + "/baz")); + + List routes = restSendToExtensionAction.routes(); + assertEquals(expected.size(), routes.size()); + List expectedPaths = expected.stream().map(Route::getPath).collect(Collectors.toList()); + List paths = routes.stream().map(Route::getPath).collect(Collectors.toList()); + List expectedMethods = expected.stream().map(Route::getMethod).collect(Collectors.toList()); + List methods = routes.stream().map(Route::getMethod).collect(Collectors.toList()); + assertTrue(paths.containsAll(expectedPaths)); + assertTrue(expectedPaths.containsAll(paths)); + assertTrue(methods.containsAll(expectedMethods)); + assertTrue(expectedMethods.containsAll(methods)); + } + + public void testRestSendToExtensionActionBadMethod() throws Exception { + RegisterRestActionsRequest registerRestActionRequest = new RegisterRestActionsRequest( + "uniqueid1", + List.of("/foo", "PUT /bar", "POST /baz") + ); + expectThrows( + IllegalArgumentException.class, + () -> new RestSendToExtensionAction(registerRestActionRequest, discoveryExtensionNode, transportService) + ); + } + + public void testRestSendToExtensionActionMissingUri() throws Exception { + RegisterRestActionsRequest registerRestActionRequest = new RegisterRestActionsRequest( + "uniqueid1", + List.of("GET", "PUT /bar", "POST /baz") + ); + expectThrows( + IllegalArgumentException.class, + () -> new RestSendToExtensionAction(registerRestActionRequest, discoveryExtensionNode, transportService) + ); + } +}