Skip to content

Commit

Permalink
ready datalabeler for deserialization and improvement on serialization for datalabeler (#879)
Browse files Browse the repository at this point in the history
  • Loading branch information
ksneab7 authored Jun 20, 2023
1 parent d9b5f49 commit f46b8a9
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 116 deletions.
4 changes: 3 additions & 1 deletion dataprofiler/labelers/base_data_labeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,9 @@ def load_from_library(cls, name: str) -> BaseDataLabeler:
:return: DataLabeler class
:rtype: BaseDataLabeler
"""
return cls(os.path.join(default_labeler_dir, name))
labeler = cls(os.path.join(default_labeler_dir, name))
labeler._default_model_loc = name
return labeler

@classmethod
def load_from_disk(cls, dirpath: str, load_options: dict = None) -> BaseDataLabeler:
Expand Down
11 changes: 10 additions & 1 deletion dataprofiler/profilers/json_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,16 @@ def default(self, to_serialize):
elif isinstance(to_serialize, pd.Timestamp):
return to_serialize.isoformat()
elif isinstance(to_serialize, BaseDataLabeler):
return to_serialize._default_model_loc
# TODO: This does not allow the user to serialize a model if it is loaded
# "from_disk". Changes to BaseDataLabeler are needed for this feature
if to_serialize._default_model_loc is None:
raise ValueError(
"Serialization cannot be done on labelers with "
"_default_model_loc not set"
)

return {"from_library": to_serialize._default_model_loc}

elif callable(to_serialize):
return to_serialize.__name__
return json.JSONEncoder.default(self, to_serialize)
4 changes: 4 additions & 0 deletions dataprofiler/tests/labelers/test_data_labelers.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,15 @@ def test_load_data_labeler(self, *mocks):
def test_load_from_library(self, *mocks):
data_labeler = dp.DataLabeler.load_from_library("structured_model")
self.assertIsInstance(data_labeler, BaseDataLabeler)
# Testing to ensure _default_model_loc is set correctly
self.assertEqual("structured_model", data_labeler._default_model_loc)

data_labeler = dp.DataLabeler.load_from_library(
"structured_model", trainable=True
)
self.assertIsInstance(data_labeler, TrainableDataLabeler)
# Testing to ensure _default_model_loc is set correctly
self.assertEqual("structured_model", data_labeler._default_model_loc)

@mock.patch("tensorflow.keras.models.load_model")
def test_load_from_disk(self, *mocks):
Expand Down
144 changes: 30 additions & 114 deletions dataprofiler/tests/profilers/test_data_labeler_column_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pandas as pd

from dataprofiler.labelers import BaseDataLabeler
from dataprofiler.profilers import utils
from dataprofiler.profilers.data_labeler_column_profile import DataLabelerColumn
from dataprofiler.profilers.json_encoder import ProfileEncoder
Expand All @@ -14,7 +15,10 @@
from . import utils as test_utils


@mock.patch("dataprofiler.profilers.data_labeler_column_profile.DataLabeler")
@mock.patch(
"dataprofiler.profilers.data_labeler_column_profile.DataLabeler",
spec=BaseDataLabeler,
)
class TestDataLabelerColumnProfiler(unittest.TestCase):
@staticmethod
def _setup_data_labeler_mock(mock_instance):
Expand Down Expand Up @@ -396,15 +400,20 @@ def test_empty_data(self, *mocks):
diff_profile = profiler1.diff(profiler2)
self.assertIsNone(merge_profile.data_label)


class TestDataLabelerColumnProfilerNoMock(unittest.TestCase):
@classmethod
def setUpClass(cls):
test_utils.set_seed(seed=0)

def test_json_encode(self):
def test_json_encode(self, mock_instance):
self._setup_data_labeler_mock(mock_instance)
profiler = DataLabelerColumn("")

# Validate that an error is raised if the labeler's model loc is not set
profiler.data_labeler._default_model_loc = None
with self.assertRaisesRegex(
ValueError,
"Serialization cannot be done on labelers with _default_model_loc not set",
):
_ = json.dumps(profiler, cls=ProfileEncoder)

# Set the model loc to a test value so that serialization succeeds
profiler.data_labeler._default_model_loc = "this is a test model loc"
serialized = json.dumps(profiler, cls=ProfileEncoder)

expected = json.dumps(
Expand All @@ -418,7 +427,7 @@ def test_json_encode(self):
"times": {},
"thread_safe": False,
"_max_sample_size": 1000,
"data_labeler": "structured_model",
"data_labeler": {"from_library": "this is a test model loc"},
"_reverse_label_mapping": None,
"_possible_data_labels": None,
"_rank_distribution": None,
Expand All @@ -435,127 +444,34 @@ def test_json_encode(self):

self.assertEqual(serialized, expected)

def test_json_encode_after_update(self):
data = pd.Series(["1", "2", "3"], dtype=object)
def test_json_encode_after_update(self, mock_instance):
self._setup_data_labeler_mock(mock_instance)
data = pd.Series(["1", "2", "3", "4"], dtype=object)
profiler = DataLabelerColumn(data.name)

profiler.data_labeler._default_model_loc = "this is a test model loc"
with test_utils.mock_timeit():
profiler.update(data)

serialized = json.dumps(profiler, cls=ProfileEncoder)

expected = json.dumps(
{
"class": "DataLabelerColumn",
"data": {
"name": None,
"col_index": float("nan"),
"sample_size": 3,
"sample_size": 4,
"metadata": {},
"times": {"data_labeler_predict": 3.0},
"times": {"data_labeler_predict": 1.0},
"thread_safe": False,
"_max_sample_size": 1000,
"data_labeler": "structured_model",
"_reverse_label_mapping": {
"1": "UNKNOWN",
"2": "ADDRESS",
"3": "BAN",
"4": "CREDIT_CARD",
"5": "DATE",
"6": "TIME",
"7": "DATETIME",
"8": "DRIVERS_LICENSE",
"9": "EMAIL_ADDRESS",
"10": "UUID",
"11": "HASH_OR_KEY",
"12": "IPV4",
"13": "IPV6",
"14": "MAC_ADDRESS",
"15": "PERSON",
"16": "PHONE_NUMBER",
"17": "SSN",
"18": "URL",
"19": "US_STATE",
"20": "INTEGER",
"21": "FLOAT",
"22": "QUANTITY",
"23": "ORDINAL",
},
"_possible_data_labels": [
"UNKNOWN",
"ADDRESS",
"BAN",
"CREDIT_CARD",
"DATE",
"TIME",
"DATETIME",
"DRIVERS_LICENSE",
"EMAIL_ADDRESS",
"UUID",
"HASH_OR_KEY",
"IPV4",
"IPV6",
"MAC_ADDRESS",
"PERSON",
"PHONE_NUMBER",
"SSN",
"URL",
"US_STATE",
"INTEGER",
"FLOAT",
"QUANTITY",
"ORDINAL",
],
"data_labeler": {"from_library": "this is a test model loc"},
"_reverse_label_mapping": {0: "a", 1: "b"},
"_possible_data_labels": ["a", "b"],
"_rank_distribution": {
"UNKNOWN": 0,
"ADDRESS": 0,
"BAN": 0,
"CREDIT_CARD": 0,
"DATE": 0,
"TIME": 0,
"DATETIME": 0,
"DRIVERS_LICENSE": 0,
"EMAIL_ADDRESS": 0,
"UUID": 0,
"HASH_OR_KEY": 0,
"IPV4": 0,
"IPV6": 0,
"MAC_ADDRESS": 0,
"PERSON": 0,
"PHONE_NUMBER": 0,
"SSN": 0,
"URL": 0,
"US_STATE": 0,
"INTEGER": 3,
"FLOAT": 0,
"QUANTITY": 0,
"ORDINAL": 0,
"a": 2,
"b": 2,
},
"_sum_predictions": [
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
3.0,
0.0,
0.0,
0.0,
],
"_sum_predictions": [2.0, 2.0],
"_top_k_voting": 1,
"_min_voting_prob": 0.2,
"_min_prob_differential": 0.2,
Expand Down

0 comments on commit f46b8a9

Please sign in to comment.