diff --git a/dataprofiler/labelers/base_data_labeler.py b/dataprofiler/labelers/base_data_labeler.py index 611d07d8b..201f78998 100644 --- a/dataprofiler/labelers/base_data_labeler.py +++ b/dataprofiler/labelers/base_data_labeler.py @@ -637,7 +637,9 @@ def load_from_library(cls, name: str) -> BaseDataLabeler: :return: DataLabeler class :rtype: BaseDataLabeler """ - return cls(os.path.join(default_labeler_dir, name)) + labeler = cls(os.path.join(default_labeler_dir, name)) + labeler._default_model_loc = name + return labeler @classmethod def load_from_disk(cls, dirpath: str, load_options: dict = None) -> BaseDataLabeler: diff --git a/dataprofiler/profilers/json_encoder.py b/dataprofiler/profilers/json_encoder.py index 45c771b73..95a06efc6 100644 --- a/dataprofiler/profilers/json_encoder.py +++ b/dataprofiler/profilers/json_encoder.py @@ -35,7 +35,16 @@ def default(self, to_serialize): elif isinstance(to_serialize, pd.Timestamp): return to_serialize.isoformat() elif isinstance(to_serialize, BaseDataLabeler): - return to_serialize._default_model_loc + # TODO: This does not allow the user to serialize a model if it is loaded + # "from_disk". Changes to BaseDataLabeler are needed for this feature + if to_serialize._default_model_loc is None: + raise ValueError( + "Serialization cannot be done on labelers with " + "_default_model_loc not set" + ) + + return {"from_library": to_serialize._default_model_loc} + elif callable(to_serialize): return to_serialize.__name__ return json.JSONEncoder.default(self, to_serialize) diff --git a/dataprofiler/tests/labelers/test_data_labelers.py b/dataprofiler/tests/labelers/test_data_labelers.py index b7b2f00f3..bbde1c506 100644 --- a/dataprofiler/tests/labelers/test_data_labelers.py +++ b/dataprofiler/tests/labelers/test_data_labelers.py @@ -137,11 +137,15 @@ def test_load_data_labeler(self, *mocks): def test_load_from_library(self, *mocks): data_labeler = dp.DataLabeler.load_from_library("structured_model") self.assertIsInstance(data_labeler, BaseDataLabeler) + # Testing to ensure _default_model_loc is set correctly + self.assertEqual("structured_model", data_labeler._default_model_loc) data_labeler = dp.DataLabeler.load_from_library( "structured_model", trainable=True ) self.assertIsInstance(data_labeler, TrainableDataLabeler) + # Testing to ensure _default_model_loc is set correctly + self.assertEqual("structured_model", data_labeler._default_model_loc) @mock.patch("tensorflow.keras.models.load_model") def test_load_from_disk(self, *mocks): diff --git a/dataprofiler/tests/profilers/test_data_labeler_column_profile.py b/dataprofiler/tests/profilers/test_data_labeler_column_profile.py index ac362f4b9..d71d50ed6 100644 --- a/dataprofiler/tests/profilers/test_data_labeler_column_profile.py +++ b/dataprofiler/tests/profilers/test_data_labeler_column_profile.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd +from dataprofiler.labelers import BaseDataLabeler from dataprofiler.profilers import utils from dataprofiler.profilers.data_labeler_column_profile import DataLabelerColumn from dataprofiler.profilers.json_encoder import ProfileEncoder @@ -14,7 +15,10 @@ from . import utils as test_utils -@mock.patch("dataprofiler.profilers.data_labeler_column_profile.DataLabeler") +@mock.patch( + "dataprofiler.profilers.data_labeler_column_profile.DataLabeler", + spec=BaseDataLabeler, +) class TestDataLabelerColumnProfiler(unittest.TestCase): @staticmethod def _setup_data_labeler_mock(mock_instance): @@ -396,15 +400,20 @@ def test_empty_data(self, *mocks): diff_profile = profiler1.diff(profiler2) self.assertIsNone(merge_profile.data_label) - -class TestDataLabelerColumnProfilerNoMock(unittest.TestCase): - @classmethod - def setUpClass(cls): - test_utils.set_seed(seed=0) - - def test_json_encode(self): + def test_json_encode(self, mock_instance): + self._setup_data_labeler_mock(mock_instance) profiler = DataLabelerColumn("") + # Validates that error is raised if model loc is not set for labeler + profiler.data_labeler._default_model_loc = None + with self.assertRaisesRegex( + ValueError, + "Serialization cannot be done on labelers with _default_model_loc not set", + ): + _ = json.dumps(profiler, cls=ProfileEncoder) + + # Reset the model loc to its initial value + profiler.data_labeler._default_model_loc = "this is a test model loc" serialized = json.dumps(profiler, cls=ProfileEncoder) expected = json.dumps( @@ -418,7 +427,7 @@ def test_json_encode(self): "times": {}, "thread_safe": False, "_max_sample_size": 1000, - "data_labeler": "structured_model", + "data_labeler": {"from_library": "this is a test model loc"}, "_reverse_label_mapping": None, "_possible_data_labels": None, "_rank_distribution": None, @@ -435,127 +444,34 @@ def test_json_encode(self): self.assertEqual(serialized, expected) - def test_json_encode_after_update(self): - data = pd.Series(["1", "2", "3"], dtype=object) + def test_json_encode_after_update(self, mock_instance): + self._setup_data_labeler_mock(mock_instance) + data = pd.Series(["1", "2", "3", "4"], dtype=object) profiler = DataLabelerColumn(data.name) - + profiler.data_labeler._default_model_loc = "this is a test model loc" with test_utils.mock_timeit(): profiler.update(data) serialized = json.dumps(profiler, cls=ProfileEncoder) - expected = json.dumps( { "class": "DataLabelerColumn", "data": { "name": None, "col_index": float("nan"), - "sample_size": 3, + "sample_size": 4, "metadata": {}, - "times": {"data_labeler_predict": 3.0}, + "times": {"data_labeler_predict": 1.0}, "thread_safe": False, "_max_sample_size": 1000, - "data_labeler": "structured_model", - "_reverse_label_mapping": { - "1": "UNKNOWN", - "2": "ADDRESS", - "3": "BAN", - "4": "CREDIT_CARD", - "5": "DATE", - "6": "TIME", - "7": "DATETIME", - "8": "DRIVERS_LICENSE", - "9": "EMAIL_ADDRESS", - "10": "UUID", - "11": "HASH_OR_KEY", - "12": "IPV4", - "13": "IPV6", - "14": "MAC_ADDRESS", - "15": "PERSON", - "16": "PHONE_NUMBER", - "17": "SSN", - "18": "URL", - "19": "US_STATE", - "20": "INTEGER", - "21": "FLOAT", - "22": "QUANTITY", - "23": "ORDINAL", - }, - "_possible_data_labels": [ - "UNKNOWN", - "ADDRESS", - "BAN", - "CREDIT_CARD", - "DATE", - "TIME", - "DATETIME", - "DRIVERS_LICENSE", - "EMAIL_ADDRESS", - "UUID", - "HASH_OR_KEY", - "IPV4", - "IPV6", - "MAC_ADDRESS", - "PERSON", - "PHONE_NUMBER", - "SSN", - "URL", - "US_STATE", - "INTEGER", - "FLOAT", - "QUANTITY", - "ORDINAL", - ], + "data_labeler": {"from_library": "this is a test model loc"}, + "_reverse_label_mapping": {0: "a", 1: "b"}, + "_possible_data_labels": ["a", "b"], "_rank_distribution": { - "UNKNOWN": 0, - "ADDRESS": 0, - "BAN": 0, - "CREDIT_CARD": 0, - "DATE": 0, - "TIME": 0, - "DATETIME": 0, - "DRIVERS_LICENSE": 0, - "EMAIL_ADDRESS": 0, - "UUID": 0, - "HASH_OR_KEY": 0, - "IPV4": 0, - "IPV6": 0, - "MAC_ADDRESS": 0, - "PERSON": 0, - "PHONE_NUMBER": 0, - "SSN": 0, - "URL": 0, - "US_STATE": 0, - "INTEGER": 3, - "FLOAT": 0, - "QUANTITY": 0, - "ORDINAL": 0, + "a": 2, + "b": 2, }, - "_sum_predictions": [ - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 3.0, - 0.0, - 0.0, - 0.0, - ], + "_sum_predictions": [2.0, 2.0], "_top_k_voting": 1, "_min_voting_prob": 0.2, "_min_prob_differential": 0.2,