Skip to content

Commit

Permalink
[ML] adding prediction_field_type to inference config (#55128)
Browse files Browse the repository at this point in the history
Data frame analytics dynamically determines the classification field type. This field type then dictates the encoded JSON that is written to Elasticsearch. 

Inference needs to know about this field type so that it may provide the EXACT SAME predicted values as analytics. 

This commit adds a new field, `prediction_field_type`, which indicates the desired type. Options are: `string` (DEFAULT), `number`, `boolean` (where close_to(1.0) == true, false otherwise). 

Analytics provides the default `prediction_field_type` when the model is created from the process.
  • Loading branch information
benwtrent authored Apr 15, 2020
1 parent da8f411 commit c1afda4
Show file tree
Hide file tree
Showing 21 changed files with 423 additions and 80 deletions.
4 changes: 4 additions & 0 deletions docs/reference/ingest/processors/inference.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-results-field]
(Optional, string)
include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-top-classes-results-field]

`prediction_field_type`::
(Optional, string)
include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-prediction-field-type]

[discrete]
[[inference-processor-config-example]]
==== `inference_config` examples
Expand Down
4 changes: 4 additions & 0 deletions docs/reference/ml/df-analytics/apis/put-inference.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-
(Optional, integer)
include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-feature-importance-values]

`prediction_field_type`::::
(Optional, string)
include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-prediction-field-type]

`results_field`::::
(Optional, string)
include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-results-field]
Expand Down
6 changes: 6 additions & 0 deletions docs/reference/ml/ml-shared.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -1127,6 +1127,12 @@ Specifies the field to which the top classes are written. Defaults to
`top_classes`.
end::inference-config-classification-top-classes-results-field[]

tag::inference-config-classification-prediction-field-type[]
Specifies the type of the predicted field to write.
Acceptable values are: `string`, `number`, `boolean`. When `boolean` is provided,
`1.0` is transformed to `true` and `0.0` to `false`. Defaults to `string`.
end::inference-config-classification-prediction-field-type[]

