Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] updating feature_importance results mapping #61104

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
@Override
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
Map<String, Object> additionalProperties = new HashMap<>();
additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.featureImportanceMapping());
additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.classificationFeatureImportanceMapping());
Object dependentVariableMapping = extractMapping(dependentVariable, mappingsProperties);
if ((dependentVariableMapping instanceof Map) == false) {
return additionalProperties;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,46 @@

final class MapUtils {

private static final Map<String, Object> FEATURE_IMPORTANCE_MAPPING;
static {
Map<String, Object> featureImportanceMappingProperties = new HashMap<>();
private static Map<String, Object> createFeatureImportanceMapping(Map<String, Object> featureImportanceMappingProperties){
featureImportanceMappingProperties.put("feature_name", Collections.singletonMap("type", KeywordFieldMapper.CONTENT_TYPE));
featureImportanceMappingProperties.put("importance",
Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));
Map<String, Object> featureImportanceMapping = new HashMap<>();
// TODO sorted indices don't support nested types
//featureImportanceMapping.put("dynamic", true);
//featureImportanceMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE);
featureImportanceMapping.put("properties", featureImportanceMappingProperties);
FEATURE_IMPORTANCE_MAPPING = Collections.unmodifiableMap(featureImportanceMapping);
return featureImportanceMapping;
}

// Mapping for classification feature_importance results: each feature carries a
// "classes" array whose entries hold a keyword "class_name" and a double "importance".
// The shared outer structure (feature_name keyword, summary importance double) is
// added by createFeatureImportanceMapping.
private static final Map<String, Object> CLASSIFICATION_FEATURE_IMPORTANCE_MAPPING;
static {
Map<String, Object> classImportancePropertiesMapping = new HashMap<>();
// TODO sorted indices don't support nested types
//classImportancePropertiesMapping.put("dynamic", true);
//classImportancePropertiesMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE);
classImportancePropertiesMapping.put("class_name", Collections.singletonMap("type", KeywordFieldMapper.CONTENT_TYPE));
classImportancePropertiesMapping.put("importance",
Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));
Map<String, Object> featureImportancePropertiesMapping = new HashMap<>();
featureImportancePropertiesMapping.put("classes", Collections.singletonMap("properties", classImportancePropertiesMapping));
CLASSIFICATION_FEATURE_IMPORTANCE_MAPPING =
Collections.unmodifiableMap(createFeatureImportanceMapping(featureImportancePropertiesMapping));
}

// Mapping for regression feature_importance results: per-feature "importance" double only
// (no per-class breakdown), wrapped in the shared outer structure by createFeatureImportanceMapping.
private static final Map<String, Object> REGRESSION_FEATURE_IMPORTANCE_MAPPING;
static {
Map<String, Object> featureImportancePropertiesMapping = new HashMap<>();
featureImportancePropertiesMapping.put("importance",
Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));
REGRESSION_FEATURE_IMPORTANCE_MAPPING =
Collections.unmodifiableMap(createFeatureImportanceMapping(featureImportancePropertiesMapping));
}

// Returns the immutable mapping used for regression feature_importance result fields.
static Map<String, Object> regressionFeatureImportanceMapping() {
return REGRESSION_FEATURE_IMPORTANCE_MAPPING;
}

