Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
Fix issue where max number of multi-entity detector doesn't work for …
Browse files Browse the repository at this point in the history
…UpdateDetector (#285)
  • Loading branch information
yizheliu-amazon authored Oct 22, 2020
1 parent 68f2104 commit ab836c7
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import static com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector.ANOMALY_DETECTORS_INDEX;
import static com.amazon.opendistroforelasticsearch.ad.util.RestHandlerUtils.XCONTENT_WITH_TYPE;
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

import java.io.IOException;
import java.time.Instant;
Expand Down Expand Up @@ -53,6 +54,7 @@
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
Expand Down Expand Up @@ -245,31 +247,52 @@ private void updateAnomalyDetector(String detectorId) {
);
}

private void onGetAnomalyDetectorResponse(GetResponse response) throws IOException {
private void onGetAnomalyDetectorResponse(GetResponse response) {
if (!response.isExists()) {
listener
.onFailure(new ElasticsearchStatusException("AnomalyDetector is not found with id: " + detectorId, RestStatus.NOT_FOUND));
return;
}
try (XContentParser parser = RestHandlerUtils.createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser::getTokenLocation);
AnomalyDetector existingDetector = AnomalyDetector.parse(parser, response.getId(), response.getVersion());
if (!hasCategoryField(existingDetector) && hasCategoryField(this.anomalyDetector)) {
validateAgainstExistingMultiEntityAnomalyDetector(detectorId);
} else {
validateCategoricalField(detectorId);
}
} catch (IOException e) {
String message = "Failed to parse anomaly detector " + detectorId;
logger.error(message, e);
listener.onFailure(new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR));
}

validateCategoricalField(detectorId);
}

private boolean hasCategoryField(AnomalyDetector detector) {
return detector.getCategoryField() != null && !detector.getCategoryField().isEmpty();
}

private void validateAgainstExistingMultiEntityAnomalyDetector(String detectorId) {
QueryBuilder query = QueryBuilders.boolQuery().filter(QueryBuilders.existsQuery(AnomalyDetector.CATEGORY_FIELD));

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout);

SearchRequest searchRequest = new SearchRequest(ANOMALY_DETECTORS_INDEX).source(searchSourceBuilder);

client
.search(
searchRequest,
ActionListener
.wrap(response -> onSearchMultiEntityAdResponse(response, detectorId), exception -> listener.onFailure(exception))
);
}

private void createAnomalyDetector() {
try {
List<String> categoricalFields = anomalyDetector.getCategoryField();
if (categoricalFields != null && categoricalFields.size() > 0) {
QueryBuilder query = QueryBuilders.boolQuery().filter(QueryBuilders.existsQuery(AnomalyDetector.CATEGORY_FIELD));

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout);

SearchRequest searchRequest = new SearchRequest(ANOMALY_DETECTORS_INDEX).source(searchSourceBuilder);

client
.search(
searchRequest,
ActionListener.wrap(response -> onSearchMultiEntityAdResponse(response), exception -> listener.onFailure(exception))
);
validateAgainstExistingMultiEntityAnomalyDetector(null);
} else {
QueryBuilder query = QueryBuilders.matchAllQuery();
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout);
Expand Down Expand Up @@ -298,13 +321,13 @@ private void onSearchSingleEntityAdResponse(SearchResponse response) throws IOEx
}
}

private void onSearchMultiEntityAdResponse(SearchResponse response) throws IOException {
private void onSearchMultiEntityAdResponse(SearchResponse response, String detectorId) throws IOException {
if (response.getHits().getTotalHits().value >= maxMultiEntityAnomalyDetectors) {
String errorMsg = EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG + maxMultiEntityAnomalyDetectors;
logger.error(errorMsg);
listener.onFailure(new IllegalArgumentException(errorMsg));
} else {
validateCategoricalField(null);
validateCategoricalField(detectorId);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ public <Request extends ActionRequest, Response extends ActionResponse> void doE
}

@SuppressWarnings("unchecked")
private void testValidTypeTepmlate(String filedTypeName) throws IOException {
private void testValidTypeTemplate(String filedTypeName) throws IOException {
String field = "a";
AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field));

Expand Down Expand Up @@ -388,24 +388,24 @@ public <Request extends ActionRequest, Response extends ActionResponse> void doE
}

public void testIpField() throws IOException {
testValidTypeTepmlate(CommonName.IP_TYPE);
testValidTypeTemplate(CommonName.IP_TYPE);
}

public void testKeywordField() throws IOException {
testValidTypeTepmlate(CommonName.KEYWORD_TYPE);
testValidTypeTemplate(CommonName.KEYWORD_TYPE);
}

@SuppressWarnings("unchecked")
private void testUpdateTepmlate(String fieldTypeName) throws IOException {
private void testUpdateTemplate(String fieldTypeName) throws IOException {
String field = "a";
AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field));

SearchResponse detectorResponse = mock(SearchResponse.class);
int totalHits = 9;
when(detectorResponse.getHits()).thenReturn(createSearchHits(totalHits));

GetResponse getDetectorResponse = mock(GetResponse.class);
when(getDetectorResponse.isExists()).thenReturn(true);
GetResponse getDetectorResponse = TestHelpers
.createGetResponse(detector, detector.getDetectorId(), AnomalyDetector.ANOMALY_DETECTORS_INDEX);

SearchResponse userIndexResponse = mock(SearchResponse.class);
int userIndexHits = 0;
Expand Down Expand Up @@ -485,15 +485,15 @@ public <Request extends ActionRequest, Response extends ActionResponse> void doE
}

