Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an annotation to expose transforms to yaml. #28208

Merged
merged 3 commits into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions sdks/python/apache_beam/transforms/ptransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ class and wrapper class that allows lambda functions to be used as

import copy
import itertools
import json
import logging
import operator
import os
import sys
import threading
import warnings
from functools import reduce
from functools import wraps
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -83,6 +85,7 @@ class and wrapper class that allows lambda functions to be used as
from apache_beam.typehints.trivial_inference import instance_to_type
from apache_beam.typehints.typehints import validate_composite_type_param
from apache_beam.utils import proto_utils
from apache_beam.utils import python_callable

if TYPE_CHECKING:
from apache_beam import coders
Expand All @@ -95,6 +98,7 @@ class and wrapper class that allows lambda functions to be used as
'PTransform',
'ptransform_fn',
'label_from_callable',
'annotate_yaml',
]

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -1093,3 +1097,51 @@ def __ror__(self, pvalueish, _unused=None):

def expand(self, pvalue):
raise RuntimeError("Should never be expanded directly.")


# Defined here to avoid circular import issues for Beam library transforms.
def annotate_yaml(constructor):
"""Causes instances of this transform to be annotated with their yaml syntax.

Should only be used for transforms that are fully defined by their constructor
arguments.
"""
@wraps(constructor)
def wrapper(*args, **kwargs):
transform = constructor(*args, **kwargs)

fully_qualified_name = (
f'{constructor.__module__}.{constructor.__qualname__}')
try:
imported_constructor = (
python_callable.PythonCallableWithSource.
load_from_fully_qualified_name(fully_qualified_name))
if imported_constructor != wrapper:
raise ImportError('Different object.')
except ImportError:
warnings.warn(f'Cannot import {constructor} as {fully_qualified_name}.')
return transform

try:
config = json.dumps({
'constructor': fully_qualified_name,
'args': args,
'kwargs': kwargs,
})
except TypeError as exn:
warnings.warn(
f'Cannot serialize arguments for {constructor} as json: {exn}')
return transform

original_annotations = transform.annotations
transform.annotations = lambda: {
**original_annotations(),
# These override whatever may have been provided earlier.
# The outermost call is expected to be the most specific.
'yaml_provider': 'python',
'yaml_type': 'PyTransform',
'yaml_args': config,
}
return transform

return wrapper
30 changes: 30 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_transform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,23 @@ def test_name_is_ambiguous(self):
output: AnotherFilter
''')

def test_annotations(self):
t = LinearTransform(5, b=100)
annotations = t.annotations()
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
result = p | YamlTransform(
'''
type: chain
transforms:
- type: Create
config:
elements: [0, 1, 2, 3]
- type: %r
config: %s
''' % (annotations['yaml_type'], annotations['yaml_args']))
assert_that(result, equal_to([100, 105, 110, 115]))


class CreateTimestamped(beam.PTransform):
def __init__(self, elements):
Expand Down Expand Up @@ -610,6 +627,19 @@ def test_prefers_same_provider_class(self):
label='StartWith3')


@beam.transforms.ptransform.annotate_yaml
class LinearTransform(beam.PTransform):
"""A transform used for testing annotate_yaml."""
def __init__(self, a, b):
self._a = a
self._b = b

def expand(self, pcoll):
a = self._a
b = self._b
return pcoll | beam.Map(lambda x: a * x + b)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()