tag::inference-config-regression-num-top-feature-importance-values[]
Specifies the maximum number of
{ml-docs}/dfa-regression.html#dfa-regression-feature-importance[feature
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.mapper.FieldAliasMapper;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
Expand Down Expand Up @@ -234,7 +235,7 @@ public Map<String, Object> getParams(FieldInfo fieldInfo) {
if (predictionFieldName != null) {
params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
}
String predictionFieldType = getPredictionFieldType(fieldInfo.getTypes(dependentVariable));
String predictionFieldType = getPredictionFieldTypeParamString(getPredictionFieldType(fieldInfo.getTypes(dependentVariable)));
if (predictionFieldType != null) {
params.put(PREDICTION_FIELD_TYPE, predictionFieldType);
}
Expand All @@ -243,19 +244,36 @@ public Map<String, Object> getParams(FieldInfo fieldInfo) {
return params;
}

private static String getPredictionFieldType(Set<String> dependentVariableTypes) {
/**
 * Translates a {@link PredictionFieldType} into the string form used for the
 * prediction field type parameter.
 *
 * @param predictionFieldType the type to translate; may be {@code null}
 * @return {@code "int"}, {@code "string"} or {@code "bool"}; {@code null} when the
 *         argument is {@code null} or has no corresponding parameter string
 */
private static String getPredictionFieldTypeParamString(PredictionFieldType predictionFieldType) {
    if (predictionFieldType == PredictionFieldType.NUMBER) {
        // C++ process uses int64_t type, so it is safe for the dependent variable to use long numbers.
        return "int";
    }
    if (predictionFieldType == PredictionFieldType.STRING) {
        return "string";
    }
    if (predictionFieldType == PredictionFieldType.BOOLEAN) {
        return "bool";
    }
    // Reached for a null argument and for any type without a parameter string.
    return null;
}

/**
 * Determines the prediction field type from the set of types of the dependent variable.
 * The groups are checked in order: categorical, boolean, then discrete numerical. A group
 * matches only when it contains every one of the supplied types.
 *
 * @param dependentVariableTypes types of the dependent variable; may be {@code null}
 * @return the {@link PredictionFieldType} of the first matching group, or {@code null} when
 *         the argument is {@code null} or no single group contains all of the types
 */
public static PredictionFieldType getPredictionFieldType(Set<String> dependentVariableTypes) {
if (dependentVariableTypes == null) {
return null;
}
if (Types.categorical().containsAll(dependentVariableTypes)) {
return "string";
return PredictionFieldType.STRING;
}
if (Types.bool().containsAll(dependentVariableTypes)) {
return "bool";
return PredictionFieldType.BOOLEAN;
}
if (Types.discreteNumerical().containsAll(dependentVariableTypes)) {
// C++ process uses int64_t type, so it is safe for the dependent variable to use long numbers.
return "int";
return PredictionFieldType.NUMBER;
}
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
*/
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
Expand All @@ -30,6 +32,7 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
private final String resultsField;
private final String classificationLabel;
private final List<TopClassEntry> topClasses;
private final PredictionFieldType predictionFieldType;

public ClassificationInferenceResults(double value,
String classificationLabel,
Expand Down Expand Up @@ -58,6 +61,7 @@ private ClassificationInferenceResults(double value,
this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses);
this.topNumClassesField = classificationConfig.getTopClassesResultsField();
this.resultsField = classificationConfig.getResultsField();
this.predictionFieldType = classificationConfig.getPredictionFieldType();
}

public ClassificationInferenceResults(StreamInput in) throws IOException {
Expand All @@ -66,6 +70,11 @@ public ClassificationInferenceResults(StreamInput in) throws IOException {
this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new));
this.topNumClassesField = in.readString();
this.resultsField = in.readString();
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
this.predictionFieldType = in.readEnum(PredictionFieldType.class);
} else {
this.predictionFieldType = PredictionFieldType.STRING;
}
}

public String getClassificationLabel() {
Expand All @@ -83,6 +92,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeCollection(topClasses);
out.writeString(topNumClassesField);
out.writeString(resultsField);
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
out.writeEnum(predictionFieldType);
}
}

@Override
Expand All @@ -95,12 +107,19 @@ public boolean equals(Object object) {
&& Objects.equals(resultsField, that.resultsField)
&& Objects.equals(topNumClassesField, that.topNumClassesField)
&& Objects.equals(topClasses, that.topClasses)
&& Objects.equals(predictionFieldType, that.predictionFieldType)
&& Objects.equals(getFeatureImportance(), that.getFeatureImportance());
}

/**
 * Computes a hash code from the result value, the classification label, the top classes,
 * the two result field names, the feature importance and the prediction field type.
 */
@Override
public int hashCode() {
return Objects.hash(value(), classificationLabel, topClasses, resultsField, topNumClassesField, getFeatureImportance());
return Objects.hash(value(),
classificationLabel,
topClasses,
resultsField,
topNumClassesField,
getFeatureImportance(),
predictionFieldType);
}