public void testUpdateIpField() throws IOException {
testUpdateTepmlate(CommonName.IP_TYPE);
testUpdateTemplate(CommonName.IP_TYPE);
}

public void testUpdateKeywordField() throws IOException {
testUpdateTepmlate(CommonName.KEYWORD_TYPE);
testUpdateTemplate(CommonName.KEYWORD_TYPE);
}

public void testUpdateTextField() throws IOException {
testUpdateTepmlate(TEXT_FIELD_TYPE);
testUpdateTemplate(TEXT_FIELD_TYPE);
}

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -527,4 +527,151 @@ public void testMoreThanTenMultiEntityDetectors() throws IOException {
assertTrue(value instanceof IllegalArgumentException);
assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG));
}

@SuppressWarnings("unchecked")
public void testTenMultiEntityDetectorsUpdateSingleEntityAdToMulti() throws IOException {
int totalHits = 10;
AnomalyDetector existingDetector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, null);
GetResponse getDetectorResponse = TestHelpers
.createGetResponse(existingDetector, existingDetector.getDetectorId(), AnomalyDetector.ANOMALY_DETECTORS_INDEX);

SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.getHits()).thenReturn(createSearchHits(totalHits));

doAnswer(invocation -> {
Object[] args = invocation.getArguments();
assertTrue(String.format("The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), args.length == 2);

assertTrue(args[0] instanceof SearchRequest);
assertTrue(args[1] instanceof ActionListener);

ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) args[1];

listener.onResponse(searchResponse);

return null;
}).when(clientMock).search(any(SearchRequest.class), any());

doAnswer(invocation -> {
Object[] args = invocation.getArguments();
assertTrue(String.format("The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), args.length == 2);

assertTrue(args[0] instanceof GetRequest);
assertTrue(args[1] instanceof ActionListener);

ActionListener<GetResponse> listener = (ActionListener<GetResponse>) args[1];

listener.onResponse(getDetectorResponse);

return null;
}).when(clientMock).get(any(GetRequest.class), any());

ClusterName clusterName = new ClusterName("test");
ClusterState clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().build()).build();
when(clusterService.state()).thenReturn(clusterState);

handler = new IndexAnomalyDetectorActionHandler(
clusterService,
clientMock,
channel,
anomalyDetectionIndices,
detectorId,
seqNo,
primaryTerm,
refreshPolicy,
detector,
requestTimeout,
maxSingleEntityAnomalyDetectors,
maxMultiEntityAnomalyDetectors,
maxAnomalyFeatures,
RestRequest.Method.PUT,
xContentRegistry(),
mock(RestClient.class),
null
);

handler.resolveUserAndStart();

ArgumentCaptor<Exception> response = ArgumentCaptor.forClass(Exception.class);
verify(clientMock, times(1)).search(any(SearchRequest.class), any());
verify(clientMock, times(1)).get(any(GetRequest.class), any());
verify(channel).onFailure(response.capture());
Exception value = response.getValue();
assertTrue(value instanceof IllegalArgumentException);
assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG));
}

@SuppressWarnings("unchecked")
public void testTenMultiEntityDetectorsUpdateExistingMultiEntityAd() throws IOException {
int totalHits = 10;
AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList("a"));
GetResponse getDetectorResponse = TestHelpers
.createGetResponse(detector, detector.getDetectorId(), AnomalyDetector.ANOMALY_DETECTORS_INDEX);

SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.getHits()).thenReturn(createSearchHits(totalHits));

doAnswer(invocation -> {
Object[] args = invocation.getArguments();
assertTrue(String.format("The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), args.length == 2);

assertTrue(args[0] instanceof SearchRequest);
assertTrue(args[1] instanceof ActionListener);

ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) args[1];

listener.onResponse(searchResponse);

return null;
}).when(clientMock).search(any(SearchRequest.class), any());

doAnswer(invocation -> {
Object[] args = invocation.getArguments();
assertTrue(String.format("The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), args.length == 2);

assertTrue(args[0] instanceof GetRequest);
assertTrue(args[1] instanceof ActionListener);

ActionListener<GetResponse> listener = (ActionListener<GetResponse>) args[1];

listener.onResponse(getDetectorResponse);

return null;
}).when(clientMock).get(any(GetRequest.class), any());

ClusterName clusterName = new ClusterName("test");
ClusterState clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().build()).build();
when(clusterService.state()).thenReturn(clusterState);

handler = new IndexAnomalyDetectorActionHandler(
clusterService,
clientMock,
channel,
anomalyDetectionIndices,
detectorId,
seqNo,
primaryTerm,
refreshPolicy,
detector,
requestTimeout,
maxSingleEntityAnomalyDetectors,
maxMultiEntityAnomalyDetectors,
maxAnomalyFeatures,
RestRequest.Method.PUT,
xContentRegistry(),
mock(RestClient.class),
null
);

handler.resolveUserAndStart();

ArgumentCaptor<Exception> response = ArgumentCaptor.forClass(Exception.class);
verify(clientMock, times(0)).search(any(SearchRequest.class), any());
verify(clientMock, times(1)).get(any(GetRequest.class), any());
verify(channel).onFailure(response.capture());
Exception value = response.getValue();
// make sure execution passes all necessary checks
assertTrue(value instanceof IllegalStateException);
assertTrue(value.getMessage().contains("NodeClient has not been initialized"));
}
}

0 comments on commit ab836c7

Please sign in to comment.