Skip to content

Commit

Permalink
[ML] Parse single named object in config classes (elastic#53472) (ela…
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle authored Mar 17, 2020
1 parent 71b703e commit 2b63573
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,8 @@ public class TrainedModelDefinition implements ToXContentObject {
true,
TrainedModelDefinition.Builder::new);
static {
PARSER.declareNamedObjects(TrainedModelDefinition.Builder::setTrainedModel,
PARSER.declareNamedObject(TrainedModelDefinition.Builder::setTrainedModel,
(p, c, n) -> p.namedObject(TrainedModel.class, n, null),
(modelDocBuilder) -> { /* Noop does not matter client side*/ },
TRAINED_MODEL);
PARSER.declareNamedObjects(TrainedModelDefinition.Builder::setPreProcessors,
(p, c, n) -> p.namedObject(PreProcessor.class, n, null),
Expand Down Expand Up @@ -124,11 +123,6 @@ public Builder setTrainedModel(TrainedModel trainedModel) {
return this;
}

private Builder setTrainedModel(List<TrainedModel> trainedModel) {
assert trainedModel.size() == 1;
return setTrainedModel(trainedModel.get(0));
}

public TrainedModelDefinition build() {
return new TrainedModelDefinition(this.trainedModel, this.preProcessors);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,8 @@ public class Ensemble implements TrainedModel {
p.namedObject(TrainedModel.class, n, null),
(ensembleBuilder) -> { /* Noop does not matter client side */ },
TRAINED_MODELS);
PARSER.declareNamedObjects(Ensemble.Builder::setOutputAggregatorFromParser,
PARSER.declareNamedObject(Ensemble.Builder::setOutputAggregator,
(p, c, n) -> p.namedObject(OutputAggregator.class, n, null),
(ensembleBuilder) -> { /* Noop does not matter client side */ },
AGGREGATE_OUTPUT);
PARSER.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE);
PARSER.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
Expand Down Expand Up @@ -194,9 +193,6 @@ public Builder setClassificationWeights(List<Double> classificationWeights) {
return this;
}

private void setOutputAggregatorFromParser(List<OutputAggregator> outputAggregators) {
this.setOutputAggregator(outputAggregators.get(0));
}

private void setTargetType(String targetType) {
this.targetType = TargetType.fromString(targetType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,10 @@ private static ObjectParser<TrainedModelDefinition.Builder, Void> createParser(b
ObjectParser<TrainedModelDefinition.Builder, Void> parser = new ObjectParser<>(NAME,
ignoreUnknownFields,
TrainedModelDefinition.Builder::builderForParser);
parser.declareNamedObjects(TrainedModelDefinition.Builder::setTrainedModel,
parser.declareNamedObject(TrainedModelDefinition.Builder::setTrainedModel,
(p, c, n) -> ignoreUnknownFields ?
p.namedObject(LenientlyParsedTrainedModel.class, n, null) :
p.namedObject(StrictlyParsedTrainedModel.class, n, null),
(modelDocBuilder) -> { /* Noop does not matter as we will throw if more than one is defined */ },
TRAINED_MODEL);
parser.declareNamedObjects(TrainedModelDefinition.Builder::setPreProcessors,
(p, c, n) -> ignoreUnknownFields ?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,10 @@ private static ObjectParser<Ensemble.Builder, Void> createParser(boolean lenient
p.namedObject(StrictlyParsedTrainedModel.class, n, null),
(ensembleBuilder) -> ensembleBuilder.setModelsAreOrdered(true),
TRAINED_MODELS);
parser.declareNamedObjects(Ensemble.Builder::setOutputAggregatorFromParser,
parser.declareNamedObject(Ensemble.Builder::setOutputAggregator,
(p, c, n) ->
lenient ? p.namedObject(LenientlyParsedOutputAggregator.class, n, null) :
p.namedObject(StrictlyParsedOutputAggregator.class, n, null),
(ensembleBuilder) -> {/*Noop as it could be an array or object, it just has to be a one*/},
AGGREGATE_OUTPUT);
parser.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE);
parser.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
Expand Down Expand Up @@ -414,14 +413,6 @@ public Builder setClassificationWeights(List<Double> classificationWeights) {
return this;
}

private void setOutputAggregatorFromParser(List<OutputAggregator> outputAggregators) {
if (outputAggregators.size() != 1) {
throw ExceptionsHelper.badRequestException("[{}] must have exactly one aggregator defined.",
AGGREGATE_OUTPUT.getPreferredName());
}
this.setOutputAggregator(outputAggregators.get(0));
}

private void setTargetType(String targetType) {
this.targetType = TargetType.fromString(targetType);
}
Expand Down

0 comments on commit 2b63573

Please sign in to comment.