Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Return total SHAP per feature as a new result type #1387

Merged
merged 19 commits into from
Aug 13, 2020

Conversation

valeriy42
Copy link
Contributor

@valeriy42 valeriy42 commented Jul 6, 2020

This PR add computation of the total feature importance values and outputs it as a new result type.

Example outputs:

  • Regression
{
    "row_results": {
        "checksum": 0,
        "results": {
            "ml": {
                "target_prediction": 1358.2039794921876,
                "is_training": true,
                "feature_importance": [
                    {
                        "feature_name": "c1",
                        "importance": 448.3700609767421
                    },
                    {
                        "feature_name": "c2",
                        "importance": 1146.2576129663276
                    },
                    {
                        "feature_name": "c3",
                        "importance": -388.896988459571
                    },
                    {
                        "feature_name": "c4",
                        "importance": 158.7185644712811
                    }
                ]
            }
        }
    }
},
{
    "model_metadata": {
        "total_feature_importance": [
            {
                "feature_name": "c4",
                "importance": {
                    "mean_magnitude": 233.4600671221131,
                    "min": -565.7664157184156,
                    "max": 468.8953979253651
                }
            },
            {
                "feature_name": "c3",
                "importance": {
                    "mean_magnitude": 227.52681995349807,
                    "min": -474.187447119175,
                    "max": 500.79764582176218
                }
            },
            {
                "feature_name": "c1",
                "importance": {
                    "mean_magnitude": 479.8491325534919,
                    "min": -584.6059620924166,
                    "max": 601.3424189083114
                }
            },
            {
                "feature_name": "c2",
                "importance": {
                    "mean_magnitude": 729.7375145579323,
                    "min": -1438.469059491588,
                    "max": 1428.738023747545
                }
            }
        ]
    }
}
  • Binary classification
{
    "row_results": {
        "checksum": 0,
        "results": {
            "ml": {
                "target_prediction": "foo",
                "prediction_probability": 0.9632688006864724,
                "prediction_score": 0.9632688006864724,
                "is_training": true,
                "feature_importance": [
                    {
                        "feature_name": "c1",
                        "classes": [
                            {
                                "class_name": "foo",
                                "importance": -0.050036180361561228
                            },
                            {
                                "class_name": "bar",
                                "importance": 0.050036180361561228
                            }
                        ]
                    },
                    {
                        "feature_name": "c2",
                        "classes": [
                            {
                                "class_name": "foo",
                                "importance": -2.787898169333443
                            },
                            {
                                "class_name": "bar",
                                "importance": 2.787898169333443
                            }
                        ]
                    },
                    {
                        "feature_name": "c3",
                        "classes": [
                            {
                                "class_name": "foo",
                                "importance": -0.9016447487592819
                            },
                            {
                                "class_name": "bar",
                                "importance": 0.9016447487592819
                            }
                        ]
                    },
                    {
                        "feature_name": "c4",
                        "classes": [
                            {
                                "class_name": "foo",
                                "importance": 0.4345632399908005
                            },
                            {
                                "class_name": "bar",
                                "importance": -0.4345632399908005
                            }
                        ]
                    }
                ]
            }
        }
    }
},
{
    "model_metadata": {
        "total_feature_importance": [
            {
                "feature_name": "c4",
                "classes": [
                    {
                        "class_name": "foo",
                        "importance": {
                            "mean_magnitude": 0.5077140893490363,
                            "min": -1.2245953772608847,
                            "max": 1.2245953772608847
                        }
                    },
                    {
                        "class_name": "bar",
                        "importance": {
                            "mean_magnitude": 0.5077140893490363,
                            "min": -1.2245953772608847,
                            "max": 1.2245953772608847
                        }
                    }
                ]
            },
            {
                "feature_name": "c3",
                "classes": [
                    {
                        "class_name": "foo",
                        "importance": {
                            "mean_magnitude": 0.37436172343432769,
                            "min": -1.3221622827321056,
                            "max": 1.3221622827321056
                        }
                    },
                    {
                        "class_name": "bar",
                        "importance": {
                            "mean_magnitude": 0.37436172343432769,
                            "min": -1.3221622827321056,
                            "max": 1.3221622827321056
                        }
                    }
                ]
            },
            {
                "feature_name": "c1",
                "classes": [
                    {
                        "class_name": "foo",
                        "importance": {
                            "mean_magnitude": 1.0116256529234005,
                            "min": -2.4239089033397378,
                            "max": 2.4239089033397378
                        }
                    },
                    {
                        "class_name": "bar",
                        "importance": {
                            "mean_magnitude": 1.0116256529234005,
                            "min": -2.4239089033397378,
                            "max": 2.4239089033397378
                        }
                    }
                ]
            },
            {
                "feature_name": "c2",
                "classes": [
                    {
                        "class_name": "foo",
                        "importance": {
                            "mean_magnitude": 1.878800695094461,
                            "min": -3.288343526748284,
                            "max": 3.288343526748284
                        }
                    },
                    {
                        "class_name": "bar",
                        "importance": {
                            "mean_magnitude": 1.878800695094461,
                            "min": -3.288343526748284,
                            "max": 3.288343526748284
                        }
                    }
                ]
            }
        ]
    }
}
  • Multi-class classification
{
    "row_results": {
        "checksum": 0,
        "results": {
            "ml": {
                "target_prediction": "foo",
                "prediction_probability": 0.9462761477876273,
                "prediction_score": 0.17141987121246278,
                "is_training": true,
                "top_classes": [
                    {
                        "class_name": "foo",
                        "class_probability": 0.9462761477876273,
                        "class_score": 0.17141987121246278
                    },
                    {
                        "class_name": "bar",
                        "class_probability": 0.034511190424692039,
                        "class_score": 0.034511190424692039
                    },
                    {
                        "class_name": "baz",
                        "class_probability": 0.019212661787680529,
                        "class_score": 0.003818541178138616
                    }
                ],
                "feature_importance": [
                    {
                        "feature_name": "c1",
                        "classes": [
                            {
                                "class_name": "foo",
                                "importance": 0.27949101986590138
                            },
                            {
                                "class_name": "baz",
                                "importance": -0.12386717688503159
                            },
                            {
                                "class_name": "bar",
                                "importance": -0.1556238429808595
                            }
                        ]
                    },
                    {
                        "feature_name": "c2",
                        "classes": [
                            {
                                "class_name": "foo",
                                "importance": 1.663225619115169
                            },
                            {
                                "class_name": "baz",
                                "importance": -1.72288680107119
                            },
                            {
                                "class_name": "bar",
                                "importance": 0.05966118195592636
                            }
                        ]
                    },
                    {
                        "feature_name": "c3",
                        "classes": [
                            {
                                "class_name": "foo",
                                "importance": 0.22379061504358678
                            },
                            {
                                "class_name": "baz",
                                "importance": -0.23352199623126147
                            },
                            {
                                "class_name": "bar",
                                "importance": 0.00973138118767497
                            }
                        ]
                    },
                    {
                        "feature_name": "c4",
                        "classes": [
                            {
                                "class_name": "foo",
                                "importance": -0.24052346877901588
                            },
                            {
                                "class_name": "baz",
                                "importance": 0.19615020783390645
                            },
                            {
                                "class_name": "bar",
                                "importance": 0.04437326094510882
                            }
                        ]
                    }
                ]
            }
        }
    }
},
{
    "model_metadata": {
        "total_feature_importance": [
            {
                "feature_name": "c4",
                "classes": [
                    {
                        "class_name": "foo",
                        "importance": {
                            "mean_magnitude": 0.24209656030374336,
                            "min": -0.5757885311922144,
                            "max": 0.6352558320805585
                        }
                    },
                    {
                        "class_name": "baz",
                        "importance": {
                            "mean_magnitude": 0.21362926518754464,
                            "min": -0.6975561926823535,
                            "max": 0.5758437812831863
                        }
                    },
                    {
                        "class_name": "bar",
                        "importance": {
                            "mean_magnitude": 0.0346461346585683,
                            "min": -0.15232748182282736,
                            "max": 0.10645868140524567
                        }
                    }
                ]
            },
            {
                "feature_name": "c3",
                "classes": [
                    {
                        "class_name": "foo",
                        "importance": {
                            "mean_magnitude": 0.30818045910633587,
                            "min": -0.7931796597779941,
                            "max": 0.3339785961510332
                        }
                    },
                    {
                        "class_name": "baz",
                        "importance": {
                            "mean_magnitude": 0.3302457015751672,
                            "min": -0.45783991999546966,
                            "max": 0.8242004300074223
                        }
                    },
                    {
                        "class_name": "bar",
                        "importance": {
                            "mean_magnitude": 0.05158362329758712,
                            "min": -0.31757994080335757,
                            "max": 0.2435538443329867
                        }
                    }
                ]
            },
            {
                "feature_name": "c1",
                "classes": [
                    {
                        "class_name": "foo",
                        "importance": {
                            "mean_magnitude": 0.6477535654120877,
                            "min": -1.9137505247875509,
                            "max": 1.287337819860563
                        }
                    },
                    {
                        "class_name": "baz",
                        "importance": {
                            "mean_magnitude": 0.7520521962038734,
                            "min": -1.531931414792879,
                            "max": 1.6810760229277138
                        }
                    },
                    {
                        "class_name": "bar",
                        "importance": {
                            "mean_magnitude": 0.16574602557111424,
                            "min": -0.3434321409380257,
                            "max": 0.4223094459672269
                        }
                    }
                ]
            },
            {
                "feature_name": "c2",
                "classes": [
                    {
                        "class_name": "foo",
                        "importance": {
                            "mean_magnitude": 1.101426013839722,
                            "min": -2.2925533638349937,
                            "max": 1.7987193407562752
                        }
                    },
                    {
                        "class_name": "baz",
                        "importance": {
                            "mean_magnitude": 1.1717215530182037,
                            "min": -1.8971843995291227,
                            "max": 2.6284289404335188
                        }
                    },
                    {
                        "class_name": "bar",
                        "importance": {
                            "mean_magnitude": 0.18929675306403358,
                            "min": -0.43649354351400246,
                            "max": 0.612025928728523
                        }
                    }
                ]
            }
        ]
    }
}

