diff --git a/docs/reference/ingest/processors/inference.asciidoc b/docs/reference/ingest/processors/inference.asciidoc index 619be794a728b..bad46ca3abb6f 100644 --- a/docs/reference/ingest/processors/inference.asciidoc +++ b/docs/reference/ingest/processors/inference.asciidoc @@ -71,6 +71,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-results-field] (Optional, string) include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-top-classes-results-field] +`prediction_field_type`:: +(Optional, string) +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-prediction-field-type] + [discrete] [[inference-processor-config-example]] ==== `inference_config` examples diff --git a/docs/reference/ml/df-analytics/apis/put-inference.asciidoc b/docs/reference/ml/df-analytics/apis/put-inference.asciidoc index e88f861fde147..5521ccdde676a 100644 --- a/docs/reference/ml/df-analytics/apis/put-inference.asciidoc +++ b/docs/reference/ml/df-analytics/apis/put-inference.asciidoc @@ -375,6 +375,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num- (Optional, integer) include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-feature-importance-values] +`prediction_field_type`:::: +(Optional, string) +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-prediction-field-type] + `results_field`:::: (Optional, string) include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-results-field] diff --git a/docs/reference/ml/ml-shared.asciidoc b/docs/reference/ml/ml-shared.asciidoc index 478e94c2a0800..f24e988cfccda 100644 --- a/docs/reference/ml/ml-shared.asciidoc +++ b/docs/reference/ml/ml-shared.asciidoc @@ -1027,6 +1027,12 @@ Specifies the field to which the top classes are written. Defaults to `top_classes`. end::inference-config-classification-top-classes-results-field[] +tag::inference-config-classification-prediction-field-type[] +Specifies the type of the predicted field to write. +Acceptable values are: `string`, `number`, `boolean`. When `boolean` is provided +`1.0` is transformed to `true` and `0.0` to `false`. +end::inference-config-classification-prediction-field-type[] + tag::inference-config-regression-num-top-feature-importance-values[] Specifies the maximum number of {ml-docs}/dfa-regression.html#dfa-regression-feature-importance[feature diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index a7eefe199d4a5..93ad7e7d85d4e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.mapper.FieldAliasMapper; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -236,7 +237,7 @@ public Map getParams(FieldInfo fieldInfo) { if (predictionFieldName != null) { params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); } - String predictionFieldType = getPredictionFieldType(fieldInfo.getTypes(dependentVariable)); + String predictionFieldType = getPredictionFieldTypeParamString(getPredictionFieldType(fieldInfo.getTypes(dependentVariable))); if (predictionFieldType != null) { params.put(PREDICTION_FIELD_TYPE, predictionFieldType); } @@ -245,19 +246,36 @@ public Map getParams(FieldInfo fieldInfo) { return params; } - private static String getPredictionFieldType(Set dependentVariableTypes) { + private static String getPredictionFieldTypeParamString(PredictionFieldType predictionFieldType) { + if (predictionFieldType == null) { + return null; + } + switch(predictionFieldType) + { + case NUMBER: + // C++ process uses int64_t type, so it is safe for the dependent variable to use long numbers. + return "int"; + case STRING: + return "string"; + case BOOLEAN: + return "bool"; + default: + return null; + } + } + + public static PredictionFieldType getPredictionFieldType(Set dependentVariableTypes) { if (dependentVariableTypes == null) { return null; } if (Types.categorical().containsAll(dependentVariableTypes)) { - return "string"; + return PredictionFieldType.STRING; } if (Types.bool().containsAll(dependentVariableTypes)) { - return "bool"; + return PredictionFieldType.BOOLEAN; } if (Types.discreteNumerical().containsAll(dependentVariableTypes)) { - // C++ process uses int64_t type, so it is safe for the dependent variable to use long numbers. - return "int"; + return PredictionFieldType.NUMBER; } return null; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java index d196506adc187..c3f9ef145e8a8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.inference.results; +import org.elasticsearch.Version; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -12,6 +13,7 @@ import org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -30,6 +32,7 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults private final String resultsField; private final String classificationLabel; private final List topClasses; + private final PredictionFieldType predictionFieldType; public ClassificationInferenceResults(double value, String classificationLabel, @@ -58,6 +61,7 @@ private ClassificationInferenceResults(double value, this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses); this.topNumClassesField = classificationConfig.getTopClassesResultsField(); this.resultsField = classificationConfig.getResultsField(); + this.predictionFieldType = classificationConfig.getPredictionFieldType(); } public ClassificationInferenceResults(StreamInput in) throws IOException { @@ -66,6 +70,11 @@ public ClassificationInferenceResults(StreamInput in) throws IOException { this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new)); this.topNumClassesField = in.readString(); this.resultsField = in.readString(); + if (in.getVersion().onOrAfter(Version.V_7_8_0)) { + this.predictionFieldType = in.readEnum(PredictionFieldType.class); + } else { + this.predictionFieldType = PredictionFieldType.STRING; + } } public String getClassificationLabel() { @@ -83,6 +92,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeCollection(topClasses); out.writeString(topNumClassesField); out.writeString(resultsField); + if (out.getVersion().onOrAfter(Version.V_7_8_0)) { + out.writeEnum(predictionFieldType); + } } @Override @@ -95,12 +107,19 @@ public boolean equals(Object object) { && Objects.equals(resultsField, that.resultsField) && Objects.equals(topNumClassesField, that.topNumClassesField) && Objects.equals(topClasses, that.topClasses) + && Objects.equals(predictionFieldType, that.predictionFieldType) && Objects.equals(getFeatureImportance(), that.getFeatureImportance()); } @Override public int hashCode() { - return Objects.hash(value(), classificationLabel, topClasses, resultsField, topNumClassesField, getFeatureImportance()); + return Objects.hash(value(), + classificationLabel, + topClasses, + resultsField, + topNumClassesField, + getFeatureImportance(), + predictionFieldType); } @Override @@ -112,7 +131,8 @@ public String valueAsString() { public void writeResult(IngestDocument document, String parentResultField) { ExceptionsHelper.requireNonNull(document, "document"); ExceptionsHelper.requireNonNull(parentResultField, "parentResultField"); - document.setFieldValue(parentResultField + "." + this.resultsField, valueAsString()); + document.setFieldValue(parentResultField + "." + this.resultsField, + predictionFieldType.transformPredictedValue(value(), valueAsString())); if (topClasses.size() > 0) { document.setFieldValue(parentResultField + "." + topNumClassesField, topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList())); @@ -130,34 +150,33 @@ public String getWriteableName() { return NAME; } - public static class TopClassEntry implements Writeable { public final ParseField CLASS_NAME = new ParseField("class_name"); public final ParseField CLASS_PROBABILITY = new ParseField("class_probability"); public final ParseField CLASS_SCORE = new ParseField("class_score"); - private final String classification; + private final Object classification; private final double probability; private final double score; - public TopClassEntry(String classification, double probability) { - this(classification, probability, probability); - } - - public TopClassEntry(String classification, double probability, double score) { + public TopClassEntry(Object classification, double probability, double score) { this.classification = ExceptionsHelper.requireNonNull(classification, CLASS_NAME); this.probability = probability; this.score = score; } public TopClassEntry(StreamInput in) throws IOException { - this.classification = in.readString(); + if (in.getVersion().onOrAfter(Version.V_7_8_0)) { + this.classification = in.readGenericValue(); + } else { + this.classification = in.readString(); + } this.probability = in.readDouble(); this.score = in.readDouble(); } - public String getClassification() { + public Object getClassification() { return classification; } @@ -179,7 +198,11 @@ public Map asValueMap() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeString(classification); + if (out.getVersion().onOrAfter(Version.V_7_8_0)) { + out.writeGenericValue(classification); + } else { + out.writeString(classification.toString()); + } out.writeDouble(probability); out.writeDouble(score); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java index 82749c27ce9e4..b25e78677b327 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java @@ -27,15 +27,17 @@ public class ClassificationConfig implements LenientlyParsedInferenceConfig, Str public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); public static final ParseField TOP_CLASSES_RESULTS_FIELD = new ParseField("top_classes_results_field"); public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); + public static final ParseField PREDICTION_FIELD_TYPE = new ParseField("prediction_field_type"); private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0; public static ClassificationConfig EMPTY_PARAMS = - new ClassificationConfig(0, DEFAULT_RESULTS_FIELD, DEFAULT_TOP_CLASSES_RESULTS_FIELD, null); + new ClassificationConfig(0, DEFAULT_RESULTS_FIELD, DEFAULT_TOP_CLASSES_RESULTS_FIELD, null, null); private final int numTopClasses; private final String topClassesResultsField; private final String resultsField; private final int numTopFeatureImportanceValues; + private final PredictionFieldType predictionFieldType; private static final ObjectParser LENIENT_PARSER = createParser(true); private static final ObjectParser STRICT_PARSER = createParser(false); @@ -49,6 +51,17 @@ private static ObjectParser createParser(boo parser.declareString(ClassificationConfig.Builder::setResultsField, RESULTS_FIELD); parser.declareString(ClassificationConfig.Builder::setTopClassesResultsField, TOP_CLASSES_RESULTS_FIELD); parser.declareInt(ClassificationConfig.Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES); + parser.declareField(ClassificationConfig.Builder::setPredictionFieldType, + (p, c) -> { + try { + return PredictionFieldType.fromString(p.text()); + } catch (IllegalArgumentException iae) { + if (lenient) { + return PredictionFieldType.STRING; + } + throw iae; + } + }, PREDICTION_FIELD_TYPE, ObjectParser.ValueType.STRING); return parser; } @@ -61,14 +74,14 @@ public static ClassificationConfig fromXContentLenient(XContentParser parser) { } public ClassificationConfig(Integer numTopClasses) { - this(numTopClasses, null, null, null); + this(numTopClasses, null, null, null, null); } - public ClassificationConfig(Integer numTopClasses, String resultsField, String topClassesResultsField) { - this(numTopClasses, resultsField, topClassesResultsField, 0); - } - - public ClassificationConfig(Integer numTopClasses, String resultsField, String topClassesResultsField, Integer featureImportance) { + public ClassificationConfig(Integer numTopClasses, + String resultsField, + String topClassesResultsField, + Integer featureImportance, + PredictionFieldType predictionFieldType) { this.numTopClasses = numTopClasses == null ? 0 : numTopClasses; this.topClassesResultsField = topClassesResultsField == null ? DEFAULT_TOP_CLASSES_RESULTS_FIELD : topClassesResultsField; this.resultsField = resultsField == null ? DEFAULT_RESULTS_FIELD : resultsField; @@ -77,6 +90,7 @@ public ClassificationConfig(Integer numTopClasses, String resultsField, String t "] must be greater than or equal to 0"); } this.numTopFeatureImportanceValues = featureImportance == null ? 0 : featureImportance; + this.predictionFieldType = predictionFieldType == null ? PredictionFieldType.STRING : predictionFieldType; } public ClassificationConfig(StreamInput in) throws IOException { @@ -88,6 +102,11 @@ public ClassificationConfig(StreamInput in) throws IOException { } else { this.numTopFeatureImportanceValues = 0; } + if (in.getVersion().onOrAfter(Version.V_7_8_0)) { + this.predictionFieldType = PredictionFieldType.fromStream(in); + } else { + this.predictionFieldType = PredictionFieldType.STRING; + } } public int getNumTopClasses() { @@ -106,6 +125,10 @@ public int getNumTopFeatureImportanceValues() { return numTopFeatureImportanceValues; } + public PredictionFieldType getPredictionFieldType() { + return predictionFieldType; + } + @Override public boolean requestingImportance() { return numTopFeatureImportanceValues > 0; @@ -119,6 +142,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getVersion().onOrAfter(Version.V_7_7_0)) { out.writeVInt(numTopFeatureImportanceValues); } + if (out.getVersion().onOrAfter(Version.V_7_8_0)) { + predictionFieldType.writeTo(out); + } } @Override @@ -129,12 +155,13 @@ public boolean equals(Object o) { return Objects.equals(numTopClasses, that.numTopClasses) && Objects.equals(topClassesResultsField, that.topClassesResultsField) && Objects.equals(resultsField, that.resultsField) - && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues); + && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues) + && Objects.equals(predictionFieldType, that.predictionFieldType); } @Override public int hashCode() { - return Objects.hash(numTopClasses, topClassesResultsField, resultsField, numTopFeatureImportanceValues); + return Objects.hash(numTopClasses, topClassesResultsField, resultsField, numTopFeatureImportanceValues, predictionFieldType); } @Override @@ -144,6 +171,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(TOP_CLASSES_RESULTS_FIELD.getPreferredName(), topClassesResultsField); builder.field(RESULTS_FIELD.getPreferredName(), resultsField); builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + builder.field(PREDICTION_FIELD_TYPE.getPreferredName(), predictionFieldType.toString()); builder.endObject(); return builder; } @@ -176,6 +204,7 @@ public static class Builder { private Integer numTopClasses; private String topClassesResultsField; private String resultsField; + private PredictionFieldType predictionFieldType; private Integer numTopFeatureImportanceValues; Builder() {} @@ -207,8 +236,17 @@ public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceV return this; } + public Builder setPredictionFieldType(PredictionFieldType predictionFieldType) { + this.predictionFieldType = predictionFieldType; + return this; + } + public ClassificationConfig build() { - return new ClassificationConfig(numTopClasses, resultsField, topClassesResultsField, numTopFeatureImportanceValues); + return new ClassificationConfig(numTopClasses, + resultsField, + topClassesResultsField, + numTopFeatureImportanceValues, + predictionFieldType); } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdate.java index 2df5678eb5e4a..d06bd529f6713 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdate.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdate.java @@ -20,6 +20,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig.NUM_TOP_CLASSES; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig.PREDICTION_FIELD_TYPE; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig.RESULTS_FIELD; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig.TOP_CLASSES_RESULTS_FIELD; @@ -28,12 +29,13 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate map) { Map options = new HashMap<>(map); @@ -41,18 +43,24 @@ public static ClassificationConfigUpdate fromMap(Map map) { String topClassesResultsField = (String)options.remove(TOP_CLASSES_RESULTS_FIELD.getPreferredName()); String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName()); Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName()); + String predictionFieldTypeStr = (String)options.remove(PREDICTION_FIELD_TYPE.getPreferredName()); if (options.isEmpty() == false) { throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet()); } - return new ClassificationConfigUpdate(numTopClasses, resultsField, topClassesResultsField, featureImportance); + return new ClassificationConfigUpdate(numTopClasses, + resultsField, + topClassesResultsField, + featureImportance, + predictionFieldTypeStr == null ? null : PredictionFieldType.fromString(predictionFieldTypeStr)); } public static ClassificationConfigUpdate fromConfig(ClassificationConfig config) { return new ClassificationConfigUpdate(config.getNumTopClasses(), config.getResultsField(), config.getTopClassesResultsField(), - config.getNumTopFeatureImportanceValues()); + config.getNumTopFeatureImportanceValues(), + config.getPredictionFieldType()); } private static final ObjectParser STRICT_PARSER = createParser(false); @@ -66,6 +74,7 @@ private static ObjectParser createPars parser.declareString(ClassificationConfigUpdate.Builder::setResultsField, RESULTS_FIELD); parser.declareString(ClassificationConfigUpdate.Builder::setTopClassesResultsField, TOP_CLASSES_RESULTS_FIELD); parser.declareInt(ClassificationConfigUpdate.Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES); + parser.declareString(ClassificationConfigUpdate.Builder::setPredictionFieldType, PREDICTION_FIELD_TYPE); return parser; } @@ -76,7 +85,8 @@ public static ClassificationConfigUpdate fromXContentStrict(XContentParser parse public ClassificationConfigUpdate(Integer numTopClasses, String resultsField, String topClassesResultsField, - Integer featureImportance) { + Integer featureImportance, + PredictionFieldType predictionFieldType) { this.numTopClasses = numTopClasses; this.topClassesResultsField = topClassesResultsField; this.resultsField = resultsField; @@ -85,6 +95,7 @@ public ClassificationConfigUpdate(Integer numTopClasses, "] must be greater than or equal to 0"); } this.numTopFeatureImportanceValues = featureImportance; + this.predictionFieldType = predictionFieldType; } public ClassificationConfigUpdate(StreamInput in) throws IOException { @@ -92,6 +103,7 @@ public ClassificationConfigUpdate(StreamInput in) throws IOException { this.topClassesResultsField = in.readOptionalString(); this.resultsField = in.readOptionalString(); this.numTopFeatureImportanceValues = in.readOptionalVInt(); + this.predictionFieldType = in.readOptionalWriteable(PredictionFieldType::fromStream); } public Integer getNumTopClasses() { @@ -110,12 +122,17 @@ public Integer getNumTopFeatureImportanceValues() { return numTopFeatureImportanceValues; } + public PredictionFieldType getPredictionFieldType() { + return predictionFieldType; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeOptionalInt(numTopClasses); out.writeOptionalString(topClassesResultsField); out.writeOptionalString(resultsField); out.writeOptionalVInt(numTopFeatureImportanceValues); + out.writeOptionalWriteable(predictionFieldType); } @Override @@ -126,12 +143,13 @@ public boolean equals(Object o) { return Objects.equals(numTopClasses, that.numTopClasses) && Objects.equals(topClassesResultsField, that.topClassesResultsField) && Objects.equals(resultsField, that.resultsField) - && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues); + && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues) + && Objects.equals(predictionFieldType, that.predictionFieldType); } @Override public int hashCode() { - return Objects.hash(numTopClasses, topClassesResultsField, resultsField, numTopFeatureImportanceValues); + return Objects.hash(numTopClasses, topClassesResultsField, resultsField, numTopFeatureImportanceValues, predictionFieldType); } @Override @@ -149,6 +167,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (numTopFeatureImportanceValues != null) { builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); } + if (predictionFieldType != null) { + builder.field(PREDICTION_FIELD_TYPE.getPreferredName(), predictionFieldType.toString()); + } builder.endObject(); return builder; } @@ -181,6 +202,9 @@ public ClassificationConfig apply(ClassificationConfig originalConfig) { if (numTopClasses != null) { builder.setNumTopClasses(numTopClasses); } + if (predictionFieldType != null) { + builder.setPredictionFieldType(predictionFieldType); + } return builder.build(); } @@ -199,7 +223,8 @@ boolean isNoop(ClassificationConfig originalConfig) { && (numTopFeatureImportanceValues == null || originalConfig.getNumTopFeatureImportanceValues() == numTopFeatureImportanceValues) && (topClassesResultsField == null || topClassesResultsField.equals(originalConfig.getTopClassesResultsField())) - && (numTopClasses == null || originalConfig.getNumTopClasses() == numTopClasses); + && (numTopClasses == null || originalConfig.getNumTopClasses() == numTopClasses) + && (predictionFieldType == null || predictionFieldType.equals(originalConfig.getPredictionFieldType())); } public static class Builder { @@ -207,6 +232,7 @@ public static class Builder { private String topClassesResultsField; private String resultsField; private Integer numTopFeatureImportanceValues; + private PredictionFieldType predictionFieldType; public Builder setNumTopClasses(int numTopClasses) { this.numTopClasses = numTopClasses; @@ -228,8 +254,17 @@ public Builder setNumTopFeatureImportanceValues(int numTopFeatureImportanceValue return this; } + private Builder setPredictionFieldType(String predictionFieldType) { + this.predictionFieldType = PredictionFieldType.fromString(predictionFieldType); + return this; + } + public ClassificationConfigUpdate build() { - return new ClassificationConfigUpdate(numTopClasses, resultsField, topClassesResultsField, numTopFeatureImportanceValues); + return new ClassificationConfigUpdate(numTopClasses, + resultsField, + topClassesResultsField, + numTopFeatureImportanceValues, + predictionFieldType); } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java index 05b0693ac04c5..4fdc42568d914 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java @@ -31,7 +31,8 @@ private InferenceHelpers() { } public static Tuple> topClasses(double[] probabilities, List classificationLabels, @Nullable double[] classificationWeights, - int numToInclude) { + int numToInclude, + PredictionFieldType predictionFieldType) { if (classificationLabels != null && probabilities.length != classificationLabels.size()) { throw ExceptionsHelper @@ -67,7 +68,10 @@ public static Tuple> List topClassEntries = new ArrayList<>(count); for(int i = 0; i < count; i++) { int idx = sortedIndices[i]; - topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities[idx], scores[idx])); + topClassEntries.add(new ClassificationInferenceResults.TopClassEntry( + predictionFieldType.transformPredictedValue((double)idx, labels.get(idx)), + probabilities[idx], + scores[idx])); } return Tuple.tuple(sortedIndices[0], topClassEntries); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PredictionFieldType.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PredictionFieldType.java new file mode 100644 index 0000000000000..88dfd211403a3 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PredictionFieldType.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; +import java.util.Locale; + +/** + * The type of the prediction field. + * This modifies how the predicted class values are written for classification models + */ +public enum PredictionFieldType implements Writeable { + + STRING, + NUMBER, + BOOLEAN; + + private static final double EPS = 1.0E-9; + + public static PredictionFieldType fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + public static PredictionFieldType fromStream(StreamInput in) throws IOException { + return in.readEnum(PredictionFieldType.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeEnum(this); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } + + public Object transformPredictedValue(Double value, String stringRep) { + if (value == null) { + return null; + } + switch(this) { + case STRING: + return stringRep == null ? value.toString() : stringRep; + case BOOLEAN: + if ((areClose(value, 1.0D) || areClose(value, 0.0D)) == false) { + throw new IllegalArgumentException( + "Cannot transform numbers other than 0.0 or 1.0 to boolean. Provided number [" + value + "]"); + } + return areClose(value, 1.0D); + case NUMBER: + default: + return value; + } + } + + private static boolean areClose(double value1, double value2) { + return Math.abs(value1 - value2) < EPS; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java index 7ee3ef1668b82..8799b56ff4cb2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -186,7 +186,8 @@ private InferenceResults buildResults(double[] processedInferences, processedInferences, classificationLabels, classificationWeights, - classificationConfig.getNumTopClasses()); + classificationConfig.getNumTopClasses(), + classificationConfig.getPredictionFieldType()); return new ClassificationInferenceResults((double)topClasses.v1(), classificationLabel(topClasses.v1(), classificationLabels), topClasses.v2(), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java index 5d31caba2eff1..86fe580646a0f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java @@ -20,6 +20,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -136,7 +137,8 @@ public InferenceResults infer(Map fields, InferenceConfig config probabilities, LANGUAGE_NAMES, null, - classificationConfig.getNumTopClasses()); + classificationConfig.getNumTopClasses(), + PredictionFieldType.STRING); assert topClasses.v1() >= 0 && topClasses.v1() < LANGUAGE_NAMES.size() : "Invalid language predicted. Predicted language index " + topClasses.v1(); return new ClassificationInferenceResults(topClasses.v1(), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index ddb1d6027ef6e..9565a393fc64d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -162,7 +162,8 @@ private InferenceResults buildResult(double[] value, Map featu classificationProbability(value), classificationLabels, null, - classificationConfig.getNumTopClasses()); + classificationConfig.getNumTopClasses(), + classificationConfig.getPredictionFieldType()); return new ClassificationInferenceResults(topClasses.v1(), classificationLabel(topClasses.v1(), classificationLabels), topClasses.v2(), diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java index e6ea0398dbc83..e98a48ea4f4d8 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType; import java.util.Arrays; import java.util.Collections; @@ -44,7 +45,7 @@ public static ClassificationInferenceResults createRandomResults() { } private static ClassificationInferenceResults.TopClassEntry createRandomClassEntry() { - return new ClassificationInferenceResults.TopClassEntry(randomAlphaOfLength(10), randomDouble()); + return new ClassificationInferenceResults.TopClassEntry(randomAlphaOfLength(10), randomDouble(), randomDouble()); } public void testWriteResultsWithClassificationLabel() { @@ -70,13 +71,13 @@ public void testWriteResultsWithoutClassificationLabel() { @SuppressWarnings("unchecked") public void testWriteResultsWithTopClasses() { List entries = Arrays.asList( - new ClassificationInferenceResults.TopClassEntry("foo", 0.7), - new ClassificationInferenceResults.TopClassEntry("bar", 0.2), - new ClassificationInferenceResults.TopClassEntry("baz", 0.1)); + new ClassificationInferenceResults.TopClassEntry("foo", 0.7, 0.7), + new ClassificationInferenceResults.TopClassEntry("bar", 0.2, 0.2), + new ClassificationInferenceResults.TopClassEntry("baz", 0.1, 0.1)); ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, "foo", entries, - new ClassificationConfig(3, "my_results", "bar")); + new ClassificationConfig(3, "my_results", "bar", null, PredictionFieldType.STRING)); IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); result.writeResult(document, "result_field"); @@ -103,7 +104,7 @@ public void testWriteResultsWithImportance() { "foo", Collections.emptyList(), importanceList, - new ClassificationConfig(0, "predicted_value", "top_classes", 3)); + new ClassificationConfig(0, "predicted_value", "top_classes", 3, PredictionFieldType.STRING)); IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); result.writeResult(document, "result_field"); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java index 8c00e38a81ed8..8ffd1d7ff858a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java @@ -21,7 +21,9 @@ public class ClassificationConfigTests extends AbstractBWCSerializationTestCase< public static ClassificationConfig randomClassificationConfig() { return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10), randomBoolean() ? null : randomAlphaOfLength(10), - randomBoolean() ? null : randomAlphaOfLength(10) + randomBoolean() ? null : randomAlphaOfLength(10), + randomBoolean() ? null : randomIntBetween(0, 10), + randomFrom(PredictionFieldType.values()) ); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdateTests.java index ccbedd841f074..ff8bbe132511c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdateTests.java @@ -24,20 +24,22 @@ public static ClassificationConfigUpdate randomClassificationConfig() { return new ClassificationConfigUpdate(randomBoolean() ? null : randomIntBetween(-1, 10), randomBoolean() ? null : randomAlphaOfLength(10), randomBoolean() ? null : randomAlphaOfLength(10), - randomBoolean() ? null : randomIntBetween(0, 10) + randomBoolean() ? null : randomIntBetween(0, 10), + randomBoolean() ? null : randomFrom(PredictionFieldType.values()) ); } public void testFromMap() { - ClassificationConfigUpdate expected = new ClassificationConfigUpdate(null, null, null, null); + ClassificationConfigUpdate expected = ClassificationConfigUpdate.EMPTY_PARAMS; assertThat(ClassificationConfigUpdate.fromMap(Collections.emptyMap()), equalTo(expected)); - expected = new ClassificationConfigUpdate(3, "foo", "bar", 2); + expected = new ClassificationConfigUpdate(3, "foo", "bar", 2, PredictionFieldType.NUMBER); Map configMap = new HashMap<>(); configMap.put(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3); configMap.put(ClassificationConfig.RESULTS_FIELD.getPreferredName(), "foo"); configMap.put(ClassificationConfig.TOP_CLASSES_RESULTS_FIELD.getPreferredName(), "bar"); configMap.put(ClassificationConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 2); + configMap.put(ClassificationConfig.PREDICTION_FIELD_TYPE.getPreferredName(), PredictionFieldType.NUMBER.toString()); assertThat(ClassificationConfigUpdate.fromMap(configMap), equalTo(expected)); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PredictionFieldTypeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PredictionFieldTypeTests.java new file mode 100644 index 0000000000000..ff8dd83cde2ee --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PredictionFieldTypeTests.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.test.ESTestCase; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.nullValue; + +public class PredictionFieldTypeTests extends ESTestCase { + + public void testTransformPredictedValueBoolean() { + assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(null, randomBoolean() ? null : randomAlphaOfLength(10)), + is(nullValue())); + assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, randomBoolean() ? null : randomAlphaOfLength(10)), + is(true)); + assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, randomBoolean() ? null : randomAlphaOfLength(10)), + is(false)); + expectThrows(IllegalArgumentException.class, + () -> PredictionFieldType.BOOLEAN.transformPredictedValue(0.1, randomBoolean() ? null : randomAlphaOfLength(10))); + expectThrows(IllegalArgumentException.class, + () -> PredictionFieldType.BOOLEAN.transformPredictedValue(1.1, randomBoolean() ? null : randomAlphaOfLength(10))); + } + + public void testTransformPredictedValueString() { + assertThat(PredictionFieldType.STRING.transformPredictedValue(null, randomBoolean() ? null : randomAlphaOfLength(10)), + is(nullValue())); + assertThat(PredictionFieldType.STRING.transformPredictedValue(1.0, "foo"), equalTo("foo")); + assertThat(PredictionFieldType.STRING.transformPredictedValue(1.0, null), equalTo("1.0")); + } + + public void testTransformPredictedValueNumber() { + assertThat(PredictionFieldType.NUMBER.transformPredictedValue(null, randomBoolean() ? null : randomAlphaOfLength(10)), + is(nullValue())); + assertThat(PredictionFieldType.NUMBER.transformPredictedValue(1.0, "foo"), equalTo(1.0)); + assertThat(PredictionFieldType.NUMBER.transformPredictedValue(1.0, null), equalTo(1.0)); + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index c09b38d3bcd67..5b3f5b1078e98 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -27,6 +27,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.job.messages.Messages; @@ -47,6 +48,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @@ -243,9 +245,11 @@ private InferenceConfig buildInferenceConfig(TargetType targetType) { case CLASSIFICATION: assert analytics.getAnalysis() instanceof Classification; Classification classification = ((Classification)analytics.getAnalysis()); + PredictionFieldType predictionFieldType = getPredictionFieldType(classification); return ClassificationConfig.builder() .setNumTopClasses(classification.getNumTopClasses()) .setNumTopFeatureImportanceValues(classification.getBoostedTreeParams().getNumTopFeatureImportanceValues()) + .setPredictionFieldType(predictionFieldType) .build(); case REGRESSION: assert analytics.getAnalysis() instanceof Regression; @@ -254,14 +258,24 @@ private InferenceConfig buildInferenceConfig(TargetType targetType) { .setNumTopFeatureImportanceValues(regression.getBoostedTreeParams().getNumTopFeatureImportanceValues()) .build(); default: - setAndReportFailure(ExceptionsHelper.serverError( + throw ExceptionsHelper.serverError( "process created a model with an unsupported target type [{}]", null, - targetType)); - return null; + targetType); } } + PredictionFieldType getPredictionFieldType(Classification classification) { + String dependentVariable = classification.getDependentVariable(); + Optional extractedField = fieldNames.stream() + .filter(f -> f.getName().equals(dependentVariable)) + .findAny(); + PredictionFieldType predictionFieldType = Classification.getPredictionFieldType( + extractedField.isPresent() ? extractedField.get().getTypes() : null + ); + return predictionFieldType == null ? PredictionFieldType.STRING : predictionFieldType; + } + private String getDependentVariable() { if (analytics.getAnalysis() instanceof Classification) { return ((Classification)analytics.getAnalysis()).getDependentVariable(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index 1d9c748f37670..7955e43111439 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -15,11 +15,13 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.security.user.XPackUser; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; @@ -42,6 +44,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Set; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; @@ -210,6 +213,19 @@ public void testProcess_GivenInferenceModelIsStoredSuccessfully() { Mockito.verifyNoMoreInteractions(auditor); } + public void testGetPredictionFieldType() { + List extractedFieldList = Arrays.asList( + new DocValueField("foo", Collections.emptySet()), + new DocValueField("bar", Set.of("keyword")), + new DocValueField("baz", Set.of("long")), + new DocValueField("bingo", Set.of("boolean"))); + AnalyticsResultProcessor resultProcessor = createResultProcessor(extractedFieldList); + assertThat(resultProcessor.getPredictionFieldType(new Classification("foo")), equalTo(PredictionFieldType.STRING)); + assertThat(resultProcessor.getPredictionFieldType(new Classification("bar")), equalTo(PredictionFieldType.STRING)); + assertThat(resultProcessor.getPredictionFieldType(new Classification("baz")), equalTo(PredictionFieldType.NUMBER)); + assertThat(resultProcessor.getPredictionFieldType(new Classification("bingo")), equalTo(PredictionFieldType.BOOLEAN)); + } + @SuppressWarnings("unchecked") public void testProcess_GivenInferenceModelFailedToStore() { givenDataFrameRows(0); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java index 81b842047ffb6..5e25bab0e4898 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; @@ -77,8 +78,8 @@ public void testMutateDocumentWithClassification() { @SuppressWarnings("unchecked") public void testMutateDocumentClassificationTopNClasses() { - ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, null, null, null); - ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, null); + ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, null, null, null, null); + ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, null, PredictionFieldType.STRING); InferenceProcessor inferenceProcessor = new InferenceProcessor(client, auditor, "my_processor", @@ -92,8 +93,8 @@ public void testMutateDocumentClassificationTopNClasses() { IngestDocument document = new IngestDocument(source, ingestMetadata); List classes = new ArrayList<>(2); - classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6)); - classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4)); + classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6, 0.6)); + classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4, 0.4)); InternalInferModelAction.Response response = new InternalInferModelAction.Response( Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes, classificationConfig)), @@ -107,8 +108,8 @@ public void testMutateDocumentClassificationTopNClasses() { } public void testMutateDocumentClassificationFeatureInfluence() { - ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, 2); - ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, null, null, 2); + ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, 2, PredictionFieldType.STRING); + ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, null, null, 2, null); InferenceProcessor inferenceProcessor = new InferenceProcessor(client, auditor, "my_processor", @@ -122,8 +123,8 @@ public void testMutateDocumentClassificationFeatureInfluence() { IngestDocument document = new IngestDocument(source, ingestMetadata); List classes = new ArrayList<>(2); - classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6)); - classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4)); + classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6, 0.6)); + classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4, 0.4)); List featureInfluence = new ArrayList<>(); featureInfluence.add(FeatureImportance.forRegression("feature_1", 1.13)); @@ -148,8 +149,8 @@ public void testMutateDocumentClassificationFeatureInfluence() { @SuppressWarnings("unchecked") public void testMutateDocumentClassificationTopNClassesWithSpecificField() { - ClassificationConfig classificationConfig = new ClassificationConfig(2, "result", "tops"); - ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, "result", "tops", null); + ClassificationConfig classificationConfig = new ClassificationConfig(2, "result", "tops", null, PredictionFieldType.STRING); + ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, "result", "tops", null, null); InferenceProcessor inferenceProcessor = new InferenceProcessor(client, auditor, "my_processor", @@ -163,8 +164,8 @@ public void testMutateDocumentClassificationTopNClassesWithSpecificField() { IngestDocument document = new IngestDocument(source, ingestMetadata); List classes = new ArrayList<>(2); - classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6)); - classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4)); + classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6, 0.6)); + classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4, 0.4)); InternalInferModelAction.Response response = new InternalInferModelAction.Response( Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes, classificationConfig)), @@ -240,7 +241,7 @@ public void testGenerateRequestWithEmptyMapping() { "my_processor", "my_field", modelId, - new ClassificationConfigUpdate(topNClasses, null, null, null), + new ClassificationConfigUpdate(topNClasses, null, null, null, null), Collections.emptyMap()); Map source = new HashMap(){{ @@ -269,7 +270,7 @@ public void testGenerateWithMapping() { "my_processor", "my_field", modelId, - new ClassificationConfigUpdate(topNClasses, null, null, null), + new ClassificationConfigUpdate(topNClasses, null, null, null, null), fieldMapping); Map source = new HashMap(5){{ @@ -305,7 +306,7 @@ public void testGenerateWithMappingNestedFields() { "my_processor", "my_field", modelId, - new ClassificationConfigUpdate(topNClasses, null, null, null), + new ClassificationConfigUpdate(topNClasses, null, null, null, null), fieldMapping); Map source = new HashMap(5){{ diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index b2127d4923eec..ef8d3d44678b8 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.ml.inference.loadingservice; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; @@ -17,6 +18,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; @@ -74,13 +76,13 @@ public void testClassificationInfer() throws Exception { put("categorical", "dog"); }}; - SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null)); + SingleValueInferenceResults result = getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS); assertThat(result.value(), equalTo(0.0)); assertThat(result.valueAsString(), is("0")); assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(1L)); ClassificationInferenceResults classificationResult = - (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfigUpdate(1, null, null, null)); + (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfigUpdate(1, null, null, null, null)); assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("0")); assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(1L)); @@ -97,30 +99,88 @@ public void testClassificationInfer() throws Exception { Collections.singletonMap("field.foo", "field.foo.keyword"), ClassificationConfig.EMPTY_PARAMS, modelStatsService); - result = getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null)); + result = getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS); assertThat(result.value(), equalTo(0.0)); assertThat(result.valueAsString(), equalTo("not_to_be")); classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, - new ClassificationConfigUpdate(1, null, null, null)); + new ClassificationConfigUpdate(1, null, null, null, null)); assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("not_to_be")); assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(2L)); classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, - new ClassificationConfigUpdate(2, null, null, null)); + new ClassificationConfigUpdate(2, null, null, null, null)); assertThat(classificationResult.getTopClasses(), hasSize(2)); assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(1L)); classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, - new ClassificationConfigUpdate(-1, null, null, null)); + new ClassificationConfigUpdate(-1, null, null, null, null)); assertThat(classificationResult.getTopClasses(), hasSize(2)); assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(1L)); } + @SuppressWarnings("unchecked") + public void testClassificationInferWithDifferentPredictionFieldTypes() throws Exception { + TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class); + doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class)); + String modelId = "classification_model"; + List inputFields = Arrays.asList("field.foo.keyword", "field.bar", "categorical"); + TrainedModelDefinition definition = new TrainedModelDefinition.Builder() + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) + .setTrainedModel(buildClassification(true)) + .build(); + + Model model = new LocalModel<>(modelId, + "test-node", + definition, + new TrainedModelInput(inputFields), + Collections.singletonMap("field.foo", "field.foo.keyword"), + ClassificationConfig.EMPTY_PARAMS, + modelStatsService); + Map fields = new HashMap<>() {{ + put("field.foo", 1.0); + put("field.bar", 0.5); + put("categorical", "dog"); + }}; + + InferenceResults result = getInferenceResult( + model, + fields, + new ClassificationConfigUpdate(2, null, null, null, PredictionFieldType.STRING)); + + IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); + result.writeResult(document, "result_field"); + assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("not_to_be")); + List list = document.getFieldValue("result_field.top_classes", List.class); + assertThat(list.size(), equalTo(2)); + assertThat(((Map)list.get(0)).get("class_name"), equalTo("not_to_be")); + assertThat(((Map)list.get(1)).get("class_name"), equalTo("to_be")); + + result = getInferenceResult(model, fields, new ClassificationConfigUpdate(2, null, null, null, PredictionFieldType.NUMBER)); + + document = new IngestDocument(new HashMap<>(), new HashMap<>()); + result.writeResult(document, "result_field"); + assertThat(document.getFieldValue("result_field.predicted_value", Double.class), equalTo(0.0)); + list = document.getFieldValue("result_field.top_classes", List.class); + assertThat(list.size(), equalTo(2)); + assertThat(((Map)list.get(0)).get("class_name"), equalTo(0.0)); + assertThat(((Map)list.get(1)).get("class_name"), equalTo(1.0)); + + result = getInferenceResult(model, fields, new ClassificationConfigUpdate(2, null, null, null, PredictionFieldType.BOOLEAN)); + + document = new IngestDocument(new HashMap<>(), new HashMap<>()); + result.writeResult(document, "result_field"); + assertThat(document.getFieldValue("result_field.predicted_value", Boolean.class), equalTo(false)); + list = document.getFieldValue("result_field.top_classes", List.class); + assertThat(list.size(), equalTo(2)); + assertThat(((Map)list.get(0)).get("class_name"), equalTo(false)); + assertThat(((Map)list.get(1)).get("class_name"), equalTo(true)); + } + public void testRegression() throws Exception { TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class); doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class)); @@ -201,9 +261,9 @@ public void testInferPersistsStatsAfterNumberOfCalls() throws Exception { }}; for(int i = 0; i < 100; i++) { - getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null)); + getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS); } - SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null)); + SingleValueInferenceResults result = getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS); assertThat(result.value(), equalTo(0.0)); assertThat(result.valueAsString(), is("0")); // Should have reset after persistence, so only 2 docs have been seen since last persistence diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index 1749d00f82191..a47ad91e8c5e3 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -168,7 +168,7 @@ public void testInferModels() throws Exception { contains("not_to_be", "to_be")); // Get top classes - request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfigUpdate(2, null, null, null), true); + request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfigUpdate(2, null, null, null, null), true); response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet(); ClassificationInferenceResults classificationInferenceResults = @@ -187,7 +187,7 @@ public void testInferModels() throws Exception { greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability())); // Test that top classes restrict the number returned - request = new InternalInferModelAction.Request(modelId2, toInfer2, new ClassificationConfigUpdate(1, null, null, null), true); + request = new InternalInferModelAction.Request(modelId2, toInfer2, new ClassificationConfigUpdate(1, null, null, null, null), true); response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet(); classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0); @@ -281,7 +281,7 @@ public void testInferModelMultiClassModel() throws Exception { // Get top classes - request = new InternalInferModelAction.Request(modelId, toInfer, new ClassificationConfigUpdate(3, null, null, null), true); + request = new InternalInferModelAction.Request(modelId, toInfer, new ClassificationConfigUpdate(3, null, null, null, null), true); response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet(); ClassificationInferenceResults classificationInferenceResults =