diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java index b4cdfe2f..a2f8abd1 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java @@ -24,7 +24,7 @@ import java.util.Optional; import java.util.concurrent.Callable; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.time.Instant; @@ -62,7 +62,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListenerResponseHandler; import org.elasticsearch.action.ActionRequest; -import org.elasticsearch.action.LatchedActionListener; import org.elasticsearch.action.admin.indices.create.CreateIndexResponse; import org.elasticsearch.action.bulk.BackoffPolicy; import org.elasticsearch.action.index.IndexRequest; @@ -114,7 +113,6 @@ public class AnomalyResultTransportAction extends HandledTransportAction rcfResults = new ArrayList<>(); final AtomicReference failure = new AtomicReference(); + final AtomicInteger responseCount = new AtomicInteger(); + for (int i = 0; i < rcfPartitionNum; i++) { String rcfModelID = modelManager.getRcfModelId(adID, i); @@ -321,10 +320,24 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener< } LOG.info("Sending RCF request to {} for model {}", rcfNodeId, rcfModelID); - LatchedActionListener rcfListener = new LatchedActionListener<>( - new RCFActionListener(rcfResults, rcfModelID.toString(), failure, rcfNodeId), - rcfLatch + + RCFActionListener rcfListener = new RCFActionListener( + rcfResults, + rcfModelID.toString(), + failure, + rcfNodeId, + detector, + listener, + thresholdModelID, + thresholdNode, + featureInResponse, + startTime, + endTime, + rcfPartitionNum, + responseCount, + adID ); + transportService .sendRequest( rcfNode.get(), @@ -334,87 +347,9 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener< new ActionListenerResponseHandler<>(rcfListener, RCFResultResponse::new) ); } - - // wait a bit longer than transport timeout - long latchWaitSecs = Math.round(requestTimeout.getSeconds() * 1.25d); - - try { - LOG.debug("Wait for RCF results..."); - rcfLatch.await(latchWaitSecs, TimeUnit.SECONDS); - } catch (InterruptedException e) { - listener.onFailure(new InternalFailure(adID, CommonErrorMessages.WAIT_ERR_MSG, e)); - return; - } - - if (coldStartIfNoModel(failure, detector.get()) || rcfResults.isEmpty()) { - listener.onFailure(new InternalFailure(adID, NO_MODEL_ERR_MSG)); - return; - } - - CombinedRcfResult combinedResult = getCombinedResult(rcfResults); - double combinedScore = combinedResult.getScore(); - - final CountDownLatch thresholdLatch = createCountDownLatch(1); - - final AtomicReference anomalyResultResponse = new AtomicReference<>(); - - LOG.info("Sending threshold request to {} for model {}", thresholdNodeId, thresholdModelID); - LatchedActionListener thresholdListener = new LatchedActionListener<>( - new ThresholdActionListener(anomalyResultResponse, featureInResponse, thresholdModelID, failure, thresholdNodeId), - thresholdLatch - ); - transportService - .sendRequest( - thresholdNode.get(), - ThresholdResultAction.NAME, - new ThresholdResultRequest(adID, thresholdModelID, combinedScore), - option, - new ActionListenerResponseHandler<>(thresholdListener, ThresholdResultResponse::new) - ); - - try { - LOG.debug("Wait for threshold results..."); - thresholdLatch.await(latchWaitSecs, TimeUnit.SECONDS); - } catch (InterruptedException e) { - listener.onFailure(new InternalFailure(adID, WAIT_FOR_THRESHOLD_ERR_MSG, e)); - return; - } - - if (coldStartIfNoModel(failure, detector.get())) { - listener.onFailure(new InternalFailure(adID, NO_MODEL_ERR_MSG)); - return; - } - - if (anomalyResultResponse.get() != null) { - AnomalyResultResponse response = anomalyResultResponse.get(); - double confidence = response.getConfidence() * combinedResult.getConfidence(); - response = new AnomalyResultResponse(response.getAnomalyGrade(), confidence, response.getFeatures()); - listener.onResponse(response); - indexAnomalyResult( - new AnomalyResult( - adID, - Double.valueOf(combinedScore), - Double.valueOf(response.getAnomalyGrade()), - Double.valueOf(confidence), - featureInResponse, - Instant.ofEpochMilli(startTime), - Instant.ofEpochMilli(endTime) - ) - ); - } else if (failure.get() != null) { - listener.onFailure(failure.get()); - } else { - listener.onFailure(new InternalFailure(adID, "Unexpected exception")); - } - } catch (ClientException clientException) { - listener.onFailure(clientException); - } catch (AnomalyDetectionException adEx) { - listener.onFailure(new InternalFailure(adEx)); - } catch (Exception throwable) { - Throwable cause = ExceptionsHelper.unwrapCause(throwable); - listener.onFailure(new InternalFailure(adID, cause)); + } catch (Exception ex) { + handleExecuteException(ex, listener, adID); } - } /** @@ -637,37 +572,134 @@ void saveDetectorResult(IndexRequest indexRequest, String context, Iterator