@Override
Expand All @@ -112,7 +131,8 @@ public String valueAsString() {
public void writeResult(IngestDocument document, String parentResultField) {
ExceptionsHelper.requireNonNull(document, "document");
ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
document.setFieldValue(parentResultField + "." + this.resultsField, valueAsString());
document.setFieldValue(parentResultField + "." + this.resultsField,
predictionFieldType.transformPredictedValue(value(), valueAsString()));
if (topClasses.size() > 0) {
document.setFieldValue(parentResultField + "." + topNumClassesField,
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
Expand All @@ -130,34 +150,33 @@ public String getWriteableName() {
return NAME;
}


public static class TopClassEntry implements Writeable {

public final ParseField CLASS_NAME = new ParseField("class_name");
public final ParseField CLASS_PROBABILITY = new ParseField("class_probability");
public final ParseField CLASS_SCORE = new ParseField("class_score");

private final String classification;
private final Object classification;
private final double probability;
private final double score;

public TopClassEntry(String classification, double probability) {
this(classification, probability, probability);
}

public TopClassEntry(String classification, double probability, double score) {
public TopClassEntry(Object classification, double probability, double score) {
this.classification = ExceptionsHelper.requireNonNull(classification, CLASS_NAME);
this.probability = probability;
this.score = score;
}

public TopClassEntry(StreamInput in) throws IOException {
this.classification = in.readString();
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
this.classification = in.readGenericValue();
} else {
this.classification = in.readString();
}
this.probability = in.readDouble();
this.score = in.readDouble();
}

public String getClassification() {
public Object getClassification() {
return classification;
}

Expand All @@ -179,7 +198,11 @@ public Map<String, Object> asValueMap() {

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(classification);
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
out.writeGenericValue(classification);
} else {
out.writeString(classification.toString());
}
out.writeDouble(probability);
out.writeDouble(score);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,17 @@ public class ClassificationConfig implements LenientlyParsedInferenceConfig, Str
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
public static final ParseField TOP_CLASSES_RESULTS_FIELD = new ParseField("top_classes_results_field");
public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
public static final ParseField PREDICTION_FIELD_TYPE = new ParseField("prediction_field_type");
private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0;

public static ClassificationConfig EMPTY_PARAMS =
new ClassificationConfig(0, DEFAULT_RESULTS_FIELD, DEFAULT_TOP_CLASSES_RESULTS_FIELD, null);
new ClassificationConfig(0, DEFAULT_RESULTS_FIELD, DEFAULT_TOP_CLASSES_RESULTS_FIELD, null, null);

private final int numTopClasses;
private final String topClassesResultsField;
private final String resultsField;
private final int numTopFeatureImportanceValues;
private final PredictionFieldType predictionFieldType;

private static final ObjectParser<ClassificationConfig.Builder, Void> LENIENT_PARSER = createParser(true);
private static final ObjectParser<ClassificationConfig.Builder, Void> STRICT_PARSER = createParser(false);
Expand All @@ -49,6 +51,17 @@ private static ObjectParser<ClassificationConfig.Builder, Void> createParser(boo
parser.declareString(ClassificationConfig.Builder::setResultsField, RESULTS_FIELD);
parser.declareString(ClassificationConfig.Builder::setTopClassesResultsField, TOP_CLASSES_RESULTS_FIELD);
parser.declareInt(ClassificationConfig.Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES);
parser.declareField(ClassificationConfig.Builder::setPredictionFieldType,
(p, c) -> {
try {
return PredictionFieldType.fromString(p.text());
} catch (IllegalArgumentException iae) {
if (lenient) {
return PredictionFieldType.STRING;
}
throw iae;
}
}, PREDICTION_FIELD_TYPE, ObjectParser.ValueType.STRING);
return parser;
}

Expand All @@ -61,14 +74,14 @@ public static ClassificationConfig fromXContentLenient(XContentParser parser) {
}

/**
 * Creates a configuration with the given {@code num_top_classes} value, passing
 * {@code null} for every other setting.
 *
 * @param numTopClasses value for {@code num_top_classes}; may be {@code null}
 */
public ClassificationConfig(Integer numTopClasses) {
this(numTopClasses, null, null, null);
this(numTopClasses, null, null, null, null);
}

/**
 * Creates a configuration from the given values, passing {@code 0} as the fourth
 * (feature importance) argument.
 *
 * @param numTopClasses value for {@code num_top_classes}
 * @param resultsField name of the results field
 * @param topClassesResultsField name of the top classes results field
 */
public ClassificationConfig(Integer numTopClasses, String resultsField, String topClassesResultsField) {
this(numTopClasses, resultsField, topClassesResultsField, 0);
}

public ClassificationConfig(Integer numTopClasses, String resultsField, String topClassesResultsField, Integer featureImportance) {
public ClassificationConfig(Integer numTopClasses,
String resultsField,
String topClassesResultsField,
Integer featureImportance,
PredictionFieldType predictionFieldType) {
this.numTopClasses = numTopClasses == null ? 0 : numTopClasses;
this.topClassesResultsField = topClassesResultsField == null ? DEFAULT_TOP_CLASSES_RESULTS_FIELD : topClassesResultsField;
this.resultsField = resultsField == null ? DEFAULT_RESULTS_FIELD : resultsField;
Expand All @@ -77,6 +90,7 @@ public ClassificationConfig(Integer numTopClasses, String resultsField, String t
"] must be greater than or equal to 0");
}
this.numTopFeatureImportanceValues = featureImportance == null ? 0 : featureImportance;
this.predictionFieldType = predictionFieldType == null ? PredictionFieldType.STRING : predictionFieldType;
}

public ClassificationConfig(StreamInput in) throws IOException {
Expand All @@ -88,6 +102,11 @@ public ClassificationConfig(StreamInput in) throws IOException {
} else {
this.numTopFeatureImportanceValues = 0;
}
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
this.predictionFieldType = PredictionFieldType.fromStream(in);
} else {
this.predictionFieldType = PredictionFieldType.STRING;
}
}

public int getNumTopClasses() {
Expand All @@ -106,6 +125,10 @@ public int getNumTopFeatureImportanceValues() {
return numTopFeatureImportanceValues;
}

/**
 * Returns the prediction field type held by this configuration.
 *
 * @return the configured {@link PredictionFieldType}
 */
public PredictionFieldType getPredictionFieldType() {
    return this.predictionFieldType;
}

@Override
public boolean requestingImportance() {
return numTopFeatureImportanceValues > 0;
Expand All @@ -119,6 +142,9 @@ public void writeTo(StreamOutput out) throws IOException {
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeVInt(numTopFeatureImportanceValues);
}
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
predictionFieldType.writeTo(out);
}
}

@Override
Expand All @@ -129,12 +155,13 @@ public boolean equals(Object o) {
return Objects.equals(numTopClasses, that.numTopClasses)
&& Objects.equals(topClassesResultsField, that.topClassesResultsField)
&& Objects.equals(resultsField, that.resultsField)
&& Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues);
&& Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues)
&& Objects.equals(predictionFieldType, that.predictionFieldType);
}

/**
 * Computes a hash code from {@code numTopClasses}, {@code topClassesResultsField},
 * {@code resultsField}, {@code numTopFeatureImportanceValues} and {@code predictionFieldType}.
 */
@Override
public int hashCode() {
return Objects.hash(numTopClasses, topClassesResultsField, resultsField, numTopFeatureImportanceValues);
return Objects.hash(numTopClasses, topClassesResultsField, resultsField, numTopFeatureImportanceValues, predictionFieldType);
}

@Override
Expand All @@ -144,6 +171,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(TOP_CLASSES_RESULTS_FIELD.getPreferredName(), topClassesResultsField);
builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
builder.field(PREDICTION_FIELD_TYPE.getPreferredName(), predictionFieldType.toString());
builder.endObject();
return builder;
}
Expand Down Expand Up @@ -176,6 +204,7 @@ public static class Builder {
private Integer numTopClasses;
private String topClassesResultsField;
private String resultsField;
private PredictionFieldType predictionFieldType;
private Integer numTopFeatureImportanceValues;

Builder() {}
Expand Down Expand Up @@ -207,8 +236,17 @@ public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceV
return this;
}

/**
 * Sets the prediction field type handed to the configuration when it is built.
 *
 * @param predictionFieldType the type to use; stored as given
 * @return this builder, to allow call chaining
 */
public Builder setPredictionFieldType(PredictionFieldType predictionFieldType) {
    this.predictionFieldType = predictionFieldType;
    return this;
}

/**
 * Creates a {@link ClassificationConfig} from the values currently held by this builder.
 *
 * @return a new configuration instance
 */
public ClassificationConfig build() {
return new ClassificationConfig(numTopClasses, resultsField, topClassesResultsField, numTopFeatureImportanceValues);
return new ClassificationConfig(numTopClasses,
resultsField,
topClassesResultsField,
numTopFeatureImportanceValues,
predictionFieldType);
}
}
}
Loading

0 comments on commit c1afda4

Please sign in to comment.