diff --git a/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py b/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py index 2a86f0251e7..62dbbc5fb77 100644 --- a/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py +++ b/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py @@ -53,18 +53,15 @@ def test_combining_value_state(self): @parameterized_class([ - {'runner': direct_runner.BundleBasedDirectRunner, 'pickler': 'dill'}, - {'runner': direct_runner.BundleBasedDirectRunner, 'pickler': 'cloudpickle'}, - {'runner': fn_api_runner.FnApiRunner, 'pickler': 'dill'}, - {'runner': fn_api_runner.FnApiRunner, 'pickler': 'cloudpickle'}, - ]) # yapf: disable + {'runner': direct_runner.BundleBasedDirectRunner}, + {'runner': fn_api_runner.FnApiRunner}, +]) # yapf: disable class LocalCombineFnLifecycleTest(unittest.TestCase): def tearDown(self): CallSequenceEnforcingCombineFn.instances.clear() def test_combine(self): - test_options = PipelineOptions(flags=[f"--pickle_library={self.pickler}"]) - run_combine(TestPipeline(runner=self.runner(), options=test_options)) + run_combine(TestPipeline(runner=self.runner())) self._assert_teardown_called() def test_non_liftable_combine(self): diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 456007c4ca4..fa03840d45a 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -3147,40 +3147,33 @@ def process(self, element): yield pvalue.TaggedOutput('hot', ((self._nonce % fanout, key), value)) class PreCombineFn(CombineFn): - def __init__(self): - # Deepcopy of the combine_fn to avoid sharing state between lifted - # stages when using cloudpickle. - self._combine_fn_copy = copy.deepcopy(combine_fn) - self.setup = self._combine_fn_copy.setup - self.create_accumulator = self._combine_fn_copy.create_accumulator - self.add_input = self._combine_fn_copy.add_input - self.merge_accumulators = self._combine_fn_copy.merge_accumulators - self.compact = self._combine_fn_copy.compact - self.teardown = self._combine_fn_copy.teardown - @staticmethod def extract_output(accumulator): # Boolean indicates this is an accumulator. return (True, accumulator) + setup = combine_fn.setup + create_accumulator = combine_fn.create_accumulator + add_input = combine_fn.add_input + merge_accumulators = combine_fn.merge_accumulators + compact = combine_fn.compact + teardown = combine_fn.teardown + class PostCombineFn(CombineFn): - def __init__(self): - # Deepcopy of the combine_fn to avoid sharing state between lifted - # stages when using cloudpickle. - self._combine_fn_copy = copy.deepcopy(combine_fn) - self.setup = self._combine_fn_copy.setup - self.create_accumulator = self._combine_fn_copy.create_accumulator - self.merge_accumulators = self._combine_fn_copy.merge_accumulators - self.compact = self._combine_fn_copy.compact - self.extract_output = self._combine_fn_copy.extract_output - self.teardown = self._combine_fn_copy.teardown - - def add_input(self, accumulator, element): + @staticmethod + def add_input(accumulator, element): is_accumulator, value = element if is_accumulator: - return self._combine_fn_copy.merge_accumulators([accumulator, value]) + return combine_fn.merge_accumulators([accumulator, value]) else: - return self._combine_fn_copy.add_input(accumulator, value) + return combine_fn.add_input(accumulator, value) + + setup = combine_fn.setup + create_accumulator = combine_fn.create_accumulator + merge_accumulators = combine_fn.merge_accumulators + compact = combine_fn.compact + extract_output = combine_fn.extract_output + teardown = combine_fn.teardown def StripNonce(nonce_key_value): (_, key), value = nonce_key_value