-
Notifications
You must be signed in to change notification settings - Fork 165
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Categorical PSI #1039
Categorical PSI #1039
Changes from 4 commits
b269fe7
bd9f11c
9aed67c
dd1aa6f
1b42df0
0b47910
bd770ea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,18 +2,20 @@ | |
from __future__ import annotations | ||
|
||
import math | ||
import warnings | ||
from collections import defaultdict | ||
from operator import itemgetter | ||
from typing import cast | ||
|
||
import datasketches | ||
from pandas import DataFrame, Series | ||
|
||
from .. import dp_logging | ||
from . import profiler_utils | ||
from .base_column_profilers import BaseColumnProfiler | ||
from .profiler_options import CategoricalOptions | ||
|
||
logger = dp_logging.get_child_logger(__name__) | ||
|
||
|
||
class CategoricalColumn(BaseColumnProfiler["CategoricalColumn"]): | ||
""" | ||
|
@@ -306,24 +308,29 @@ def diff(self, other_profile: CategoricalColumn, options: dict = None) -> dict: | |
other_profile._categories.items(), key=itemgetter(1), reverse=True | ||
) | ||
) | ||
if cat_count1.keys() == cat_count2.keys(): | ||
total_psi = 0.0 | ||
for key in cat_count1.keys(): | ||
perc_A = cat_count1[key] / self.sample_size | ||
perc_B = cat_count2[key] / other_profile.sample_size | ||
total_psi += (perc_B - perc_A) * math.log(perc_B / perc_A) | ||
differences["statistics"]["psi"] = total_psi | ||
else: | ||
warnings.warn( | ||
"psi was not calculated due to the differences in categories " | ||
"of the profiles. Differences:\n" | ||
f"{set(cat_count1.keys()) ^ set(cat_count2.keys())}", | ||
RuntimeWarning, | ||
) | ||
( | ||
self_cat_count, | ||
other_cat_count, | ||
) = self._preprocess_for_categorical_psi_calculation( | ||
self_cat_count=cat_count1, | ||
other_cat_count=cat_count2, | ||
) | ||
Comment on lines
+311
to
+317
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. formatting is pre-commit |
||
|
||
total_psi = 0.0 | ||
for iter_key in self_cat_count.keys(): | ||
percent_self = self_cat_count[iter_key] / self.sample_size | ||
percent_other = other_cat_count[iter_key] / other_profile.sample_size | ||
if (percent_other == 0) or (percent_self == 0): | ||
total_psi += 0.0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this |
||
else: | ||
total_psi += (percent_other - percent_self) * math.log( | ||
percent_other / percent_self | ||
) | ||
taylorfturner marked this conversation as resolved.
Show resolved
Hide resolved
|
||
differences["statistics"]["psi"] = total_psi | ||
|
||
differences["statistics"][ | ||
"categorical_count" | ||
] = profiler_utils.find_diff_of_dicts(cat_count1, cat_count2) | ||
] = profiler_utils.find_diff_of_dicts(self_cat_count, other_cat_count) | ||
|
||
return differences | ||
|
||
|
@@ -431,6 +438,27 @@ def is_match(self) -> bool: | |
is_match = True | ||
return is_match | ||
|
||
def _preprocess_for_categorical_psi_calculation( | ||
self, self_cat_count, other_cat_count | ||
): | ||
super_set_categories = set(self_cat_count.keys()) | set(other_cat_count.keys()) | ||
if (super_set_categories != self_cat_count.keys()) or ( | ||
super_set_categories != other_cat_count.keys() | ||
): | ||
taylorfturner marked this conversation as resolved.
Show resolved
Hide resolved
|
||
logger.info( | ||
f"""PSI data pre-processing found that categories between | ||
the profiles were not equal. Both profiles not contain | ||
taylorfturner marked this conversation as resolved.
Show resolved
Hide resolved
|
||
the following categories {super_set_categories}.""" | ||
) | ||
|
||
for iter_key in super_set_categories: | ||
for iter_dictionary in [self_cat_count, other_cat_count]: | ||
try: | ||
iter_dictionary[iter_key] = iter_dictionary[iter_key] | ||
except KeyError: | ||
iter_dictionary[iter_key] = 0 | ||
return self_cat_count, other_cat_count | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. main fix is here to ensure that each cat_count has the same keys even if some are zero.... this is to ensure the PSI is calculated when new categories are added or old categories are removed over time |
||
def _check_stop_condition_is_met(self, sample_size: int, unqiue_ratio: float): | ||
"""Return boolean given stop conditions. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -701,6 +701,7 @@ def test_gini_impurity(self): | |
self.assertEqual(profile.gini_impurity, None) | ||
|
||
def test_categorical_diff(self): | ||
# test psi new category in another profile | ||
df_categorical = pd.Series(["y", "y", "y", "y", "n", "n", "n"]) | ||
profile = CategoricalColumn(df_categorical.name) | ||
profile.update(df_categorical) | ||
|
@@ -720,21 +721,17 @@ def test_categorical_diff(self): | |
"categories": [[], ["y", "n"], ["maybe"]], | ||
"gini_impurity": -0.16326530612244894, | ||
"unalikeability": -0.19047619047619047, | ||
"categorical_count": {"y": 1, "n": 1, "maybe": [None, 2]}, | ||
"categorical_count": {"y": 1, "n": 1, "maybe": -2}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. justification for this change lies in profiler_utils.find_diff_of_dicts ... on L577, we do not hit that any more due to both cat_count dictionaries having all the categories between both profiles |
||
"chi2-test": { | ||
"chi2-statistic": 82 / 35, | ||
"df": 2, | ||
"p-value": 0.3099238764710244, | ||
}, | ||
"psi": 0.0990210257942779, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. PSI value for when new category is added There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. new |
||
}, | ||
} | ||
with self.assertWarnsRegex( | ||
RuntimeWarning, | ||
"psi was not calculated due to the differences in categories " | ||
"of the profiles. Differences:\n{'maybe'}", | ||
): | ||
test_profile_diff = profile.diff(profile2) | ||
Comment on lines
-731
to
-736
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. removing since this is un-needed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. since we now handle keys not being equal in the the preprocessing function for categorical PSI |
||
self.assertDictEqual(expected_diff, test_profile_diff) | ||
actual_diff = profile.diff(profile2) | ||
self.assertDictEqual(expected_diff, actual_diff) | ||
Comment on lines
+733
to
+734
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. variable clean up -- no functional change |
||
|
||
# Test with one categorical column matching | ||
df_not_categorical = pd.Series( | ||
|
@@ -770,10 +767,6 @@ def test_categorical_diff(self): | |
profile2 = CategoricalColumn(df_categorical.name) | ||
profile2.update(df_categorical) | ||
|
||
# chi2-statistic = sum((observed-expected)^2/expected for each category in each column) | ||
# df = categories - 1 | ||
# psi = (% of records based on Sample (A) - % of records Sample (B)) * ln(A/ B) | ||
Comment on lines
-773
to
-775
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. removing un-needed |
||
# p-value found through using chi2 CDF | ||
expected_diff = { | ||
"categorical": "unchanged", | ||
"statistics": { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -500,12 +500,13 @@ def test_column_stats_profile_compiler_stats_diff(self): | |
"categories": [["1"], ["9"], ["10"]], | ||
"gini_impurity": 0.06944444444444448, | ||
"unalikeability": 0.16666666666666663, | ||
"categorical_count": {"9": -1, "1": [1, None], "10": [None, 1]}, | ||
"categorical_count": {"9": -1, "1": 1, "10": -1}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. actually I found the justification for this change and it lies in |
||
"chi2-test": { | ||
"chi2-statistic": 2.1, | ||
"df": 2, | ||
"p-value": 0.3499377491111554, | ||
}, | ||
"psi": 0.009815252971365292, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. PSI value for ColumnStatsProfileCompiler that is identifying the test data as categorical... therefore needs a PSI value for expected_diff |
||
}, | ||
} | ||
self.assertDictEqual(expected_diff, compiler1.diff(compiler2)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
adding a logger for awareness of preprocessing changes to categories