Skip to content

Commit

Permalink
Revert "Deepcopy combine_fn in PrecombineFn and PostCombineFn." (apac…
Browse files Browse the repository at this point in the history
…he#32634)

This reverts commit eaf53e5.
  • Loading branch information
claudevdm authored and reeba212 committed Dec 4, 2024
1 parent 8737e66 commit 57df53a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 32 deletions.
11 changes: 4 additions & 7 deletions sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,15 @@ def test_combining_value_state(self):


@parameterized_class([
{'runner': direct_runner.BundleBasedDirectRunner, 'pickler': 'dill'},
{'runner': direct_runner.BundleBasedDirectRunner, 'pickler': 'cloudpickle'},
{'runner': fn_api_runner.FnApiRunner, 'pickler': 'dill'},
{'runner': fn_api_runner.FnApiRunner, 'pickler': 'cloudpickle'},
]) # yapf: disable
{'runner': direct_runner.BundleBasedDirectRunner},
{'runner': fn_api_runner.FnApiRunner},
]) # yapf: disable
class LocalCombineFnLifecycleTest(unittest.TestCase):
def tearDown(self):
CallSequenceEnforcingCombineFn.instances.clear()

def test_combine(self):
test_options = PipelineOptions(flags=[f"--pickle_library={self.pickler}"])
run_combine(TestPipeline(runner=self.runner(), options=test_options))
run_combine(TestPipeline(runner=self.runner()))
self._assert_teardown_called()

def test_non_liftable_combine(self):
Expand Down
43 changes: 18 additions & 25 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3147,40 +3147,33 @@ def process(self, element):
yield pvalue.TaggedOutput('hot', ((self._nonce % fanout, key), value))

class PreCombineFn(CombineFn):
def __init__(self):
# Deepcopy of the combine_fn to avoid sharing state between lifted
# stages when using cloudpickle.
self._combine_fn_copy = copy.deepcopy(combine_fn)
self.setup = self._combine_fn_copy.setup
self.create_accumulator = self._combine_fn_copy.create_accumulator
self.add_input = self._combine_fn_copy.add_input
self.merge_accumulators = self._combine_fn_copy.merge_accumulators
self.compact = self._combine_fn_copy.compact
self.teardown = self._combine_fn_copy.teardown

@staticmethod
def extract_output(accumulator):
# Boolean indicates this is an accumulator.
return (True, accumulator)

setup = combine_fn.setup
create_accumulator = combine_fn.create_accumulator
add_input = combine_fn.add_input
merge_accumulators = combine_fn.merge_accumulators
compact = combine_fn.compact
teardown = combine_fn.teardown

class PostCombineFn(CombineFn):
def __init__(self):
# Deepcopy of the combine_fn to avoid sharing state between lifted
# stages when using cloudpickle.
self._combine_fn_copy = copy.deepcopy(combine_fn)
self.setup = self._combine_fn_copy.setup
self.create_accumulator = self._combine_fn_copy.create_accumulator
self.merge_accumulators = self._combine_fn_copy.merge_accumulators
self.compact = self._combine_fn_copy.compact
self.extract_output = self._combine_fn_copy.extract_output
self.teardown = self._combine_fn_copy.teardown

def add_input(self, accumulator, element):
@staticmethod
def add_input(accumulator, element):
is_accumulator, value = element
if is_accumulator:
return self._combine_fn_copy.merge_accumulators([accumulator, value])
return combine_fn.merge_accumulators([accumulator, value])
else:
return self._combine_fn_copy.add_input(accumulator, value)
return combine_fn.add_input(accumulator, value)

setup = combine_fn.setup
create_accumulator = combine_fn.create_accumulator
merge_accumulators = combine_fn.merge_accumulators
compact = combine_fn.compact
extract_output = combine_fn.extract_output
teardown = combine_fn.teardown

def StripNonce(nonce_key_value):
(_, key), value = nonce_key_value
Expand Down

0 comments on commit 57df53a

Please sign in to comment.