Closes #974

EDIT: I updated the format example above.

@valeriy42
Copy link
Contributor Author

For v7.9.0 only if we manage to have a Java parser implemented timely.

Copy link
Contributor

@tveasey tveasey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made a couple of style suggestions. My main comments are I think:

  1. We should normalise by the document count. (You should be able to switch the map's value type to a CBasicStatistics::SSampleMean<TVector>::TAccumulator to do this. Although note you have to initialise this with a zero vector of the correct size.)
  2. It would be nice to compute these quantities accurately, i.e. not just summing over the top importances for each document.

lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc Outdated Show resolved Hide resolved
lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc Outdated Show resolved Hide resolved
lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc Outdated Show resolved Hide resolved
lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc Outdated Show resolved Hide resolved
@droberts195
Copy link
Contributor

We agreed to defer this from 7.9 to 7.10, so I altered the labels.

writer.Double(item.second(0));
} else {
for (int j = 0; j < item.second.size() && j < numberClasses; ++j) {
writer.Key(classValues[j]);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will not work for storage in ES. This index stores the information for all trained models and indexing the class names for the feature importances will not scale.

I propose this format:

{
	"feature_name": "c4",
	"importance": 0.4810469375580312,
	"class_importance": [
		{
			"class_name": "foo",
			"importance": 0.24052346877901588
		},
		{
			"class_name": "baz",
			"importance": 0.19615020783390645
		},
		{
			"class_name": "bar",
			"importance": 0.04437326094510882
		}
	]
}

class_importance will be a nested data type that allows aggregations and searches for specific models and classnames.

@benwtrent
Copy link
Member

Java side parsing: elastic/elasticsearch#59725

We will probably have to mute integration tests, merge C++ side, then unmute and merge the parsing java side.

@valeriy42 valeriy42 added the WIP label Jul 17, 2020
@benwtrent
Copy link
Member

@valeriy42 the new format has min, max, mean ? Are there any other changes we are considering?

@valeriy42
Copy link
Contributor Author

the new format has min, max, mean ? Are there any other changes we are considering?

@benwtrent I cannot think of anything more. min and max are useful for visualization to define the axis range.

@benwtrent
Copy link
Member

@valeriy42 it would be nice if the per class importance looked like the regression importance. That way we have consistent JSON objects.

{
                "feature_name": "c2",
                "classes":[
                    {
                        "class_name": "baz",
                        "importance": {
                           "mean_magnitude": 1.1717215530182037,
                           "min": -1.8971843995291227,
                           "max": 2.6284289404335188
                        }
                    }...
                ]
            }

@valeriy42
Copy link
Contributor Author

@benwtrent you are right! I overlooked at in my past commit. I'll fix it accordingly.

@valeriy42 valeriy42 removed the WIP label Aug 12, 2020
Copy link
Contributor

@tveasey tveasey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only a couple of minor comments and a suggestion for one additional bit of testing, which I think is worthwhile. Otherwise, LGTM.

include/api/CInferenceModelMetadata.h Outdated Show resolved Hide resolved
lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc Outdated Show resolved Hide resolved
lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc Outdated Show resolved Hide resolved
@valeriy42 valeriy42 merged commit 3f1b575 into elastic:master Aug 13, 2020
@valeriy42 valeriy42 deleted the total-shap branch August 13, 2020 18:32
valeriy42 added a commit to valeriy42/ml-cpp that referenced this pull request Aug 13, 2020
This PR add computation of the total feature importance values.
valeriy42 added a commit that referenced this pull request Aug 14, 2020
This PR add computation of the total feature importance values.

Backport of #1387.
benwtrent added a commit to elastic/elasticsearch that referenced this pull request Aug 14, 2020
This updates the feature_importance mapping change from elastic/ml-cpp#1387
benwtrent added a commit to benwtrent/elasticsearch that referenced this pull request Aug 14, 2020
This updates the feature_importance mapping change from elastic/ml-cpp#1387
benwtrent added a commit to elastic/elasticsearch that referenced this pull request Aug 14, 2020
This updates the feature_importance mapping change from elastic/ml-cpp#1387
benwtrent pushed a commit that referenced this pull request Aug 14, 2020
Activate the output of the model metadata and the corresponding unit tests for total feature importance.

The implementation itself was introduced in #1387 however, I need to fix the documentation, it was originally attributed to v7.10. Hence, I mark this PR as enhancement to rectify the docs.
benwtrent added a commit that referenced this pull request Aug 24, 2020
Activate the output of the model metadata and the corresponding unit tests for total feature importance.

The implementation itself was introduced in #1387 however, I need to fix the documentation, it was originally attributed to v7.10. Hence, I mark this PR as enhancement to rectify the docs.

Co-authored-by: Valeriy Khakhutskyy <[email protected]>
valeriy42 added a commit to valeriy42/ml-cpp that referenced this pull request Sep 1, 2020
Activate the output of the model metadata and the corresponding unit tests for total feature importance.

The implementation itself was introduced in elastic#1387 however, I need to fix the documentation, it was originally attributed to v7.10. Hence, I mark this PR as enhancement to rectify the docs.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[ML] Return total SHAP per feature as a new result type
4 participants