From 867c1b3e353cffa3c110db5bd06d9de07a2657db Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Wed, 30 Nov 2022 16:00:19 -0800 Subject: [PATCH 1/5] AD model performance benchmark (#729) (#734) This PR adds an AD model performance benchmark so that we can compare model performance across versions. Regarding benchmark data, we randomly generated synthetic data with known anomalies inserted throughout the signal. In particular, these are one/two/four dimensional data where each dimension is a noisy cosine wave. Anomalies are inserted into one dimension with 0.003 probability. Anomalies across each dimension can be independent or dependent. We have approximately 5000 observations per data set. The data set is generated using the same random seed so the result is comparable across versions. We also backported #600 so that we can capture the performance data in CI output. Testing done: * added unit tests to run the benchmark. Signed-off-by: Kaituo Li --- .github/workflows/benchmark.yml | 33 ++ DEVELOPER_GUIDE.md | 2 + build.gradle | 20 +- .../ad/e2e/AbstractSyntheticDataTest.java | 243 +++++++++++ .../ad/e2e/DetectionResultEvalutationIT.java | 394 +----------------- .../ad/e2e/SingleStreamModelPerfIT.java | 230 ++++++++++ .../ad/ml/AbstractCosineDataTest.java | 254 +++++++++++ .../ad/ml/EntityColdStarterTests.java | 219 +--------- .../opensearch/ad/ml/HCADModelPerfTests.java | 342 +++++++++++++++ .../ad/util/LabelledAnomalyGenerator.java | 40 +- 10 files changed, 1165 insertions(+), 612 deletions(-) create mode 100644 .github/workflows/benchmark.yml create mode 100644 src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java create mode 100644 src/test/java/org/opensearch/ad/e2e/SingleStreamModelPerfIT.java create mode 100644 src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java create mode 100644 src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml new file mode 100644 index 000000000..328e051ca --- /dev/null +++ b/.github/workflows/benchmark.yml @@ -0,0 +1,33 @@ +name: Run AD benchmark +on: + push: + branches: + - "*" + pull_request: + branches: + - "*" + +jobs: + Build-ad: + strategy: + matrix: + java: [17] + fail-fast: false + + name: Run Anomaly detection model performance benchmark + runs-on: ubuntu-latest + + steps: + - name: Setup Java ${{ matrix.java }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.java }} + + # anomaly-detection + - name: Checkout AD + uses: actions/checkout@v2 + + - name: Build and Run Tests + run: | + ./gradlew ':test' --tests "org.opensearch.ad.ml.HCADModelPerfTests" -Dtests.seed=2AEBDBBAE75AC5E0 -Dtests.security.manager=false -Dtests.locale=es-CU -Dtests.timezone=Chile/EasterIsland -Dtest.logs=true -Dmodel-benchmark=true + ./gradlew integTest --tests "org.opensearch.ad.e2e.SingleStreamModelPerfIT" -Dtests.seed=60CDDB34427ACD0C -Dtests.security.manager=false -Dtests.locale=kab-DZ -Dtests.timezone=Asia/Hebron -Dtest.logs=true -Dmodel-benchmark=true \ No newline at end of file diff --git a/DEVELOPER_GUIDE.md b/DEVELOPER_GUIDE.md index ef9460fb4..e4c66e8f0 100644 --- a/DEVELOPER_GUIDE.md +++ b/DEVELOPER_GUIDE.md @@ -48,6 +48,8 @@ Currently we just put RCF jar in lib as dependency. Plan to publish to Maven and 8. 
`./gradlew adBwcCluster#rollingUpgradeClusterTask -Dtests.security.manager=false` launches a cluster with three nodes of the bwc version of OpenSearch with anomaly-detection and job-scheduler and tests backwards compatibility by performing a rolling upgrade of each node to the current version of OpenSearch with anomaly-detection and job-scheduler.
9. `./gradlew adBwcCluster#fullRestartClusterTask -Dtests.security.manager=false` launches a cluster with three nodes of the bwc version of OpenSearch with anomaly-detection and job-scheduler and tests backwards compatibility by performing a full restart on the cluster, upgrading all the nodes to the current version of OpenSearch with anomaly-detection and job-scheduler.
10. `./gradlew bwcTestSuite -Dtests.security.manager=false` runs all the above bwc tests combined.
+11. `./gradlew ':test' --tests "org.opensearch.ad.ml.HCADModelPerfTests" -Dtests.seed=2AEBDBBAE75AC5E0 -Dtests.security.manager=false -Dtests.locale=es-CU -Dtests.timezone=Chile/EasterIsland -Dtest.logs=true -Dmodel-benchmark=true` launches the HCAD model performance tests and logs the results to standard output
+12. `./gradlew integTest --tests "org.opensearch.ad.e2e.SingleStreamModelPerfIT" -Dtests.seed=60CDDB34427ACD0C -Dtests.security.manager=false -Dtests.locale=kab-DZ -Dtests.timezone=Asia/Hebron -Dtest.logs=true -Dmodel-benchmark=true` launches the single-stream AD model performance tests and logs the results to standard output
 
When launching a cluster using one of the above commands, logs are placed in `/build/cluster/run node0/opensearch-/logs`. Though the logs are teed to the console, in practice it's best to check the actual log file.
diff --git a/build.gradle b/build.gradle
index 1177c4978..c66674c16 100644
--- a/build.gradle
+++ b/build.gradle
@@ -134,7 +134,7 @@ configurations.all {
 if (it.state != Configuration.State.UNRESOLVED) return
 resolutionStrategy {
 force "joda-time:joda-time:${versions.joda}"
- force "com.fasterxml.jackson.core:jackson-core:2.13.4"
+ force "com.fasterxml.jackson.core:jackson-core:2.14.0"
 force "commons-logging:commons-logging:${versions.commonslogging}"
 force "org.apache.httpcomponents:httpcore:${versions.httpcore}"
 force "commons-codec:commons-codec:${versions.commonscodec}"
@@ -214,6 +214,12 @@ test {
 }
 include '**/*Tests.class'
 systemProperty 'tests.security.manager', 'false'
+
+ if (System.getProperty("model-benchmark") == null || System.getProperty("model-benchmark") == "false") {
+ filter {
+ excludeTestsMatching "org.opensearch.ad.ml.HCADModelPerfTests"
+ }
+ }
 }
 
 task integTest(type: RestIntegTestTask) {
@@ -259,6 +265,12 @@ integTest {
 }
 }
 
+ if (System.getProperty("model-benchmark") == null || System.getProperty("model-benchmark") == "false") {
+ filter {
+ excludeTestsMatching "org.opensearch.ad.e2e.SingleStreamModelPerfIT"
+ }
+ }
+
 // The 'doFirst' delays till execution time.
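 // Benchmark note: the filter above keeps SingleStreamModelPerfIT out of regular runs;
 // pass -Dmodel-benchmark=true (e.g. ./gradlew integTest --tests "org.opensearch.ad.e2e.SingleStreamModelPerfIT" -Dmodel-benchmark=true)
 // to include it.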
doFirst { // Tell the test JVM if the cluster JVM is running under a debugger so that tests can @@ -665,9 +677,9 @@ dependencies { implementation 'software.amazon.randomcutforest:randomcutforest-core:3.0-rc3' // force Jackson version to avoid version conflict issue - implementation "com.fasterxml.jackson.core:jackson-core:2.13.4" - implementation "com.fasterxml.jackson.core:jackson-databind:2.13.4.2" - implementation "com.fasterxml.jackson.core:jackson-annotations:2.13.4" + implementation "com.fasterxml.jackson.core:jackson-core:2.14.0" + implementation "com.fasterxml.jackson.core:jackson-databind:2.14.0" + implementation "com.fasterxml.jackson.core:jackson-annotations:2.14.0" // used for serializing/deserializing rcf models. implementation group: 'io.protostuff', name: 'protostuff-core', version: '1.8.0' diff --git a/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java b/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java new file mode 100644 index 000000000..630c21e06 --- /dev/null +++ b/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java @@ -0,0 +1,243 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.e2e; + +import static org.opensearch.ad.TestHelpers.toHttpEntity; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; + +import java.io.File; +import java.io.FileReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.nio.charset.Charset; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import org.apache.http.HttpHeaders; +import org.apache.http.message.BasicHeader; +import org.opensearch.ad.ODFERestTestCase; +import org.opensearch.ad.TestHelpers; +import org.opensearch.client.Request; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.Response; +import org.opensearch.client.RestClient; +import org.opensearch.client.WarningsHandler; +import org.opensearch.common.Strings; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.json.JsonXContent; + +import com.google.common.collect.ImmutableList; +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; + +public class AbstractSyntheticDataTest extends ODFERestTestCase { + + /** + * In real time AD, we mute a node for a detector if that node keeps returning + * ResourceNotFoundException (5 times in a row). This is a problem for batch mode + * testing as we issue a large amount of requests quickly. Due to the speed, we + * won't be able to finish cold start before the ResourceNotFoundException mutes + * a node. Since our test case has only one node, there is no other nodes to fall + * back on. Here we disable such fault tolerance by setting max retries before + * muting to a large number and the actual wait time during muting to 0. 
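+ * Concretely, the helper raises MAX_RETRY_FOR_UNRESPONSIVE_NODE to 100_000 and drops
+ * BACKOFF_MINUTES to 0 through a PUT to the cluster settings API.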
+ * + * @throws IOException when failing to create http request body + */ + protected void disableResourceNotFoundFaultTolerence() throws IOException { + XContentBuilder settingCommand = JsonXContent.contentBuilder(); + + settingCommand.startObject(); + settingCommand.startObject("persistent"); + settingCommand.field(MAX_RETRY_FOR_UNRESPONSIVE_NODE.getKey(), 100_000); + settingCommand.field(BACKOFF_MINUTES.getKey(), 0); + settingCommand.endObject(); + settingCommand.endObject(); + Request request = new Request("PUT", "/_cluster/settings"); + request.setJsonEntity(Strings.toString(settingCommand)); + + adminClient().performRequest(request); + } + + protected List getData(String datasetFileName) throws Exception { + JsonArray jsonArray = JsonParser + .parseReader(new FileReader(new File(getClass().getResource(datasetFileName).toURI()), Charset.defaultCharset())) + .getAsJsonArray(); + List list = new ArrayList<>(jsonArray.size()); + jsonArray.iterator().forEachRemaining(i -> list.add(i.getAsJsonObject())); + return list; + } + + protected Map getDetectionResult(String detectorId, Instant begin, Instant end, RestClient client) { + try { + Request request = new Request( + "POST", + String.format(Locale.ROOT, "/_opendistro/_anomaly_detection/detectors/%s/_run", detectorId) + ); + request + .setJsonEntity( + String.format(Locale.ROOT, "{ \"period_start\": %d, \"period_end\": %d }", begin.toEpochMilli(), end.toEpochMilli()) + ); + return entityAsMap(client.performRequest(request)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + protected void bulkIndexTrainData( + String datasetName, + List data, + int trainTestSplit, + RestClient client, + String categoryField + ) throws Exception { + Request request = new Request("PUT", datasetName); + String requestBody = null; + if (Strings.isEmpty(categoryField)) { + requestBody = "{ \"mappings\": { \"properties\": { \"timestamp\": { \"type\": \"date\"}," + + " \"Feature1\": { \"type\": \"double\" }, \"Feature2\": { \"type\": \"double\" } } } }"; + } else { + requestBody = String + .format( + Locale.ROOT, + "{ \"mappings\": { \"properties\": { \"timestamp\": { \"type\": \"date\"}," + + " \"Feature1\": { \"type\": \"double\" }, \"Feature2\": { \"type\": \"double\" }," + + "\"%s\": { \"type\": \"keyword\"} } } }", + categoryField + ); + } + + request.setJsonEntity(requestBody); + setWarningHandler(request, false); + client.performRequest(request); + Thread.sleep(1_000); + + StringBuilder bulkRequestBuilder = new StringBuilder(); + for (int i = 0; i < trainTestSplit; i++) { + bulkRequestBuilder.append("{ \"index\" : { \"_index\" : \"" + datasetName + "\", \"_id\" : \"" + i + "\" } }\n"); + bulkRequestBuilder.append(data.get(i).toString()).append("\n"); + } + TestHelpers + .makeRequest( + client, + "POST", + "_bulk?refresh=true", + null, + toHttpEntity(bulkRequestBuilder.toString()), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + Thread.sleep(1_000); + waitAllSyncheticDataIngested(trainTestSplit, datasetName, client); + } + + protected String createDetector( + String datasetName, + int intervalMinutes, + RestClient client, + String categoryField, + long windowDelayInMins + ) throws Exception { + Request request = new Request("POST", "/_plugins/_anomaly_detection/detectors/"); + String requestBody = null; + if (Strings.isEmpty(categoryField)) { + requestBody = String + .format( + Locale.ROOT, + "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\"" + + ", \"indices\": [\"%s\"], 
\"feature_attributes\": [{ \"feature_name\": \"feature 1\", \"feature_enabled\": " + + "\"true\", \"aggregation_query\": { \"Feature1\": { \"sum\": { \"field\": \"Feature1\" } } } }, { \"feature_name\"" + + ": \"feature 2\", \"feature_enabled\": \"true\", \"aggregation_query\": { \"Feature2\": { \"sum\": { \"field\": " + + "\"Feature2\" } } } }], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }, " + + "\"window_delay\": { \"period\": {\"interval\": %d, \"unit\": \"MINUTES\"}}," + + "\"schema_version\": 0 }", + datasetName, + intervalMinutes, + windowDelayInMins + ); + } else { + requestBody = String + .format( + Locale.ROOT, + "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\"" + + ", \"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_name\": \"feature 1\", \"feature_enabled\": " + + "\"true\", \"aggregation_query\": { \"Feature1\": { \"sum\": { \"field\": \"Feature1\" } } } }, { \"feature_name\"" + + ": \"feature 2\", \"feature_enabled\": \"true\", \"aggregation_query\": { \"Feature2\": { \"sum\": { \"field\": " + + "\"Feature2\" } } } }], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }, " + + "\"category_field\": [\"%s\"], " + + "\"window_delay\": { \"period\": {\"interval\": %d, \"unit\": \"MINUTES\"}}," + + "\"schema_version\": 0 }", + datasetName, + intervalMinutes, + categoryField, + windowDelayInMins + ); + } + + request.setJsonEntity(requestBody); + Map response = entityAsMap(client.performRequest(request)); + String detectorId = (String) response.get("_id"); + Thread.sleep(1_000); + return detectorId; + } + + protected void waitAllSyncheticDataIngested(int expectedSize, String datasetName, RestClient client) throws Exception { + int maxWaitCycles = 3; + do { + Request request = new Request("POST", String.format(Locale.ROOT, "/%s/_search", datasetName)); + request + .setJsonEntity( + String + .format( + Locale.ROOT, + "{\"query\": {" + + " \"match_all\": {}" + + " }," + + " \"size\": 1," + + " \"sort\": [" + + " {" + + " \"timestamp\": {" + + " \"order\": \"desc\"" + + " }" + + " }" + + " ]}" + ) + ); + // Make sure all of the test data has been ingested + // Expected response: + // "_index":"synthetic","_type":"_doc","_id":"10080","_score":null,"_source":{"timestamp":"2019-11-08T00:00:00Z","Feature1":156.30028000000001,"Feature2":100.211205,"host":"host1"},"sort":[1573171200000]} + Response response = client.performRequest(request); + JsonObject json = JsonParser + .parseReader(new InputStreamReader(response.getEntity().getContent(), Charset.defaultCharset())) + .getAsJsonObject(); + JsonArray hits = json.getAsJsonObject("hits").getAsJsonArray("hits"); + if (hits != null + && hits.size() == 1 + && expectedSize - 1 == hits.get(0).getAsJsonObject().getAsJsonPrimitive("_id").getAsLong()) { + break; + } else { + request = new Request("POST", String.format(Locale.ROOT, "/%s/_refresh", datasetName)); + client.performRequest(request); + } + Thread.sleep(1_000); + } while (maxWaitCycles-- >= 0); + } + + protected void setWarningHandler(Request request, boolean strictDeprecationMode) { + RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder(); + options.setWarningsHandler(strictDeprecationMode ? 
WarningsHandler.STRICT : WarningsHandler.PERMISSIVE); + request.setOptions(options.build()); + } +} diff --git a/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java b/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java index 09810556b..0c2cce832 100644 --- a/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java +++ b/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java @@ -12,212 +12,38 @@ package org.opensearch.ad.e2e; import static org.opensearch.ad.TestHelpers.toHttpEntity; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; -import java.io.File; -import java.io.FileReader; -import java.io.IOException; -import java.io.InputStreamReader; -import java.nio.charset.Charset; import java.text.SimpleDateFormat; import java.time.Clock; import java.time.Instant; import java.time.format.DateTimeFormatter; import java.time.temporal.ChronoUnit; -import java.util.AbstractMap.SimpleEntry; import java.util.ArrayList; import java.util.Calendar; import java.util.Date; -import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.Map.Entry; -import java.util.Set; import java.util.TimeZone; import java.util.concurrent.TimeUnit; -import org.apache.http.HttpHeaders; -import org.apache.http.message.BasicHeader; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.core.Logger; -import org.opensearch.ad.ODFERestTestCase; import org.opensearch.ad.TestHelpers; import org.opensearch.ad.constant.CommonErrorMessages; import org.opensearch.client.Request; -import org.opensearch.client.RequestOptions; import org.opensearch.client.Response; import org.opensearch.client.RestClient; -import org.opensearch.client.WarningsHandler; -import org.opensearch.common.Strings; -import org.opensearch.common.xcontent.XContentBuilder; -import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.common.xcontent.support.XContentMapValues; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.gson.JsonArray; import com.google.gson.JsonElement; import com.google.gson.JsonObject; -import com.google.gson.JsonParser; import com.google.gson.JsonPrimitive; -public class DetectionResultEvalutationIT extends ODFERestTestCase { +public class DetectionResultEvalutationIT extends AbstractSyntheticDataTest { protected static final Logger LOG = (Logger) LogManager.getLogger(DetectionResultEvalutationIT.class); - public void testDataset() throws Exception { - // TODO: this test case will run for a much longer time and timeout with security enabled - if (!isHttps()) { - disableResourceNotFoundFaultTolerence(); - verifyAnomaly("synthetic", 1, 1500, 8, .4, .9, 10); - } - } - - private void verifyAnomaly( - String datasetName, - int intervalMinutes, - int trainTestSplit, - int shingleSize, - double minPrecision, - double minRecall, - double maxError - ) throws Exception { - RestClient client = client(); - - String dataFileName = String.format(Locale.ROOT, "data/%s.data", datasetName); - String labelFileName = String.format(Locale.ROOT, "data/%s.label", datasetName); - - List data = getData(dataFileName); - List> anomalies = getAnomalyWindows(labelFileName); - - bulkIndexTrainData(datasetName, data, trainTestSplit, client, null); - // single-stream detector can use window delay 0 here because we give the 
run api the actual data time - String detectorId = createDetector(datasetName, intervalMinutes, client, null, 0); - simulateSingleStreamStartDetector(detectorId, data, trainTestSplit, shingleSize, intervalMinutes, client); - bulkIndexTestData(data, datasetName, trainTestSplit, client); - double[] testResults = getTestResults(detectorId, data, trainTestSplit, intervalMinutes, anomalies, client); - verifyTestResults(testResults, anomalies, minPrecision, minRecall, maxError); - } - - private void verifyTestResults( - double[] testResults, - List> anomalies, - double minPrecision, - double minRecall, - double maxError - ) { - - double positives = testResults[0]; - double truePositives = testResults[1]; - double positiveAnomalies = testResults[2]; - double errors = testResults[3]; - - // precision = predicted anomaly points that are true / predicted anomaly points - double precision = positives > 0 ? truePositives / positives : 1; - assertTrue(precision >= minPrecision); - - // recall = windows containing predicted anomaly points / total anomaly windows - double recall = anomalies.size() > 0 ? positiveAnomalies / anomalies.size() : 1; - assertTrue(recall >= minRecall); - - assertTrue(errors <= maxError); - LOG.info("Precision: {}, Window recall: {}", precision, recall); - } - - private int isAnomaly(Instant time, List> labels) { - for (int i = 0; i < labels.size(); i++) { - Entry window = labels.get(i); - if (time.compareTo(window.getKey()) >= 0 && time.compareTo(window.getValue()) <= 0) { - return i; - } - } - return -1; - } - - private double[] getTestResults( - String detectorId, - List data, - int trainTestSplit, - int intervalMinutes, - List> anomalies, - RestClient client - ) throws Exception { - - double positives = 0; - double truePositives = 0; - Set positiveAnomalies = new HashSet<>(); - double errors = 0; - for (int i = trainTestSplit; i < data.size(); i++) { - Instant begin = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(data.get(i).get("timestamp").getAsString())); - Instant end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); - try { - Map response = getDetectionResult(detectorId, begin, end, client); - double anomalyGrade = (double) response.get("anomalyGrade"); - if (anomalyGrade > 0) { - positives++; - int result = isAnomaly(begin, anomalies); - if (result != -1) { - truePositives++; - positiveAnomalies.add(result); - } - } - } catch (Exception e) { - errors++; - logger.error("failed to get detection results", e); - } - } - return new double[] { positives, truePositives, positiveAnomalies.size(), errors }; - } - - /** - * Simulate starting detector without waiting for job scheduler to run. Our build process is already very slow (takes 10 mins+) - * to finish integration tests. This method triggers run API to simulate job scheduler execution in a fast-paced way. 
- * @param detectorId Detector Id - * @param data Data in Json format - * @param trainTestSplit Training data size - * @param shingleSize Shingle size - * @param intervalMinutes Detector Interval - * @param client OpenSearch Client - * @throws Exception when failing to query/indexing from/to OpenSearch - */ - private void simulateSingleStreamStartDetector( - String detectorId, - List data, - int trainTestSplit, - int shingleSize, - int intervalMinutes, - RestClient client - ) throws Exception { - - Instant trainTime = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(data.get(trainTestSplit - 1).get("timestamp").getAsString())); - - Instant begin = null; - Instant end = null; - for (int i = 0; i < shingleSize; i++) { - begin = trainTime.minus(intervalMinutes * (shingleSize - 1 - i), ChronoUnit.MINUTES); - end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); - try { - getDetectionResult(detectorId, begin, end, client); - } catch (Exception e) {} - } - // It takes time to wait for model initialization - long startTime = System.currentTimeMillis(); - do { - try { - Thread.sleep(5_000); - getDetectionResult(detectorId, begin, end, client); - break; - } catch (Exception e) { - long duration = System.currentTimeMillis() - startTime; - // we wait at most 60 secs - if (duration > 60_000) { - throw new RuntimeException(e); - } - } - } while (true); - } - /** * Simulate starting the given HCAD detector. * @param detectorId Detector Id @@ -274,224 +100,6 @@ private void simulateHCADStartDetector( } while (duration <= 60_000); } - private String createDetector(String datasetName, int intervalMinutes, RestClient client, String categoryField, long windowDelayInMins) - throws Exception { - Request request = new Request("POST", "/_plugins/_anomaly_detection/detectors/"); - String requestBody = null; - if (Strings.isEmpty(categoryField)) { - requestBody = String - .format( - Locale.ROOT, - "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\"" - + ", \"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_name\": \"feature 1\", \"feature_enabled\": " - + "\"true\", \"aggregation_query\": { \"Feature1\": { \"sum\": { \"field\": \"Feature1\" } } } }, { \"feature_name\"" - + ": \"feature 2\", \"feature_enabled\": \"true\", \"aggregation_query\": { \"Feature2\": { \"sum\": { \"field\": " - + "\"Feature2\" } } } }], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }, " - + "\"window_delay\": { \"period\": {\"interval\": %d, \"unit\": \"MINUTES\"}}," - + "\"schema_version\": 0 }", - datasetName, - intervalMinutes, - windowDelayInMins - ); - } else { - requestBody = String - .format( - Locale.ROOT, - "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\"" - + ", \"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_name\": \"feature 1\", \"feature_enabled\": " - + "\"true\", \"aggregation_query\": { \"Feature1\": { \"sum\": { \"field\": \"Feature1\" } } } }, { \"feature_name\"" - + ": \"feature 2\", \"feature_enabled\": \"true\", \"aggregation_query\": { \"Feature2\": { \"sum\": { \"field\": " - + "\"Feature2\" } } } }], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }, " - + "\"category_field\": [\"%s\"], " - + "\"window_delay\": { \"period\": {\"interval\": %d, \"unit\": \"MINUTES\"}}," - + "\"schema_version\": 0 }", - datasetName, - intervalMinutes, - categoryField, - windowDelayInMins - ); - } - - request.setJsonEntity(requestBody); - Map response = 
entityAsMap(client.performRequest(request)); - String detectorId = (String) response.get("_id"); - Thread.sleep(1_000); - return detectorId; - } - - private List> getAnomalyWindows(String labalFileName) throws Exception { - JsonArray windows = JsonParser - .parseReader(new FileReader(new File(getClass().getResource(labalFileName).toURI()), Charset.defaultCharset())) - .getAsJsonArray(); - List> anomalies = new ArrayList<>(windows.size()); - for (int i = 0; i < windows.size(); i++) { - JsonArray window = windows.get(i).getAsJsonArray(); - Instant begin = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(window.get(0).getAsString())); - Instant end = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(window.get(1).getAsString())); - anomalies.add(new SimpleEntry<>(begin, end)); - } - return anomalies; - } - - private void bulkIndexTrainData(String datasetName, List data, int trainTestSplit, RestClient client, String categoryField) - throws Exception { - Request request = new Request("PUT", datasetName); - String requestBody = null; - if (Strings.isEmpty(categoryField)) { - requestBody = "{ \"mappings\": { \"properties\": { \"timestamp\": { \"type\": \"date\"}," - + " \"Feature1\": { \"type\": \"double\" }, \"Feature2\": { \"type\": \"double\" } } } }"; - } else { - requestBody = String - .format( - Locale.ROOT, - "{ \"mappings\": { \"properties\": { \"timestamp\": { \"type\": \"date\"}," - + " \"Feature1\": { \"type\": \"double\" }, \"Feature2\": { \"type\": \"double\" }," - + "\"%s\": { \"type\": \"keyword\"} } } }", - categoryField - ); - } - - request.setJsonEntity(requestBody); - setWarningHandler(request, false); - client.performRequest(request); - Thread.sleep(1_000); - - StringBuilder bulkRequestBuilder = new StringBuilder(); - for (int i = 0; i < trainTestSplit; i++) { - bulkRequestBuilder.append("{ \"index\" : { \"_index\" : \"" + datasetName + "\", \"_id\" : \"" + i + "\" } }\n"); - bulkRequestBuilder.append(data.get(i).toString()).append("\n"); - } - TestHelpers - .makeRequest( - client, - "POST", - "_bulk?refresh=true", - null, - toHttpEntity(bulkRequestBuilder.toString()), - ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) - ); - Thread.sleep(1_000); - waitAllSyncheticDataIngested(trainTestSplit, datasetName, client); - } - - private void bulkIndexTestData(List data, String datasetName, int trainTestSplit, RestClient client) throws Exception { - StringBuilder bulkRequestBuilder = new StringBuilder(); - for (int i = trainTestSplit; i < data.size(); i++) { - bulkRequestBuilder.append("{ \"index\" : { \"_index\" : \"" + datasetName + "\", \"_id\" : \"" + i + "\" } }\n"); - bulkRequestBuilder.append(data.get(i).toString()).append("\n"); - } - TestHelpers - .makeRequest( - client, - "POST", - "_bulk?refresh=true", - null, - toHttpEntity(bulkRequestBuilder.toString()), - ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) - ); - Thread.sleep(1_000); - waitAllSyncheticDataIngested(data.size(), datasetName, client); - } - - private void waitAllSyncheticDataIngested(int expectedSize, String datasetName, RestClient client) throws Exception { - int maxWaitCycles = 3; - do { - Request request = new Request("POST", String.format(Locale.ROOT, "/%s/_search", datasetName)); - request - .setJsonEntity( - String - .format( - Locale.ROOT, - "{\"query\": {" - + " \"match_all\": {}" - + " }," - + " \"size\": 1," - + " \"sort\": [" - + " {" - + " \"timestamp\": {" - + " \"order\": \"desc\"" - + " }" - + " }" - + " ]}" - ) - ); - // Make sure all of the test data has 
been ingested - // Expected response: - // "_index":"synthetic","_type":"_doc","_id":"10080","_score":null,"_source":{"timestamp":"2019-11-08T00:00:00Z","Feature1":156.30028000000001,"Feature2":100.211205,"host":"host1"},"sort":[1573171200000]} - Response response = client.performRequest(request); - JsonObject json = JsonParser - .parseReader(new InputStreamReader(response.getEntity().getContent(), Charset.defaultCharset())) - .getAsJsonObject(); - JsonArray hits = json.getAsJsonObject("hits").getAsJsonArray("hits"); - if (hits != null - && hits.size() == 1 - && expectedSize - 1 == hits.get(0).getAsJsonObject().getAsJsonPrimitive("_id").getAsLong()) { - break; - } else { - request = new Request("POST", String.format(Locale.ROOT, "/%s/_refresh", datasetName)); - client.performRequest(request); - } - Thread.sleep(1_000); - } while (maxWaitCycles-- >= 0); - } - - private void setWarningHandler(Request request, boolean strictDeprecationMode) { - RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder(); - options.setWarningsHandler(strictDeprecationMode ? WarningsHandler.STRICT : WarningsHandler.PERMISSIVE); - request.setOptions(options.build()); - } - - private List getData(String datasetFileName) throws Exception { - JsonArray jsonArray = JsonParser - .parseReader(new FileReader(new File(getClass().getResource(datasetFileName).toURI()), Charset.defaultCharset())) - .getAsJsonArray(); - List list = new ArrayList<>(jsonArray.size()); - jsonArray.iterator().forEachRemaining(i -> list.add(i.getAsJsonObject())); - return list; - } - - private Map getDetectionResult(String detectorId, Instant begin, Instant end, RestClient client) { - try { - Request request = new Request( - "POST", - String.format(Locale.ROOT, "/_opendistro/_anomaly_detection/detectors/%s/_run", detectorId) - ); - request - .setJsonEntity( - String.format(Locale.ROOT, "{ \"period_start\": %d, \"period_end\": %d }", begin.toEpochMilli(), end.toEpochMilli()) - ); - return entityAsMap(client.performRequest(request)); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - /** - * In real time AD, we mute a node for a detector if that node keeps returning - * ResourceNotFoundException (5 times in a row). This is a problem for batch mode - * testing as we issue a large amount of requests quickly. Due to the speed, we - * won't be able to finish cold start before the ResourceNotFoundException mutes - * a node. Since our test case has only one node, there is no other nodes to fall - * back on. Here we disable such fault tolerance by setting max retries before - * muting to a large number and the actual wait time during muting to 0. 
- * - * @throws IOException when failing to create http request body - */ - private void disableResourceNotFoundFaultTolerence() throws IOException { - XContentBuilder settingCommand = JsonXContent.contentBuilder(); - - settingCommand.startObject(); - settingCommand.startObject("persistent"); - settingCommand.field(MAX_RETRY_FOR_UNRESPONSIVE_NODE.getKey(), 100_000); - settingCommand.field(BACKOFF_MINUTES.getKey(), 0); - settingCommand.endObject(); - settingCommand.endObject(); - Request request = new Request("PUT", "/_cluster/settings"); - request.setJsonEntity(Strings.toString(settingCommand)); - - adminClient().performRequest(request); - } - public void testValidationIntervalRecommendation() throws Exception { RestClient client = client(); long recDetectorIntervalMillis = 180000; diff --git a/src/test/java/org/opensearch/ad/e2e/SingleStreamModelPerfIT.java b/src/test/java/org/opensearch/ad/e2e/SingleStreamModelPerfIT.java new file mode 100644 index 000000000..61a79ee96 --- /dev/null +++ b/src/test/java/org/opensearch/ad/e2e/SingleStreamModelPerfIT.java @@ -0,0 +1,230 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.e2e; + +import static org.opensearch.ad.TestHelpers.toHttpEntity; + +import java.io.File; +import java.io.FileReader; +import java.nio.charset.Charset; +import java.time.Instant; +import java.time.format.DateTimeFormatter; +import java.time.temporal.ChronoUnit; +import java.util.AbstractMap.SimpleEntry; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; + +import org.apache.http.HttpHeaders; +import org.apache.http.message.BasicHeader; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.core.Logger; +import org.opensearch.ad.TestHelpers; +import org.opensearch.client.RestClient; + +import com.google.common.collect.ImmutableList; +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; + +public class SingleStreamModelPerfIT extends AbstractSyntheticDataTest { + protected static final Logger LOG = (Logger) LogManager.getLogger(SingleStreamModelPerfIT.class); + + public void testDataset() throws Exception { + // TODO: this test case will run for a much longer time and timeout with security enabled + if (!isHttps()) { + disableResourceNotFoundFaultTolerence(); + verifyAnomaly("synthetic", 1, 1500, 8, .4, .9, 10); + } + } + + private void verifyAnomaly( + String datasetName, + int intervalMinutes, + int trainTestSplit, + int shingleSize, + double minPrecision, + double minRecall, + double maxError + ) throws Exception { + RestClient client = client(); + + String dataFileName = String.format(Locale.ROOT, "data/%s.data", datasetName); + String labelFileName = String.format(Locale.ROOT, "data/%s.label", datasetName); + + List data = getData(dataFileName); + List> anomalies = getAnomalyWindows(labelFileName); + + bulkIndexTrainData(datasetName, data, trainTestSplit, client, null); + // single-stream detector can use window delay 0 here because we give the run api the actual data time + String detectorId = createDetector(datasetName, intervalMinutes, client, null, 0); + 
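+ // prime the model by replaying the run API over the training range (see the helper below)
+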
simulateSingleStreamStartDetector(detectorId, data, trainTestSplit, shingleSize, intervalMinutes, client); + bulkIndexTestData(data, datasetName, trainTestSplit, client); + double[] testResults = getTestResults(detectorId, data, trainTestSplit, intervalMinutes, anomalies, client); + verifyTestResults(testResults, anomalies, minPrecision, minRecall, maxError); + } + + private void verifyTestResults( + double[] testResults, + List> anomalies, + double minPrecision, + double minRecall, + double maxError + ) { + + double positives = testResults[0]; + double truePositives = testResults[1]; + double positiveAnomalies = testResults[2]; + double errors = testResults[3]; + + // precision = predicted anomaly points that are true / predicted anomaly points + double precision = positives > 0 ? truePositives / positives : 1; + assertTrue(precision >= minPrecision); + + // recall = windows containing predicted anomaly points / total anomaly windows + double recall = anomalies.size() > 0 ? positiveAnomalies / anomalies.size() : 1; + assertTrue(recall >= minRecall); + + assertTrue(errors <= maxError); + LOG.info("Precision: {}, Window recall: {}", precision, recall); + } + + private double[] getTestResults( + String detectorId, + List data, + int trainTestSplit, + int intervalMinutes, + List> anomalies, + RestClient client + ) throws Exception { + + double positives = 0; + double truePositives = 0; + Set positiveAnomalies = new HashSet<>(); + double errors = 0; + for (int i = trainTestSplit; i < data.size(); i++) { + Instant begin = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(data.get(i).get("timestamp").getAsString())); + Instant end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); + try { + Map response = getDetectionResult(detectorId, begin, end, client); + double anomalyGrade = (double) response.get("anomalyGrade"); + if (anomalyGrade > 0) { + positives++; + int result = isAnomaly(begin, anomalies); + if (result != -1) { + truePositives++; + positiveAnomalies.add(result); + } + } + } catch (Exception e) { + errors++; + logger.error("failed to get detection results", e); + } + } + return new double[] { positives, truePositives, positiveAnomalies.size(), errors }; + } + + private List> getAnomalyWindows(String labalFileName) throws Exception { + JsonArray windows = JsonParser + .parseReader(new FileReader(new File(getClass().getResource(labalFileName).toURI()), Charset.defaultCharset())) + .getAsJsonArray(); + List> anomalies = new ArrayList<>(windows.size()); + for (int i = 0; i < windows.size(); i++) { + JsonArray window = windows.get(i).getAsJsonArray(); + Instant begin = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(window.get(0).getAsString())); + Instant end = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(window.get(1).getAsString())); + anomalies.add(new SimpleEntry<>(begin, end)); + } + return anomalies; + } + + /** + * Simulate starting detector without waiting for job scheduler to run. Our build process is already very slow (takes 10 mins+) + * to finish integration tests. This method triggers run API to simulate job scheduler execution in a fast-paced way. 
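+ * The run API is replayed shingleSize times over the tail of the training data to fill the
+ * shingle buffer, and is then polled (for up to 60 seconds) until the model finishes initializing.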
+ * @param detectorId Detector Id + * @param data Data in Json format + * @param trainTestSplit Training data size + * @param shingleSize Shingle size + * @param intervalMinutes Detector Interval + * @param client OpenSearch Client + * @throws Exception when failing to query/indexing from/to OpenSearch + */ + private void simulateSingleStreamStartDetector( + String detectorId, + List data, + int trainTestSplit, + int shingleSize, + int intervalMinutes, + RestClient client + ) throws Exception { + + Instant trainTime = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(data.get(trainTestSplit - 1).get("timestamp").getAsString())); + + Instant begin = null; + Instant end = null; + for (int i = 0; i < shingleSize; i++) { + begin = trainTime.minus(intervalMinutes * (shingleSize - 1 - i), ChronoUnit.MINUTES); + end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); + try { + getDetectionResult(detectorId, begin, end, client); + } catch (Exception e) {} + } + // It takes time to wait for model initialization + long startTime = System.currentTimeMillis(); + do { + try { + Thread.sleep(5_000); + getDetectionResult(detectorId, begin, end, client); + break; + } catch (Exception e) { + long duration = System.currentTimeMillis() - startTime; + // we wait at most 60 secs + if (duration > 60_000) { + throw new RuntimeException(e); + } + } + } while (true); + } + + private void bulkIndexTestData(List data, String datasetName, int trainTestSplit, RestClient client) throws Exception { + StringBuilder bulkRequestBuilder = new StringBuilder(); + for (int i = trainTestSplit; i < data.size(); i++) { + bulkRequestBuilder.append("{ \"index\" : { \"_index\" : \"" + datasetName + "\", \"_id\" : \"" + i + "\" } }\n"); + bulkRequestBuilder.append(data.get(i).toString()).append("\n"); + } + TestHelpers + .makeRequest( + client, + "POST", + "_bulk?refresh=true", + null, + toHttpEntity(bulkRequestBuilder.toString()), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + Thread.sleep(1_000); + waitAllSyncheticDataIngested(data.size(), datasetName, client); + } + + private int isAnomaly(Instant time, List> labels) { + for (int i = 0; i < labels.size(); i++) { + Entry window = labels.get(i); + if (time.compareTo(window.getKey()) >= 0 && time.compareTo(window.getValue()) <= 0) { + return i; + } + } + return -1; + } +} diff --git a/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java b/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java new file mode 100644 index 000000000..4949770f4 --- /dev/null +++ b/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java @@ -0,0 +1,254 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ml; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; + +import java.time.Clock; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.ad.AbstractADTest; +import org.opensearch.ad.AnomalyDetectorPlugin; +import org.opensearch.ad.MemoryTracker; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.TestHelpers; +import org.opensearch.ad.dataprocessor.IntegerSensitiveSingleFeatureLinearUniformInterpolator; +import org.opensearch.ad.dataprocessor.Interpolator; +import org.opensearch.ad.dataprocessor.LinearUniformInterpolator; +import org.opensearch.ad.dataprocessor.SingleFeatureLinearUniformInterpolator; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.Entity; +import org.opensearch.ad.model.IntervalTimeConfiguration; +import org.opensearch.ad.ratelimit.CheckpointWriteWorker; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.util.ClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodeRole; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.test.ClusterServiceUtils; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; + +import com.google.common.collect.ImmutableList; + +public class AbstractCosineDataTest extends AbstractADTest { + int numMinSamples; + String modelId; + String entityName; + String detectorId; + ModelState modelState; + Clock clock; + float priority; + EntityColdStarter entityColdStarter; + NodeStateManager stateManager; + SearchFeatureDao searchFeatureDao; + Interpolator interpolator; + CheckpointDao checkpoint; + FeatureManager featureManager; + Settings settings; + ThreadPool threadPool; + AtomicBoolean released; + Runnable releaseSemaphore; + ActionListener listener; + CountDownLatch inProgressLatch; + CheckpointWriteWorker checkpointWriteQueue; + Entity entity; + AnomalyDetector detector; + long rcfSeed; + ModelManager modelManager; + ClientUtil clientUtil; + ClusterService clusterService; + ClusterSettings clusterSettings; + DiscoveryNode discoveryNode; + Set> nodestateSetting; + + @SuppressWarnings("unchecked") + @Override + public void setUp() throws Exception { + super.setUp(); + numMinSamples = AnomalyDetectorSettings.NUM_MIN_SAMPLES; + + clock = mock(Clock.class); + when(clock.instant()).thenReturn(Instant.now()); + + threadPool = 
mock(ThreadPool.class); + setUpADThreadPool(threadPool); + + settings = Settings.EMPTY; + + Client client = mock(Client.class); + clientUtil = mock(ClientUtil.class); + + detector = TestHelpers.AnomalyDetectorBuilder + .newInstance() + .setDetectionInterval(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES)) + .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) + .build(); + when(clock.millis()).thenReturn(1602401500000L); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + + listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, AnomalyDetector.ANOMALY_DETECTORS_INDEX)); + + return null; + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); + + nodestateSetting = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + nodestateSetting.add(MAX_RETRY_FOR_UNRESPONSIVE_NODE); + nodestateSetting.add(BACKOFF_MINUTES); + nodestateSetting.add(CHECKPOINT_SAVING_FREQ); + clusterSettings = new ClusterSettings(Settings.EMPTY, nodestateSetting); + + discoveryNode = new DiscoveryNode( + "node1", + OpenSearchTestCase.buildNewFakeTransportAddress(), + Collections.emptyMap(), + DiscoveryNodeRole.BUILT_IN_ROLES, + Version.CURRENT + ); + + clusterService = ClusterServiceUtils.createClusterService(threadPool, discoveryNode, clusterSettings); + + stateManager = new NodeStateManager( + client, + xContentRegistry(), + settings, + clientUtil, + clock, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + clusterService + ); + + SingleFeatureLinearUniformInterpolator singleFeatureLinearUniformInterpolator = + new IntegerSensitiveSingleFeatureLinearUniformInterpolator(); + interpolator = new LinearUniformInterpolator(singleFeatureLinearUniformInterpolator); + + searchFeatureDao = mock(SearchFeatureDao.class); + checkpoint = mock(CheckpointDao.class); + + featureManager = new FeatureManager( + searchFeatureDao, + interpolator, + clock, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, + AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, + AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, + AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, + AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, + AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + threadPool, + AnomalyDetectorPlugin.AD_THREAD_POOL_NAME + ); + + checkpointWriteQueue = mock(CheckpointWriteWorker.class); + + rcfSeed = 2051L; + entityColdStarter = new EntityColdStarter( + clock, + threadPool, + stateManager, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.TIME_DECAY, + numMinSamples, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + interpolator, + searchFeatureDao, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + featureManager, + settings, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + rcfSeed, + AnomalyDetectorSettings.MAX_COLD_START_ROUNDS + ); + + detectorId = "123"; + modelId = "123_entity_abc"; + entityName = "abc"; + priority = 0.3f; + entity = Entity.createSingleAttributeEntity("field", entityName); + + released = new AtomicBoolean(); + + inProgressLatch = new CountDownLatch(1); + releaseSemaphore = () -> { + released.set(true); + inProgressLatch.countDown(); + }; + listener = ActionListener.wrap(releaseSemaphore); + + modelManager = new ModelManager( + mock(CheckpointDao.class), + 
mock(Clock.class), + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.TIME_DECAY, + AnomalyDetectorSettings.NUM_MIN_SAMPLES, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + AnomalyDetectorSettings.MIN_PREVIEW_SIZE, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ, + entityColdStarter, + mock(FeatureManager.class), + mock(MemoryTracker.class), + settings, + clusterService + ); + } + + protected void checkSemaphoreRelease() throws InterruptedException { + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + assertTrue(released.get()); + } + + public int searchInsert(long[] timestamps, long target) { + int pivot, left = 0, right = timestamps.length - 1; + while (left <= right) { + pivot = left + (right - left) / 2; + if (timestamps[pivot] == target) + return pivot; + if (target < timestamps[pivot]) + right = pivot - 1; + else + left = pivot + 1; + } + return left; + } +} diff --git a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java index ddea2510b..1236c6217 100644 --- a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java +++ b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java @@ -19,9 +19,6 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; import java.io.IOException; import java.time.Clock; @@ -31,7 +28,6 @@ import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; -import java.util.HashSet; import java.util.List; import java.util.Map.Entry; import java.util.Optional; @@ -39,48 +35,29 @@ import java.util.Random; import java.util.Set; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; import org.junit.AfterClass; import org.junit.BeforeClass; -import org.opensearch.Version; import org.opensearch.action.ActionListener; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; -import org.opensearch.ad.AbstractADTest; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.MemoryTracker; -import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.TestHelpers; import org.opensearch.ad.common.exception.AnomalyDetectionException; -import org.opensearch.ad.dataprocessor.IntegerSensitiveSingleFeatureLinearUniformInterpolator; -import org.opensearch.ad.dataprocessor.Interpolator; -import org.opensearch.ad.dataprocessor.LinearUniformInterpolator; -import org.opensearch.ad.dataprocessor.SingleFeatureLinearUniformInterpolator; import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.feature.SearchFeatureDao; import org.opensearch.ad.ml.ModelManager.ModelType; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.Entity; import org.opensearch.ad.model.IntervalTimeConfiguration; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.settings.EnabledSetting; -import org.opensearch.ad.util.ClientUtil; -import 
org.opensearch.client.Client; -import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.cluster.node.DiscoveryNodeRole; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.collect.Tuple; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException; -import org.opensearch.test.ClusterServiceUtils; -import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.threadpool.ThreadPool; import test.org.opensearch.ad.util.LabelledAnomalyGenerator; import test.org.opensearch.ad.util.MLUtil; @@ -91,33 +68,7 @@ import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import com.google.common.collect.ImmutableList; -public class EntityColdStarterTests extends AbstractADTest { - int numMinSamples; - String modelId; - String entityName; - String detectorId; - ModelState modelState; - Clock clock; - float priority; - EntityColdStarter entityColdStarter; - NodeStateManager stateManager; - SearchFeatureDao searchFeatureDao; - Interpolator interpolator; - CheckpointDao checkpoint; - FeatureManager featureManager; - Settings settings; - ThreadPool threadPool; - AtomicBoolean released; - Runnable releaseSemaphore; - ActionListener listener; - CountDownLatch inProgressLatch; - CheckpointWriteWorker checkpointWriteQueue; - Entity entity; - AnomalyDetector detector; - long rcfSeed; - ModelManager modelManager; - ClientUtil clientUtil; - ClusterService clusterService; +public class EntityColdStarterTests extends AbstractCosineDataTest { @BeforeClass public static void initOnce() { @@ -136,145 +87,10 @@ public static void clearOnce() { EnabledSetting.getInstance().setSettingValue(EnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, false); } - @SuppressWarnings("unchecked") @Override public void setUp() throws Exception { super.setUp(); - numMinSamples = AnomalyDetectorSettings.NUM_MIN_SAMPLES; - - clock = mock(Clock.class); - when(clock.instant()).thenReturn(Instant.now()); - - threadPool = mock(ThreadPool.class); - setUpADThreadPool(threadPool); - - settings = Settings.EMPTY; - - Client client = mock(Client.class); - clientUtil = mock(ClientUtil.class); - - detector = TestHelpers.AnomalyDetectorBuilder - .newInstance() - .setDetectionInterval(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES)) - .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) - .build(); - when(clock.millis()).thenReturn(1602401500000L); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - - listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, AnomalyDetector.ANOMALY_DETECTORS_INDEX)); - - return null; - }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); - - Set> nodestateSetting = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); - nodestateSetting.add(MAX_RETRY_FOR_UNRESPONSIVE_NODE); - nodestateSetting.add(BACKOFF_MINUTES); - nodestateSetting.add(CHECKPOINT_SAVING_FREQ); - ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, nodestateSetting); - - DiscoveryNode discoveryNode = new DiscoveryNode( - "node1", - OpenSearchTestCase.buildNewFakeTransportAddress(), - Collections.emptyMap(), - DiscoveryNodeRole.BUILT_IN_ROLES, - Version.CURRENT - ); - - clusterService = ClusterServiceUtils.createClusterService(threadPool, discoveryNode, clusterSettings); - - stateManager = new 
NodeStateManager( - client, - xContentRegistry(), - settings, - clientUtil, - clock, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - clusterService - ); - - SingleFeatureLinearUniformInterpolator singleFeatureLinearUniformInterpolator = - new IntegerSensitiveSingleFeatureLinearUniformInterpolator(); - interpolator = new LinearUniformInterpolator(singleFeatureLinearUniformInterpolator); - - searchFeatureDao = mock(SearchFeatureDao.class); - checkpoint = mock(CheckpointDao.class); - - featureManager = new FeatureManager( - searchFeatureDao, - interpolator, - clock, - AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, - AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, - AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, - AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, - AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, - AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, - AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - threadPool, - AnomalyDetectorPlugin.AD_THREAD_POOL_NAME - ); - - checkpointWriteQueue = mock(CheckpointWriteWorker.class); - - rcfSeed = 2051L; - entityColdStarter = new EntityColdStarter( - clock, - threadPool, - stateManager, - AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, - AnomalyDetectorSettings.NUM_TREES, - AnomalyDetectorSettings.TIME_DECAY, - numMinSamples, - AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, - AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - interpolator, - searchFeatureDao, - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, - featureManager, - settings, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - checkpointWriteQueue, - rcfSeed, - AnomalyDetectorSettings.MAX_COLD_START_ROUNDS - ); EnabledSetting.getInstance().setSettingValue(EnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, Boolean.TRUE); - - detectorId = "123"; - modelId = "123_entity_abc"; - entityName = "abc"; - priority = 0.3f; - entity = Entity.createSingleAttributeEntity("field", entityName); - - released = new AtomicBoolean(); - - inProgressLatch = new CountDownLatch(1); - releaseSemaphore = () -> { - released.set(true); - inProgressLatch.countDown(); - }; - listener = ActionListener.wrap(releaseSemaphore); - - modelManager = new ModelManager( - mock(CheckpointDao.class), - mock(Clock.class), - AnomalyDetectorSettings.NUM_TREES, - AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, - AnomalyDetectorSettings.TIME_DECAY, - AnomalyDetectorSettings.NUM_MIN_SAMPLES, - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, - AnomalyDetectorSettings.MIN_PREVIEW_SIZE, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ, - entityColdStarter, - mock(FeatureManager.class), - mock(MemoryTracker.class), - settings, - clusterService - - ); } @Override @@ -283,11 +99,6 @@ public void tearDown() throws Exception { super.tearDown(); } - private void checkSemaphoreRelease() throws InterruptedException { - assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); - assertTrue(released.get()); - } - // train using samples directly public void testTrainUsingSamples() throws InterruptedException { Queue samples = MLUtil.createQueueSamples(numMinSamples); @@ -757,7 +568,18 @@ private void accuracyTemplate(int detectorIntervalMins, float precisionThreshold LOG.info("seed = " + seed); // create labelled data MultiDimDataWithTime dataWithKeys = LabelledAnomalyGenerator - .getMultiDimData(dataSize + detector.getShingleSize() - 1, 50, 100, 5, seed, baseDimension, false, trainTestSplit, delta); + 
.getMultiDimData( + dataSize + detector.getShingleSize() - 1, + 50, + 100, + 5, + seed, + baseDimension, + false, + trainTestSplit, + delta, + false + ); long[] timestamps = dataWithKeys.timestampsMs; double[][] data = dataWithKeys.data; when(clock.millis()).thenReturn(timestamps[trainTestSplit - 1]); @@ -858,21 +680,6 @@ public int compare(Entry p1, Entry p2) { assertTrue("precision is " + prec, prec >= precisionThreshold); assertTrue("recall is " + recall, recall >= recallThreshold); - LOG.info("Interval {}, Precision: {}, recall: {}", detectorIntervalMins, prec, recall); - } - - public int searchInsert(long[] timestamps, long target) { - int pivot, left = 0, right = timestamps.length - 1; - while (left <= right) { - pivot = left + (right - left) / 2; - if (timestamps[pivot] == target) - return pivot; - if (target < timestamps[pivot]) - right = pivot - 1; - else - left = pivot + 1; - } - return left; } public void testAccuracyTenMinuteInterval() throws Exception { diff --git a/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java b/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java new file mode 100644 index 000000000..5d2849401 --- /dev/null +++ b/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java @@ -0,0 +1,342 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ml; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.when; + +import java.time.Clock; +import java.time.temporal.ChronoUnit; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.apache.lucene.tests.util.TimeUnits; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.ad.AnomalyDetectorPlugin; +import org.opensearch.ad.MemoryTracker; +import org.opensearch.ad.TestHelpers; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.ml.ModelManager.ModelType; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.Entity; +import org.opensearch.ad.model.IntervalTimeConfiguration; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.test.ClusterServiceUtils; + +import test.org.opensearch.ad.util.LabelledAnomalyGenerator; +import test.org.opensearch.ad.util.MultiDimDataWithTime; + +import com.carrotsearch.randomizedtesting.annotations.TimeoutSuite; +import com.google.common.collect.ImmutableList; + +@TimeoutSuite(millis = 60 * TimeUnits.MINUTE) // rcf may be slow due to bounding box cache disabled +public class HCADModelPerfTests extends AbstractCosineDataTest { + + /** + * A template to perform precision/recall test by simulating 
HCAD logic with only one entity.
+ *
+ * @param detectorIntervalMins Detector interval
+ * @param precisionThreshold precision threshold
+ * @param recallThreshold recall threshold
+ * @param baseDimension the number of dimensions
+ * @param anomalyIndependent whether anomalies in each dimension are generated independently
+ * @throws Exception when failing to create the anomaly detector or training data
+ */
+ @SuppressWarnings("unchecked")
+ private void averageAccuracyTemplate(
+ int detectorIntervalMins,
+ float precisionThreshold,
+ float recallThreshold,
+ int baseDimension,
+ boolean anomalyIndependent
+ ) throws Exception {
+ int dataSize = 20 * AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE;
+ int trainTestSplit = 300;
+ // detector interval
+ int interval = detectorIntervalMins;
+ int delta = 60000 * interval;
+
+ int numberOfTrials = 10;
+ double prec = 0;
+ double recall = 0;
+ double totalPrec = 0;
+ double totalRecall = 0;
+
+ // training data ranges from timestamps[0] ~ timestamps[trainTestSplit-1]
+ // set up detector
+ detector = TestHelpers.AnomalyDetectorBuilder
+ .newInstance()
+ .setDetectionInterval(new IntervalTimeConfiguration(interval, ChronoUnit.MINUTES))
+ .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5)))
+ .setShingleSize(AnomalyDetectorSettings.DEFAULT_SHINGLE_SIZE)
+ .build();
+
+ doAnswer(invocation -> {
+ ActionListener listener = invocation.getArgument(2);
+
+ listener.onResponse(TestHelpers.createGetResponse(detector, detector.getDetectorId(), AnomalyDetector.ANOMALY_DETECTORS_INDEX));
+ return null;
+ }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class));
+
+ for (int z = 1; z <= numberOfTrials; z++) {
+ long seed = z;
+ LOG.info("seed = " + seed);
+ // recreate in each loop; otherwise, we will run into a heap overflow issue.
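+ // (Each trial mocks SearchFeatureDao and rebuilds the cluster service, feature manager,
+ // cold starter, and model manager, then drops the references at the end of the loop so
+ // the previous trial's stubbed data arrays and RCF state can be garbage collected.)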
+ searchFeatureDao = mock(SearchFeatureDao.class); + clusterSettings = new ClusterSettings(Settings.EMPTY, nodestateSetting); + clusterService = ClusterServiceUtils.createClusterService(threadPool, discoveryNode, clusterSettings); + + featureManager = new FeatureManager( + searchFeatureDao, + interpolator, + clock, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, + AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, + AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, + AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, + AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, + AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + threadPool, + AnomalyDetectorPlugin.AD_THREAD_POOL_NAME + ); + + entityColdStarter = new EntityColdStarter( + clock, + threadPool, + stateManager, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.TIME_DECAY, + numMinSamples, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + interpolator, + searchFeatureDao, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + featureManager, + settings, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + seed, + AnomalyDetectorSettings.MAX_COLD_START_ROUNDS + ); + + modelManager = new ModelManager( + mock(CheckpointDao.class), + mock(Clock.class), + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.TIME_DECAY, + AnomalyDetectorSettings.NUM_MIN_SAMPLES, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + AnomalyDetectorSettings.MIN_PREVIEW_SIZE, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ, + entityColdStarter, + mock(FeatureManager.class), + mock(MemoryTracker.class), + settings, + clusterService + ); + + // create labelled data + MultiDimDataWithTime dataWithKeys = LabelledAnomalyGenerator + .getMultiDimData( + dataSize + detector.getShingleSize() - 1, + 50, + 100, + 5, + seed, + baseDimension, + false, + trainTestSplit, + delta, + anomalyIndependent + ); + + long[] timestamps = dataWithKeys.timestampsMs; + double[][] data = dataWithKeys.data; + when(clock.millis()).thenReturn(timestamps[trainTestSplit - 1]); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(Optional.of(timestamps[0])); + return null; + }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any()); + + doAnswer(invocation -> { + List> ranges = invocation.getArgument(1); + List> coldStartSamples = new ArrayList<>(); + + Collections.sort(ranges, new Comparator>() { + @Override + public int compare(Entry p1, Entry p2) { + return Long.compare(p1.getKey(), p2.getKey()); + } + }); + for (int j = 0; j < ranges.size(); j++) { + Entry range = ranges.get(j); + Long start = range.getKey(); + int valueIndex = searchInsert(timestamps, start); + coldStartSamples.add(Optional.of(data[valueIndex])); + } + + ActionListener>> listener = invocation.getArgument(4); + listener.onResponse(coldStartSamples); + return null; + }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any()); + + entity = Entity.createSingleAttributeEntity("field", entityName + z); + EntityModel model = new EntityModel(entity, new ArrayDeque<>(), null); + ModelState modelState = new ModelState<>( + model, + entity.getModelId(detectorId).get(), + 
detector.getDetectorId(),
+ ModelType.ENTITY.getName(),
+ clock,
+ priority
+ );
+
+ released = new AtomicBoolean();
+
+ inProgressLatch = new CountDownLatch(1);
+ listener = ActionListener.wrap(() -> {
+ released.set(true);
+ inProgressLatch.countDown();
+ });
+
+ entityColdStarter.trainModel(entity, detector.getDetectorId(), modelState, listener);
+
+ checkSemaphoreRelease();
+ assertTrue(model.getTrcf().isPresent());
+
+ int tp = 0;
+ int fp = 0;
+ int fn = 0;
+ long[] changeTimestamps = dataWithKeys.changeTimeStampsMs;
+
+ for (int j = trainTestSplit; j < data.length; j++) {
+ ThresholdingResult result = modelManager
+ .getAnomalyResultForEntity(data[j], modelState, modelId, entity, detector.getShingleSize());
+ if (result.getGrade() > 0) {
+ if (changeTimestamps[j] == 0) {
+ fp++;
+ } else {
+ tp++;
+ }
+ } else {
+ if (changeTimestamps[j] != 0) {
+ fn++;
+ }
+ // else ok
+ }
+ }
+
+ if (tp + fp == 0) {
+ prec = 1;
+ } else {
+ prec = tp * 1.0 / (tp + fp);
+ }
+
+ if (tp + fn == 0) {
+ recall = 1;
+ } else {
+ recall = tp * 1.0 / (tp + fn);
+ }
+
+ totalPrec += prec;
+ totalRecall += recall;
+ modelState = null;
+ dataWithKeys = null;
+ reset(searchFeatureDao);
+ searchFeatureDao = null;
+ clusterService = null;
+ }
+
+ double avgPrec = totalPrec / numberOfTrials;
+ double avgRecall = totalRecall / numberOfTrials;
+ LOG.info("{} features, Interval {}, Precision: {}, recall: {}", baseDimension, detectorIntervalMins, avgPrec, avgRecall);
+ assertTrue("average precision is " + avgPrec, avgPrec >= precisionThreshold);
+ assertTrue("average recall is " + avgRecall, avgRecall >= recallThreshold);
+ }
+
+ /**
+ * Split the average accuracy tests into two in case a single test times out.
+ * @throws Exception when failing to perform tests
+ */
+ public void testAverageAccuracyDependent() throws Exception {
+ LOG.info("Anomalies are injected dependently");
+
+ // 10 minute interval, 4 features
+ averageAccuracyTemplate(10, 0.4f, 0.3f, 4, false);
+
+ // 10 minute interval, 2 features
+ averageAccuracyTemplate(10, 0.4f, 0.4f, 2, false);
+
+ // 10 minute interval, 1 feature
+ averageAccuracyTemplate(10, 0.4f, 0.4f, 1, false);
+
+ // 5 minute interval, 4 features
+ averageAccuracyTemplate(5, 0.4f, 0.3f, 4, false);
+
+ // 5 minute interval, 2 features
+ averageAccuracyTemplate(5, 0.4f, 0.4f, 2, false);
+
+ // 5 minute interval, 1 feature
+ averageAccuracyTemplate(5, 0.4f, 0.4f, 1, false);
+ }
+
+ /**
+ * Split the average accuracy tests into two in case a single test times out.
+ * @throws Exception when failing to perform tests
+ */
+ public void testAverageAccuracyIndependent() throws Exception {
+ LOG.info("Anomalies are injected independently");
+
+ // 10 minute interval, 4 features
+ averageAccuracyTemplate(10, 0.3f, 0.1f, 4, true);
+
+ // 10 minute interval, 2 features
+ averageAccuracyTemplate(10, 0.4f, 0.4f, 2, true);
+
+ // 10 minute interval, 1 feature
+ averageAccuracyTemplate(10, 0.3f, 0.4f, 1, true);
+
+ // 5 minute interval, 4 features
+ averageAccuracyTemplate(5, 0.2f, 0.1f, 4, true);
+
+ // 5 minute interval, 2 features
+ averageAccuracyTemplate(5, 0.4f, 0.4f, 2, true);
+
+ // 5 minute interval, 1 feature
+ averageAccuracyTemplate(5, 0.3f, 0.4f, 1, true);
+ }
+}
diff --git a/src/test/java/test/org/opensearch/ad/util/LabelledAnomalyGenerator.java b/src/test/java/test/org/opensearch/ad/util/LabelledAnomalyGenerator.java
index f2ef3cc2d..f77c135fb 100644
--- a/src/test/java/test/org/opensearch/ad/util/LabelledAnomalyGenerator.java
+++ b/src/test/java/test/org/opensearch/ad/util/LabelledAnomalyGenerator.java
@@ -29,6 +29,7 @@ public class LabelledAnomalyGenerator {
 * @param useSlope whether to use slope in cosine data
 * @param historicalData the number of historical points relative to now
 * @param delta point interval
+ * @param anomalyIndependent whether anomalies in each dimension are generated independently
 * @return the labelled data
 */
 public static MultiDimDataWithTime getMultiDimData(
@@ -40,7 +41,8 @@ public static MultiDimDataWithTime getMultiDimData(
 int baseDimension,
 boolean useSlope,
 int historicalData,
- int delta
+ int delta,
+ boolean anomalyIndependent
 ) {
 double[][] data = new double[num][];
 long[] timestamps = new long[num];
@@ -66,14 +68,34 @@ public static MultiDimDataWithTime getMultiDimData(
 startEpochMs += delta;
 data[i] = new double[baseDimension];
 double[] newChange = new double[baseDimension];
- for (int j = 0; j < baseDimension; j++) {
- data[i][j] = amp[j] * Math.cos(2 * PI * (i + phase[j]) / period) + slope[j] * i + noise * noiseprg.nextDouble();
- if (noiseprg.nextDouble() < 0.01 && noiseprg.nextDouble() < 0.3) {
- double factor = 5 * (1 + noiseprg.nextDouble());
- double change = noiseprg.nextDouble() < 0.5 ? factor * noise : -factor * noise;
- data[i][j] += newChange[j] = change;
- changedTimestamps[i] = timestamps[i];
- changes[i] = newChange;
+ // Decide whether to inject an anomaly at this point.
+ // If we decide per dimension, each dimension's anomalies
+ // are independent, which makes it harder for RCF to detect them.
+ // Deciding at the point level makes each dimension's anomalies
+ // correlated.
+ if (anomalyIndependent) {
+ for (int j = 0; j < baseDimension; j++) {
+ data[i][j] = amp[j] * Math.cos(2 * PI * (i + phase[j]) / period) + slope[j] * i + noise * noiseprg.nextDouble();
+ if (noiseprg.nextDouble() < 0.01 && noiseprg.nextDouble() < 0.3) {
+ double factor = 5 * (1 + noiseprg.nextDouble());
+ double change = noiseprg.nextDouble() < 0.5 ?
factor * noise : -factor * noise;
+ data[i][j] += newChange[j] = change;
+ changedTimestamps[i] = timestamps[i];
+ changes[i] = newChange;
+ }
+ }
+ } else {
+ boolean flag = (noiseprg.nextDouble() < 0.01);
+ for (int j = 0; j < baseDimension; j++) {
+ data[i][j] = amp[j] * Math.cos(2 * PI * (i + phase[j]) / period) + slope[j] * i + noise * noiseprg.nextDouble();
+ // the extra < 0.3 condition keeps some variance in which features actually carry the anomaly
+ if (flag && noiseprg.nextDouble() < 0.3) {
+ double factor = 5 * (1 + noiseprg.nextDouble());
+ double change = noiseprg.nextDouble() < 0.5 ? factor * noise : -factor * noise;
+ data[i][j] += newChange[j] = change;
+ changedTimestamps[i] = timestamps[i];
+ changes[i] = newChange;
+ }
 }
 }
 }

From dc293acbf6d6ae50d850aabba3607df56a778ef1 Mon Sep 17 00:00:00 2001
From: Amit Galitzky
Date: Wed, 7 Dec 2022 12:52:48 -0800
Subject: [PATCH 2/5] Fix _source bug (#749)

Signed-off-by: Amit Galitzky
---
 .../opensearch/ad/model/AnomalyDetector.java | 2 +-
 .../ad/rest/AbstractSearchAction.java | 44 ++-----------------
 .../rest/RestSearchAnomalyResultAction.java | 2 +-
 .../opensearch/ad/util/RestHandlerUtils.java | 29 +++++++++---
 .../ad/util/RestHandlerUtilsTests.java | 40 +++++++++++++----
 5 files changed, 61 insertions(+), 56 deletions(-)

diff --git a/src/main/java/org/opensearch/ad/model/AnomalyDetector.java b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java
index 2c069d703..da64b30ea 100644
--- a/src/main/java/org/opensearch/ad/model/AnomalyDetector.java
+++ b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java
@@ -139,7 +139,7 @@ public class AnomalyDetector implements Writeable, ToXContentObject {
 * @param detectionInterval detecting interval
 * @param windowDelay max delay window for realtime data
 * @param shingleSize number of the most recent time intervals to form a shingled data point
- * @param uiMetadata metadata used by Kibana
+ * @param uiMetadata metadata used by OpenSearch-Dashboards
 * @param schemaVersion anomaly detector index mapping version
 * @param lastUpdateTime detector's last update time
 * @param categoryFields a list of partition fields
diff --git a/src/main/java/org/opensearch/ad/rest/AbstractSearchAction.java b/src/main/java/org/opensearch/ad/rest/AbstractSearchAction.java
index 3bba68bcd..07c55e6df 100644
--- a/src/main/java/org/opensearch/ad/rest/AbstractSearchAction.java
+++ b/src/main/java/org/opensearch/ad/rest/AbstractSearchAction.java
@@ -13,8 +13,6 @@
 import static org.opensearch.ad.util.RestHandlerUtils.getSourceContext;
 import static org.opensearch.common.xcontent.ToXContent.EMPTY_PARAMS;
-import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder;
-import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

 import java.io.IOException;
 import java.util.ArrayList;
@@ -27,17 +25,9 @@
 import org.opensearch.action.search.SearchRequest;
 import org.opensearch.action.search.SearchResponse;
 import org.opensearch.ad.constant.CommonErrorMessages;
-import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.rest.handler.AnomalyDetectorFunction;
 import org.opensearch.ad.settings.EnabledSetting;
 import org.opensearch.client.node.NodeClient;
-import org.opensearch.common.bytes.BytesReference;
-import org.opensearch.common.util.concurrent.ThreadContext;
-import org.opensearch.common.xcontent.LoggingDeprecationHandler;
 import org.opensearch.common.xcontent.ToXContentObject;
-import org.opensearch.common.xcontent.XContentBuilder;
-import
org.opensearch.common.xcontent.XContentParser;
-import org.opensearch.common.xcontent.XContentType;
 import org.opensearch.rest.BaseRestHandler;
 import org.opensearch.rest.BytesRestResponse;
 import org.opensearch.rest.RestChannel;
@@ -45,7 +35,6 @@
 import org.opensearch.rest.RestResponse;
 import org.opensearch.rest.RestStatus;
 import org.opensearch.rest.action.RestResponseListener;
-import org.opensearch.search.SearchHit;
 import org.opensearch.search.builder.SearchSourceBuilder;

 /**
@@ -82,21 +71,15 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
 }
 SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
 searchSourceBuilder.parseXContent(request.contentOrSourceParamParser());
- searchSourceBuilder.fetchSource(getSourceContext(request));
+ // the order of the response will be re-arranged every time we use `_source`; we sometimes do this
+ // even if the user doesn't give this field, since we exclude ui_metadata when the request isn't from OSD
+ // ref-link: https://github.com/elastic/elasticsearch/issues/17639
+ searchSourceBuilder.fetchSource(getSourceContext(request, searchSourceBuilder));
 searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true);
 SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(this.index);
 return channel -> client.execute(actionType, searchRequest, search(channel));
 }

- protected void executeWithAdmin(NodeClient client, AnomalyDetectorFunction function, RestChannel channel) {
- try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
- function.execute();
- } catch (Exception e) {
- logger.error("Failed to execute with admin", e);
- onFailure(channel, e);
- }
- }
-
 protected void onFailure(RestChannel channel, Exception e) {
 try {
 channel.sendResponse(new BytesRestResponse(channel, e));
@@ -112,25 +95,6 @@ public RestResponse buildResponse(SearchResponse response) throws Exception {
 if (response.isTimedOut()) {
 return new BytesRestResponse(RestStatus.REQUEST_TIMEOUT, response.toString());
 }
-
- if (clazz == AnomalyDetector.class) {
- for (SearchHit hit : response.getHits()) {
- XContentParser parser = XContentType.JSON
- .xContent()
- .createParser(
- channel.request().getXContentRegistry(),
- LoggingDeprecationHandler.INSTANCE,
- hit.getSourceAsString()
- );
- ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
-
- // write back id and version to anomaly detector object
- ToXContentObject xContentObject = AnomalyDetector.parse(parser, hit.getId(), hit.getVersion());
- XContentBuilder builder = xContentObject.toXContent(jsonBuilder(), EMPTY_PARAMS);
- hit.sourceRef(BytesReference.bytes(builder));
- }
- }
-
 return new BytesRestResponse(RestStatus.OK, response.toXContent(channel.newBuilder(), EMPTY_PARAMS));
 }
 };
diff --git a/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyResultAction.java b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyResultAction.java
index 64771181a..f8468bd15 100644
--- a/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyResultAction.java
+++ b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyResultAction.java
@@ -70,7 +70,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli

 SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
 searchSourceBuilder.parseXContent(request.contentOrSourceParamParser());
- searchSourceBuilder.fetchSource(getSourceContext(request));
+ searchSourceBuilder.fetchSource(getSourceContext(request, searchSourceBuilder));
searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true);
 SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(this.index);

diff --git a/src/main/java/org/opensearch/ad/util/RestHandlerUtils.java b/src/main/java/org/opensearch/ad/util/RestHandlerUtils.java
index c11918653..36dc76a28 100644
--- a/src/main/java/org/opensearch/ad/util/RestHandlerUtils.java
+++ b/src/main/java/org/opensearch/ad/util/RestHandlerUtils.java
@@ -19,6 +19,7 @@
 import java.util.List;
 import java.util.Set;

+import org.apache.commons.lang.ArrayUtils;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.opensearch.OpenSearchStatusException;
@@ -43,6 +44,7 @@
 import org.opensearch.rest.RestChannel;
 import org.opensearch.rest.RestRequest;
 import org.opensearch.rest.RestStatus;
+import org.opensearch.search.builder.SearchSourceBuilder;
 import org.opensearch.search.fetch.subphase.FetchSourceContext;

 import com.google.common.base.Throwables;
@@ -80,21 +82,36 @@ public final class RestHandlerUtils {
 public static final String VALIDATE = "_validate";
 public static final ToXContent.MapParams XCONTENT_WITH_TYPE = new ToXContent.MapParams(ImmutableMap.of("with_type", "true"));

- private static final String OPENSEARCH_DASHBOARDS_USER_AGENT = "OpenSearch Dashboards";
- private static final String[] UI_METADATA_EXCLUDE = new String[] { AnomalyDetector.UI_METADATA_FIELD };
+ public static final String OPENSEARCH_DASHBOARDS_USER_AGENT = "OpenSearch Dashboards";
+ public static final String[] UI_METADATA_EXCLUDE = new String[] { AnomalyDetector.UI_METADATA_FIELD };

 private RestHandlerUtils() {}

 /**
- * Checks to see if the request came from Kibana, if so we want to return the UI Metadata from the document.
+ * Checks to see if the request came from OpenSearch-Dashboards; if so, we want to return the UI Metadata from the document.
 * If the request came from the client then we exclude the UI Metadata from the search result.
- *
+ * We also take the given `_source` field into account and return the correct set of fields.
 * @param request rest request
+ * @param searchSourceBuilder an instance of the searchSourceBuilder to fetch _source field
 * @return instance of {@link org.opensearch.search.fetch.subphase.FetchSourceContext}
 */
- public static FetchSourceContext getSourceContext(RestRequest request) {
+ public static FetchSourceContext getSourceContext(RestRequest request, SearchSourceBuilder searchSourceBuilder) {
 String userAgent = Strings.coalesceToEmpty(request.header("User-Agent"));
- if (!userAgent.contains(OPENSEARCH_DASHBOARDS_USER_AGENT)) {
+
+ // If a _source is given in the request, then we either add UI_Metadata to the excludes or not, depending on whether the request
+ // is from OpenSearch-Dashboards; if there is no _source field, then we either exclude UI_metadata or return nothing at all.
+ if (searchSourceBuilder.fetchSource() != null) { + if (userAgent.contains(OPENSEARCH_DASHBOARDS_USER_AGENT)) { + return new FetchSourceContext( + true, + searchSourceBuilder.fetchSource().includes(), + searchSourceBuilder.fetchSource().excludes() + ); + } else { + String[] newArray = (String[]) ArrayUtils.addAll(searchSourceBuilder.fetchSource().excludes(), UI_METADATA_EXCLUDE); + return new FetchSourceContext(true, searchSourceBuilder.fetchSource().includes(), newArray); + } + } else if (!userAgent.contains(OPENSEARCH_DASHBOARDS_USER_AGENT)) { return new FetchSourceContext(true, Strings.EMPTY_ARRAY, UI_METADATA_EXCLUDE); } else { return null; diff --git a/src/test/java/org/opensearch/ad/util/RestHandlerUtilsTests.java b/src/test/java/org/opensearch/ad/util/RestHandlerUtilsTests.java index 0e37109ef..08089a11f 100644 --- a/src/test/java/org/opensearch/ad/util/RestHandlerUtilsTests.java +++ b/src/test/java/org/opensearch/ad/util/RestHandlerUtilsTests.java @@ -13,6 +13,7 @@ import static org.opensearch.ad.TestHelpers.builder; import static org.opensearch.ad.TestHelpers.randomFeature; +import static org.opensearch.ad.util.RestHandlerUtils.OPENSEARCH_DASHBOARDS_USER_AGENT; import java.io.IOException; @@ -24,6 +25,7 @@ import org.opensearch.common.xcontent.XContentParser; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.rest.FakeRestChannel; @@ -34,17 +36,39 @@ public class RestHandlerUtilsTests extends OpenSearchTestCase { - public void testGetSourceContext() { - RestRequest request = new FakeRestRequest(); - FetchSourceContext context = RestHandlerUtils.getSourceContext(request); - assertArrayEquals(new String[] { "ui_metadata" }, context.excludes()); + public void testGetSourceContextFromOpenSearchDashboardEmptyExcludes() { + FakeRestRequest.Builder builder = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY); + builder.withHeaders(ImmutableMap.of("User-Agent", ImmutableList.of(OPENSEARCH_DASHBOARDS_USER_AGENT, randomAlphaOfLength(10)))); + SearchSourceBuilder testSearchSourceBuilder = new SearchSourceBuilder(); + testSearchSourceBuilder.fetchSource(new String[] { "a" }, new String[0]); + FetchSourceContext sourceContext = RestHandlerUtils.getSourceContext(builder.build(), testSearchSourceBuilder); + assertArrayEquals(new String[] { "a" }, sourceContext.includes()); + assertEquals(0, sourceContext.excludes().length); + assertEquals(1, sourceContext.includes().length); + } + + public void testGetSourceContextFromClientWithExcludes() { + FakeRestRequest.Builder builder = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY); + SearchSourceBuilder testSearchSourceBuilder = new SearchSourceBuilder(); + testSearchSourceBuilder.fetchSource(new String[] { "a" }, new String[] { "b" }); + FetchSourceContext sourceContext = RestHandlerUtils.getSourceContext(builder.build(), testSearchSourceBuilder); + assertEquals(sourceContext.excludes().length, 2); + } + + public void testGetSourceContextFromClientWithoutSource() { + FakeRestRequest.Builder builder = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY); + SearchSourceBuilder testSearchSourceBuilder = new SearchSourceBuilder(); + FetchSourceContext sourceContext = RestHandlerUtils.getSourceContext(builder.build(), testSearchSourceBuilder); + assertEquals(sourceContext.excludes().length, 1); + 
assertEquals(sourceContext.includes().length, 0); } - public void testGetSourceContextForKibana() { + public void testGetSourceContextOpenSearchDashboardWithoutSources() { FakeRestRequest.Builder builder = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY); - builder.withHeaders(ImmutableMap.of("User-Agent", ImmutableList.of("OpenSearch Dashboards", randomAlphaOfLength(10)))); - FetchSourceContext context = RestHandlerUtils.getSourceContext(builder.build()); - assertNull(context); + builder.withHeaders(ImmutableMap.of("User-Agent", ImmutableList.of(OPENSEARCH_DASHBOARDS_USER_AGENT, randomAlphaOfLength(10)))); + SearchSourceBuilder testSearchSourceBuilder = new SearchSourceBuilder(); + FetchSourceContext sourceContext = RestHandlerUtils.getSourceContext(builder.build(), testSearchSourceBuilder); + assertNull(sourceContext); } public void testCreateXContentParser() throws IOException { From 012da25befbef70e42bbf7a280dc4537d60f7e4a Mon Sep 17 00:00:00 2001 From: "Daniel (dB.) Doubrovkine" Date: Thu, 15 Dec 2022 17:11:10 -0500 Subject: [PATCH 3/5] Fix: typo in ohltyler. (#760) Signed-off-by: dblock (cherry picked from commit 609abe42df8af450e48bfa122746361f7a8201c7) --- MAINTAINERS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MAINTAINERS.md b/MAINTAINERS.md index 3b99b305d..d6730245e 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -9,7 +9,7 @@ | Kaituo Li | [kaituo](https://github.com/kaituo) | Amazon | | Lai Jiang | [wnbts](https://github.com/wnbts) | Amazon | | Sarat Vemulapalli | [saratvemulapalli](https://github.com/saratvemulapalli) | Amazon | -| Tyler Ohlsen | [ohtyler](https://github.com/ohltyler) | Amazon | +| Tyler Ohlsen | [ohltyler](https://github.com/ohltyler) | Amazon | | Vamshi Vijay Nakkirtha | [vamshin](https://github.com/vamshin) | Amazon | | Vijayan Balasubramanian | [VijayanB](https://github.com/VijayanB) | Amazon | | Yaliang Wu | [ylwu-amzn](https://github.com/ylwu-amzn) | Amazon | From 021f92fe8e0c2ecbf37dce4ce826840bbae9763f Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Tue, 6 Dec 2022 15:44:33 -0800 Subject: [PATCH 4/5] Model Profile Test (#748) (cherry picked from commit ce3d747317aec5ccef080c767d4ba8e7015c5a2a) --- .../ad/model/ModelProfileTests.java | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 src/test/java/org/opensearch/ad/model/ModelProfileTests.java diff --git a/src/test/java/org/opensearch/ad/model/ModelProfileTests.java b/src/test/java/org/opensearch/ad/model/ModelProfileTests.java new file mode 100644 index 000000000..d4d0fbf49 --- /dev/null +++ b/src/test/java/org/opensearch/ad/model/ModelProfileTests.java @@ -0,0 +1,56 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */
+
+package org.opensearch.ad.model;
+
+import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder;
+
+import java.io.IOException;
+
+import org.opensearch.ad.AbstractADTest;
+import org.opensearch.ad.constant.CommonName;
+import org.opensearch.common.Strings;
+import org.opensearch.common.xcontent.ToXContent;
+import org.opensearch.common.xcontent.XContentBuilder;
+
+import test.org.opensearch.ad.util.JsonDeserializer;
+
+public class ModelProfileTests extends AbstractADTest {
+
+ public void testToXContent() throws IOException {
+ ModelProfile profile1 = new ModelProfile(
+ randomAlphaOfLength(5),
+ Entity.createSingleAttributeEntity(randomAlphaOfLength(5), randomAlphaOfLength(5)),
+ 0
+ );
+ XContentBuilder builder = getBuilder(profile1);
+ String json = Strings.toString(builder);
+ assertTrue(JsonDeserializer.hasChildNode(json, CommonName.ENTITY_KEY));
+ assertFalse(JsonDeserializer.hasChildNode(json, CommonName.MODEL_SIZE_IN_BYTES));
+
+ ModelProfile profile2 = new ModelProfile(randomAlphaOfLength(5), null, 1);
+
+ builder = getBuilder(profile2);
+ json = Strings.toString(builder);
+
+ assertFalse(JsonDeserializer.hasChildNode(json, CommonName.ENTITY_KEY));
+ assertTrue(JsonDeserializer.hasChildNode(json, CommonName.MODEL_SIZE_IN_BYTES));
+
+ }
+
+ private XContentBuilder getBuilder(ModelProfile profile) throws IOException {
+ XContentBuilder builder = jsonBuilder();
+ builder.startObject();
+ profile.toXContent(builder, ToXContent.EMPTY_PARAMS);
+ builder.endObject();
+ return builder;
+ }
+}

From 8ce449345e43dd34c7891b9f7d5835c3f242f27f Mon Sep 17 00:00:00 2001
From: Kaituo Li
Date: Fri, 16 Dec 2022 10:30:51 -0800
Subject: [PATCH 5/5] Speed up cold start (#753) (#763)

If there is enough historical data, a single stream detector takes 1 interval for cold start to be triggered + 1 interval for the state document to be updated. Similar to single stream detectors, HCAD cold start needs 2 intervals, plus one more interval to make sure an entity appears more than once. So HCAD needs three intervals to complete cold start. Long initialization is the single most complained-about problem of AD. This PR reduces both single stream and HCAD detectors' initialization time to 1 minute by

* delaying the real time cache update by one minute when we receive ResourceNotFoundException in single stream detectors or when the init progress of the HCAD real time cache is 0. Thus, we can finish the cold start and write the checkpoint one minute later, then update the state document accordingly. This optimization saves the interval spent waiting for the state document update.
* disabling the door keeper by default so that we won't have to wait an extra interval in HCAD.
* triggering cold start when starting a real time detector. This optimization saves the interval spent waiting for the cold start to be triggered.

Testing done:
* verified the cold start time is reduced to 1 minute.
* added tests for new code.
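For reference, the first bullet amounts to scheduling the real-time task update on the AD thread pool instead of running it inline. A minimal, illustrative sketch of the idea (the class and method names here are assumptions for exposition, not this PR's exact code):

    import org.opensearch.ad.AnomalyDetectorPlugin;
    import org.opensearch.common.unit.TimeValue;
    import org.opensearch.threadpool.ThreadPool;

    class DelayedTaskUpdateSketch {
        private final ThreadPool threadPool;

        DelayedTaskUpdateSketch(ThreadPool threadPool) {
            this.threadPool = threadPool;
        }

        // Postpone the real-time task update by one minute so a just-triggered
        // cold start can finish and write its checkpoint before the state
        // document is refreshed.
        void updateLater(Runnable updateLatestRealtimeTask) {
            threadPool.schedule(updateLatestRealtimeTask, TimeValue.timeValueMinutes(1), AnomalyDetectorPlugin.AD_THREAD_POOL_NAME);
        }
    }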
Signed-off-by: Kaituo Li --- build.gradle | 8 +- .../ad/AnomalyDetectorJobRunner.java | 373 +++++++++--------- .../opensearch/ad/AnomalyDetectorPlugin.java | 48 ++- .../ad/ExecuteADResultResponseRecorder.java | 288 ++++++++++++++ .../opensearch/ad/caching/PriorityCache.java | 49 +-- .../opensearch/ad/model/AnomalyDetector.java | 1 + .../ad/ratelimit/EntityColdStartWorker.java | 49 ++- .../RestExecuteAnomalyDetectorAction.java | 1 - .../IndexAnomalyDetectorJobActionHandler.java | 122 ++++-- .../ad/settings/EnabledSetting.java | 20 +- .../org/opensearch/ad/task/ADTaskManager.java | 11 +- .../AnomalyDetectorJobTransportAction.java | 10 +- .../AnomalyResultTransportAction.java | 2 +- .../ad/AnomalyDetectorJobRunnerTests.java | 84 +++- .../ad/caching/PriorityCacheTests.java | 36 +- .../ad/e2e/DetectionResultEvalutationIT.java | 25 +- .../ad/ml/EntityColdStarterTests.java | 1 - ...alyDetectorJobTransportActionWithUser.java | 10 +- .../ratelimit/EntityColdStartWorkerTests.java | 29 +- ...xAnomalyDetectorJobActionHandlerTests.java | 354 +++++++++++++++++ .../AnomalyDetectorJobActionTests.java | 4 +- 21 files changed, 1193 insertions(+), 332 deletions(-) create mode 100644 src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java create mode 100644 src/test/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandlerTests.java diff --git a/build.gradle b/build.gradle index c66674c16..594fd7813 100644 --- a/build.gradle +++ b/build.gradle @@ -134,7 +134,7 @@ configurations.all { if (it.state != Configuration.State.UNRESOLVED) return resolutionStrategy { force "joda-time:joda-time:${versions.joda}" - force "com.fasterxml.jackson.core:jackson-core:2.14.0" + force "com.fasterxml.jackson.core:jackson-core:2.14.1" force "commons-logging:commons-logging:${versions.commonslogging}" force "org.apache.httpcomponents:httpcore:${versions.httpcore}" force "commons-codec:commons-codec:${versions.commonscodec}" @@ -677,9 +677,9 @@ dependencies { implementation 'software.amazon.randomcutforest:randomcutforest-core:3.0-rc3' // force Jackson version to avoid version conflict issue - implementation "com.fasterxml.jackson.core:jackson-core:2.14.0" - implementation "com.fasterxml.jackson.core:jackson-databind:2.14.0" - implementation "com.fasterxml.jackson.core:jackson-annotations:2.14.0" + implementation "com.fasterxml.jackson.core:jackson-core:2.14.1" + implementation "com.fasterxml.jackson.core:jackson-databind:2.14.1" + implementation "com.fasterxml.jackson.core:jackson-annotations:2.14.1" // used for serializing/deserializing rcf models. 
implementation group: 'io.protostuff', name: 'protostuff-core', version: '1.8.0' diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java b/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java index 655267dcf..aa0ce8075 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java @@ -14,21 +14,18 @@ import static org.opensearch.action.DocWriteResponse.Result.CREATED; import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.ad.AnomalyDetectorPlugin.AD_THREAD_POOL_NAME; -import static org.opensearch.ad.constant.CommonErrorMessages.CAN_NOT_FIND_LATEST_TASK; import static org.opensearch.ad.util.RestHandlerUtils.XCONTENT_WITH_TYPE; import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; import java.time.Instant; -import java.util.ArrayList; -import java.util.HashSet; import java.util.List; -import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.action.ActionListener; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; @@ -37,15 +34,10 @@ import org.opensearch.ad.common.exception.AnomalyDetectionException; import org.opensearch.ad.common.exception.EndRunException; import org.opensearch.ad.common.exception.InternalFailure; -import org.opensearch.ad.common.exception.ResourceNotFoundException; -import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.AnomalyDetectionIndices; import org.opensearch.ad.model.ADTaskState; +import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyDetectorJob; -import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.model.DetectorProfileName; -import org.opensearch.ad.model.FeatureData; -import org.opensearch.ad.model.IntervalTimeConfiguration; import org.opensearch.ad.rest.handler.AnomalyDetectorFunction; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.task.ADTaskManager; @@ -53,12 +45,7 @@ import org.opensearch.ad.transport.AnomalyResultRequest; import org.opensearch.ad.transport.AnomalyResultResponse; import org.opensearch.ad.transport.AnomalyResultTransportAction; -import org.opensearch.ad.transport.ProfileAction; -import org.opensearch.ad.transport.ProfileRequest; -import org.opensearch.ad.transport.handler.AnomalyIndexHandler; -import org.opensearch.ad.util.DiscoveryNodeFilterer; import org.opensearch.client.Client; -import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.NamedXContentRegistry; @@ -66,7 +53,6 @@ import org.opensearch.common.xcontent.XContentParser; import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.InjectSecurity; -import org.opensearch.commons.authuser.User; import org.opensearch.jobscheduler.spi.JobExecutionContext; import org.opensearch.jobscheduler.spi.LockModel; import org.opensearch.jobscheduler.spi.ScheduledJobParameter; @@ -88,11 +74,11 @@ public class AnomalyDetectorJobRunner implements ScheduledJobRunner { private int maxRetryForEndRunException; private Client client; private ThreadPool threadPool; - 
private AnomalyIndexHandler anomalyResultHandler; private ConcurrentHashMap detectorEndRunExceptionCount; private AnomalyDetectionIndices anomalyDetectionIndices; - private DiscoveryNodeFilterer nodeFilter; private ADTaskManager adTaskManager; + private NodeStateManager nodeStateManager; + private ExecuteADResultResponseRecorder recorder; public static AnomalyDetectorJobRunner getJobRunnerInstance() { if (INSTANCE != null) { @@ -120,10 +106,6 @@ public void setThreadPool(ThreadPool threadPool) { this.threadPool = threadPool; } - public void setAnomalyResultHandler(AnomalyIndexHandler anomalyResultHandler) { - this.anomalyResultHandler = anomalyResultHandler; - } - public void setSettings(Settings settings) { this.settings = settings; this.maxRetryForEndRunException = AnomalyDetectorSettings.MAX_RETRY_FOR_END_RUN_EXCEPTION.get(settings); @@ -137,8 +119,12 @@ public void setAnomalyDetectionIndices(AnomalyDetectionIndices anomalyDetectionI this.anomalyDetectionIndices = anomalyDetectionIndices; } - public void setNodeFilter(DiscoveryNodeFilterer nodeFilter) { - this.nodeFilter = nodeFilter; + public void setNodeStateManager(NodeStateManager nodeStateManager) { + this.nodeStateManager = nodeStateManager; + } + + public void setExecuteADResultResponseRecorder(ExecuteADResultResponseRecorder recorder) { + this.recorder = recorder; } @Override @@ -159,28 +145,49 @@ public void runJob(ScheduledJobParameter scheduledJobParameter, JobExecutionCont final LockService lockService = context.getLockService(); Runnable runnable = () -> { - if (jobParameter.getLockDurationSeconds() != null) { - lockService - .acquireLock( - jobParameter, - context, - ActionListener - .wrap(lock -> runAdJob(jobParameter, lockService, lock, detectionStartTime, executionStartTime), exception -> { - indexAnomalyResultException( - jobParameter, - lockService, - null, - detectionStartTime, - executionStartTime, - exception, - false - ); - throw new IllegalStateException("Failed to acquire lock for AD job: " + detectorId); - }) - ); - } else { - log.warn("Can't get lock for AD job: " + detectorId); - } + nodeStateManager.getAnomalyDetector(detectorId, ActionListener.wrap(detectorOptional -> { + if (!detectorOptional.isPresent()) { + log.error(new ParameterizedMessage("fail to get detector [{}]", detectorId)); + return; + } + AnomalyDetector detector = detectorOptional.get(); + + if (jobParameter.getLockDurationSeconds() != null) { + lockService + .acquireLock( + jobParameter, + context, + ActionListener + .wrap( + lock -> runAdJob( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + recorder, + detector + ), + exception -> { + indexAnomalyResultException( + jobParameter, + lockService, + null, + detectionStartTime, + executionStartTime, + exception, + false, + recorder, + detector + ); + throw new IllegalStateException("Failed to acquire lock for AD job: " + detectorId); + } + ) + ); + } else { + log.warn("Can't get lock for AD job: " + detectorId); + } + }, e -> log.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), e))); }; ExecutorService executor = threadPool.executor(AD_THREAD_POOL_NAME); @@ -195,13 +202,17 @@ public void runJob(ScheduledJobParameter scheduledJobParameter, JobExecutionCont * @param lock lock to run job * @param detectionStartTime detection start time * @param executionStartTime detection end time + * @param recorder utility to record job execution result + * @param detector associated detector accessor */ protected void runAdJob( AnomalyDetectorJob 
jobParameter, LockService lockService, LockModel lock, Instant detectionStartTime, - Instant executionStartTime + Instant executionStartTime, + ExecuteADResultResponseRecorder recorder, + AnomalyDetector detector ) { String detectorId = jobParameter.getName(); if (lock == null) { @@ -212,7 +223,9 @@ protected void runAdJob( detectionStartTime, executionStartTime, "Can't run AD job due to null lock", - false + false, + recorder, + detector ); return; } @@ -242,16 +255,38 @@ protected void runAdJob( } String resultIndex = jobParameter.getResultIndex(); if (resultIndex == null) { - runAnomalyDetectionJob(jobParameter, lockService, lock, detectionStartTime, executionStartTime, detectorId, user, roles); + runAnomalyDetectionJob( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + detectorId, + user, + roles, + recorder, + detector + ); return; } ActionListener listener = ActionListener.wrap(r -> { log.debug("Custom index is valid"); }, e -> { Exception exception = new EndRunException(detectorId, e.getMessage(), true); - handleAdException(jobParameter, lockService, lock, detectionStartTime, executionStartTime, exception); + handleAdException(jobParameter, lockService, lock, detectionStartTime, executionStartTime, exception, recorder, detector); }); anomalyDetectionIndices.validateCustomIndexForBackendJob(resultIndex, detectorId, user, roles, () -> { listener.onResponse(true); - runAnomalyDetectionJob(jobParameter, lockService, lock, detectionStartTime, executionStartTime, detectorId, user, roles); + runAnomalyDetectionJob( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + detectorId, + user, + roles, + recorder, + detector + ); }, listener); } @@ -263,7 +298,9 @@ private void runAnomalyDetectionJob( Instant executionStartTime, String detectorId, String user, - List roles + List roles, + ExecuteADResultResponseRecorder recorder, + AnomalyDetector detector ) { try (InjectSecurity injectSecurity = new InjectSecurity(detectorId, settings, client.threadPool().getThreadContext())) { @@ -282,15 +319,43 @@ private void runAnomalyDetectionJob( ActionListener .wrap( response -> { - indexAnomalyResult(jobParameter, lockService, lock, detectionStartTime, executionStartTime, response); + indexAnomalyResult( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + response, + recorder, + detector + ); }, exception -> { - handleAdException(jobParameter, lockService, lock, detectionStartTime, executionStartTime, exception); + handleAdException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + exception, + recorder, + detector + ); } ) ); } catch (Exception e) { - indexAnomalyResultException(jobParameter, lockService, lock, detectionStartTime, executionStartTime, e, true); + indexAnomalyResultException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + e, + true, + recorder, + detector + ); log.error("Failed to execute AD job " + detectorId, e); } } @@ -331,6 +396,8 @@ private void runAnomalyDetectionJob( * @param detectionStartTime detection start time * @param executionStartTime detection end time * @param exception exception + * @param recorder utility to record job execution result + * @param detector associated detector accessor */ protected void handleAdException( AnomalyDetectorJob jobParameter, @@ -338,7 +405,9 @@ protected void handleAdException( LockModel lock, Instant detectionStartTime, Instant executionStartTime, - Exception exception 
+ Exception exception, + ExecuteADResultResponseRecorder recorder, + AnomalyDetector detector ) { String detectorId = jobParameter.getName(); if (exception instanceof EndRunException) { @@ -353,7 +422,9 @@ protected void handleAdException( lock, detectionStartTime, executionStartTime, - (EndRunException) exception + (EndRunException) exception, + recorder, + detector ); } else { detectorEndRunExceptionCount.compute(detectorId, (k, v) -> { @@ -378,7 +449,9 @@ protected void handleAdException( lock, detectionStartTime, executionStartTime, - (EndRunException) exception + (EndRunException) exception, + recorder, + detector ); return; } @@ -389,7 +462,9 @@ protected void handleAdException( detectionStartTime, executionStartTime, exception.getMessage(), - true + true, + recorder, + detector ); } } else { @@ -399,7 +474,17 @@ protected void handleAdException( } else { log.error("Failed to execute anomaly result action for " + detectorId, exception); } - indexAnomalyResultException(jobParameter, lockService, lock, detectionStartTime, executionStartTime, exception, true); + indexAnomalyResultException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + exception, + true, + recorder, + detector + ); } } @@ -409,7 +494,9 @@ private void stopAdJobForEndRunException( LockModel lock, Instant detectionStartTime, Instant executionStartTime, - EndRunException exception + EndRunException exception, + ExecuteADResultResponseRecorder recorder, + AnomalyDetector detector ) { String detectorId = jobParameter.getName(); detectorEndRunExceptionCount.remove(detectorId); @@ -427,7 +514,9 @@ private void stopAdJobForEndRunException( executionStartTime, error, true, - ADTaskState.STOPPED.name() + ADTaskState.STOPPED.name(), + recorder, + detector ) ); } @@ -491,81 +580,22 @@ private void indexAnomalyResult( LockModel lock, Instant detectionStartTime, Instant executionStartTime, - AnomalyResultResponse response + AnomalyResultResponse response, + ExecuteADResultResponseRecorder recorder, + AnomalyDetector detector ) { String detectorId = jobParameter.getName(); detectorEndRunExceptionCount.remove(detectorId); try { - // skipping writing to the result index if not necessary - // For a single-entity detector, the result is not useful if error is null - // and rcf score (thus anomaly grade/confidence) is null. - // For a HCAD detector, we don't need to save on the detector level. - // We return 0 or Double.NaN rcf score if there is no error. 
- if ((response.getAnomalyScore() <= 0 || Double.isNaN(response.getAnomalyScore())) && response.getError() == null) { - updateRealtimeTask(response, detectorId); - return; - } - IntervalTimeConfiguration windowDelay = (IntervalTimeConfiguration) jobParameter.getWindowDelay(); - Instant dataStartTime = detectionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); - Instant dataEndTime = executionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); - User user = jobParameter.getUser(); - - if (response.getError() != null) { - log.info("Anomaly result action run successfully for {} with error {}", detectorId, response.getError()); - } - - AnomalyResult anomalyResult = response - .toAnomalyResult( - detectorId, - dataStartTime, - dataEndTime, - executionStartTime, - Instant.now(), - anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT), - user, - response.getError() - ); - - String resultIndex = jobParameter.getResultIndex(); - anomalyResultHandler.index(anomalyResult, detectorId, resultIndex); - updateRealtimeTask(response, detectorId); + recorder.indexAnomalyResult(detectionStartTime, executionStartTime, response, detector); } catch (EndRunException e) { - handleAdException(jobParameter, lockService, lock, detectionStartTime, executionStartTime, e); + handleAdException(jobParameter, lockService, lock, detectionStartTime, executionStartTime, e, recorder, detector); } catch (Exception e) { log.error("Failed to index anomaly result for " + detectorId, e); } finally { releaseLock(jobParameter, lockService, lock); } - } - private void updateRealtimeTask(AnomalyResultResponse response, String detectorId) { - if (response.isHCDetector() != null - && response.isHCDetector() - && !adTaskManager.skipUpdateHCRealtimeTask(detectorId, response.getError())) { - DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); - Set profiles = new HashSet<>(); - profiles.add(DetectorProfileName.INIT_PROGRESS); - ProfileRequest profileRequest = new ProfileRequest(detectorId, profiles, true, dataNodes); - client.execute(ProfileAction.INSTANCE, profileRequest, ActionListener.wrap(r -> { - log.debug("Update latest realtime task for HC detector {}, total updates: {}", detectorId, r.getTotalUpdates()); - updateLatestRealtimeTask( - detectorId, - null, - r.getTotalUpdates(), - response.getDetectorIntervalInMinutes(), - response.getError() - ); - }, e -> { log.error("Failed to update latest realtime task for " + detectorId, e); })); - } else { - log.debug("Update latest realtime task for SINGLE detector {}, total updates: {}", detectorId, response.getRcfTotalUpdates()); - updateLatestRealtimeTask( - detectorId, - null, - response.getRcfTotalUpdates(), - response.getDetectorIntervalInMinutes(), - response.getError() - ); - } } private void indexAnomalyResultException( @@ -575,13 +605,25 @@ private void indexAnomalyResultException( Instant detectionStartTime, Instant executionStartTime, Exception exception, - boolean releaseLock + boolean releaseLock, + ExecuteADResultResponseRecorder recorder, + AnomalyDetector detector ) { try { String errorMessage = exception instanceof AnomalyDetectionException ? 
exception.getMessage() : Throwables.getStackTraceAsString(exception); - indexAnomalyResultException(jobParameter, lockService, lock, detectionStartTime, executionStartTime, errorMessage, releaseLock); + indexAnomalyResultException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + errorMessage, + releaseLock, + recorder, + detector + ); } catch (Exception e) { log.error("Failed to index anomaly result for " + jobParameter.getName(), e); } @@ -594,7 +636,9 @@ private void indexAnomalyResultException( Instant detectionStartTime, Instant executionStartTime, String errorMessage, - boolean releaseLock + boolean releaseLock, + ExecuteADResultResponseRecorder recorder, + AnomalyDetector detector ) { indexAnomalyResultException( jobParameter, @@ -604,7 +648,9 @@ private void indexAnomalyResultException( executionStartTime, errorMessage, releaseLock, - null + null, + recorder, + detector ); } @@ -616,40 +662,12 @@ private void indexAnomalyResultException( Instant executionStartTime, String errorMessage, boolean releaseLock, - String taskState + String taskState, + ExecuteADResultResponseRecorder recorder, + AnomalyDetector detector ) { - String detectorId = jobParameter.getName(); try { - IntervalTimeConfiguration windowDelay = (IntervalTimeConfiguration) jobParameter.getWindowDelay(); - Instant dataStartTime = detectionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); - Instant dataEndTime = executionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); - User user = jobParameter.getUser(); - - AnomalyResult anomalyResult = new AnomalyResult( - detectorId, - null, // no task id - new ArrayList(), - dataStartTime, - dataEndTime, - executionStartTime, - Instant.now(), - errorMessage, - null, // single-stream detectors have no entity - user, - anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT), - null // no model id - ); - String resultIndex = jobParameter.getResultIndex(); - if (resultIndex != null && !anomalyDetectionIndices.doesIndexExist(resultIndex)) { - // Set result index as null, will write exception to default result index. - anomalyResultHandler.index(anomalyResult, detectorId, null); - } else { - anomalyResultHandler.index(anomalyResult, detectorId, resultIndex); - } - - updateLatestRealtimeTask(detectorId, taskState, null, null, errorMessage); - } catch (Exception e) { - log.error("Failed to index anomaly result for " + detectorId, e); + recorder.indexAnomalyResultException(detectionStartTime, executionStartTime, errorMessage, taskState, detector); } finally { if (releaseLock) { releaseLock(jobParameter, lockService, lock); @@ -657,37 +675,6 @@ private void indexAnomalyResultException( } } - private void updateLatestRealtimeTask( - String detectorId, - String taskState, - Long rcfTotalUpdates, - Long detectorIntervalInMinutes, - String error - ) { - // Don't need info as this will be printed repeatedly in each interval - adTaskManager - .updateLatestRealtimeTaskOnCoordinatingNode( - detectorId, - taskState, - rcfTotalUpdates, - detectorIntervalInMinutes, - error, - ActionListener.wrap(r -> { - if (r != null) { - log.debug("Updated latest realtime task successfully for detector {}, taskState: {}", detectorId, taskState); - } - }, e -> { - if ((e instanceof ResourceNotFoundException) && e.getMessage().contains(CAN_NOT_FIND_LATEST_TASK)) { - // Clear realtime task cache, will recreate AD task in next run, check AnomalyResultTransportAction. 
- log.error("Can't find latest realtime task of detector " + detectorId); - adTaskManager.removeRealtimeTaskCache(detectorId); - } else { - log.error("Failed to update latest realtime task for detector " + detectorId, e); - } - }) - ); - } - private void releaseLock(AnomalyDetectorJob jobParameter, LockService lockService, LockModel lock) { lockService .release( diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java b/src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java index b21d11cc5..cace64e3f 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java @@ -234,11 +234,12 @@ public class AnomalyDetectorPlugin extends Plugin implements ActionPlugin, Scrip private ClientUtil clientUtil; private DiscoveryNodeFilterer nodeFilter; private IndexUtils indexUtils; - private ADTaskCacheManager adTaskCacheManager; private ADTaskManager adTaskManager; private ADBatchTaskRunner adBatchTaskRunner; // package private for testing GenericObjectPool serializeRCFBufferPool; + private NodeStateManager stateManager; + private ExecuteADResultResponseRecorder adResultResponseRecorder; static { SpecialPermission.check(); @@ -259,25 +260,14 @@ public List getRestHandlers( IndexNameExpressionResolver indexNameExpressionResolver, Supplier nodesInCluster ) { - AnomalyIndexHandler anomalyResultHandler = new AnomalyIndexHandler( - client, - settings, - threadPool, - CommonName.ANOMALY_RESULT_INDEX_ALIAS, - anomalyDetectionIndices, - this.clientUtil, - this.indexUtils, - clusterService - ); - AnomalyDetectorJobRunner jobRunner = AnomalyDetectorJobRunner.getJobRunnerInstance(); jobRunner.setClient(client); jobRunner.setThreadPool(threadPool); - jobRunner.setAnomalyResultHandler(anomalyResultHandler); jobRunner.setSettings(settings); jobRunner.setAnomalyDetectionIndices(anomalyDetectionIndices); - jobRunner.setNodeFilter(nodeFilter); jobRunner.setAdTaskManager(adTaskManager); + jobRunner.setNodeStateManager(stateManager); + jobRunner.setExecuteADResultResponseRecorder(adResultResponseRecorder); RestGetAnomalyDetectorAction restGetAnomalyDetectorAction = new RestGetAnomalyDetectorAction(); RestIndexAnomalyDetectorAction restIndexAnomalyDetectorAction = new RestIndexAnomalyDetectorAction(settings, clusterService); @@ -383,7 +373,7 @@ public Collection createComponents( adCircuitBreakerService ); - NodeStateManager stateManager = new NodeStateManager( + stateManager = new NodeStateManager( client, xContentRegistry, settings, @@ -568,7 +558,8 @@ public PooledObject wrap(LinkedBuffer obj) { AnomalyDetectorSettings.QUEUE_MAINTENANCE, entityColdStarter, AnomalyDetectorSettings.HOURLY_MAINTENANCE, - stateManager + stateManager, + cacheProvider ); ModelManager modelManager = new ModelManager( @@ -714,7 +705,7 @@ public PooledObject wrap(LinkedBuffer obj) { anomalyDetectorRunner = new AnomalyDetectorRunner(modelManager, featureManager, AnomalyDetectorSettings.MAX_PREVIEW_RESULTS); - adTaskCacheManager = new ADTaskCacheManager(settings, clusterService, memoryTracker); + ADTaskCacheManager adTaskCacheManager = new ADTaskCacheManager(settings, clusterService, memoryTracker); adTaskManager = new ADTaskManager( settings, clusterService, @@ -754,6 +745,26 @@ public PooledObject wrap(LinkedBuffer obj) { ADSearchHandler adSearchHandler = new ADSearchHandler(settings, clusterService, client); + AnomalyIndexHandler anomalyResultHandler = new AnomalyIndexHandler( + client, + settings, + threadPool, + CommonName.ANOMALY_RESULT_INDEX_ALIAS, + 
anomalyDetectionIndices, + this.clientUtil, + this.indexUtils, + clusterService + ); + + adResultResponseRecorder = new ExecuteADResultResponseRecorder( + anomalyDetectionIndices, + anomalyResultHandler, + adTaskManager, + nodeFilter, + threadPool, + client + ); + // return objects used by Guice to inject dependencies for e.g., // transport action handler constructors return ImmutableList @@ -795,7 +806,8 @@ public PooledObject wrap(LinkedBuffer obj) { checkpointWriteQueue, coldEntityQueue, entityColdStarter, - adTaskCacheManager + adTaskCacheManager, + adResultResponseRecorder ); } diff --git a/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java b/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java new file mode 100644 index 000000000..19710f0cb --- /dev/null +++ b/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java @@ -0,0 +1,288 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad; + +import static org.opensearch.ad.constant.CommonErrorMessages.CAN_NOT_FIND_LATEST_TASK; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.common.exception.EndRunException; +import org.opensearch.ad.common.exception.ResourceNotFoundException; +import org.opensearch.ad.constant.CommonErrorMessages; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.AnomalyDetectionIndices; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.model.DetectorProfileName; +import org.opensearch.ad.model.FeatureData; +import org.opensearch.ad.model.IntervalTimeConfiguration; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.AnomalyResultResponse; +import org.opensearch.ad.transport.ProfileAction; +import org.opensearch.ad.transport.ProfileRequest; +import org.opensearch.ad.transport.RCFPollingAction; +import org.opensearch.ad.transport.RCFPollingRequest; +import org.opensearch.ad.transport.handler.AnomalyIndexHandler; +import org.opensearch.ad.util.DiscoveryNodeFilterer; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.threadpool.ThreadPool; + +public class ExecuteADResultResponseRecorder { + private static final Logger log = LogManager.getLogger(ExecuteADResultResponseRecorder.class); + + private AnomalyDetectionIndices anomalyDetectionIndices; + private AnomalyIndexHandler anomalyResultHandler; + private ADTaskManager adTaskManager; + private DiscoveryNodeFilterer nodeFilter; + private ThreadPool threadPool; + private Client client; + + public ExecuteADResultResponseRecorder( + AnomalyDetectionIndices anomalyDetectionIndices, + AnomalyIndexHandler anomalyResultHandler, + ADTaskManager adTaskManager, + DiscoveryNodeFilterer nodeFilter, + ThreadPool threadPool, + Client client + ) { + this.anomalyDetectionIndices = anomalyDetectionIndices; + this.anomalyResultHandler = 
anomalyResultHandler;
+ this.adTaskManager = adTaskManager;
+ this.nodeFilter = nodeFilter;
+ this.threadPool = threadPool;
+ this.client = client;
+ }
+
+ public void indexAnomalyResult(
+ Instant detectionStartTime,
+ Instant executionStartTime,
+ AnomalyResultResponse response,
+ AnomalyDetector detector
+ ) {
+ String detectorId = detector.getDetectorId();
+ try {
+ // skip writing to the result index if not necessary
+ // For a single-entity detector, the result is not useful if error is null
+ // and rcf score (thus anomaly grade/confidence) is null.
+ // For an HCAD detector, we don't need to save on the detector level.
+ // We return 0 or Double.NaN rcf score if there is no error.
+ if ((response.getAnomalyScore() <= 0 || Double.isNaN(response.getAnomalyScore())) && response.getError() == null) {
+ updateRealtimeTask(response, detectorId);
+ return;
+ }
+ IntervalTimeConfiguration windowDelay = (IntervalTimeConfiguration) detector.getWindowDelay();
+ Instant dataStartTime = detectionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit());
+ Instant dataEndTime = executionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit());
+ User user = detector.getUser();
+
+ if (response.getError() != null) {
+ log.info("Anomaly result action run successfully for {} with error {}", detectorId, response.getError());
+ }
+
+ AnomalyResult anomalyResult = response
+ .toAnomalyResult(
+ detectorId,
+ dataStartTime,
+ dataEndTime,
+ executionStartTime,
+ Instant.now(),
+ anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT),
+ user,
+ response.getError()
+ );
+
+ String resultIndex = detector.getResultIndex();
+ anomalyResultHandler.index(anomalyResult, detectorId, resultIndex);
+ updateRealtimeTask(response, detectorId);
+ } catch (EndRunException e) {
+ throw e;
+ } catch (Exception e) {
+ log.error("Failed to index anomaly result for " + detectorId, e);
+ }
+ }
+
+ /**
+ * Update the real-time task (one document per detector in the state index). If the real-time task has no changes compared with the
+ * local cache, the task won't be updated. The task is only updated when the state changes, an error happens, or the AD job stops. The
+ * task is mainly consumed by the front-end to track detector status. For single-stream detectors, we embed the model's total updates
+ * in AnomalyResultResponse and update state accordingly. For HCAD, we won't wait for the model to finish updating before returning a
+ * response to the job scheduler, since it might take long before all entities finish execution. So we don't embed model total updates
+ * in AnomalyResultResponse. Instead, we issue a profile request to poll each model node and get the maximum total updates among all models.
+ * @param response response returned from executing AnomalyResultAction
+ * @param detectorId Detector Id
+ */
+ private void updateRealtimeTask(AnomalyResultResponse response, String detectorId) {
+ if (response.isHCDetector() != null && response.isHCDetector()) {
+ if (adTaskManager.skipUpdateHCRealtimeTask(detectorId, response.getError())) {
+ return;
+ }
+ DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes();
+ Set<DetectorProfileName> profiles = new HashSet<>();
+ profiles.add(DetectorProfileName.INIT_PROGRESS);
+ ProfileRequest profileRequest = new ProfileRequest(detectorId, profiles, true, dataNodes);
+ Runnable profileHCInitProgress = () -> {
+ client.execute(ProfileAction.INSTANCE, profileRequest, ActionListener.wrap(r -> {
+ log.debug("Update latest realtime task for HC detector {}, total updates: {}", detectorId, r.getTotalUpdates());
+ updateLatestRealtimeTask(
+ detectorId,
+ null,
+ r.getTotalUpdates(),
+ response.getDetectorIntervalInMinutes(),
+ response.getError()
+ );
+ }, e -> { log.error("Failed to update latest realtime task for " + detectorId, e); }));
+ };
+ if (!adTaskManager.isHCRealtimeTaskStartInitializing(detectorId)) {
+ // A real time init progress of 0 may mean this is a newly started detector.
+ // Delay the real time cache update by one minute. If we are in init status, the delay may give the model training
+ // time to finish. We can then mark the detector as running immediately instead of waiting for the next interval.
+ threadPool.schedule(profileHCInitProgress, new TimeValue(60, TimeUnit.SECONDS), AnomalyDetectorPlugin.AD_THREAD_POOL_NAME);
+ } else {
+ profileHCInitProgress.run();
+ }
+
+ } else {
+ log
+ .debug(
+ "Update latest realtime task for single stream detector {}, total updates: {}",
+ detectorId,
+ response.getRcfTotalUpdates()
+ );
+ updateLatestRealtimeTask(
+ detectorId,
+ null,
+ response.getRcfTotalUpdates(),
+ response.getDetectorIntervalInMinutes(),
+ response.getError()
+ );
+ }
+ }
+
+ private void updateLatestRealtimeTask(
+ String detectorId,
+ String taskState,
+ Long rcfTotalUpdates,
+ Long detectorIntervalInMinutes,
+ String error
+ ) {
+ // Don't log at info level since this would be printed repeatedly in each interval
+ adTaskManager
+ .updateLatestRealtimeTaskOnCoordinatingNode(
+ detectorId,
+ taskState,
+ rcfTotalUpdates,
+ detectorIntervalInMinutes,
+ error,
+ ActionListener.wrap(r -> {
+ if (r != null) {
+ log.debug("Updated latest realtime task successfully for detector {}, taskState: {}", detectorId, taskState);
+ }
+ }, e -> {
+ if ((e instanceof ResourceNotFoundException) && e.getMessage().contains(CAN_NOT_FIND_LATEST_TASK)) {
+ // Clear the realtime task cache; the AD task will be recreated in the next run. See AnomalyResultTransportAction.
+ log.error("Can't find latest realtime task of detector " + detectorId);
+ adTaskManager.removeRealtimeTaskCache(detectorId);
+ } else {
+ log.error("Failed to update latest realtime task for detector " + detectorId, e);
+ }
+ })
+ );
+ }
+
+ /**
+ * This function not only indexes the result together with the exception, but also updates the task state after
+ * 60s if the exception is related to cold start (index not found exceptions) for a single stream detector.
+ *
+ * @param detectionStartTime execution start time
+ * @param executionStartTime execution end time
+ * @param errorMessage Error message to record
+ * @param taskState AD task state (e.g., stopped)
+ * @param detector Detector config accessor
+ */
+ public void indexAnomalyResultException(
+ Instant detectionStartTime,
+ Instant executionStartTime,
+ String errorMessage,
+ String taskState,
+ AnomalyDetector detector
+ ) {
+ String detectorId = detector.getDetectorId();
+ try {
+ IntervalTimeConfiguration windowDelay = (IntervalTimeConfiguration) detector.getWindowDelay();
+ Instant dataStartTime = detectionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit());
+ Instant dataEndTime = executionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit());
+ User user = detector.getUser();
+
+ AnomalyResult anomalyResult = new AnomalyResult(
+ detectorId,
+ null, // no task id
+ new ArrayList<FeatureData>(),
+ dataStartTime,
+ dataEndTime,
+ executionStartTime,
+ Instant.now(),
+ errorMessage,
+ null, // single-stream detectors have no entity
+ user,
+ anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT),
+ null // no model id
+ );
+ String resultIndex = detector.getResultIndex();
+ if (resultIndex != null && !anomalyDetectionIndices.doesIndexExist(resultIndex)) {
+ // Set the result index to null; the exception will be written to the default result index.
+ anomalyResultHandler.index(anomalyResult, detectorId, null);
+ } else {
+ anomalyResultHandler.index(anomalyResult, detectorId, resultIndex);
+ }
+
+ if (errorMessage.contains(CommonErrorMessages.NO_CHECKPOINT_ERR_MSG) && !detector.isMultiCategoryDetector()) {
+ // A single stream detector raises a ResourceNotFoundException containing CommonErrorMessages.NO_CHECKPOINT_ERR_MSG
+ // when there is no checkpoint.
+ // Delay the real time cache update by one minute so we will have trained models by then and update the state
+ // document accordingly.
+ threadPool.schedule(() -> {
+ RCFPollingRequest request = new RCFPollingRequest(detectorId);
+ client.execute(RCFPollingAction.INSTANCE, request, ActionListener.wrap(rcfPollResponse -> {
+ long totalUpdates = rcfPollResponse.getTotalUpdates();
+ // if there are updates, don't record failures
+ updateLatestRealtimeTask(
+ detectorId,
+ taskState,
+ totalUpdates,
+ detector.getDetectorIntervalInMinutes(),
+ totalUpdates > 0 ? "" : errorMessage
+ );
+ }, e -> {
+ log.error("Failed to execute RCFPollingAction", e);
+ updateLatestRealtimeTask(detectorId, taskState, null, null, errorMessage);
+ }));
+ }, new TimeValue(60, TimeUnit.SECONDS), AnomalyDetectorPlugin.AD_THREAD_POOL_NAME);
+ } else {
+ updateLatestRealtimeTask(detectorId, taskState, null, null, errorMessage);
+ }
+
+ } catch (Exception e) {
+ log.error("Failed to index anomaly result for " + detectorId, e);
+ }
+ }
+
+}
diff --git a/src/main/java/org/opensearch/ad/caching/PriorityCache.java b/src/main/java/org/opensearch/ad/caching/PriorityCache.java
index 0e94663d2..dfae6b931 100644
--- a/src/main/java/org/opensearch/ad/caching/PriorityCache.java
+++ b/src/main/java/org/opensearch/ad/caching/PriorityCache.java
@@ -55,6 +55,7 @@
 import org.opensearch.ad.ratelimit.CheckpointMaintainWorker;
 import org.opensearch.ad.ratelimit.CheckpointWriteWorker;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;
+import org.opensearch.ad.settings.EnabledSetting;
 import org.opensearch.ad.util.DateUtils;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.Strings;
@@ -161,29 +162,31 @@ public ModelState<EntityModel> get(String modelId, AnomalyDetector detector) {

 // during maintenance period, stop putting new entries
 if (!maintenanceLock.isLocked() && modelState == null) {
- DoorKeeper doorKeeper = doorKeepers
- .computeIfAbsent(
- detectorId,
- id -> {
- // reset every 60 intervals
- return new DoorKeeper(
- AnomalyDetectorSettings.DOOR_KEEPER_FOR_CACHE_MAX_INSERTION,
- AnomalyDetectorSettings.DOOR_KEEPER_FAULSE_POSITIVE_RATE,
- detector.getDetectionIntervalDuration().multipliedBy(AnomalyDetectorSettings.DOOR_KEEPER_MAINTENANCE_FREQ),
- clock
- );
- }
- );
-
- // first hit, ignore
- // since door keeper may get reset during maintenance, it is possible
- // the entity is still active even though door keeper has no record of
- // this model Id. We have to call isActive method to make sure. Otherwise,
- // the entity might miss an anomaly result every 60 intervals due to door keeper
- // reset.
- if (!doorKeeper.mightContain(modelId) && !isActive(detectorId, modelId)) {
- doorKeeper.put(modelId);
- return null;
+ if (EnabledSetting.isDoorKeeperInCacheEnabled()) {
+ DoorKeeper doorKeeper = doorKeepers
+ .computeIfAbsent(
+ detectorId,
+ id -> {
+ // reset every 60 intervals
+ return new DoorKeeper(
+ AnomalyDetectorSettings.DOOR_KEEPER_FOR_CACHE_MAX_INSERTION,
+ AnomalyDetectorSettings.DOOR_KEEPER_FAULSE_POSITIVE_RATE,
+ detector.getDetectionIntervalDuration().multipliedBy(AnomalyDetectorSettings.DOOR_KEEPER_MAINTENANCE_FREQ),
+ clock
+ );
+ }
+ );
+
+ // first hit, ignore
+ // since door keeper may get reset during maintenance, it is possible
+ // the entity is still active even though door keeper has no record of
+ // this model Id. We have to call isActive method to make sure. Otherwise,
+ // the entity might miss an anomaly result every 60 intervals due to door keeper
+ // reset.
+ if (!doorKeeper.mightContain(modelId) && !isActive(detectorId, modelId)) { + doorKeeper.put(modelId); + return null; + } } try { diff --git a/src/main/java/org/opensearch/ad/model/AnomalyDetector.java b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java index da64b30ea..b7492300d 100644 --- a/src/main/java/org/opensearch/ad/model/AnomalyDetector.java +++ b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java @@ -123,6 +123,7 @@ public class AnomalyDetector implements Writeable, ToXContentObject { private DetectionDateRange detectionDateRange; public static final int MAX_RESULT_INDEX_NAME_SIZE = 255; + // OS doesn’t allow uppercase: https://tinyurl.com/yse2xdbx public static final String RESULT_INDEX_NAME_PATTERN = "[a-z0-9_-]+"; /** diff --git a/src/main/java/org/opensearch/ad/ratelimit/EntityColdStartWorker.java b/src/main/java/org/opensearch/ad/ratelimit/EntityColdStartWorker.java index c1166d4d9..8702fafcc 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/EntityColdStartWorker.java +++ b/src/main/java/org/opensearch/ad/ratelimit/EntityColdStartWorker.java @@ -22,13 +22,16 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.action.ActionListener; import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.caching.CacheProvider; import org.opensearch.ad.ml.EntityColdStarter; import org.opensearch.ad.ml.EntityModel; import org.opensearch.ad.ml.ModelManager.ModelType; import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.util.ExceptionUtil; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; @@ -49,6 +52,7 @@ public class EntityColdStartWorker extends SingleRequestWorker { public static final String WORKER_NAME = "cold-start"; private final EntityColdStarter entityColdStarter; + private final CacheProvider cacheProvider; public EntityColdStartWorker( long heapSizeInBytes, @@ -67,7 +71,8 @@ public EntityColdStartWorker( Duration executionTtl, EntityColdStarter entityColdStarter, Duration stateTtl, - NodeStateManager nodeStateManager + NodeStateManager nodeStateManager, + CacheProvider cacheProvider ) { super( WORKER_NAME, @@ -90,6 +95,7 @@ public EntityColdStartWorker( nodeStateManager ); this.entityColdStarter = entityColdStarter; + this.cacheProvider = cacheProvider; } @Override @@ -114,15 +120,42 @@ protected void executeRequest(EntityRequest coldStartRequest, ActionListener failureListener = ActionListener.delegateResponse(listener, (delegateListener, e) -> { - if (ExceptionUtil.isOverloaded(e)) { - LOG.error("OpenSearch is overloaded"); - setCoolDownStart(); + ActionListener coldStartListener = ActionListener.wrap(r -> { + nodeStateManager.getAnomalyDetector(detectorId, ActionListener.wrap(detectorOptional -> { + try { + if (!detectorOptional.isPresent()) { + LOG + .error( + new ParameterizedMessage( + "fail to load trained model [{}] to cache due to the detector not being found.", + modelState.getModelId() + ) + ); + return; + } + AnomalyDetector detector = detectorOptional.get(); + EntityModel model = modelState.getModel(); + // load to cache if cold start succeeds + if (model != null && model.getTrcf() != null) { + cacheProvider.get().hostIfPossible(detector, modelState); + } + } finally { + listener.onResponse(null); + } + }, listener::onFailure)); + + }, e -> { + try { + 
if (ExceptionUtil.isOverloaded(e)) { + LOG.error("OpenSearch is overloaded"); + setCoolDownStart(); + } + nodeStateManager.setException(detectorId, e); + } finally { + listener.onFailure(e); } - nodeStateManager.setException(detectorId, e); - delegateListener.onFailure(e); }); - entityColdStarter.trainModel(coldStartRequest.getEntity(), detectorId, modelState, failureListener); + entityColdStarter.trainModel(coldStartRequest.getEntity(), detectorId, modelState, coldStartListener); } } diff --git a/src/main/java/org/opensearch/ad/rest/RestExecuteAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestExecuteAnomalyDetectorAction.java index 6e2ac9131..363199d58 100644 --- a/src/main/java/org/opensearch/ad/rest/RestExecuteAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestExecuteAnomalyDetectorAction.java @@ -70,7 +70,6 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } AnomalyDetectorExecutionInput input = getAnomalyDetectorExecutionInput(request); return channel -> { - String rawPath = request.rawPath(); String error = validateAdExecutionInput(input); if (StringUtils.isNotBlank(error)) { channel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, error)); diff --git a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java index b4c7eba3e..31256f3ed 100644 --- a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java +++ b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java @@ -31,6 +31,7 @@ import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.WriteRequest; +import org.opensearch.ad.ExecuteADResultResponseRecorder; import org.opensearch.ad.indices.AnomalyDetectionIndices; import org.opensearch.ad.model.ADTaskState; import org.opensearch.ad.model.AnomalyDetector; @@ -38,6 +39,8 @@ import org.opensearch.ad.model.IntervalTimeConfiguration; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.transport.AnomalyDetectorJobResponse; +import org.opensearch.ad.transport.AnomalyResultAction; +import org.opensearch.ad.transport.AnomalyResultRequest; import org.opensearch.ad.transport.StopDetectorAction; import org.opensearch.ad.transport.StopDetectorRequest; import org.opensearch.ad.transport.StopDetectorResponse; @@ -52,6 +55,8 @@ import org.opensearch.rest.RestStatus; import org.opensearch.transport.TransportService; +import com.google.common.base.Throwables; + /** * Anomaly detector job REST action handler to process POST/PUT request. */ @@ -62,19 +67,18 @@ public class IndexAnomalyDetectorJobActionHandler { private final Long seqNo; private final Long primaryTerm; private final Client client; - private final ActionListener listener; private final NamedXContentRegistry xContentRegistry; private final TransportService transportService; private final ADTaskManager adTaskManager; private final Logger logger = LogManager.getLogger(IndexAnomalyDetectorJobActionHandler.class); private final TimeValue requestTimeout; + private final ExecuteADResultResponseRecorder recorder; /** * Constructor function. 
* * @param client ES node client that executes actions on the local node - * @param listener Listener to send responses * @param anomalyDetectionIndices anomaly detector index manager * @param detectorId detector identifier * @param seqNo sequence number of last modification @@ -83,10 +87,10 @@ public class IndexAnomalyDetectorJobActionHandler { * @param xContentRegistry Registry which is used for XContentParser * @param transportService transport service * @param adTaskManager AD task manager + * @param recorder Utility to record AnomalyResultAction execution result */ public IndexAnomalyDetectorJobActionHandler( Client client, - ActionListener listener, AnomalyDetectionIndices anomalyDetectionIndices, String detectorId, Long seqNo, @@ -94,10 +98,10 @@ public IndexAnomalyDetectorJobActionHandler( TimeValue requestTimeout, NamedXContentRegistry xContentRegistry, TransportService transportService, - ADTaskManager adTaskManager + ADTaskManager adTaskManager, + ExecuteADResultResponseRecorder recorder ) { this.client = client; - this.listener = listener; this.anomalyDetectionIndices = anomalyDetectionIndices; this.detectorId = detectorId; this.seqNo = seqNo; @@ -106,6 +110,7 @@ public IndexAnomalyDetectorJobActionHandler( this.xContentRegistry = xContentRegistry; this.transportService = transportService; this.adTaskManager = adTaskManager; + this.recorder = recorder; } /** @@ -113,16 +118,56 @@ public IndexAnomalyDetectorJobActionHandler( * 1. If job doesn't exist, create new job. * 2. If job exists: a). if job enabled, return error message; b). if job disabled, enable job. * @param detector anomaly detector + * @param listener Listener to send responses */ - public void startAnomalyDetectorJob(AnomalyDetector detector) { + public void startAnomalyDetectorJob(AnomalyDetector detector, ActionListener listener) { + // this start listener is created & injected throughout the job handler so that whenever the job response is received, + // there's the extra step of trying to index results and update detector state with a 60s delay. 
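The wrapping idiom the comment describes is worth spelling out: the caller's listener is decorated so that a follow-up step runs on success before the original response is forwarded, while failures pass straight through. A minimal sketch of that shape follows; the ListenerChaining and recordResults names are hypothetical, standing in for the real follow-up below, which issues an AnomalyResultRequest and hands the response to the recorder.

```java
import org.opensearch.action.ActionListener;
import org.opensearch.ad.transport.AnomalyDetectorJobResponse;

// Hypothetical illustration of decorating a listener with an extra success-time step.
final class ListenerChaining {
    static ActionListener<AnomalyDetectorJobResponse> withFollowUp(ActionListener<AnomalyDetectorJobResponse> original) {
        return ActionListener.wrap(r -> {
            try {
                recordResults(r); // stand-in for the recorder call made in the real handler
            } catch (Exception ex) {
                original.onFailure(ex);
                return;
            }
            original.onResponse(r); // only now does the original caller see the response
        }, original::onFailure);
    }

    private static void recordResults(AnomalyDetectorJobResponse r) {
        // omitted: in the real handler this triggers an AnomalyResultRequest and result indexing
    }
}
```

The handler code that follows implements exactly this shape, threading the decorated startListener through createJob and the index callbacks.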
+ ActionListener startListener = ActionListener.wrap(r -> { + try { + Instant executionEndTime = Instant.now(); + IntervalTimeConfiguration schedule = (IntervalTimeConfiguration) detector.getDetectionInterval(); + Instant executionStartTime = executionEndTime.minus(schedule.getInterval(), schedule.getUnit()); + AnomalyResultRequest getRequest = new AnomalyResultRequest( + detector.getDetectorId(), + executionStartTime.toEpochMilli(), + executionEndTime.toEpochMilli() + ); + client + .execute( + AnomalyResultAction.INSTANCE, + getRequest, + ActionListener + .wrap( + response -> recorder.indexAnomalyResult(executionStartTime, executionEndTime, response, detector), + exception -> { + + recorder + .indexAnomalyResultException( + executionStartTime, + executionEndTime, + Throwables.getStackTraceAsString(exception), + null, + detector + ); + } + ) + ); + } catch (Exception ex) { + listener.onFailure(ex); + return; + } + listener.onResponse(r); + + }, listener::onFailure); if (!anomalyDetectionIndices.doesAnomalyDetectorJobIndexExist()) { anomalyDetectionIndices.initAnomalyDetectorJobIndex(ActionListener.wrap(response -> { if (response.isAcknowledged()) { logger.info("Created {} with mappings.", ANOMALY_DETECTORS_INDEX); - createJob(detector); + createJob(detector, startListener); } else { logger.warn("Created {} with mappings call not acknowledged.", ANOMALY_DETECTORS_INDEX); - listener + startListener .onFailure( new OpenSearchStatusException( "Created " + ANOMALY_DETECTORS_INDEX + " with mappings call not acknowledged.", @@ -130,13 +175,13 @@ public void startAnomalyDetectorJob(AnomalyDetector detector) { ) ); } - }, exception -> listener.onFailure(exception))); + }, exception -> startListener.onFailure(exception))); } else { - createJob(detector); + createJob(detector, startListener); } } - private void createJob(AnomalyDetector detector) { + private void createJob(AnomalyDetector detector, ActionListener listener) { try { IntervalTimeConfiguration interval = (IntervalTimeConfiguration) detector.getDetectionInterval(); Schedule schedule = new IntervalSchedule(Instant.now(), (int) interval.getInterval(), interval.getUnit()); @@ -155,7 +200,7 @@ private void createJob(AnomalyDetector detector) { detector.getResultIndex() ); - getAnomalyDetectorJobForWrite(detector, job); + getAnomalyDetectorJobForWrite(detector, job, listener); } catch (Exception e) { String message = "Failed to parse anomaly detector job " + detectorId; logger.error(message, e); @@ -163,19 +208,30 @@ private void createJob(AnomalyDetector detector) { } } - private void getAnomalyDetectorJobForWrite(AnomalyDetector detector, AnomalyDetectorJob job) { + private void getAnomalyDetectorJobForWrite( + AnomalyDetector detector, + AnomalyDetectorJob job, + ActionListener listener + ) { GetRequest getRequest = new GetRequest(AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX).id(detectorId); client .get( getRequest, ActionListener - .wrap(response -> onGetAnomalyDetectorJobForWrite(response, detector, job), exception -> listener.onFailure(exception)) + .wrap( + response -> onGetAnomalyDetectorJobForWrite(response, detector, job, listener), + exception -> listener.onFailure(exception) + ) ); } - private void onGetAnomalyDetectorJobForWrite(GetResponse response, AnomalyDetector detector, AnomalyDetectorJob job) - throws IOException { + private void onGetAnomalyDetectorJobForWrite( + GetResponse response, + AnomalyDetector detector, + AnomalyDetectorJob job, + ActionListener listener + ) throws IOException { if (response.isExists()) { try 
(XContentParser parser = createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); @@ -207,7 +263,7 @@ private void onGetAnomalyDetectorJobForWrite(GetResponse response, AnomalyDetect transportService, ActionListener .wrap( - r -> { indexAnomalyDetectorJob(newJob, null); }, + r -> { indexAnomalyDetectorJob(newJob, null, listener); }, e -> { // Have logged error message in ADTaskManager#startDetector listener.onFailure(e); @@ -227,12 +283,16 @@ private void onGetAnomalyDetectorJobForWrite(GetResponse response, AnomalyDetect null, job.getUser(), transportService, - ActionListener.wrap(r -> { indexAnomalyDetectorJob(job, null); }, e -> listener.onFailure(e)) + ActionListener.wrap(r -> { indexAnomalyDetectorJob(job, null, listener); }, e -> listener.onFailure(e)) ); } } - private void indexAnomalyDetectorJob(AnomalyDetectorJob job, AnomalyDetectorFunction function) throws IOException { + private void indexAnomalyDetectorJob( + AnomalyDetectorJob job, + AnomalyDetectorFunction function, + ActionListener listener + ) throws IOException { IndexRequest indexRequest = new IndexRequest(AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .source(job.toXContent(XContentFactory.jsonBuilder(), RestHandlerUtils.XCONTENT_WITH_TYPE)) @@ -244,11 +304,18 @@ private void indexAnomalyDetectorJob(AnomalyDetectorJob job, AnomalyDetectorFunc .index( indexRequest, ActionListener - .wrap(response -> onIndexAnomalyDetectorJobResponse(response, function), exception -> listener.onFailure(exception)) + .wrap( + response -> onIndexAnomalyDetectorJobResponse(response, function, listener), + exception -> listener.onFailure(exception) + ) ); } - private void onIndexAnomalyDetectorJobResponse(IndexResponse response, AnomalyDetectorFunction function) { + private void onIndexAnomalyDetectorJobResponse( + IndexResponse response, + AnomalyDetectorFunction function, + ActionListener listener + ) { if (response == null || (response.getResult() != CREATED && response.getResult() != UPDATED)) { String errorMsg = getShardsFailure(response); listener.onFailure(new OpenSearchStatusException(errorMsg, response.status())); @@ -274,8 +341,9 @@ private void onIndexAnomalyDetectorJobResponse(IndexResponse response, AnomalyDe * 2.If job exists: a).if job state is disabled, return error message; b).if job state is enabled, disable job. 
*
 * @param detectorId detector identifier
+ * @param listener Listener to send responses
 */
- public void stopAnomalyDetectorJob(String detectorId) {
+ public void stopAnomalyDetectorJob(String detectorId, ActionListener<AnomalyDetectorJobResponse> listener) {
 GetRequest getRequest = new GetRequest(AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX).id(detectorId);

 client.get(getRequest, ActionListener.wrap(response -> {
@@ -304,8 +372,9 @@ public void stopAnomalyDetectorJob(String detectorId) {
 .execute(
 StopDetectorAction.INSTANCE,
 new StopDetectorRequest(detectorId),
- stopAdDetectorListener(detectorId)
- )
+ stopAdDetectorListener(detectorId, listener)
+ ),
+ listener
 );
 }
 } catch (IOException e) {
@@ -319,7 +388,10 @@ public void stopAnomalyDetectorJob(String detectorId) {
 }, exception -> listener.onFailure(exception)));
 }

- private ActionListener<StopDetectorResponse> stopAdDetectorListener(String detectorId) {
+ private ActionListener<StopDetectorResponse> stopAdDetectorListener(
+ String detectorId,
+ ActionListener<AnomalyDetectorJobResponse> listener
+ ) {
 return new ActionListener<StopDetectorResponse>() {
 @Override
 public void onResponse(StopDetectorResponse stopDetectorResponse) {
diff --git a/src/main/java/org/opensearch/ad/settings/EnabledSetting.java b/src/main/java/org/opensearch/ad/settings/EnabledSetting.java
index 6a22b769b..797df2a2b 100644
--- a/src/main/java/org/opensearch/ad/settings/EnabledSetting.java
+++ b/src/main/java/org/opensearch/ad/settings/EnabledSetting.java
@@ -39,8 +39,9 @@ public class EnabledSetting extends AbstractSetting {

 public static final String LEGACY_OPENDISTRO_AD_BREAKER_ENABLED = "opendistro.anomaly_detection.breaker.enabled";

- public static final String INTERPOLATION_IN_HCAD_COLD_START_ENABLED =
- "plugins.anomaly_detection.hcad_cold_start_interpolation.enabled";;
+ public static final String INTERPOLATION_IN_HCAD_COLD_START_ENABLED = "plugins.anomaly_detection.hcad_cold_start_interpolation.enabled";
+
+ public static final String DOOR_KEEPER_IN_CACHE_ENABLED = "plugins.anomaly_detection.door_keeper_in_cache.enabled";

 public static final Map<String, Setting<?>> settings = unmodifiableMap(new HashMap<String, Setting<?>>() {
 {
@@ -75,6 +76,13 @@ public class EnabledSetting extends AbstractSetting {
 INTERPOLATION_IN_HCAD_COLD_START_ENABLED,
 Setting.boolSetting(INTERPOLATION_IN_HCAD_COLD_START_ENABLED, false, NodeScope, Dynamic)
 );
+
+ /**
+ * We have a bloom filter placed in front of the inactive entity cache to
+ * filter out unpopular items that are not likely to appear more than once.
+ * This setting controls whether that bloom filter is enabled.
+ */
+ put(DOOR_KEEPER_IN_CACHE_ENABLED, Setting.boolSetting(DOOR_KEEPER_IN_CACHE_ENABLED, false, NodeScope, Dynamic));
 }
 });

@@ -112,4 +120,12 @@ public static boolean isADBreakerEnabled() {
 public static boolean isInterpolationInColdStartEnabled() {
 return EnabledSetting.getInstance().getSettingValue(EnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED);
 }
+
+ /**
+ * If enabled, we filter out unpopular items that are not likely to appear more than once.
+ * @return Whether the door keeper in cache is enabled or not.
+ */ + public static boolean isDoorKeeperInCacheEnabled() { + return EnabledSetting.getInstance().getSettingValue(EnabledSetting.DOOR_KEEPER_IN_CACHE_ENABLED); + } } diff --git a/src/main/java/org/opensearch/ad/task/ADTaskManager.java b/src/main/java/org/opensearch/ad/task/ADTaskManager.java index 3d8f0ea95..1b3a775c8 100644 --- a/src/main/java/org/opensearch/ad/task/ADTaskManager.java +++ b/src/main/java/org/opensearch/ad/task/ADTaskManager.java @@ -334,7 +334,7 @@ private void startRealtimeOrHistoricalDetection( try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { if (detectionDateRange == null) { // start realtime job - handler.startAnomalyDetectorJob(detector.get()); + handler.startAnomalyDetectorJob(detector.get(), listener); } else { // start historical analysis task forwardApplyForTaskSlotsRequestToLeadNode(detector.get(), detectionDateRange, user, transportService, listener); @@ -852,7 +852,7 @@ public void stopDetector( ); } else { // stop realtime detector job - handler.stopAnomalyDetectorJob(detectorId); + handler.stopAnomalyDetectorJob(detectorId, listener); } }, listener); } @@ -2820,6 +2820,13 @@ public boolean skipUpdateHCRealtimeTask(String detectorId, String error) { && Objects.equals(error, realtimeTaskCache.getError()); } + public boolean isHCRealtimeTaskStartInitializing(String detectorId) { + ADRealtimeTaskCache realtimeTaskCache = adTaskCacheManager.getRealtimeTaskCache(detectorId); + return realtimeTaskCache != null + && realtimeTaskCache.getInitProgress() != null + && realtimeTaskCache.getInitProgress().floatValue() > 0; + } + public String convertEntityToString(ADTask adTask) { if (adTask == null || !adTask.isEntityTask()) { return null; diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportAction.java b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportAction.java index c7f2beddf..181dfc797 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportAction.java @@ -24,6 +24,7 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.ExecuteADResultResponseRecorder; import org.opensearch.ad.indices.AnomalyDetectionIndices; import org.opensearch.ad.model.DetectionDateRange; import org.opensearch.ad.rest.handler.IndexAnomalyDetectorJobActionHandler; @@ -51,6 +52,7 @@ public class AnomalyDetectorJobTransportAction extends HandledTransportAction filterByEnabled = it); + this.recorder = recorder; } @Override @@ -131,7 +135,6 @@ private void executeDetector( ) { IndexAnomalyDetectorJobActionHandler handler = new IndexAnomalyDetectorJobActionHandler( client, - listener, anomalyDetectionIndices, detectorId, seqNo, @@ -139,7 +142,8 @@ private void executeDetector( requestTimeout, xContentRegistry, transportService, - adTaskManager + adTaskManager, + recorder ); if (rawPath.endsWith(RestHandlerUtils.START_JOB)) { adTaskManager.startDetector(detectorId, detectionDateRange, handler, user, transportService, context, listener); diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java index 0aa359a8f..921a5dc55 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java +++ 
b/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java @@ -696,7 +696,7 @@ private Exception coldStartIfNoModel(AtomicReference failure, Anomaly } LOG.info("Trigger cold start for {}", detector.getDetectorId()); coldStart(detector); - return previousException.orElse(new InternalFailure(adID, NO_MODEL_ERR_MSG)); + return previousException.orElse(exp); } private void findException(Throwable cause, String adID, AtomicReference failure, String nodeId) { diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java b/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java index a56addd72..0c3d35037 100644 --- a/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java +++ b/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java @@ -30,6 +30,7 @@ import java.util.Collections; import java.util.Iterator; import java.util.Locale; +import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.concurrent.ThreadFactory; @@ -49,12 +50,14 @@ import org.opensearch.action.index.IndexResponse; import org.opensearch.ad.common.exception.EndRunException; import org.opensearch.ad.indices.AnomalyDetectionIndices; +import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.model.IntervalTimeConfiguration; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.transport.handler.AnomalyIndexHandler; import org.opensearch.ad.util.ClientUtil; +import org.opensearch.ad.util.DiscoveryNodeFilterer; import org.opensearch.ad.util.IndexUtils; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; @@ -114,6 +117,13 @@ public class AnomalyDetectorJobRunnerTests extends AbstractADTest { @Mock private AnomalyDetectionIndices indexUtil; + private ExecuteADResultResponseRecorder recorder; + + @Mock + private DiscoveryNodeFilterer nodeFilter; + + private AnomalyDetector detector; + @BeforeClass public static void setUpBeforeClass() { setUpThreadPool(AnomalyDetectorJobRunnerTests.class.getSimpleName()); @@ -138,7 +148,6 @@ public void setup() throws Exception { Mockito.doReturn(threadContext).when(mockedThreadPool).getThreadContext(); runner.setThreadPool(mockedThreadPool); runner.setClient(client); - runner.setAnomalyResultHandler(anomalyResultHandler); runner.setAdTaskManager(adTaskManager); Settings settings = Settings @@ -154,7 +163,6 @@ public void setup() throws Exception { AnomalyDetectionIndices anomalyDetectionIndices = mock(AnomalyDetectionIndices.class); IndexNameExpressionResolver indexNameResolver = mock(IndexNameExpressionResolver.class); IndexUtils indexUtils = new IndexUtils(client, clientUtil, clusterService, indexNameResolver); - NodeStateManager stateManager = mock(NodeStateManager.class); runner.setAnomalyDetectionIndices(indexUtil); @@ -195,6 +203,18 @@ public void setup() throws Exception { return null; }).when(client).index(any(), any()); + + recorder = new ExecuteADResultResponseRecorder(indexUtil, anomalyResultHandler, adTaskManager, nodeFilter, threadPool, client); + runner.setExecuteADResultResponseRecorder(recorder); + detector = TestHelpers.randomAnomalyDetectorWithEmptyFeature(); + + NodeStateManager stateManager = mock(NodeStateManager.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + 
}).when(stateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + runner.setNodeStateManager(stateManager); } @Rule @@ -222,7 +242,7 @@ public void testRunJobWithNullLockDuration() throws InterruptedException { when(jobParameter.getLockDurationSeconds()).thenReturn(null); when(jobParameter.getSchedule()).thenReturn(new IntervalSchedule(Instant.now(), 1, ChronoUnit.MINUTES)); runner.runJob(jobParameter, context); - Thread.sleep(1000); + Thread.sleep(2000); assertTrue(testAppender.containsMessage("Can't get lock for AD job")); } @@ -239,7 +259,7 @@ public void testRunJobWithLockDuration() throws InterruptedException { @Test public void testRunAdJobWithNullLock() { LockModel lock = null; - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now()); + runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); verify(client, never()).execute(any(), any(), any()); } @@ -247,7 +267,7 @@ public void testRunAdJobWithNullLock() { public void testRunAdJobWithLock() { LockModel lock = new LockModel("indexName", "jobId", Instant.now(), 10, false); - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now()); + runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); verify(client, times(1)).execute(any(), any(), any()); } @@ -257,7 +277,7 @@ public void testRunAdJobWithExecuteException() { doThrow(RuntimeException.class).when(client).execute(any(), any(), any()); - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now()); + runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); verify(client, times(1)).execute(any(), any(), any()); assertTrue(testAppender.containsMessage("Failed to execute AD job")); } @@ -266,7 +286,17 @@ public void testRunAdJobWithExecuteException() { public void testRunAdJobWithEndRunExceptionNow() { LockModel lock = new LockModel("indexName", "jobId", Instant.now(), 10, false); Exception exception = new EndRunException(jobParameter.getName(), randomAlphaOfLength(5), true); - runner.handleAdException(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), exception); + runner + .handleAdException( + jobParameter, + lockService, + lock, + Instant.now().minusMillis(1000 * 60), + Instant.now(), + exception, + recorder, + detector + ); verify(anomalyResultHandler).index(any(), any(), any()); } @@ -366,7 +396,17 @@ private void testRunAdJobWithEndRunExceptionNowAndStopAdJob(boolean jobExists, b return null; }).when(client).index(any(IndexRequest.class), any()); - runner.handleAdException(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), exception); + runner + .handleAdException( + jobParameter, + lockService, + lock, + Instant.now().minusMillis(1000 * 60), + Instant.now(), + exception, + recorder, + detector + ); } @Test @@ -380,7 +420,17 @@ public void testRunAdJobWithEndRunExceptionNowAndGetJobException() { return null; }).when(client).get(any(GetRequest.class), any()); - runner.handleAdException(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), exception); + runner + .handleAdException( + jobParameter, + lockService, + lock, + Instant.now().minusMillis(1000 * 60), + Instant.now(), + exception, + recorder, + detector + ); 
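The assertions that follow, like the others in this test class, rely on testAppender, a log-capturing helper from the plugin's test utilities. Its exact implementation isn't shown in this patch, but a capturing Log4j2 appender of roughly this shape would support containsMessage and countMessage; this is a sketch under that assumption, and the class name is illustrative.

```java
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;

import org.apache.logging.log4j.core.LogEvent;
import org.apache.logging.log4j.core.appender.AbstractAppender;
import org.apache.logging.log4j.core.config.Property;

// Records formatted log messages so tests can assert on them; illustrative only.
class CapturingAppender extends AbstractAppender {
    private final List<String> messages = new CopyOnWriteArrayList<>();

    CapturingAppender(String name) {
        super(name, null, null, true, Property.EMPTY_ARRAY);
    }

    @Override
    public void append(LogEvent event) {
        messages.add(event.getMessage().getFormattedMessage());
    }

    boolean containsMessage(String fragment) {
        return messages.stream().anyMatch(m -> m.contains(fragment));
    }

    int countMessage(String fragment) {
        return (int) messages.stream().filter(m -> m.contains(fragment)).count();
    }
}
```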
assertTrue(testAppender.containsMessage("JobRunner will stop AD job due to EndRunException for")); assertTrue(testAppender.containsMessage("JobRunner failed to get detector job")); verify(anomalyResultHandler).index(any(), any(), any()); @@ -404,7 +454,17 @@ public void testRunAdJobWithEndRunExceptionNowAndFailToGetJob() { return null; }).when(client).get(any(), any()); - runner.handleAdException(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), exception); + runner + .handleAdException( + jobParameter, + lockService, + lock, + Instant.now().minusMillis(1000 * 60), + Instant.now(), + exception, + recorder, + detector + ); verify(anomalyResultHandler).index(any(), any(), any()); assertEquals(1, testAppender.countMessage("JobRunner failed to get detector job")); } @@ -425,10 +485,10 @@ public void testRunAdJobWithEndRunExceptionNotNowAndRetryUntilStop() throws Inte }).when(client).execute(any(), any(), any()); for (int i = 0; i < 3; i++) { - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime); + runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); assertEquals(i + 1, testAppender.countMessage("EndRunException happened for")); } - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime); + runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); assertEquals(1, testAppender.countMessage("JobRunner will stop AD job due to EndRunException retry exceeds upper limit")); } diff --git a/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java b/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java index a4e0ccc13..d939884bf 100644 --- a/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java +++ b/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java @@ -56,6 +56,7 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.Entity; import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.settings.EnabledSetting; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; @@ -650,21 +651,26 @@ public void testSelectEmpty() { // test that detector interval is more than 1 hour that maintenance is called before // the next get method public void testLongDetectorInterval() { - when(clock.instant()).thenReturn(Instant.ofEpochSecond(1000)); - when(detector.getDetectionIntervalDuration()).thenReturn(Duration.ofHours(12)); - String modelId = entity1.getModelId(detectorId).get(); - // record last access time 1000 - entityCache.get(modelId, detector); - assertEquals(-1, entityCache.getLastActiveMs(detectorId, modelId)); - // 2 hour = 7200 seconds have passed - long currentTimeEpoch = 8200; - when(clock.instant()).thenReturn(Instant.ofEpochSecond(currentTimeEpoch)); - // door keeper should not be expired since we reclaim space every 60 intervals - entityCache.maintenance(); - // door keeper still has the record and won't blocks entity state being created - entityCache.get(modelId, detector); - // * 1000 to convert to milliseconds - assertEquals(currentTimeEpoch * 1000, entityCache.getLastActiveMs(detectorId, modelId)); + try { + EnabledSetting.getInstance().setSettingValue(EnabledSetting.DOOR_KEEPER_IN_CACHE_ENABLED, true); + 
when(clock.instant()).thenReturn(Instant.ofEpochSecond(1000));
+ when(detector.getDetectionIntervalDuration()).thenReturn(Duration.ofHours(12));
+ String modelId = entity1.getModelId(detectorId).get();
+ // record last access time 1000
+ assertTrue(null == entityCache.get(modelId, detector));
+ assertEquals(-1, entityCache.getLastActiveMs(detectorId, modelId));
+ // 2 hours = 7200 seconds have passed
+ long currentTimeEpoch = 8200;
+ when(clock.instant()).thenReturn(Instant.ofEpochSecond(currentTimeEpoch));
+ // door keeper should not be expired since we reclaim space every 60 intervals
+ entityCache.maintenance();
+ // door keeper still has the record and won't block entity state being created
+ entityCache.get(modelId, detector);
+ // * 1000 to convert to milliseconds
+ assertEquals(currentTimeEpoch * 1000, entityCache.getLastActiveMs(detectorId, modelId));
+ } finally {
+ EnabledSetting.getInstance().setSettingValue(EnabledSetting.DOOR_KEEPER_IN_CACHE_ENABLED, false);
+ }
 }

 public void testGetNoPriorityUpdate() {
diff --git a/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java b/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java
index 0c2cce832..074b15552 100644
--- a/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java
+++ b/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java
@@ -16,8 +16,6 @@
 import java.text.SimpleDateFormat;
 import java.time.Clock;
 import java.time.Instant;
-import java.time.format.DateTimeFormatter;
-import java.time.temporal.ChronoUnit;
 import java.util.ArrayList;
 import java.util.Calendar;
 import java.util.Date;
@@ -45,7 +43,7 @@ public class DetectionResultEvalutationIT extends AbstractSyntheticDataTest {
 protected static final Logger LOG = (Logger) LogManager.getLogger(DetectionResultEvalutationIT.class);

 /**
- * Simulate starting the given HCAD detector.
+ * Wait for HCAD cold start to finish.
* @param detectorId Detector Id * @param data Data in Json format * @param trainTestSplit Training data size @@ -54,7 +52,7 @@ public class DetectionResultEvalutationIT extends AbstractSyntheticDataTest { * @param client OpenSearch Client * @throws Exception when failing to query/indexing from/to OpenSearch */ - private void simulateHCADStartDetector( + private void waitForHCADStartDetector( String detectorId, List data, int trainTestSplit, @@ -63,18 +61,6 @@ private void simulateHCADStartDetector( RestClient client ) throws Exception { - Instant trainTime = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(data.get(trainTestSplit - 1).get("timestamp").getAsString())); - - Instant begin = null; - Instant end = null; - for (int i = 0; i < shingleSize; i++) { - begin = trainTime.minus(intervalMinutes * (shingleSize - 1 - i), ChronoUnit.MINUTES); - end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); - try { - getDetectionResult(detectorId, begin, end, client); - } catch (Exception e) {} - } - // It takes time to wait for model initialization long startTime = System.currentTimeMillis(); long duration = 0; do { @@ -94,7 +80,7 @@ private void simulateHCADStartDetector( break; } try { - getDetectionResult(detectorId, begin, end, client); + profileDetectorInitProgress(detectorId, client); } catch (Exception e) {} duration = System.currentTimeMillis() - startTime; } while (duration <= 60_000); @@ -280,14 +266,15 @@ private void verifyRestart(String datasetName, int intervalMinutes, int shingleS String detectorId = createDetector(datasetName, intervalMinutes, client, categoricalField, 0); // cannot stop without actually starting detector because ad complains no ad job index startDetector(detectorId, client); + profileDetectorInitProgress(detectorId, client); // it would be long if we wait for the job actually run the work periodically; speed it up by using simulateHCADStartDetector - simulateHCADStartDetector(detectorId, data, trainTestSplit, shingleSize, intervalMinutes, client); + waitForHCADStartDetector(detectorId, data, trainTestSplit, shingleSize, intervalMinutes, client); String initProgress = profileDetectorInitProgress(detectorId, client); assertEquals("init progress is " + initProgress, "100%", initProgress); stopDetector(detectorId, client); // restart detector startDetector(detectorId, client); - simulateHCADStartDetector(detectorId, data, trainTestSplit, shingleSize, intervalMinutes, client); + waitForHCADStartDetector(detectorId, data, trainTestSplit, shingleSize, intervalMinutes, client); initProgress = profileDetectorInitProgress(detectorId, client); assertEquals("init progress is " + initProgress, "100%", initProgress); } diff --git a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java index 1236c6217..4f4ca09c4 100644 --- a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java +++ b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java @@ -114,7 +114,6 @@ public void testTrainUsingSamples() throws InterruptedException { public void testColdStart() throws InterruptedException, IOException { Queue samples = MLUtil.createQueueSamples(1); - double[] savedSample = samples.peek(); EntityModel model = new EntityModel(entity, samples, null); modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); diff --git a/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobTransportActionWithUser.java 
b/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobTransportActionWithUser.java index 4b7b8b00d..caa36cdaa 100644 --- a/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobTransportActionWithUser.java +++ b/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobTransportActionWithUser.java @@ -20,6 +20,7 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.ExecuteADResultResponseRecorder; import org.opensearch.ad.indices.AnomalyDetectionIndices; import org.opensearch.ad.model.DetectionDateRange; import org.opensearch.ad.rest.handler.IndexAnomalyDetectorJobActionHandler; @@ -52,6 +53,7 @@ public class MockAnomalyDetectorJobTransportActionWithUser extends private ThreadContext.StoredContext context; private final ADTaskManager adTaskManager; private final TransportService transportService; + private final ExecuteADResultResponseRecorder recorder; @Inject public MockAnomalyDetectorJobTransportActionWithUser( @@ -62,7 +64,8 @@ public MockAnomalyDetectorJobTransportActionWithUser( Settings settings, AnomalyDetectionIndices anomalyDetectionIndices, NamedXContentRegistry xContentRegistry, - ADTaskManager adTaskManager + ADTaskManager adTaskManager, + ExecuteADResultResponseRecorder recorder ) { super(MockAnomalyDetectorJobAction.NAME, transportService, actionFilters, AnomalyDetectorJobRequest::new); this.transportService = transportService; @@ -77,6 +80,7 @@ public MockAnomalyDetectorJobTransportActionWithUser( ThreadContext threadContext = new ThreadContext(settings); context = threadContext.stashContext(); + this.recorder = recorder; } @Override @@ -131,7 +135,6 @@ private void executeDetector( ) { IndexAnomalyDetectorJobActionHandler handler = new IndexAnomalyDetectorJobActionHandler( client, - listener, anomalyDetectionIndices, detectorId, seqNo, @@ -139,7 +142,8 @@ private void executeDetector( requestTimeout, xContentRegistry, transportService, - adTaskManager + adTaskManager, + recorder ); if (rawPath.endsWith(RestHandlerUtils.START_JOB)) { adTaskManager.startDetector(detectorId, detectionDateRange, handler, user, transportService, context, listener); diff --git a/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java index 77dc78015..0bc23d60c 100644 --- a/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java +++ b/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java @@ -30,7 +30,10 @@ import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionListener; import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.caching.CacheProvider; import org.opensearch.ad.ml.EntityColdStarter; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelState; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; @@ -38,10 +41,13 @@ import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException; import org.opensearch.rest.RestStatus; +import test.org.opensearch.ad.util.MLUtil; + public class EntityColdStartWorkerTests extends AbstractRateLimitingTest { ClusterService clusterService; EntityColdStartWorker worker; EntityColdStarter entityColdStarter; + CacheProvider cacheProvider; @Override 
public void setUp() throws Exception { @@ -64,6 +70,8 @@ public void setUp() throws Exception { entityColdStarter = mock(EntityColdStarter.class); + cacheProvider = mock(CacheProvider.class); + // Integer.MAX_VALUE makes a huge heap worker = new EntityColdStartWorker( Integer.MAX_VALUE, @@ -82,7 +90,8 @@ public void setUp() throws Exception { AnomalyDetectorSettings.QUEUE_MAINTENANCE, entityColdStarter, AnomalyDetectorSettings.HOURLY_MAINTENANCE, - nodeStateManager + nodeStateManager, + cacheProvider ); } @@ -135,4 +144,22 @@ public void testException() { verify(entityColdStarter, times(2)).trainModel(any(), anyString(), any(), any()); verify(nodeStateManager, times(2)).setException(eq(detectorId), any(OpenSearchStatusException.class)); } + + public void testModelHosted() { + EntityRequest request = new EntityRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + + ModelState state = invocation.getArgument(2); + state.setModel(MLUtil.createNonEmptyModel(detectorId)); + listener.onResponse(null); + + return null; + }).when(entityColdStarter).trainModel(any(), anyString(), any(), any()); + + worker.put(request); + + verify(cacheProvider, times(1)).get(); + } } diff --git a/src/test/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandlerTests.java b/src/test/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandlerTests.java new file mode 100644 index 000000000..c0366e95a --- /dev/null +++ b/src/test/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandlerTests.java @@ -0,0 +1,354 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */
+
+package org.opensearch.ad.rest.handler;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+import static org.opensearch.action.DocWriteResponse.Result.CREATED;
+import static org.opensearch.ad.constant.CommonErrorMessages.CAN_NOT_FIND_LATEST_TASK;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.opensearch.action.ActionListener;
+import org.opensearch.action.get.GetRequest;
+import org.opensearch.action.get.GetResponse;
+import org.opensearch.action.index.IndexRequest;
+import org.opensearch.action.index.IndexResponse;
+import org.opensearch.action.update.UpdateResponse;
+import org.opensearch.ad.ExecuteADResultResponseRecorder;
+import org.opensearch.ad.TestHelpers;
+import org.opensearch.ad.common.exception.ResourceNotFoundException;
+import org.opensearch.ad.constant.CommonErrorMessages;
+import org.opensearch.ad.constant.CommonName;
+import org.opensearch.ad.indices.AnomalyDetectionIndices;
+import org.opensearch.ad.mock.model.MockSimpleLog;
+import org.opensearch.ad.model.AnomalyDetector;
+import org.opensearch.ad.model.AnomalyResult;
+import org.opensearch.ad.model.Feature;
+import org.opensearch.ad.task.ADTaskManager;
+import org.opensearch.ad.transport.AnomalyDetectorJobResponse;
+import org.opensearch.ad.transport.AnomalyResultAction;
+import org.opensearch.ad.transport.AnomalyResultResponse;
+import org.opensearch.ad.transport.ProfileAction;
+import org.opensearch.ad.transport.ProfileResponse;
+import org.opensearch.ad.transport.handler.AnomalyIndexHandler;
+import org.opensearch.ad.util.DiscoveryNodeFilterer;
+import org.opensearch.client.Client;
+import org.opensearch.common.unit.TimeValue;
+import org.opensearch.common.xcontent.NamedXContentRegistry;
+import org.opensearch.search.aggregations.AggregationBuilder;
+import org.opensearch.test.OpenSearchTestCase;
+import org.opensearch.threadpool.ThreadPool;
+import org.opensearch.transport.TransportService;
+
+import com.google.common.collect.ImmutableList;
+
+public class IndexAnomalyDetectorJobActionHandlerTests extends OpenSearchTestCase {
+
+    private static AnomalyDetectionIndices anomalyDetectionIndices;
+    private static String detectorId;
+    private static Long seqNo;
+    private static Long primaryTerm;
+
+    private static NamedXContentRegistry xContentRegistry;
+    private static TransportService transportService;
+    private static TimeValue requestTimeout;
+    private static DiscoveryNodeFilterer nodeFilter;
+    private static AnomalyDetector detector;
+
+    private ADTaskManager adTaskManager;
+
+    private ThreadPool threadPool;
+
+    private ExecuteADResultResponseRecorder recorder;
+    private Client client;
+    private IndexAnomalyDetectorJobActionHandler handler;
+    private AnomalyIndexHandler<AnomalyResult> anomalyResultHandler;
+
+    @BeforeClass
+    public static void setOnce() throws IOException {
+        detectorId = "123";
+        seqNo = 1L;
+        primaryTerm = 2L;
+        anomalyDetectionIndices = mock(AnomalyDetectionIndices.class);
+        xContentRegistry = NamedXContentRegistry.EMPTY;
+        transportService = mock(TransportService.class);
+
+        requestTimeout = TimeValue.timeValueMinutes(60);
+
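+        // these tests assume the AD job index already exists, so starting a detector job
+        // exercises the job-start path rather than index creation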
+        when(anomalyDetectionIndices.doesAnomalyDetectorJobIndexExist()).thenReturn(true);
+
+        nodeFilter = mock(DiscoveryNodeFilterer.class);
+        detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList("a"));
+    }
+
+    @SuppressWarnings("unchecked")
+    @Override
+    @Before
+    public void setUp() throws Exception {
+        super.setUp();
+        client = mock(Client.class);
+        doAnswer(invocation -> {
+            Object[] args = invocation.getArguments();
+            ActionListener<GetResponse> listener = (ActionListener<GetResponse>) args[1];
+
+            GetResponse response = mock(GetResponse.class);
+            when(response.isExists()).thenReturn(false);
+            listener.onResponse(response);
+
+            return null;
+        }).when(client).get(any(GetRequest.class), any());
+
+        doAnswer(invocation -> {
+            Object[] args = invocation.getArguments();
+            ActionListener<IndexResponse> listener = (ActionListener<IndexResponse>) args[1];
+
+            IndexResponse response = mock(IndexResponse.class);
+            when(response.getResult()).thenReturn(CREATED);
+            listener.onResponse(response);
+
+            return null;
+        }).when(client).index(any(IndexRequest.class), any());
+
+        doAnswer(invocation -> {
+            Object[] args = invocation.getArguments();
+            ActionListener<AnomalyResultResponse> listener = (ActionListener<AnomalyResultResponse>) args[2];
+
+            AnomalyResultResponse response = new AnomalyResultResponse(null, "", 0L, 10L, true);
+            listener.onResponse(response);
+
+            return null;
+        }).when(client).execute(any(AnomalyResultAction.class), any(), any());
+
+        adTaskManager = mock(ADTaskManager.class);
+        doAnswer(invocation -> {
+            Object[] args = invocation.getArguments();
+            ActionListener<AnomalyDetectorJobResponse> listener = (ActionListener<AnomalyDetectorJobResponse>) args[4];
+
+            AnomalyDetectorJobResponse response = mock(AnomalyDetectorJobResponse.class);
+            listener.onResponse(response);
+
+            return null;
+        }).when(adTaskManager).startDetector(any(), any(), any(), any(), any());
+
+        threadPool = mock(ThreadPool.class);
+
+        anomalyResultHandler = mock(AnomalyIndexHandler.class);
+
+        recorder = new ExecuteADResultResponseRecorder(
+            anomalyDetectionIndices,
+            anomalyResultHandler,
+            adTaskManager,
+            nodeFilter,
+            threadPool,
+            client
+        );
+
+        handler = new IndexAnomalyDetectorJobActionHandler(
+            client,
+            anomalyDetectionIndices,
+            detectorId,
+            seqNo,
+            primaryTerm,
+            requestTimeout,
+            xContentRegistry,
+            transportService,
+            adTaskManager,
+            recorder
+        );
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testDelayHCProfile() {
+        when(adTaskManager.isHCRealtimeTaskStartInitializing(anyString())).thenReturn(false);
+
+        ActionListener<AnomalyDetectorJobResponse> listener = mock(ActionListener.class);
+
+        handler.startAnomalyDetectorJob(detector, listener);
+
+        verify(client, times(1)).get(any(), any());
+        verify(client, times(1)).execute(any(), any(), any());
+        verify(adTaskManager, times(1)).startDetector(any(), any(), any(), any(), any());
+        verify(adTaskManager, times(1)).isHCRealtimeTaskStartInitializing(anyString());
+        verify(threadPool, times(1)).schedule(any(), any(), any());
+        verify(listener, times(1)).onResponse(any());
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testNoDelayHCProfile() {
+        doAnswer(invocation -> {
+            Object[] args = invocation.getArguments();
+            ActionListener<ProfileResponse> listener = (ActionListener<ProfileResponse>) args[2];
+
+            ProfileResponse response = mock(ProfileResponse.class);
+            when(response.getTotalUpdates()).thenReturn(3L);
+            listener.onResponse(response);
+
+            return null;
+        }).when(client).execute(any(ProfileAction.class), any(), any());
+
+        when(adTaskManager.isHCRealtimeTaskStartInitializing(anyString())).thenReturn(true);
+
+        ActionListener<AnomalyDetectorJobResponse> listener = mock(ActionListener.class);
+
+        handler.startAnomalyDetectorJob(detector, listener);
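+
+        // the realtime task is already initializing, so the handler reads total updates from
+        // the profile response right away instead of scheduling a delayed poll on the thread pool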
+        verify(client, times(1)).get(any(), any());
+        verify(client, times(2)).execute(any(), any(), any());
+        verify(adTaskManager, times(1)).startDetector(any(), any(), any(), any(), any());
+        verify(adTaskManager, times(1)).isHCRealtimeTaskStartInitializing(anyString());
+        verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any());
+        verify(threadPool, never()).schedule(any(), any(), any());
+        verify(listener, times(1)).onResponse(any());
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testHCProfileException() {
+        doAnswer(invocation -> {
+            Object[] args = invocation.getArguments();
+            ActionListener<ProfileResponse> listener = (ActionListener<ProfileResponse>) args[2];
+
+            listener.onFailure(new RuntimeException());
+
+            return null;
+        }).when(client).execute(any(ProfileAction.class), any(), any());
+
+        when(adTaskManager.isHCRealtimeTaskStartInitializing(anyString())).thenReturn(true);
+
+        ActionListener<AnomalyDetectorJobResponse> listener = mock(ActionListener.class);
+
+        handler.startAnomalyDetectorJob(detector, listener);
+
+        verify(client, times(1)).get(any(), any());
+        verify(client, times(2)).execute(any(), any(), any());
+        verify(adTaskManager, times(1)).startDetector(any(), any(), any(), any(), any());
+        verify(adTaskManager, times(1)).isHCRealtimeTaskStartInitializing(anyString());
+        verify(adTaskManager, never()).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any());
+        verify(threadPool, never()).schedule(any(), any(), any());
+        verify(listener, times(1)).onResponse(any());
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testUpdateLatestRealtimeTaskOnCoordinatingNodeResourceNotFoundException() {
+        doAnswer(invocation -> {
+            Object[] args = invocation.getArguments();
+            ActionListener<ProfileResponse> listener = (ActionListener<ProfileResponse>) args[2];
+
+            ProfileResponse response = mock(ProfileResponse.class);
+            when(response.getTotalUpdates()).thenReturn(3L);
+            listener.onResponse(response);
+
+            return null;
+        }).when(client).execute(any(ProfileAction.class), any(), any());
+
+        when(adTaskManager.isHCRealtimeTaskStartInitializing(anyString())).thenReturn(true);
+
+        doAnswer(invocation -> {
+            Object[] args = invocation.getArguments();
+            ActionListener<UpdateResponse> listener = (ActionListener<UpdateResponse>) args[5];
+
+            listener.onFailure(new ResourceNotFoundException(CAN_NOT_FIND_LATEST_TASK));
+
+            return null;
+        }).when(adTaskManager).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any());
+
+        ActionListener<AnomalyDetectorJobResponse> listener = mock(ActionListener.class);
+
+        handler.startAnomalyDetectorJob(detector, listener);
+
+        verify(client, times(1)).get(any(), any());
+        verify(client, times(2)).execute(any(), any(), any());
+        verify(adTaskManager, times(1)).startDetector(any(), any(), any(), any(), any());
+        verify(adTaskManager, times(1)).isHCRealtimeTaskStartInitializing(anyString());
+        verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any());
+        verify(adTaskManager, times(1)).removeRealtimeTaskCache(anyString());
+        verify(threadPool, never()).schedule(any(), any(), any());
+        verify(listener, times(1)).onResponse(any());
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testUpdateLatestRealtimeTaskOnCoordinatingException() {
+        doAnswer(invocation -> {
+            Object[] args = invocation.getArguments();
+            ActionListener<ProfileResponse> listener = (ActionListener<ProfileResponse>) args[2];
+
+            ProfileResponse response = mock(ProfileResponse.class);
+            when(response.getTotalUpdates()).thenReturn(3L);
+            listener.onResponse(response);
+
+            return null;
+
+        }).when(client).execute(any(ProfileAction.class), any(), any());
+
+        when(adTaskManager.isHCRealtimeTaskStartInitializing(anyString())).thenReturn(true);
+
+        doAnswer(invocation -> {
+            Object[] args = invocation.getArguments();
+            ActionListener<UpdateResponse> listener = (ActionListener<UpdateResponse>) args[5];
+
+            listener.onFailure(new RuntimeException());
+
+            return null;
+        }).when(adTaskManager).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any());
+
+        ActionListener<AnomalyDetectorJobResponse> listener = mock(ActionListener.class);
+
+        handler.startAnomalyDetectorJob(detector, listener);
+
+        verify(client, times(1)).get(any(), any());
+        verify(client, times(2)).execute(any(), any(), any());
+        verify(adTaskManager, times(1)).startDetector(any(), any(), any(), any(), any());
+        verify(adTaskManager, times(1)).isHCRealtimeTaskStartInitializing(anyString());
+        verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any());
+        verify(adTaskManager, never()).removeRealtimeTaskCache(anyString());
+        verify(threadPool, never()).schedule(any(), any(), any());
+        verify(listener, times(1)).onResponse(any());
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testIndexException() throws IOException {
+        doAnswer(invocation -> {
+            Object[] args = invocation.getArguments();
+            ActionListener<AnomalyResultResponse> listener = (ActionListener<AnomalyResultResponse>) args[2];
+
+            listener.onFailure(new ResourceNotFoundException(detectorId, CommonErrorMessages.NO_CHECKPOINT_ERR_MSG));
+
+            return null;
+        }).when(client).execute(any(AnomalyResultAction.class), any(), any());
+
+        ActionListener<AnomalyDetectorJobResponse> listener = mock(ActionListener.class);
+        AggregationBuilder aggregationBuilder = TestHelpers
+            .parseAggregation("{\"test\":{\"max\":{\"field\":\"" + MockSimpleLog.VALUE_FIELD + "\"}}}");
+        Feature feature = new Feature(randomAlphaOfLength(5), randomAlphaOfLength(10), true, aggregationBuilder);
+        detector = TestHelpers
+            .randomDetector(
+                ImmutableList.of(feature),
+                "test",
+                10,
+                MockSimpleLog.TIME_FIELD,
+                null,
+                CommonName.CUSTOM_RESULT_INDEX_PREFIX + "index"
+            );
+        when(anomalyDetectionIndices.doesIndexExist(anyString())).thenReturn(false);
+        handler.startAnomalyDetectorJob(detector, listener);
+        verify(anomalyResultHandler, times(1)).index(any(), any(), eq(null));
+        verify(threadPool, times(1)).schedule(any(), any(), any());
+    }
+}
diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobActionTests.java b/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobActionTests.java
index cefb64f05..860347a21 100644
--- a/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobActionTests.java
@@ -25,6 +25,7 @@
 import org.junit.Test;
 import org.opensearch.action.ActionListener;
 import org.opensearch.action.support.ActionFilters;
+import org.opensearch.ad.ExecuteADResultResponseRecorder;
 import org.opensearch.ad.indices.AnomalyDetectionIndices;
 import org.opensearch.ad.model.DetectionDateRange;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;
@@ -76,7 +77,8 @@ public void setUp() throws Exception {
             indexSettings(),
             mock(AnomalyDetectionIndices.class),
             xContentRegistry(),
-            mock(ADTaskManager.class)
+            mock(ADTaskManager.class),
+            mock(ExecuteADResultResponseRecorder.class)
         );
         task = mock(Task.class);
         request = new AnomalyDetectorJobRequest("1234", 4567, 7890, "_start");