static Map<String, Object> featureImportanceMapping() {
return FEATURE_IMPORTANCE_MAPPING;
static Map<String, Object> classificationFeatureImportanceMapping() {
return CLASSIFICATION_FEATURE_IMPORTANCE_MAPPING;
}

private MapUtils() {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
@Override
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
Map<String, Object> additionalProperties = new HashMap<>();
additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.featureImportanceMapping());
additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.regressionFeatureImportanceMapping());
// Prediction field should be always mapped as "double" rather than "float" in order to increase precision in case of
// high (over 10M) values of dependent variable.
additionalProperties.put(resultsFieldName + "." + predictionFieldName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand All @@ -16,65 +17,74 @@

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;

public class FeatureImportance implements Writeable, ToXContentObject {

private final Map<String, Double> classImportance;
private final List<ClassImportance> classImportance;
private final double importance;
private final String featureName;
static final String IMPORTANCE = "importance";
static final String FEATURE_NAME = "feature_name";
static final String CLASS_IMPORTANCE = "class_importance";
static final String CLASSES = "classes";

public static FeatureImportance forRegression(String featureName, double importance) {
return new FeatureImportance(featureName, importance, null);
}

public static FeatureImportance forClassification(String featureName, Map<String, Double> classImportance) {
return new FeatureImportance(featureName, classImportance.values().stream().mapToDouble(Math::abs).sum(), classImportance);
public static FeatureImportance forClassification(String featureName, List<ClassImportance> classImportance) {
return new FeatureImportance(featureName,
classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum(),
classImportance);
}

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<FeatureImportance, Void> PARSER =
new ConstructingObjectParser<>("feature_importance",
a -> new FeatureImportance((String) a[0], (Double) a[1], (Map<String, Double>) a[2])
a -> new FeatureImportance((String) a[0], (Double) a[1], (List<ClassImportance>) a[2])
);

static {
PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME));
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
new ParseField(FeatureImportance.CLASS_IMPORTANCE));
PARSER.declareObjectArray(optionalConstructorArg(),
(p, c) -> ClassImportance.fromXContent(p),
new ParseField(FeatureImportance.CLASSES));
}

public static FeatureImportance fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
FeatureImportance(String featureName, double importance, List<ClassImportance> classImportance) {
this.featureName = Objects.requireNonNull(featureName);
this.importance = importance;
this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance);
this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance);
}

public FeatureImportance(StreamInput in) throws IOException {
this.featureName = in.readString();
this.importance = in.readDouble();
if (in.readBoolean()) {
this.classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
if (in.getVersion().before(Version.V_7_10_0)) {
Map<String, Double> classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
this.classImportance = ClassImportance.fromMap(classImportance);
} else {
this.classImportance = in.readList(ClassImportance::new);
}
} else {
this.classImportance = null;
}
}

public Map<String, Double> getClassImportance() {
public List<ClassImportance> getClassImportance() {
return classImportance;
}

Expand All @@ -92,7 +102,11 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(this.importance);
out.writeBoolean(this.classImportance != null);
if (this.classImportance != null) {
out.writeMap(this.classImportance, StreamOutput::writeString, StreamOutput::writeDouble);
if (out.getVersion().before(Version.V_7_10_0)) {
out.writeMap(ClassImportance.toMap(this.classImportance), StreamOutput::writeString, StreamOutput::writeDouble);
} else {
out.writeList(this.classImportance);
}
}
}

