[Python] Remove get_artifacts in MLTransform since artifacts are stored in artifact location #29016

Merged: 11 commits, Oct 25, 2023
13 changes: 0 additions & 13 deletions sdks/python/apache_beam/ml/transforms/base.py
@@ -67,26 +67,13 @@ def apply_transform(self, data: OperationInputT,
inputs: input data.
"""

@abc.abstractmethod
def get_artifacts(
self, data: OperationInputT,
output_column_prefix: str) -> Optional[Dict[str, OperationOutputT]]:
"""
If the operation generates any artifacts, they can be returned from this
method.
"""
pass

def __call__(self, data: OperationInputT,
output_column_name: str) -> Dict[str, OperationOutputT]:
"""
This method is called when the instance of the class is called.
This method will invoke the apply() method of the class.
"""
transformed_data = self.apply_transform(data, output_column_name)
artifacts = self.get_artifacts(data, output_column_name)
if artifacts:
transformed_data = {**transformed_data, **artifacts}
return transformed_data

def get_counter(self):
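With get_artifacts removed, the operation contract is simpler: a subclass implements only apply_transform, and __call__ returns its output unchanged. A minimal sketch of the new flow (the ScaleBy10 class below is hypothetical, not part of this PR):

from typing import Dict, List

class ScaleBy10:
  def apply_transform(self, data: Dict[str, List[int]],
                      output_column_name: str) -> Dict[str, List[int]]:
    # Multiply every value by 10; any artifacts an op produces are now
    # written to the artifact location instead of being returned here.
    return {output_column_name: [x * 10 for x in data['x']]}

  def __call__(self, data, output_column_name):
    # No get_artifacts call and no dict merge: the transformed data is
    # returned as-is.
    return self.apply_transform(data, output_column_name)

op = ScaleBy10()
print(op({'x': [1, 2, 3]}, 'x'))  # {'x': [10, 20, 30]}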
18 changes: 0 additions & 18 deletions sdks/python/apache_beam/ml/transforms/handlers_test.py
@@ -58,14 +58,6 @@ def apply_transform(self, inputs, output_column_name, **kwargs):
return {output_column_name: inputs * 10}


class _FakeOperationWithArtifacts(TFTOperation):
def apply_transform(self, inputs, output_column_name, **kwargs):
return {output_column_name: inputs}

def get_artifacts(self, data, col_name):
return {'artifact': tf.convert_to_tensor([1])}


class IntType(NamedTuple):
x: int

@@ -106,16 +98,6 @@ def test_tft_operation_preprocessing_fn(
actual_result = process_handler.process_data_fn(inputs)
self.assertDictEqual(actual_result, expected_result)

def test_preprocessing_fn_with_artifacts(self):
process_handler = handlers.TFTProcessHandler(
transforms=[_FakeOperationWithArtifacts(columns=['x'])],
artifact_location=self.artifact_location)
inputs = {'x': [1, 2, 3]}
preprocessing_fn = process_handler.process_data_fn
actual_result = preprocessing_fn(inputs)
expected_result = {'x': [1, 2, 3], 'artifact': tf.convert_to_tensor([1])}
self.assertDictEqual(actual_result, expected_result)

def test_input_type_from_schema_named_tuple_pcoll(self):
data = [{'x': 1}]
with beam.Pipeline() as p:
30 changes: 9 additions & 21 deletions sdks/python/apache_beam/ml/transforms/tft.py
@@ -95,13 +95,6 @@ def __init__(self, columns: List[str]) -> None:
"Columns are not specified. Please specify the column for the "
" op %s" % self.__class__.__name__)

def get_artifacts(self, data: common_types.TensorType,
col_name: str) -> Dict[str, common_types.TensorType]:
"""
Returns the artifacts generated by the operation.
"""
return {}

@tf.function
def _split_string_with_delimiter(self, data, delimiter):
"""
@@ -527,6 +520,7 @@ def __init__(
ngram_range: Tuple[int, int] = (1, 1),
ngrams_separator: Optional[str] = None,
compute_word_count: bool = False,
key_vocab_filename: str = 'key_vocab_mapping',
name: Optional[str] = None,
):
"""
@@ -547,9 +541,9 @@ def __init__(
n-gram sizes.
ngrams_separator: A string that will be inserted between each ngram.
compute_word_count: A boolean that specifies whether to compute
the unique word count and add it as an artifact to the output.
Note that the count will be computed over the entire dataset so
it will be the same value for all inputs.
the unique word count over the entire dataset. Defaults to False.
key_vocab_filename: The file name for the key vocabulary file when
compute_word_count is True.
name: A name for the operation (optional).

Note that original order of the input may not be preserved.
@@ -560,6 +554,7 @@ def __init__(
self.ngrams_separator = ngrams_separator
self.name = name
self.split_string_by_delimiter = split_string_by_delimiter
self.key_vocab_filename = key_vocab_filename
if compute_word_count:
self.compute_word_count_fn = count_unqiue_words
else:
@@ -575,18 +570,11 @@ def apply_transform(self, data: tf.SparseTensor, output_col_name: str):
data, self.split_string_by_delimiter)
output = tft.bag_of_words(
data, self.ngram_range, self.ngrams_separator, self.name)
# word counts are written to the key_vocab_filename
self.compute_word_count_fn(data, self.key_vocab_filename)
return {output_col_name: output}

def get_artifacts(self, data: tf.SparseTensor,
col_name: str) -> Dict[str, tf.Tensor]:
return self.compute_word_count_fn(data, col_name)


def count_unqiue_words(data: tf.SparseTensor,
output_col_name: str) -> Dict[str, tf.Tensor]:
keys, count = tft.count_per_key(data)
shape = [tf.shape(data)[0], tf.shape(keys)[0]]
return {
output_col_name + '_unique_elements': tf.broadcast_to(keys, shape),
output_col_name + '_counts': tf.broadcast_to(count, shape)
}
output_vocab_name: str) -> None:
tft.count_per_key(data, key_vocabulary_filename=output_vocab_name)
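With this change, the word count is no longer surfaced as extra output columns; tft.count_per_key writes it as a TFT asset file under the artifact location. A hedged usage sketch mirroring the new test below (the temp directory and data are illustrative):

import os
import tempfile

import apache_beam as beam
from apache_beam.ml.transforms import base
from apache_beam.ml.transforms import tft

artifact_location = tempfile.mkdtemp()
data = [{'x': ['pie', 'pie', 'yum']}]

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create(data)
      | base.MLTransform(
          write_artifact_location=artifact_location,
          transforms=[
              tft.BagOfWords(
                  columns=['x'],
                  compute_word_count=True,
                  key_vocab_filename='my_vocab')
          ]))

# The counts land in the key vocabulary asset, one "<count> <key>" pair
# per line, e.g. '2 pie' and '1 yum'.
vocab_path = os.path.join(artifact_location, 'transform_fn/assets/my_vocab')
with open(vocab_path, 'r') as f:
  print([line.strip() for line in f])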
44 changes: 27 additions & 17 deletions sdks/python/apache_beam/ml/transforms/tft_test.py
@@ -38,8 +38,6 @@
if not tft:
raise unittest.SkipTest('tensorflow_transform is not installed.')

z_score_expected = {'x_mean': 3.5, 'x_var': 2.9166666666666665}


class ScaleZScoreTest(unittest.TestCase):
def setUp(self) -> None:
@@ -576,6 +574,17 @@ def setUp(self) -> None:
def tearDown(self):
shutil.rmtree(self.artifact_location)

def validate_count_per_key(self):
import os
key_vocab_location = os.path.join(
self.artifact_location, 'transform_fn/assets/key_vocab')
with open(key_vocab_location, 'r') as f:
key_vocab_list = [line.strip() for line in f]

expected_data = ['2 yum', '4 Apple', '1 like', '1 I', '4 pie', '2 Banana']
actual_data = key_vocab_list
self.assertEqual(expected_data, actual_data)

def test_bag_of_words_on_list_seperated_words_default_ngrams(self):
data = [{
'x': ['I', 'like', 'pie', 'pie', 'pie'],
@@ -691,10 +700,6 @@ def test_bag_of_words_on_by_splitting_input_text(self):
assert_that(result, equal_to(expected_data, equals_fn=np.array_equal))

def test_count_per_key_on_list(self):
def map_element_to_count(elements, counts):
d = {elements[i]: counts[i] for i in range(len(elements))}
return d

data = [{
'x': ['I', 'like', 'pie', 'pie', 'pie'],
}, {
@@ -703,24 +708,29 @@ def map_element_to_count(elements, counts):
'x': ['Banana', 'Banana', 'Apple', 'Apple', 'Apple', 'Apple']
}]
with beam.Pipeline() as p:
result = (
_ = (
p
| "Create" >> beam.Create(data)
| "MLTransform" >> base.MLTransform(
write_artifact_location=self.artifact_location,
transforms=[
tft.BagOfWords(columns=['x'], compute_word_count=True)
tft.BagOfWords(
columns=['x'],
compute_word_count=True,
key_vocab_filename='my_vocab')
]))
# the unique elements and counts are artifacts and will be
# stored in the result and same for all the elements in the
# PCollection.
result = result | beam.Map(
lambda x: map_element_to_count(x.x_unique_elements, x.x_counts))

expected_data = [{
b'Apple': 4, b'Banana': 2, b'I': 1, b'like': 1, b'pie': 4, b'yum': 2
}] * 3 # since there are 3 elements in input.
assert_that(result, equal_to(expected_data))
def validate_count_per_key(key_vocab_filename):
import os
key_vocab_location = os.path.join(
self.artifact_location, 'transform_fn/assets', key_vocab_filename)
with open(key_vocab_location, 'r') as f:
key_vocab_list = [line.strip() for line in f]
return key_vocab_list

expected_data = ['2 yum', '4 Apple', '1 like', '1 I', '4 pie', '2 Banana']
actual_data = validate_count_per_key('my_vocab')
self.assertEqual(expected_data, actual_data)


if __name__ == '__main__':
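Because artifacts now live in the artifact location, a later pipeline can reuse them by reading from the same path. A short follow-up sketch (assuming the artifact_location written in the sketch above; in read mode MLTransform loads the saved transforms, so none are passed):

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([{'x': ['pie', 'yum']}])
      | base.MLTransform(read_artifact_location=artifact_location))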