diff --git a/build.gradle b/build.gradle index d2aa92b7..4b37620c 100644 --- a/build.gradle +++ b/build.gradle @@ -261,9 +261,6 @@ List jacocoExclusions = [ 'com.amazon.opendistroforelasticsearch.ad.transport.SearchAnomalyResultTransportAction*', // TODO: hc caused coverage to drop - //'com.amazon.opendistroforelasticsearch.ad.ml.ModelManager', - 'com.amazon.opendistroforelasticsearch.ad.transport.AnomalyResultTransportAction', - 'com.amazon.opendistroforelasticsearch.ad.transport.AnomalyResultTransportAction.EntityResultListener', 'com.amazon.opendistroforelasticsearch.ad.NodeStateManager', 'com.amazon.opendistroforelasticsearch.ad.transport.handler.MultiEntityResultHandler', 'com.amazon.opendistroforelasticsearch.ad.transport.EntityProfileTransportAction*', diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java index a9dbcf0a..85e921eb 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java @@ -442,7 +442,7 @@ public Collection createComponents( .>builder() .put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) .put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put(StatNames.MODEL_INFORMATION.getName(), new ADStat<>(false, new ModelsOnNodeSupplier(modelManager))) + .put(StatNames.MODEL_INFORMATION.getName(), new ADStat<>(false, new ModelsOnNodeSupplier(modelManager, cacheProvider))) .put( StatNames.ANOMALY_DETECTORS_INDEX_STATUS.getName(), new ADStat<>(true, new IndexStatusSupplier(indexUtils, AnomalyDetector.ANOMALY_DETECTORS_INDEX)) diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/CacheBuffer.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/CacheBuffer.java index 192013b7..cdec3bd6 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/CacheBuffer.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/CacheBuffer.java @@ -20,10 +20,12 @@ import java.time.Instant; import java.util.AbstractMap.SimpleImmutableEntry; import java.util.Comparator; +import java.util.List; import java.util.Map.Entry; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentSkipListSet; +import java.util.stream.Collectors; import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; @@ -525,4 +527,8 @@ public boolean expired(Duration stateTtl) { public String getDetectorId() { return detectorId; } + + public List> getAllModels() { + return items.values().stream().collect(Collectors.toList()); + } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/EntityCache.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/EntityCache.java index fd42cc13..9915f605 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/EntityCache.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/EntityCache.java @@ -15,6 +15,8 @@ package com.amazon.opendistroforelasticsearch.ad.caching; +import java.util.List; + import com.amazon.opendistroforelasticsearch.ad.CleanState; import com.amazon.opendistroforelasticsearch.ad.MaintenanceState; import com.amazon.opendistroforelasticsearch.ad.ml.EntityModel; @@ -72,4 +74,11 @@ public interface EntityCache extends MaintenanceState, CleanState { * @return RCF model total updates of specific entity */ long getTotalUpdates(String detectorId, String entityModelId); + + /** + * Gets modelStates of all model hosted on a node + * + * @return list of modelStates + */ + List> getAllModels(); } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/PriorityCache.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/PriorityCache.java index 43f11aa9..508d8e42 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/PriorityCache.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/PriorityCache.java @@ -23,6 +23,8 @@ import java.time.Instant; import java.util.AbstractMap.SimpleImmutableEntry; import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; @@ -554,4 +556,16 @@ public int getTotalActiveEntities() { activeEnities.values().stream().forEach(cacheBuffer -> { total.addAndGet(cacheBuffer.getActiveEntities()); }); return total.get(); } + + /** + * Gets modelStates of all model hosted on a node + * + * @return list of modelStates + */ + @Override + public List> getAllModels() { + List> states = new ArrayList<>(); + activeEnities.values().stream().forEach(cacheBuffer -> states.addAll(cacheBuffer.getAllModels())); + return states; + } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/SearchFeatureDao.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/SearchFeatureDao.java index c696e980..b445c0a8 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/SearchFeatureDao.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/SearchFeatureDao.java @@ -812,8 +812,8 @@ public void getFeaturesByEntities( new ThreadedActionListener<>(logger, threadPool, AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, termsListener, false) ); - } catch (IOException e) { - throw new EndRunException(detector.getDetectorId(), CommonErrorMessages.INVALID_SEARCH_QUERY_MSG, e, true); + } catch (Exception e) { + throw new EndRunException(detector.getDetectorId(), CommonErrorMessages.INVALID_SEARCH_QUERY_MSG, e, false); } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java index dfcf7e6c..cfd759ab 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java @@ -20,6 +20,7 @@ import java.time.Clock; import java.time.Duration; import java.time.Instant; +import java.util.ArrayDeque; import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; @@ -1022,7 +1023,9 @@ public void processEntityCheckpoint( modelState.setLastCheckpointTime(clock.instant().minus(checkpointInterval)); } - assert (modelState.getModel() != null); + if (modelState.getModel() == null) { + modelState.setModel(new EntityModel(modelId, new ArrayDeque<>(), null, null)); + } maybeTrainBeforeScore(modelState, entityName); } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/stats/suppliers/ModelsOnNodeSupplier.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/stats/suppliers/ModelsOnNodeSupplier.java index d8cb4b66..669b5da8 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/stats/suppliers/ModelsOnNodeSupplier.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/stats/suppliers/ModelsOnNodeSupplier.java @@ -27,7 +27,9 @@ import java.util.Set; import java.util.function.Supplier; import java.util.stream.Collectors; +import java.util.stream.Stream; +import com.amazon.opendistroforelasticsearch.ad.caching.CacheProvider; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; /** @@ -35,6 +37,7 @@ */ public class ModelsOnNodeSupplier implements Supplier>> { private ModelManager modelManager; + private CacheProvider cache; /** * Set that contains the model stats that should be exposed. @@ -45,16 +48,18 @@ public class ModelsOnNodeSupplier implements Supplier>> * Constructor * * @param modelManager object that manages the model partitions hosted on the node + * @param cache object that manages multi-entity detectors' models */ - public ModelsOnNodeSupplier(ModelManager modelManager) { + public ModelsOnNodeSupplier(ModelManager modelManager, CacheProvider cache) { this.modelManager = modelManager; + this.cache = cache; } @Override public List> get() { List> values = new ArrayList<>(); - modelManager - .getAllModels() + Stream + .concat(modelManager.getAllModels().stream(), cache.get().getAllModels().stream()) .forEach( modelState -> values .add( diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java index 514709ac..6bdeb9c8 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java @@ -22,6 +22,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Optional; +import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; @@ -104,6 +105,8 @@ public class AnomalyResultTransportAction extends HandledTransportAction listener) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - AnomalyResultRequest request = AnomalyResultRequest.fromActionRequest(actionRequest); ActionListener original = listener; listener = ActionListener.wrap(original::onResponse, e -> { @@ -233,7 +235,6 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener< listener.onFailure(new LimitExceededException(adID, CommonErrorMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, false)); return; } - try { stateManager.getAnomalyDetector(adID, onGetDetector(listener, adID, request)); } catch (Exception ex) { @@ -297,7 +298,7 @@ private ActionListener> onGetDetector( ) ); } else { - entityFeatures + Set>> node2Entities = entityFeatures .entrySet() .stream() .collect( @@ -307,26 +308,29 @@ private ActionListener> onGetDetector( Collectors.toMap(Entry::getKey, Entry::getValue) ) ) - .entrySet() - .stream() - .forEach(nodeEntity -> { - DiscoveryNode node = nodeEntity.getKey(); - transportService - .sendRequest( - node, - EntityResultAction.NAME, - new EntityResultRequest(adID, nodeEntity.getValue(), dataStartTime, dataEndTime), - this.option, - new ActionListenerResponseHandler<>( - new EntityResultListener(node.getId(), adID), - AcknowledgedResponse::new, - ThreadPool.Names.SAME - ) - ); - }); + .entrySet(); + + int nodeCount = node2Entities.size(); + AtomicInteger responseCount = new AtomicInteger(); + + final AtomicReference failure = new AtomicReference<>(); + node2Entities.stream().forEach(nodeEntity -> { + DiscoveryNode node = nodeEntity.getKey(); + transportService + .sendRequest( + node, + EntityResultAction.NAME, + new EntityResultRequest(adID, nodeEntity.getValue(), dataStartTime, dataEndTime), + this.option, + new ActionListenerResponseHandler<>( + new EntityResultListener(node.getId(), adID, responseCount, nodeCount, failure, listener), + AcknowledgedResponse::new, + ThreadPool.Names.SAME + ) + ); + }); } - listener.onResponse(new AnomalyResultResponse(0, 0, 0, new ArrayList())); }, exception -> handleFailure(exception, listener, adID)); threadPool @@ -482,7 +486,7 @@ private ActionListener onFeatureResponse( private void handleFailure(Exception exception, ActionListener listener, String adID) { if (exception instanceof IndexNotFoundException) { - listener.onFailure(new EndRunException(adID, "Having trouble querying data: " + exception.getMessage(), true)); + listener.onFailure(new EndRunException(adID, TROUBLE_QUERYING_ERR_MSG + exception.getMessage(), true)); } else if (exception instanceof EndRunException) { // invalid feature query listener.onFailure(exception); @@ -555,7 +559,7 @@ private void findException(Throwable cause, String adID, AtomicReference failure) { - LOG.error(new ParameterizedMessage("Received an error from node {} while fetching anomaly grade for {}", nodeID, adID), e); + LOG.error(new ParameterizedMessage("Received an error from node {} while doing model inference for {}", nodeID, adID), e); if (e == null) { return; } @@ -801,6 +805,8 @@ private void handlePredictionFailure(Exception e, String adID, String nodeID, At /** * Check if the input exception indicates connection issues. + * During blue-green deployment, we may see ActionNotFoundTransportException. + * Count that as connection issue and isolate that node if it continues to happen. * * @param e exception * @return true if we get disconnected from the node or the node is not in the @@ -811,7 +817,8 @@ private boolean hasConnectionIssue(Throwable e) { || e instanceof NodeClosedException || e instanceof ReceiveTimeoutTransportException || e instanceof NodeNotConnectedException - || e instanceof ConnectException; + || e instanceof ConnectException + || e instanceof ActionNotFoundTransportException; } private void handleConnectionException(String node) { @@ -1015,18 +1022,45 @@ private Optional coldStartIfNoCheckPoint(AnomalyDetec class EntityResultListener implements ActionListener { private String nodeId; private final String adID; + private AtomicInteger responseCount; + private int nodeCount; + private ActionListener listener; + private List ackResponses; + private AtomicReference failure; - EntityResultListener(String nodeId, String adID) { + EntityResultListener( + String nodeId, + String adID, + AtomicInteger responseCount, + int nodeCount, + AtomicReference failure, + ActionListener listener + ) { this.nodeId = nodeId; this.adID = adID; + this.responseCount = responseCount; + this.nodeCount = nodeCount; + this.failure = failure; + this.listener = listener; + this.ackResponses = new ArrayList<>(); } @Override public void onResponse(AcknowledgedResponse response) { - stateManager.resetBackpressureCounter(nodeId); - if (response.isAcknowledged() == false) { - LOG.error("Cannot send entities' features to {} for {}", nodeId, adID); - stateManager.addPressure(nodeId); + try { + stateManager.resetBackpressureCounter(nodeId); + if (response.isAcknowledged() == false) { + LOG.error("Cannot send entities' features to {} for {}", nodeId, adID); + stateManager.addPressure(nodeId); + } else { + ackResponses.add(response); + } + } catch (Exception ex) { + LOG.error("Unexpected exception: {} for {}", ex, adID); + } finally { + if (nodeCount == responseCount.incrementAndGet()) { + handleEntityResponses(); + } } } @@ -1035,13 +1069,28 @@ public void onFailure(Exception e) { if (e == null) { return; } - Throwable cause = ExceptionsHelper.unwrapCause(e); - // in case of connection issue or the other node has no multi-entity - // transport actions (e.g., blue green deployment) - if (hasConnectionIssue(cause) || cause instanceof ActionNotFoundTransportException) { - handleConnectionException(nodeId); + try { + LOG.error(new ParameterizedMessage("Cannot send entities' features to {} for {}", nodeId, adID), e); + + handlePredictionFailure(e, adID, nodeId, failure); + + } catch (Exception ex) { + LOG.error("Unexpected exception: {} for {}", ex, adID); + } finally { + if (nodeCount == responseCount.incrementAndGet()) { + handleEntityResponses(); + } + } + } + + private void handleEntityResponses() { + if (failure.get() != null) { + listener.onFailure(failure.get()); + } else if (ackResponses.isEmpty()) { + listener.onFailure(new InternalFailure(adID, NO_ACK_ERR)); + } else { + listener.onResponse(new AnomalyResultResponse(0, 0, 0, new ArrayList())); } - LOG.error(new ParameterizedMessage("Cannot send entities' features to {} for {}", nodeId, adID), e); } } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/AbstractADTest.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AbstractADTest.java index 9d15ccd0..c8581808 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/AbstractADTest.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AbstractADTest.java @@ -227,6 +227,9 @@ public void setupTestNodes(Settings settings) { } public void tearDownTestNodes() { + if (testNodes == null) { + return; + } for (FakeNode testNode : testNodes) { testNode.close(); } @@ -238,7 +241,7 @@ public void assertException( Class exceptionType, String msg ) { - Exception e = expectThrows(exceptionType, () -> listener.actionGet()); + Exception e = expectThrows(exceptionType, () -> listener.actionGet(20_000)); assertThat(e.getMessage(), containsString(msg)); } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/TestHelpers.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/TestHelpers.java index e8f1d6b3..7f6e2dd1 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/TestHelpers.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/TestHelpers.java @@ -226,7 +226,7 @@ public static AnomalyDetector randomAnomalyDetectorUsingCategoryFields(String de ImmutableList.of(randomFeature(true)), randomQuery(), randomIntervalTimeConfiguration(), - randomIntervalTimeConfiguration(), + new IntervalTimeConfiguration(0, ChronoUnit.MINUTES), randomIntBetween(1, 2000), null, randomInt(), diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/caching/CacheBufferTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/caching/CacheBufferTests.java index afabdc9e..de93b6ce 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/caching/CacheBufferTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/caching/CacheBufferTests.java @@ -167,6 +167,7 @@ public void testMaintenance() { cacheBuffer.put(modelId3, MLUtil.randomModelState(initialPriority, modelId3)); cacheBuffer.maintenance(); assertEquals(3, cacheBuffer.getActiveEntities()); + assertEquals(3, cacheBuffer.getAllModels().size()); when(clock.instant()).thenReturn(Instant.MAX); cacheBuffer.maintenance(); assertEquals(0, cacheBuffer.getActiveEntities()); diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/caching/PriorityCacheTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/caching/PriorityCacheTests.java index 798e781f..c1f3dc1d 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/caching/PriorityCacheTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/caching/PriorityCacheTests.java @@ -186,6 +186,7 @@ public void testCacheHit() { // cache miss due to door keeper assertEquals(null, cacheProvider.get(modelId1, detector, point, entityName)); assertEquals(1, cacheProvider.getTotalActiveEntities()); + assertEquals(1, cacheProvider.getAllModels().size()); ModelState hitState = cacheProvider.get(modelId1, detector, point, entityName); assertEquals(detectorId, hitState.getDetectorId()); EntityModel model = hitState.getModel(); @@ -248,10 +249,12 @@ public void testSharedCache() { } assertEquals(2, cacheProvider.getActiveEntities(detectorId2)); assertEquals(3, cacheProvider.getTotalActiveEntities()); + assertEquals(3, cacheProvider.getAllModels().size()); when(memoryTracker.memoryToShed()).thenReturn(memoryPerEntity); cacheProvider.maintenance(); assertEquals(2, cacheProvider.getTotalActiveEntities()); + assertEquals(2, cacheProvider.getAllModels().size()); assertEquals(1, cacheProvider.getActiveEntities(detectorId2)); } @@ -377,9 +380,11 @@ public void testExpiredCacheBuffer() { cacheProvider.get(modelId2, detector, point, entityName); } assertEquals(2, cacheProvider.getTotalActiveEntities()); + assertEquals(2, cacheProvider.getAllModels().size()); when(clock.instant()).thenReturn(Instant.now()); cacheProvider.maintenance(); assertEquals(0, cacheProvider.getTotalActiveEntities()); + assertEquals(0, cacheProvider.getAllModels().size()); for (int i = 0; i < 2; i++) { // doorkeeper should have been reset diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java index 2527e756..692f1d93 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java @@ -71,6 +71,8 @@ import org.powermock.modules.junit4.PowerMockRunner; import org.powermock.modules.junit4.PowerMockRunnerDelegate; +import test.com.amazon.opendistroforelasticsearch.ad.util.MLUtil; + import com.amazon.opendistroforelasticsearch.ad.AnomalyDetectorPlugin; import com.amazon.opendistroforelasticsearch.ad.MemoryTracker; import com.amazon.opendistroforelasticsearch.ad.caching.EntityCache; @@ -164,6 +166,7 @@ public class ModelManagerTests { private double[] attribution; private double[] point; private DiVector attributionVec; + private String entityName; @Mock private ActionListener rcfResultListener; @@ -171,6 +174,7 @@ public class ModelManagerTests { @Mock private ActionListener thresholdResultListener; private MemoryTracker memoryTracker; + private Instant now; @Before public void setup() { @@ -232,6 +236,9 @@ public void setup() { modelPartitioner = spy(new ModelPartitioner(numSamples, numTrees, nodeFilter, memoryTracker)); + now = Instant.now(); + when(clock.instant()).thenReturn(now); + modelManager = spy( new ModelManager( rcfSerde, @@ -284,6 +291,8 @@ public void setup() { listener.onResponse(Optional.of(failCheckpoint)); return null; }).when(checkpointDao).getModelCheckpoint(eq(failModelId), any(ActionListener.class)); + + entityName = "1.0.2.3"; } private Object[] getDetectorIdForModelIdData() { @@ -1188,4 +1197,56 @@ public void getPreviewResults_returnAnomalies_forLastAnomaly() { public void getPreviewResults_throwIllegalArgument_forInvalidInput() { modelManager.getPreviewResults(new double[0][0]); } + + @Test + public void getNullState() { + assertEquals(new ThresholdingResult(0, 0, 0), modelManager.getAnomalyResultForEntity("", new double[] {}, "", null, "")); + } + + @Test + public void getEmptyStateFullSamples() { + ModelState state = MLUtil.randomModelStateWithSample(false, numMinSamples); + assertEquals( + new ThresholdingResult(0, 0, 0), + modelManager.getAnomalyResultForEntity(detectorId, new double[] { -1 }, entityName, state, modelId) + ); + assertEquals(numMinSamples, state.getModel().getSamples().size()); + } + + @Test + public void getEmptyStateNotFullSamples() { + ModelState state = MLUtil.randomModelStateWithSample(false, numMinSamples - 1); + assertEquals( + new ThresholdingResult(0, 0, 0), + modelManager.getAnomalyResultForEntity(detectorId, new double[] { -1 }, entityName, state, modelId) + ); + assertEquals(numMinSamples, state.getModel().getSamples().size()); + } + + @Test + public void scoreSamples() { + ModelState state = MLUtil.randomNonEmptyModelState(); + modelManager.getAnomalyResultForEntity(detectorId, new double[] { -1 }, entityName, state, modelId); + assertEquals(0, state.getModel().getSamples().size()); + assertEquals(now, state.getLastUsedTime()); + } + + @Test + public void processEmptyCheckpoint() { + ModelState modelState = MLUtil.randomModelStateWithSample(false, numMinSamples - 1); + modelManager.processEntityCheckpoint(Optional.empty(), modelId, entityName, modelState); + assertEquals(now.minus(checkpointInterval), modelState.getLastCheckpointTime()); + } + + @Test + public void processNonEmptyCheckpoint() { + EntityModel model = MLUtil.createNonEmptyModel(modelId); + ModelState modelState = MLUtil.randomModelStateWithSample(false, numMinSamples); + Instant checkpointTime = Instant.ofEpochMilli(1000); + modelManager + .processEntityCheckpoint(Optional.of(new SimpleImmutableEntry<>(model, checkpointTime)), modelId, entityName, modelState); + assertEquals(checkpointTime, modelState.getLastCheckpointTime()); + assertEquals(0, modelState.getModel().getSamples().size()); + assertEquals(now, modelState.getLastUsedTime()); + } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/stats/ADStatsTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/stats/ADStatsTests.java index 356ec8c1..3e17e984 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/stats/ADStatsTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/stats/ADStatsTests.java @@ -34,6 +34,11 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import test.com.amazon.opendistroforelasticsearch.ad.util.MLUtil; + +import com.amazon.opendistroforelasticsearch.ad.caching.CacheProvider; +import com.amazon.opendistroforelasticsearch.ad.caching.EntityCache; +import com.amazon.opendistroforelasticsearch.ad.ml.EntityModel; import com.amazon.opendistroforelasticsearch.ad.ml.HybridThresholdingModel; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; import com.amazon.opendistroforelasticsearch.ad.ml.ModelState; @@ -58,6 +63,9 @@ public class ADStatsTests extends ESTestCase { @Mock private ModelManager modelManager; + @Mock + private CacheProvider cacheProvider; + @Before public void setup() { MockitoAnnotations.initMocks(this); @@ -76,6 +84,15 @@ public void setup() { ); when(modelManager.getAllModels()).thenReturn(modelsInformation); + + ModelState entityModel1 = MLUtil.randomNonEmptyModelState(); + ModelState entityModel2 = MLUtil.randomNonEmptyModelState(); + + List> entityModelsInformation = new ArrayList<>(Arrays.asList(entityModel1, entityModel2)); + EntityCache cache = mock(EntityCache.class); + when(cacheProvider.get()).thenReturn(cache); + when(cache.getAllModels()).thenReturn(entityModelsInformation); + IndexUtils indexUtils = mock(IndexUtils.class); when(indexUtils.getIndexHealthStatus(anyString())).thenReturn("yellow"); @@ -90,7 +107,7 @@ public void setup() { statsMap = new HashMap>() { { put(nodeStatName1, new ADStat<>(false, new CounterSupplier())); - put(nodeStatName2, new ADStat<>(false, new ModelsOnNodeSupplier(modelManager))); + put(nodeStatName2, new ADStat<>(false, new ModelsOnNodeSupplier(modelManager, cacheProvider))); put(clusterStatName1, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index1"))); put(clusterStatName2, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index2"))); } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java index bd932fba..e21a8be6 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java @@ -16,6 +16,7 @@ package com.amazon.opendistroforelasticsearch.ad.stats.suppliers; import static com.amazon.opendistroforelasticsearch.ad.stats.suppliers.ModelsOnNodeSupplier.MODEL_STATE_STAT_KEYS; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import java.time.Clock; @@ -24,6 +25,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import java.util.stream.Stream; import org.elasticsearch.test.ESTestCase; import org.junit.Before; @@ -31,6 +33,11 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import test.com.amazon.opendistroforelasticsearch.ad.util.MLUtil; + +import com.amazon.opendistroforelasticsearch.ad.caching.CacheProvider; +import com.amazon.opendistroforelasticsearch.ad.caching.EntityCache; +import com.amazon.opendistroforelasticsearch.ad.ml.EntityModel; import com.amazon.opendistroforelasticsearch.ad.ml.HybridThresholdingModel; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; import com.amazon.opendistroforelasticsearch.ad.ml.ModelState; @@ -41,10 +48,14 @@ public class ModelsOnNodeSupplierTests extends ESTestCase { private HybridThresholdingModel thresholdingModel; private List> expectedResults; private Clock clock; + private List> entityModelsInformation; @Mock private ModelManager modelManager; + @Mock + private CacheProvider cacheProvider; + @Before public void setup() { MockitoAnnotations.initMocks(this); @@ -64,16 +75,24 @@ public void setup() { ); when(modelManager.getAllModels()).thenReturn(expectedResults); + + ModelState entityModel1 = MLUtil.randomNonEmptyModelState(); + ModelState entityModel2 = MLUtil.randomNonEmptyModelState(); + + entityModelsInformation = new ArrayList<>(Arrays.asList(entityModel1, entityModel2)); + EntityCache cache = mock(EntityCache.class); + when(cacheProvider.get()).thenReturn(cache); + when(cache.getAllModels()).thenReturn(entityModelsInformation); } @Test public void testGet() { - ModelsOnNodeSupplier modelsOnNodeSupplier = new ModelsOnNodeSupplier(modelManager); + ModelsOnNodeSupplier modelsOnNodeSupplier = new ModelsOnNodeSupplier(modelManager, cacheProvider); List> results = modelsOnNodeSupplier.get(); assertEquals( "get fails to return correct result", - expectedResults - .stream() + Stream + .concat(expectedResults.stream(), entityModelsInformation.stream()) .map( modelState -> modelState .getModelStateAsMap() diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStatsNodesTransportActionTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStatsNodesTransportActionTests.java index 8c66031b..40d8f607 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStatsNodesTransportActionTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStatsNodesTransportActionTests.java @@ -16,6 +16,7 @@ package com.amazon.opendistroforelasticsearch.ad.transport; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import java.time.Clock; import java.util.Arrays; @@ -34,6 +35,8 @@ import org.junit.Before; import org.junit.Test; +import com.amazon.opendistroforelasticsearch.ad.caching.CacheProvider; +import com.amazon.opendistroforelasticsearch.ad.caching.EntityCache; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; import com.amazon.opendistroforelasticsearch.ad.stats.ADStat; import com.amazon.opendistroforelasticsearch.ad.stats.ADStats; @@ -69,6 +72,9 @@ public void setUp() throws Exception { indexNameResolver ); ModelManager modelManager = mock(ModelManager.class); + CacheProvider cacheProvider = mock(CacheProvider.class); + EntityCache cache = mock(EntityCache.class); + when(cacheProvider.get()).thenReturn(cache); clusterStatName1 = "clusterStat1"; clusterStatName2 = "clusterStat2"; @@ -78,7 +84,7 @@ public void setUp() throws Exception { statsMap = new HashMap>() { { put(nodeStatName1, new ADStat<>(false, new CounterSupplier())); - put(nodeStatName2, new ADStat<>(false, new ModelsOnNodeSupplier(modelManager))); + put(nodeStatName2, new ADStat<>(false, new ModelsOnNodeSupplier(modelManager, cacheProvider))); put(clusterStatName1, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index1"))); put(clusterStatName2, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index2"))); } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/MultientityResultTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/MultientityResultTests.java new file mode 100644 index 00000000..6e7804d3 --- /dev/null +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/MultientityResultTests.java @@ -0,0 +1,498 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.ad.transport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Clock; +import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.function.Function; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.Transport; +import org.elasticsearch.transport.TransportException; +import org.elasticsearch.transport.TransportInterceptor; +import org.elasticsearch.transport.TransportRequest; +import org.elasticsearch.transport.TransportRequestOptions; +import org.elasticsearch.transport.TransportResponse; +import org.elasticsearch.transport.TransportResponseHandler; +import org.elasticsearch.transport.TransportService; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; + +import test.com.amazon.opendistroforelasticsearch.ad.util.MLUtil; + +import com.amazon.opendistroforelasticsearch.ad.AbstractADTest; +import com.amazon.opendistroforelasticsearch.ad.NodeStateManager; +import com.amazon.opendistroforelasticsearch.ad.TestHelpers; +import com.amazon.opendistroforelasticsearch.ad.breaker.ADCircuitBreakerService; +import com.amazon.opendistroforelasticsearch.ad.caching.CacheProvider; +import com.amazon.opendistroforelasticsearch.ad.caching.EntityCache; +import com.amazon.opendistroforelasticsearch.ad.cluster.HashRing; +import com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException; +import com.amazon.opendistroforelasticsearch.ad.common.exception.EndRunException; +import com.amazon.opendistroforelasticsearch.ad.common.exception.InternalFailure; +import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException; +import com.amazon.opendistroforelasticsearch.ad.constant.CommonErrorMessages; +import com.amazon.opendistroforelasticsearch.ad.feature.FeatureManager; +import com.amazon.opendistroforelasticsearch.ad.feature.SearchFeatureDao; +import com.amazon.opendistroforelasticsearch.ad.indices.AnomalyDetectionIndices; +import com.amazon.opendistroforelasticsearch.ad.ml.CheckpointDao; +import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; +import com.amazon.opendistroforelasticsearch.ad.ml.ModelPartitioner; +import com.amazon.opendistroforelasticsearch.ad.ml.ThresholdingResult; +import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector; +import com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings; +import com.amazon.opendistroforelasticsearch.ad.stats.ADStat; +import com.amazon.opendistroforelasticsearch.ad.stats.ADStats; +import com.amazon.opendistroforelasticsearch.ad.stats.StatNames; +import com.amazon.opendistroforelasticsearch.ad.stats.suppliers.CounterSupplier; +import com.amazon.opendistroforelasticsearch.ad.transport.handler.MultiEntityResultHandler; +import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil; +import com.amazon.opendistroforelasticsearch.ad.util.IndexUtils; + +public class MultientityResultTests extends AbstractADTest { + private AnomalyResultTransportAction action; + private AnomalyResultRequest request; + private TransportInterceptor entityResultInterceptor; + private Clock clock; + private AnomalyDetector detector; + private NodeStateManager stateManager; + private static Settings settings; + private TransportService transportService; + private SearchFeatureDao searchFeatureDao; + private Client client; + private FeatureManager featureQuery; + private ModelManager normalModelManager; + private ModelPartitioner normalModelPartitioner; + private HashRing hashRing; + private ClusterService clusterService; + private IndexNameExpressionResolver indexNameResolver; + private ADCircuitBreakerService adCircuitBreakerService; + private ADStats adStats; + private ThreadPool mockThreadPool; + private String detectorId; + private Instant now; + private String modelId; + private MultiEntityResultHandler anomalyResultHandler; + private CheckpointDao checkpointDao; + private CacheProvider provider; + private AnomalyDetectionIndices indexUtil; + + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(AnomalyResultTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + + @SuppressWarnings({ "serial", "unchecked" }) + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + now = Instant.now(); + clock = mock(Clock.class); + when(clock.instant()).thenReturn(now); + + detectorId = "123"; + modelId = "abc"; + String categoryField = "a"; + detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Collections.singletonList(categoryField)); + + stateManager = mock(NodeStateManager.class); + // make sure parameters are not null, otherwise this mock won't get invoked + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(stateManager).getAnomalyDetector(anyString(), any(ActionListener.class)); + when(stateManager.getLastIndexThrottledTime()).thenReturn(Instant.MIN); + + settings = Settings.builder().put(AnomalyDetectorSettings.COOLDOWN_MINUTES.getKey(), TimeValue.timeValueMinutes(5)).build(); + + request = new AnomalyResultRequest(detectorId, 100, 200); + + transportService = mock(TransportService.class); + + client = mock(Client.class); + ThreadContext threadContext = new ThreadContext(settings); + mockThreadPool = mock(ThreadPool.class); + setUpADThreadPool(mockThreadPool); + when(client.threadPool()).thenReturn(mockThreadPool); + when(mockThreadPool.getThreadContext()).thenReturn(threadContext); + + featureQuery = mock(FeatureManager.class); + + normalModelManager = mock(ModelManager.class); + when(normalModelManager.getEntityModelId(anyString(), anyString())).thenReturn(modelId); + + normalModelPartitioner = mock(ModelPartitioner.class); + + hashRing = mock(HashRing.class); + + clusterService = mock(ClusterService.class); + + indexNameResolver = new IndexNameExpressionResolver(); + + adCircuitBreakerService = mock(ADCircuitBreakerService.class); + when(adCircuitBreakerService.isOpen()).thenReturn(false); + + IndexUtils indexUtils = new IndexUtils(client, mock(ClientUtil.class), clusterService, indexNameResolver); + Map> statsMap = new HashMap>() { + { + put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + } + }; + adStats = new ADStats(indexUtils, normalModelManager, statsMap); + + searchFeatureDao = mock(SearchFeatureDao.class); + + action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + stateManager, + featureQuery, + normalModelManager, + normalModelPartitioner, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + mockThreadPool, + searchFeatureDao + ); + + anomalyResultHandler = mock(MultiEntityResultHandler.class); + checkpointDao = mock(CheckpointDao.class); + provider = mock(CacheProvider.class); + indexUtil = mock(AnomalyDetectionIndices.class); + } + + @Override + @After + public final void tearDown() throws Exception { + tearDownTestNodes(); + super.tearDown(); + } + + @SuppressWarnings("unchecked") + public void testQueryError() { + // non-EndRunException won't stop action from running + when(stateManager.fetchColdStartException(anyString())).thenReturn(Optional.of(new AnomalyDetectionException(detectorId, ""))); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + listener + .onFailure( + new EndRunException( + detectorId, + CommonErrorMessages.INVALID_SEARCH_QUERY_MSG, + new NoSuchElementException("No value present"), + false + ) + ); + return null; + }).when(searchFeatureDao).getFeaturesByEntities(any(), anyLong(), anyLong(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + verify(stateManager, times(1)).getAnomalyDetector(anyString(), any(ActionListener.class)); + + assertException(listener, EndRunException.class, CommonErrorMessages.INVALID_SEARCH_QUERY_MSG); + } + + public void testIndexNotFound() { + // non-EndRunException won't stop action from running + when(stateManager.fetchColdStartException(anyString())).thenReturn(Optional.of(new AnomalyDetectionException(detectorId, ""))); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + listener.onFailure(new IndexNotFoundException("", "")); + return null; + }).when(searchFeatureDao).getFeaturesByEntities(any(), anyLong(), anyLong(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + assertException(listener, EndRunException.class, AnomalyResultTransportAction.TROUBLE_QUERYING_ERR_MSG); + } + + public void testColdStartEndRunException() { + when(stateManager.fetchColdStartException(anyString())) + .thenReturn( + Optional + .of( + new EndRunException( + detectorId, + CommonErrorMessages.INVALID_SEARCH_QUERY_MSG, + new NoSuchElementException("No value present"), + false + ) + ) + ); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + assertException(listener, EndRunException.class, CommonErrorMessages.INVALID_SEARCH_QUERY_MSG); + } + + public void testEmptyFeatures() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + listener.onResponse(new HashMap()); + return null; + }).when(searchFeatureDao).getFeaturesByEntities(any(), anyLong(), anyLong(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.01); + } + + private TransportResponseHandler entityResultHandler(TransportResponseHandler handler) { + return new TransportResponseHandler() { + @Override + public T read(StreamInput in) throws IOException { + return handler.read(in); + } + + @Override + @SuppressWarnings("unchecked") + public void handleResponse(T response) { + handler.handleResponse(response); + } + + @Override + public void handleException(TransportException exp) { + handler.handleException(exp); + } + + @Override + public String executor() { + return handler.executor(); + } + }; + } + + private TransportResponseHandler unackEntityResultHandler(TransportResponseHandler handler) { + return new TransportResponseHandler() { + @Override + public T read(StreamInput in) throws IOException { + return handler.read(in); + } + + @Override + @SuppressWarnings("unchecked") + public void handleResponse(T response) { + handler.handleResponse((T) new AcknowledgedResponse(false)); + } + + @Override + public void handleException(TransportException exp) { + handler.handleException(exp); + } + + @Override + public String executor() { + return handler.executor(); + } + }; + } + + private void setUpEntityResult() { + // register entity result action + new EntityResultTransportAction( + new ActionFilters(Collections.emptySet()), + // since we send requests to testNodes[1] + testNodes[1].transportService, + normalModelManager, + adCircuitBreakerService, + anomalyResultHandler, + checkpointDao, + provider, + stateManager, + settings, + clock, + indexUtil + ); + + EntityCache entityCache = mock(EntityCache.class); + when(provider.get()).thenReturn(entityCache); + when(entityCache.get(any(), any(), any(), anyString())).thenReturn(MLUtil.randomNonEmptyModelState()); + + when(normalModelManager.getAnomalyResultForEntity(anyString(), any(), anyString(), any(), anyString())) + .thenReturn(new ThresholdingResult(0, 1, 1)); + } + + private void setUpTransportInterceptor( + Function, TransportResponseHandler> interceptor + ) { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + Map features = new HashMap(); + features.put("1.0.2.3", new double[] { 0 }); + features.put("2.0.2.3", new double[] { 1 }); + listener.onResponse(features); + return null; + }).when(searchFeatureDao).getFeaturesByEntities(any(), anyLong(), anyLong(), any()); + + entityResultInterceptor = new TransportInterceptor() { + @Override + public AsyncSender interceptSender(AsyncSender sender) { + return new AsyncSender() { + @SuppressWarnings("unchecked") + @Override + public void sendRequest( + Transport.Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + if (action.equals(EntityResultAction.NAME)) { + sender + .sendRequest( + connection, + action, + request, + options, + interceptor.apply((TransportResponseHandler) handler) + ); + } else { + sender.sendRequest(connection, action, request, options, handler); + } + } + }; + } + }; + + setupTestNodes(settings, entityResultInterceptor); + + // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor + when(hashRing.getOwningNode(any(String.class))).thenReturn(Optional.of(testNodes[1].discoveryNode())); + + TransportService realTransportService = testNodes[0].transportService; + ClusterService realClusterService = testNodes[0].clusterService; + + action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + realTransportService, + settings, + client, + stateManager, + featureQuery, + normalModelManager, + normalModelPartitioner, + hashRing, + realClusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + threadPool, + searchFeatureDao + ); + } + + public void testNonEmptyFeatures() { + setUpTransportInterceptor(this::entityResultHandler); + setUpEntityResult(); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(0d, response.getAnomalyGrade(), 0.01); + } + + public void testCircuitBreakerOpen() { + setUpTransportInterceptor(this::entityResultHandler); + + ADCircuitBreakerService openBreaker = mock(ADCircuitBreakerService.class); + when(openBreaker.isOpen()).thenReturn(true); + // register entity result action + new EntityResultTransportAction( + new ActionFilters(Collections.emptySet()), + // since we send requests to testNodes[1] + testNodes[1].transportService, + normalModelManager, + openBreaker, + anomalyResultHandler, + checkpointDao, + provider, + stateManager, + settings, + clock, + indexUtil + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + assertException(listener, LimitExceededException.class, CommonErrorMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG); + } + + public void testNotAck() { + setUpTransportInterceptor(this::unackEntityResultHandler); + setUpEntityResult(); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + assertException(listener, InternalFailure.class, AnomalyResultTransportAction.NO_ACK_ERR); + verify(stateManager, times(1)).addPressure(anyString()); + } +} diff --git a/src/test/java/test/com/amazon/opendistroforelasticsearch/ad/util/MLUtil.java b/src/test/java/test/com/amazon/opendistroforelasticsearch/ad/util/MLUtil.java index d8431339..71288db1 100644 --- a/src/test/java/test/com/amazon/opendistroforelasticsearch/ad/util/MLUtil.java +++ b/src/test/java/test/com/amazon/opendistroforelasticsearch/ad/util/MLUtil.java @@ -38,6 +38,7 @@ */ public class MLUtil { private static Random random = new Random(42); + private static int minSampleSize = AnomalyDetectorSettings.NUM_MIN_SAMPLES; private static String randomString(int targetStringLength) { int leftLimit = 97; // letter 'a' @@ -58,54 +59,79 @@ public static Queue createQueueSamples(int size) { } public static ModelState randomModelState() { - return randomModelState(random.nextBoolean(), random.nextFloat(), randomString(15)); + return randomModelState(random.nextBoolean(), random.nextFloat(), randomString(15), random.nextInt(minSampleSize)); } - public static ModelState randomModelState(boolean fullModel, float priority, String modelId) { + public static ModelState randomModelState(boolean fullModel, float priority, String modelId, int sampleSize) { String detectorId = randomString(5); - Queue samples = createQueueSamples(random.nextInt(128)); EntityModel model = null; if (fullModel) { - RandomCutForest rcf = RandomCutForest - .builder() - .dimensions(1) - .sampleSize(AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE) - .numberOfTrees(AnomalyDetectorSettings.MULTI_ENTITY_NUM_TREES) - .lambda(AnomalyDetectorSettings.TIME_DECAY) - .outputAfter(AnomalyDetectorSettings.NUM_MIN_SAMPLES) - .parallelExecutionEnabled(false) - .build(); - int numDataPoints = random.nextInt(1000) + AnomalyDetectorSettings.NUM_MIN_SAMPLES; - double[] scores = new double[numDataPoints]; - for (int j = 0; j < numDataPoints; j++) { - double[] dataPoint = new double[] { random.nextDouble() }; - scores[j] = rcf.getAnomalyScore(dataPoint); - rcf.update(dataPoint); - } - - double[] nonZeroScores = DoubleStream.of(scores).filter(score -> score > 0).toArray(); - ThresholdingModel threshold = new HybridThresholdingModel( - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, - AnomalyDetectorSettings.THRESHOLD_MAX_RANK_ERROR, - AnomalyDetectorSettings.THRESHOLD_MAX_SCORE, - AnomalyDetectorSettings.THRESHOLD_NUM_LOGNORMAL_QUANTILES, - AnomalyDetectorSettings.THRESHOLD_DOWNSAMPLES, - AnomalyDetectorSettings.THRESHOLD_MAX_SAMPLES - ); - threshold.train(nonZeroScores); - model = new EntityModel(modelId, samples, rcf, threshold); + model = createNonEmptyModel(modelId, sampleSize); } else { - model = new EntityModel(modelId, samples, null, null); + model = createEmptyModel(modelId, sampleSize); } return new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), Clock.systemUTC(), priority); } public static ModelState randomNonEmptyModelState() { - return randomModelState(true, random.nextFloat(), randomString(15)); + return randomModelState(true, random.nextFloat(), randomString(15), random.nextInt(minSampleSize)); + } + + public static ModelState randomEmptyModelState() { + return randomModelState(false, random.nextFloat(), randomString(15), random.nextInt(minSampleSize)); } public static ModelState randomModelState(float priority, String modelId) { - return randomModelState(random.nextBoolean(), priority, modelId); + return randomModelState(random.nextBoolean(), priority, modelId, random.nextInt(minSampleSize)); + } + + public static ModelState randomModelStateWithSample(boolean fullModel, int sampleSize) { + return randomModelState(fullModel, random.nextFloat(), randomString(15), sampleSize); + } + + public static EntityModel createEmptyModel(String modelId, int sampleSize) { + Queue samples = createQueueSamples(sampleSize); + return new EntityModel(modelId, samples, null, null); + } + + public static EntityModel createEmptyModel(String modelId) { + return createEmptyModel(modelId, random.nextInt(minSampleSize)); + } + + public static EntityModel createNonEmptyModel(String modelId, int sampleSize) { + Queue samples = createQueueSamples(sampleSize); + RandomCutForest rcf = RandomCutForest + .builder() + .dimensions(1) + .sampleSize(AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE) + .numberOfTrees(AnomalyDetectorSettings.MULTI_ENTITY_NUM_TREES) + .lambda(AnomalyDetectorSettings.TIME_DECAY) + .outputAfter(AnomalyDetectorSettings.NUM_MIN_SAMPLES) + .parallelExecutionEnabled(false) + .build(); + int numDataPoints = random.nextInt(1000) + AnomalyDetectorSettings.NUM_MIN_SAMPLES; + double[] scores = new double[numDataPoints]; + for (int j = 0; j < numDataPoints; j++) { + double[] dataPoint = new double[] { random.nextDouble() }; + scores[j] = rcf.getAnomalyScore(dataPoint); + rcf.update(dataPoint); + } + + double[] nonZeroScores = DoubleStream.of(scores).filter(score -> score > 0).toArray(); + ThresholdingModel threshold = new HybridThresholdingModel( + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + AnomalyDetectorSettings.THRESHOLD_MAX_RANK_ERROR, + AnomalyDetectorSettings.THRESHOLD_MAX_SCORE, + AnomalyDetectorSettings.THRESHOLD_NUM_LOGNORMAL_QUANTILES, + AnomalyDetectorSettings.THRESHOLD_DOWNSAMPLES, + AnomalyDetectorSettings.THRESHOLD_MAX_SAMPLES + ); + threshold.train(nonZeroScores); + return new EntityModel(modelId, samples, rcf, threshold); + } + + public static EntityModel createNonEmptyModel(String modelId) { + return createNonEmptyModel(modelId, random.nextInt(minSampleSize)); } }