diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py
index fd86ff1f9342..262bc30b3d5c 100644
--- a/sdks/python/apache_beam/transforms/ptransform.py
+++ b/sdks/python/apache_beam/transforms/ptransform.py
@@ -38,11 +38,13 @@ class and wrapper class that allows lambda functions to be used as
 
 import copy
 import itertools
+import json
 import logging
 import operator
 import os
 import sys
 import threading
+import warnings
 from functools import reduce
 from functools import wraps
 from typing import TYPE_CHECKING
@@ -83,6 +85,7 @@ class and wrapper class that allows lambda functions to be used as
 from apache_beam.typehints.trivial_inference import instance_to_type
 from apache_beam.typehints.typehints import validate_composite_type_param
 from apache_beam.utils import proto_utils
+from apache_beam.utils import python_callable
 
 if TYPE_CHECKING:
   from apache_beam import coders
@@ -95,6 +98,7 @@ class and wrapper class that allows lambda functions to be used as
     'PTransform',
     'ptransform_fn',
     'label_from_callable',
+    'annotate_yaml',
 ]
 
 _LOGGER = logging.getLogger(__name__)
@@ -1093,3 +1097,65 @@ def __ror__(self, pvalueish, _unused=None):
 
   def expand(self, pvalue):
     raise RuntimeError("Should never be expanded directly.")
+
+
+# Defined here to avoid circular import issues for Beam library transforms.
+def annotate_yaml(constructor):
+  """Causes instances of this transform to be annotated with their yaml syntax.
+
+  Should only be used for transforms that are fully defined by their constructor
+  arguments.
+
+  Annotation is best-effort: if the constructor cannot be re-imported under its
+  fully qualified name, or its arguments cannot be encoded as json, a warning
+  is issued and the transform is returned without the yaml annotations.
+
+  Args:
+    constructor: the callable (e.g. a PTransform class) to decorate.
+
+  Returns:
+    A callable that forwards its arguments to constructor and returns the
+    resulting transform.
+  """
+  @wraps(constructor)
+  def wrapper(*args, **kwargs):
+    transform = constructor(*args, **kwargs)
+
+    fully_qualified_name = (
+        f'{constructor.__module__}.{constructor.__qualname__}')
+    try:
+      imported_constructor = (
+          python_callable.PythonCallableWithSource.
+          load_from_fully_qualified_name(fully_qualified_name))
+      # The name must resolve to this very wrapper, otherwise the yaml would
+      # construct something else; identity, not equality, is what matters.
+      if imported_constructor is not wrapper:
+        raise ImportError('Different object.')
+    except ImportError:
+      warnings.warn(f'Cannot import {constructor} as {fully_qualified_name}.')
+      return transform
+
+    try:
+      config = json.dumps({
+          'constructor': fully_qualified_name,
+          'args': args,
+          'kwargs': kwargs,
+      })
+    except (TypeError, ValueError) as exn:
+      # json.dumps raises ValueError (not TypeError) for circular references.
+      warnings.warn(
+          f'Cannot serialize arguments for {constructor} as json: {exn}')
+      return transform
+
+    original_annotations = transform.annotations
+    transform.annotations = lambda: {
+        **original_annotations(),
+        # These override whatever may have been provided earlier.
+        # The outermost call is expected to be the most specific.
+        'yaml_provider': 'python',
+        'yaml_type': 'PyTransform',
+        'yaml_args': config,
+    }
+    return transform
+
+  return wrapper
diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py b/sdks/python/apache_beam/yaml/yaml_transform_test.py
index 9a540e3551ff..8a6b73b999f2 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py
@@ -229,6 +229,23 @@ def test_name_is_ambiguous(self):
             output: AnotherFilter
             ''')
 
+  def test_annotations(self):
+    t = LinearTransform(5, b=100)
+    annotations = t.annotations()
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      result = p | YamlTransform(
+          '''
+          type: chain
+          transforms:
+            - type: Create
+              config:
+                elements: [0, 1, 2, 3]
+            - type: %r
+              config: %s
+          ''' % (annotations['yaml_type'], annotations['yaml_args']))
+      assert_that(result, equal_to([100, 105, 110, 115]))
+
 
 class CreateTimestamped(beam.PTransform):
   def __init__(self, elements):
@@ -610,6 +627,19 @@ def test_prefers_same_provider_class(self):
           label='StartWith3')
 
 
+@beam.transforms.ptransform.annotate_yaml
+class LinearTransform(beam.PTransform):
+  """A transform used for testing annotate_yaml."""
+  def __init__(self, a, b):
+    self._a = a
+    self._b = b
+
+  def expand(self, pcoll):
+    a = self._a
+    b = self._b
+    return pcoll | beam.Map(lambda x: a * x + b)
+
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()