Skip to content

Commit

Permalink
[ML] Pass through the stop-on-warn setting for categorization jobs (#58738)
Browse files Browse the repository at this point in the history

When per_partition_categorization.stop_on_warn is set for an analysis
config, it is now passed through to the autodetect C++ process.

Also adds some end-to-end tests that exercise the functionality
added in elastic/ml-cpp#1356.

Backport of #58632
  • Loading branch information
droberts195 authored Jun 30, 2020
1 parent 874ab36 commit d9e0e0b
Show file tree
Hide file tree
Showing 9 changed files with 296 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.common.time.TimeUtils;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.results.ReservedFieldNames;
import org.elasticsearch.xpack.core.ml.job.results.Result;
Expand Down Expand Up @@ -121,7 +122,12 @@ public CategorizerStats(StreamInput in) throws IOException {
}

/**
 * Builds the document ID for this stats object.
 *
 * <p>The ID is the job's document ID prefix followed by the log time in epoch
 * milliseconds. When a partition field name is set, an underscore and the result
 * of {@code MachineLearningField.valuesToId} applied to the partition field value
 * are appended (presumably so that stats for different partitions logged at the
 * same time do not share an ID - confirm against the indexing code).
 *
 * @return the document ID
 */
public String getId() {
    StringBuilder idBuilder = new StringBuilder(documentIdPrefix(jobId));
    idBuilder.append(logTime.toEpochMilli());
    // The suffix is only added when a partition field is configured.
    if (partitionFieldName != null) {
        idBuilder.append('_').append(MachineLearningField.valuesToId(partitionFieldValue));
    }
    return idBuilder.toString();
}

public static String documentIdPrefix(String jobId) {
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ public class AutodetectBuilder {
static final String MAX_QUANTILE_INTERVAL_ARG = "--maxQuantileInterval=";
static final String SUMMARY_COUNT_FIELD_ARG = "--summarycountfield=";
static final String TIME_FIELD_ARG = "--timefield=";
static final String STOP_CATEGORIZATION_ON_WARN_ARG = "--stopCategorizationOnWarnStatus";

/**
* Name of the config setting containing the path to the logs directory
Expand Down Expand Up @@ -198,6 +199,9 @@ List<String> buildAutodetectCommand() {
if (Boolean.TRUE.equals(analysisConfig.getMultivariateByFields())) {
command.add(MULTIVARIATE_BY_FIELDS_ARG);
}
if (Boolean.TRUE.equals(analysisConfig.getPerPartitionCategorizationConfig().isStopOnWarn())) {
command.add(STOP_CATEGORIZATION_ON_WARN_ARG);
}
}

// Input is always length encoded
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ public interface AutodetectProcess extends NativeProcess {
* @param rules Detector rules
* @throws IOException If the write fails
*/
void writeUpdateDetectorRulesMessage(int detectorIndex, List<DetectionRule> rules)
throws IOException;
void writeUpdateDetectorRulesMessage(int detectorIndex, List<DetectionRule> rules) throws IOException;

/**
* Write message to update the filters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ public void onFailure(Exception e) {
GetFiltersAction.Request getFilterRequest = new GetFiltersAction.Request(updateParams.getFilter().getId());
executeAsyncWithOrigin(client, ML_ORIGIN, GetFiltersAction.INSTANCE, getFilterRequest, ActionListener.wrap(
getFilterResponse -> filterListener.onResponse(getFilterResponse.getFilters().results().get(0)),
handler::accept
handler
));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public void writeUpdateModelPlotMessage(ModelPlotConfig modelPlotConfig) throws
/**
 * Writes a categorization stop-on-warn control message carrying the
 * stop-on-warn flag of the supplied per-partition categorization config.
 *
 * @param perPartitionCategorizationConfig config whose stop-on-warn flag is written
 * @throws IOException propagated from the message writer
 */
@Override
public void writeUpdatePerPartitionCategorizationMessage(PerPartitionCategorizationConfig perPartitionCategorizationConfig)
throws IOException {
// Only the stop-on-warn flag is read from the config here; it is handed to a
// message writer obtained from newMessageWriter().
newMessageWriter().writeCategorizationStopOnWarnMessage(perPartitionCategorizationConfig.isStopOnWarn());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@
*/
public class AutodetectControlMsgWriter extends AbstractControlMsgWriter {

/**
* This must match the code defined in the api::CFieldDataCategorizer C++ class.
*/
private static final String CATEGORIZATION_STOP_ON_WARN_MESSAGE_CODE = "c";

/**
* This must match the code defined in the api::CAnomalyJob C++ class.
*/
Expand All @@ -41,12 +46,12 @@ public class AutodetectControlMsgWriter extends AbstractControlMsgWriter {
/**
* This must match the code defined in the api::CAnomalyJob C++ class.
*/
private static final String FORECAST_MESSAGE_CODE = "p";
private static final String INTERIM_MESSAGE_CODE = "i";

/**
* This must match the code defined in the api::CAnomalyJob C++ class.
*/
private static final String INTERIM_MESSAGE_CODE = "i";
private static final String FORECAST_MESSAGE_CODE = "p";

/**
* This must match the code defined in the api::CAnomalyJob C++ class.
Expand Down Expand Up @@ -190,6 +195,10 @@ public void writeUpdateModelPlotMessage(ModelPlotConfig modelPlotConfig) throws
writeMessage(configWriter.toString());
}

/**
 * Writes the control message that conveys the categorization stop-on-warn flag.
 *
 * <p>The message text is the categorization stop-on-warn message code immediately
 * followed by {@code true} or {@code false}.
 *
 * @param isStopOnWarn the flag value to send
 * @throws IOException propagated from {@code writeMessage}
 */
public void writeCategorizationStopOnWarnMessage(boolean isStopOnWarn) throws IOException {
    final String controlMessage = CATEGORIZATION_STOP_ON_WARN_MESSAGE_CODE + Boolean.toString(isStopOnWarn);
    writeMessage(controlMessage);
}

public void writeUpdateDetectorRulesMessage(int detectorIndex, List<DetectionRule> rules) throws IOException {
StringBuilder stringBuilder = new StringBuilder();
stringBuilder.append(UPDATE_MESSAGE_CODE).append("[detectorRules]\n");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,18 @@ public Iterator<T> parseResults(InputStream in) throws ElasticsearchParseExcepti
if (token != XContentParser.Token.START_ARRAY) {
throw new ElasticsearchParseException("unexpected token [" + token + "]");
}
return new ResultIterator(in, parser);
return new ResultIterator(parser);
} catch (IOException e) {
throw new ElasticsearchParseException(e.getMessage(), e);
}
}

private class ResultIterator implements Iterator<T> {

private final InputStream in;
private final XContentParser parser;
private XContentParser.Token token;

private ResultIterator(InputStream in, XContentParser parser) {
this.in = in;
private ResultIterator(XContentParser parser) {
this.parser = parser;
token = parser.currentToken();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
import org.elasticsearch.xpack.core.ml.job.config.Detector;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.config.PerPartitionCategorizationConfig;
import org.elasticsearch.xpack.ml.process.NativeController;
import org.elasticsearch.xpack.ml.process.ProcessPipes;
import org.junit.Before;
Expand All @@ -24,6 +25,7 @@
import java.util.List;

import static org.elasticsearch.xpack.core.ml.job.config.JobTests.buildJobBuilder;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;

public class AutodetectBuilderTests extends ESTestCase {
Expand All @@ -46,15 +48,25 @@ public void setUpTests() {
}

public void testBuildAutodetectCommand() {
boolean isPerPartitionCategorization = randomBoolean();

Job.Builder job = buildJobBuilder("unit-test-job");

Detector.Builder detectorBuilder = new Detector.Builder("mean", "value");
if (isPerPartitionCategorization) {
detectorBuilder.setByFieldName("mlcategory");
}
detectorBuilder.setPartitionFieldName("foo");
AnalysisConfig.Builder acBuilder = new AnalysisConfig.Builder(Collections.singletonList(detectorBuilder.build()));
acBuilder.setBucketSpan(TimeValue.timeValueSeconds(120));
acBuilder.setLatency(TimeValue.timeValueSeconds(360));
acBuilder.setSummaryCountFieldName("summaryField");
acBuilder.setMultivariateByFields(true);
if (isPerPartitionCategorization) {
acBuilder.setCategorizationFieldName("bar");
}
acBuilder.setPerPartitionCategorizationConfig(
new PerPartitionCategorizationConfig(isPerPartitionCategorization, isPerPartitionCategorization));

job.setAnalysisConfig(acBuilder);

Expand All @@ -65,12 +77,12 @@ public void testBuildAutodetectCommand() {
job.setDataDescription(dd);

List<String> command = autodetectBuilder(job.build()).buildAutodetectCommand();
assertEquals(11, command.size());
assertTrue(command.contains(AutodetectBuilder.AUTODETECT_PATH));
assertTrue(command.contains(AutodetectBuilder.BUCKET_SPAN_ARG + "120"));
assertTrue(command.contains(AutodetectBuilder.LATENCY_ARG + "360"));
assertTrue(command.contains(AutodetectBuilder.SUMMARY_COUNT_FIELD_ARG + "summaryField"));
assertTrue(command.contains(AutodetectBuilder.MULTIVARIATE_BY_FIELDS_ARG));
assertThat(command.contains(AutodetectBuilder.STOP_CATEGORIZATION_ON_WARN_ARG), is(isPerPartitionCategorization));

assertTrue(command.contains(AutodetectBuilder.LENGTH_ENCODED_INPUT_ARG));
assertTrue(command.contains(AutodetectBuilder.maxAnomalyRecordsArg(settings)));
Expand All @@ -82,6 +94,8 @@ public void testBuildAutodetectCommand() {
assertTrue(command.contains(AutodetectBuilder.PERSIST_INTERVAL_ARG + expectedPersistInterval));
int expectedMaxQuantileInterval = 21600 + AutodetectBuilder.calculateStaggeringInterval(job.getId());
assertTrue(command.contains(AutodetectBuilder.MAX_QUANTILE_INTERVAL_ARG + expectedMaxQuantileInterval));

assertEquals(isPerPartitionCategorization ? 12 : 11, command.size());
}

public void testBuildAutodetectCommand_defaultTimeField() {
Expand Down

0 comments on commit d9e0e0b

Please sign in to comment.