From 9135cace05f428b19da9ff01f29cb7df164ff19d Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 13 Aug 2020 11:02:00 -0400 Subject: [PATCH 1/3] [ML] updating feature_importance results mapping --- .../ml/dataframe/analyses/Classification.java | 2 +- .../core/ml/dataframe/analyses/MapUtils.java | 40 +++++++++++++++---- .../ml/dataframe/analyses/Regression.java | 2 +- .../analyses/ClassificationTests.java | 12 +++--- .../dataframe/analyses/RegressionTests.java | 2 +- 5 files changed, 41 insertions(+), 17 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index cff761183d70b..98766633c0f3f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -312,7 +312,7 @@ public List getFieldCardinalityConstraints() { @Override public Map getExplicitlyMappedFields(Map mappingsProperties, String resultsFieldName) { Map additionalProperties = new HashMap<>(); - additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.featureImportanceMapping()); + additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.classificationFeatureImportanceMapping()); Object dependentVariableMapping = extractMapping(dependentVariable, mappingsProperties); if ((dependentVariableMapping instanceof Map) == false) { return additionalProperties; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MapUtils.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MapUtils.java index 5440bc850c258..1b47edeeffff0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MapUtils.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MapUtils.java @@ -18,22 +18,46 @@ final class MapUtils { - private static final Map FEATURE_IMPORTANCE_MAPPING; - static { - Map featureImportanceMappingProperties = new HashMap<>(); + private static Map createFeatureImportanceMapping(Map featureImportanceMappingProperties){ featureImportanceMappingProperties.put("feature_name", Collections.singletonMap("type", KeywordFieldMapper.CONTENT_TYPE)); - featureImportanceMappingProperties.put("importance", - Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName())); Map featureImportanceMapping = new HashMap<>(); // TODO sorted indices don't support nested types //featureImportanceMapping.put("dynamic", true); //featureImportanceMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE); featureImportanceMapping.put("properties", featureImportanceMappingProperties); - FEATURE_IMPORTANCE_MAPPING = Collections.unmodifiableMap(featureImportanceMapping); + return featureImportanceMapping; + } + + private static final Map CLASSIFICATION_FEATURE_IMPORTANCE_MAPPING; + static { + Map classImportancePropertiesMapping = new HashMap<>(); + // TODO sorted indices don't support nested types + //classImportancePropertiesMapping.put("dynamic", true); + //classImportancePropertiesMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE); + classImportancePropertiesMapping.put("class_name", Collections.singletonMap("type", KeywordFieldMapper.CONTENT_TYPE)); + classImportancePropertiesMapping.put("importance", + Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName())); + Map featureImportancePropertiesMapping = new HashMap<>(); + featureImportancePropertiesMapping.put("properties", Collections.singletonMap("classes", classImportancePropertiesMapping)); + CLASSIFICATION_FEATURE_IMPORTANCE_MAPPING = + Collections.unmodifiableMap(createFeatureImportanceMapping(featureImportancePropertiesMapping)); + } + + private static final Map REGRESSION_FEATURE_IMPORTANCE_MAPPING; + static { + Map featureImportancePropertiesMapping = new HashMap<>(); + featureImportancePropertiesMapping.put("importance", + Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName())); + REGRESSION_FEATURE_IMPORTANCE_MAPPING = + Collections.unmodifiableMap(createFeatureImportanceMapping(featureImportancePropertiesMapping)); + } + + static Map regressionFeatureImportanceMapping() { + return REGRESSION_FEATURE_IMPORTANCE_MAPPING; } - static Map featureImportanceMapping() { - return FEATURE_IMPORTANCE_MAPPING; + static Map classificationFeatureImportanceMapping() { + return CLASSIFICATION_FEATURE_IMPORTANCE_MAPPING; } private MapUtils() {} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index f92e660ba733e..352807b78bbf0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -233,7 +233,7 @@ public List getFieldCardinalityConstraints() { @Override public Map getExplicitlyMappedFields(Map mappingsProperties, String resultsFieldName) { Map additionalProperties = new HashMap<>(); - additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.featureImportanceMapping()); + additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.regressionFeatureImportanceMapping()); // Prediction field should be always mapped as "double" rather than "float" in order to increase precision in case of // high (over 10M) values of dependent variable. additionalProperties.put(resultsFieldName + "." + predictionFieldName, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index 00a8ddf54d09a..d0a3ea2c71886 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -259,12 +259,12 @@ public void testFieldCardinalityLimitsIsNonEmpty() { public void testGetExplicitlyMappedFields() { assertThat(new Classification("foo").getExplicitlyMappedFields(null, "results"), - equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping()))); + equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping()))); assertThat(new Classification("foo").getExplicitlyMappedFields(Collections.emptyMap(), "results"), - equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping()))); + equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping()))); assertThat( new Classification("foo").getExplicitlyMappedFields(Collections.singletonMap("foo", "not_a_map"), "results"), - equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping()))); + equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping()))); Map explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields( Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")), "results"); @@ -272,7 +272,7 @@ public void testGetExplicitlyMappedFields() { allOf( hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")), hasEntry("results.top_classes.class_name", Collections.singletonMap("bar", "baz")))); - assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping())); + assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.classificationFeatureImportanceMapping())); explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields( new HashMap<>() {{ @@ -287,7 +287,7 @@ public void testGetExplicitlyMappedFields() { allOf( hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")), hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long")))); - assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping())); + assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.classificationFeatureImportanceMapping())); assertThat( new Classification("foo").getExplicitlyMappedFields( @@ -296,7 +296,7 @@ public void testGetExplicitlyMappedFields() { put("path", "missing"); }}), "results"), - equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping()))); + equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping()))); } public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java index 29820b883a7a6..0d695d0fbbde4 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java @@ -206,7 +206,7 @@ public void testFieldCardinalityLimitsIsEmpty() { public void testGetExplicitlyMappedFields() { Map explicitlyMappedFields = new Regression("foo").getExplicitlyMappedFields(null, "results"); assertThat(explicitlyMappedFields, hasEntry("results.foo_prediction", Collections.singletonMap("type", "double"))); - assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping())); + assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.regressionFeatureImportanceMapping())); } public void testGetStateDocId() { From 7c9d4e8ce82e7a24e5dc987b09b66cf98bf5e46e Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 13 Aug 2020 13:50:42 -0400 Subject: [PATCH 2/3] fixing formatting mapping --- .../xpack/core/ml/dataframe/analyses/MapUtils.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MapUtils.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MapUtils.java index 1b47edeeffff0..3cc8825944f28 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MapUtils.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MapUtils.java @@ -38,7 +38,7 @@ private static Map createFeatureImportanceMapping(Map featureImportancePropertiesMapping = new HashMap<>(); - featureImportancePropertiesMapping.put("properties", Collections.singletonMap("classes", classImportancePropertiesMapping)); + featureImportancePropertiesMapping.put("classes", Collections.singletonMap("properties", classImportancePropertiesMapping)); CLASSIFICATION_FEATURE_IMPORTANCE_MAPPING = Collections.unmodifiableMap(createFeatureImportanceMapping(featureImportancePropertiesMapping)); } From e306f4742584a78e2bd23f0ed60e43259ca1df73 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 13 Aug 2020 16:21:34 -0400 Subject: [PATCH 3/3] updating inference format --- .../inference/results/FeatureImportance.java | 136 +++++++++++++++--- .../trainedmodel/InferenceHelpers.java | 7 +- .../ClassificationInferenceResultsTests.java | 11 +- .../results/FeatureImportanceTests.java | 4 +- .../RegressionInferenceResultsTests.java | 2 +- 5 files changed, 133 insertions(+), 27 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java index 1f78ba11e319b..3c1a395a1f779 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.inference.results; +import org.elasticsearch.Version; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -16,65 +17,74 @@ import java.io.IOException; import java.util.Collections; -import java.util.HashMap; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; public class FeatureImportance implements Writeable, ToXContentObject { - private final Map classImportance; + private final List classImportance; private final double importance; private final String featureName; static final String IMPORTANCE = "importance"; static final String FEATURE_NAME = "feature_name"; - static final String CLASS_IMPORTANCE = "class_importance"; + static final String CLASSES = "classes"; public static FeatureImportance forRegression(String featureName, double importance) { return new FeatureImportance(featureName, importance, null); } - public static FeatureImportance forClassification(String featureName, Map classImportance) { - return new FeatureImportance(featureName, classImportance.values().stream().mapToDouble(Math::abs).sum(), classImportance); + public static FeatureImportance forClassification(String featureName, List classImportance) { + return new FeatureImportance(featureName, + classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum(), + classImportance); } @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("feature_importance", - a -> new FeatureImportance((String) a[0], (Double) a[1], (Map) a[2]) + a -> new FeatureImportance((String) a[0], (Double) a[1], (List) a[2]) ); static { PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME)); PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE)); - PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.map(HashMap::new, XContentParser::doubleValue), - new ParseField(FeatureImportance.CLASS_IMPORTANCE)); + PARSER.declareObjectArray(optionalConstructorArg(), + (p, c) -> ClassImportance.fromXContent(p), + new ParseField(FeatureImportance.CLASSES)); } public static FeatureImportance fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - FeatureImportance(String featureName, double importance, Map classImportance) { + FeatureImportance(String featureName, double importance, List classImportance) { this.featureName = Objects.requireNonNull(featureName); this.importance = importance; - this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance); + this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance); } public FeatureImportance(StreamInput in) throws IOException { this.featureName = in.readString(); this.importance = in.readDouble(); if (in.readBoolean()) { - this.classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble); + if (in.getVersion().before(Version.V_7_10_0)) { + Map classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble); + this.classImportance = ClassImportance.fromMap(classImportance); + } else { + this.classImportance = in.readList(ClassImportance::new); + } } else { this.classImportance = null; } } - public Map getClassImportance() { + public List getClassImportance() { return classImportance; } @@ -92,7 +102,11 @@ public void writeTo(StreamOutput out) throws IOException { out.writeDouble(this.importance); out.writeBoolean(this.classImportance != null); if (this.classImportance != null) { - out.writeMap(this.classImportance, StreamOutput::writeString, StreamOutput::writeDouble); + if (out.getVersion().before(Version.V_7_10_0)) { + out.writeMap(ClassImportance.toMap(this.classImportance), StreamOutput::writeString, StreamOutput::writeDouble); + } else { + out.writeList(this.classImportance); + } } } @@ -101,7 +115,7 @@ public Map toMap() { map.put(FEATURE_NAME, featureName); map.put(IMPORTANCE, importance); if (classImportance != null) { - classImportance.forEach(map::put); + map.put(CLASSES, classImportance.stream().map(ClassImportance::toMap).collect(Collectors.toList())); } return map; } @@ -112,11 +126,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(FEATURE_NAME, featureName); builder.field(IMPORTANCE, importance); if (classImportance != null && classImportance.isEmpty() == false) { - builder.startObject(CLASS_IMPORTANCE); - for (Map.Entry entry : classImportance.entrySet()) { - builder.field(entry.getKey(), entry.getValue()); - } - builder.endObject(); + builder.field(CLASSES, classImportance); } builder.endObject(); return builder; @@ -136,4 +146,92 @@ public boolean equals(Object object) { public int hashCode() { return Objects.hash(featureName, importance, classImportance); } + + public static class ClassImportance implements Writeable, ToXContentObject { + + static final String CLASS_NAME = "class_name"; + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("feature_importance_class_importance", + a -> new ClassImportance((String) a[0], (Double) a[1]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField(CLASS_NAME)); + PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE)); + } + + private static ClassImportance fromMapEntry(Map.Entry entry) { + return new ClassImportance(entry.getKey(), entry.getValue()); + } + + private static List fromMap(Map classImportanceMap) { + return classImportanceMap.entrySet().stream().map(ClassImportance::fromMapEntry).collect(Collectors.toList()); + } + + private static Map toMap(List importances) { + return importances.stream().collect(Collectors.toMap(i -> i.className, i -> i.importance)); + } + + public static ClassImportance fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final String className; + private final double importance; + + public ClassImportance(String className, double importance) { + this.className = className; + this.importance = importance; + } + + public ClassImportance(StreamInput in) throws IOException { + this.className = in.readString(); + this.importance = in.readDouble(); + } + + public String getClassName() { + return className; + } + + public double getImportance() { + return importance; + } + + public Map toMap() { + Map map = new LinkedHashMap<>(); + map.put(CLASS_NAME, className); + map.put(IMPORTANCE, importance); + return map; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(className); + out.writeDouble(importance); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CLASS_NAME, className); + builder.field(IMPORTANCE, importance); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ClassImportance that = (ClassImportance) o; + return Double.compare(that.importance, importance) == 0 && + Objects.equals(className, that.className); + } + + @Override + public int hashCode() { + return Objects.hash(className, importance); + } + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java index 5bfa4e054ff8e..d4cadf33bf489 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java @@ -15,7 +15,6 @@ import java.util.Collections; import java.util.Comparator; import java.util.HashMap; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -139,11 +138,13 @@ public static List transformFeatureImportance(Map classImportance = new LinkedHashMap<>(v.length, 1.0f); + List classImportance = new ArrayList<>(v.length); // If the classificationLabels exist, their length must match leaf_value length assert classificationLabels == null || classificationLabels.size() == v.length; for (int i = 0; i < v.length; i++) { - classImportance.put(classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i), v[i]); + classImportance.add(new FeatureImportance.ClassImportance( + classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i), + v[i])); } importances.add(FeatureImportance.forClassification(k, classImportance)); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java index efeb2cdb25603..64ca2b1592aeb 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java @@ -152,8 +152,15 @@ public void testWriteResultsWithImportance() { FeatureImportance importance = importanceList.get(i); assertThat(objectMap.get("feature_name"), equalTo(importance.getFeatureName())); assertThat(objectMap.get("importance"), equalTo(importance.getImportance())); + @SuppressWarnings("unchecked") + List> classImportances = (List>)objectMap.get("classes"); if (importance.getClassImportance() != null) { - importance.getClassImportance().forEach((k, v) -> assertThat(objectMap.get(k), equalTo(v))); + for (int j = 0; j < importance.getClassImportance().size(); j++) { + Map classMap = classImportances.get(j); + FeatureImportance.ClassImportance classImportance = importance.getClassImportance().get(j); + assertThat(classMap.get("class_name"), equalTo(classImportance.getClassName())); + assertThat(classMap.get("importance"), equalTo(classImportance.getImportance())); + } } } } @@ -205,7 +212,7 @@ public void testToXContent() { expected = "{\"predicted_value\":\"label1\",\"prediction_probability\":1.0,\"prediction_score\":1.0}"; assertEquals(expected, stringRep); - FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyMap()); + FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyList()); TopClassEntry tp = new TopClassEntry("class", 1.0, 1.0); result = new ClassificationInferenceResults(1.0, "label1", Collections.singletonList(tp), Collections.singletonList(fi), config, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportanceTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportanceTests.java index f23366b10787e..6a3563f3a46a9 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportanceTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportanceTests.java @@ -10,7 +10,6 @@ import org.elasticsearch.test.AbstractSerializingTestCase; import java.io.IOException; -import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -29,7 +28,8 @@ static FeatureImportance randomClassification() { randomAlphaOfLength(10), Stream.generate(() -> randomAlphaOfLength(10)) .limit(randomLongBetween(2, 10)) - .collect(Collectors.toMap(Function.identity(), (k) -> randomDoubleBetween(-10, 10, false)))); + .map(name -> new FeatureImportance.ClassImportance(name, randomDoubleBetween(-10, 10, false))) + .collect(Collectors.toList())); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java index 91899b688ae8a..29a402484741c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java @@ -92,7 +92,7 @@ public void testToXContent() { String expected = "{\"" + resultsField + "\":1.0}"; assertEquals(expected, stringRep); - FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyMap()); + FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyList()); result = new RegressionInferenceResults(1.0, resultsField, Collections.singletonList(fi)); stringRep = Strings.toString(result); expected = "{\"" + resultsField + "\":1.0,\"feature_importance\":[{\"feature_name\":\"foo\",\"importance\":1.0}]}";