Expand All @@ -101,7 +115,7 @@ public Map<String, Object> toMap() {
map.put(FEATURE_NAME, featureName);
map.put(IMPORTANCE, importance);
if (classImportance != null) {
classImportance.forEach(map::put);
map.put(CLASSES, classImportance.stream().map(ClassImportance::toMap).collect(Collectors.toList()));
}
return map;
}
Expand All @@ -112,11 +126,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(FEATURE_NAME, featureName);
builder.field(IMPORTANCE, importance);
if (classImportance != null && classImportance.isEmpty() == false) {
builder.startObject(CLASS_IMPORTANCE);
for (Map.Entry<String, Double> entry : classImportance.entrySet()) {
builder.field(entry.getKey(), entry.getValue());
}
builder.endObject();
builder.field(CLASSES, classImportance);
}
builder.endObject();
return builder;
Expand All @@ -136,4 +146,92 @@ public boolean equals(Object object) {
public int hashCode() {
return Objects.hash(featureName, importance, classImportance);
}

/**
 * A single (class_name, importance) pair describing how much a feature
 * contributed to the prediction of one particular class.
 */
public static class ClassImportance implements Writeable, ToXContentObject {

    static final String CLASS_NAME = "class_name";

    private final String className;
    private final double importance;

    private static final ConstructingObjectParser<ClassImportance, Void> PARSER =
        new ConstructingObjectParser<>("feature_importance_class_importance",
            a -> new ClassImportance((String) a[0], (Double) a[1])
        );

    static {
        PARSER.declareString(constructorArg(), new ParseField(CLASS_NAME));
        PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
    }

    public static ClassImportance fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    private static ClassImportance fromMapEntry(Map.Entry<String, Double> entry) {
        return new ClassImportance(entry.getKey(), entry.getValue());
    }

    // Converts the legacy flat representation (class name -> importance) into a list.
    private static List<ClassImportance> fromMap(Map<String, Double> classImportanceMap) {
        return classImportanceMap.entrySet()
            .stream()
            .map(entry -> fromMapEntry(entry))
            .collect(Collectors.toList());
    }

    // Inverse of fromMap: flattens the list back into a class-name-to-importance map.
    private static Map<String, Double> toMap(List<ClassImportance> importances) {
        return importances.stream()
            .collect(Collectors.toMap(ClassImportance::getClassName, ClassImportance::getImportance));
    }

    public ClassImportance(String className, double importance) {
        this.className = className;
        this.importance = importance;
    }

    public ClassImportance(StreamInput in) throws IOException {
        className = in.readString();
        importance = in.readDouble();
    }

    public String getClassName() {
        return className;
    }

    public double getImportance() {
        return importance;
    }

    /** Map form used when embedding this object in ingest/search result documents. */
    public Map<String, Object> toMap() {
        Map<String, Object> result = new LinkedHashMap<>();
        result.put(CLASS_NAME, className);
        result.put(IMPORTANCE, importance);
        return result;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(className);
        out.writeDouble(importance);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(CLASS_NAME, className);
        builder.field(IMPORTANCE, importance);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || o.getClass() != getClass()) {
            return false;
        }
        ClassImportance other = (ClassImportance) o;
        return Double.compare(other.importance, importance) == 0
            && Objects.equals(className, other.className);
    }

    @Override
    public int hashCode() {
        return Objects.hash(className, importance);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -139,11 +138,13 @@ public static List<FeatureImportance> transformFeatureImportance(Map<String, dou
if (v.length == 1) {
importances.add(FeatureImportance.forRegression(k, v[0]));
} else {
Map<String, Double> classImportance = new LinkedHashMap<>(v.length, 1.0f);
List<FeatureImportance.ClassImportance> classImportance = new ArrayList<>(v.length);
// If the classificationLabels exist, their length must match leaf_value length
assert classificationLabels == null || classificationLabels.size() == v.length;
for (int i = 0; i < v.length; i++) {
classImportance.put(classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i), v[i]);
classImportance.add(new FeatureImportance.ClassImportance(
classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i),
v[i]));
}
importances.add(FeatureImportance.forClassification(k, classImportance));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,20 +259,20 @@ public void testFieldCardinalityLimitsIsNonEmpty() {

public void testGetExplicitlyMappedFields() {
assertThat(new Classification("foo").getExplicitlyMappedFields(null, "results"),
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping())));
assertThat(new Classification("foo").getExplicitlyMappedFields(Collections.emptyMap(), "results"),
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping())));
assertThat(
new Classification("foo").getExplicitlyMappedFields(Collections.singletonMap("foo", "not_a_map"), "results"),
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping())));
Map<String, Object> explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields(
Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")),
"results");
assertThat(explicitlyMappedFields,
allOf(
hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")),
hasEntry("results.top_classes.class_name", Collections.singletonMap("bar", "baz"))));
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.classificationFeatureImportanceMapping()));

explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields(
new HashMap<>() {{
Expand All @@ -287,7 +287,7 @@ public void testGetExplicitlyMappedFields() {
allOf(
hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")),
hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long"))));
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.classificationFeatureImportanceMapping()));

assertThat(
new Classification("foo").getExplicitlyMappedFields(
Expand All @@ -296,7 +296,7 @@ public void testGetExplicitlyMappedFields() {
put("path", "missing");
}}),
"results"),
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping())));
}

public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ public void testFieldCardinalityLimitsIsEmpty() {
public void testGetExplicitlyMappedFields() {
Map<String, Object> explicitlyMappedFields = new Regression("foo").getExplicitlyMappedFields(null, "results");
assertThat(explicitlyMappedFields, hasEntry("results.foo_prediction", Collections.singletonMap("type", "double")));
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.regressionFeatureImportanceMapping()));
}

public void testGetStateDocId() {
Expand Down
Loading