diff --git a/sdks/python/apache_beam/examples/inference/side_input_examples/__init__.py b/sdks/python/apache_beam/examples/inference/side_input_examples/__init__.py
new file mode 100644
index 000000000000..cce3acad34a4
--- /dev/null
+++ b/sdks/python/apache_beam/examples/inference/side_input_examples/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/sdks/python/apache_beam/examples/inference/side_input_examples/sklearn_side_inputs.py b/sdks/python/apache_beam/examples/inference/side_input_examples/sklearn_side_inputs.py
new file mode 100644
index 000000000000..f1e7eae43a76
--- /dev/null
+++ b/sdks/python/apache_beam/examples/inference/side_input_examples/sklearn_side_inputs.py
@@ -0,0 +1,132 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pylint: skip-file
+
+import argparse
+from typing import Iterable
+from typing import List
+from typing import Tuple
+
+import apache_beam as beam
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy
+from apache_beam.ml.inference.utils import WatchFilePattern
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from apache_beam.transforms import window
+
+
+def process_input(row: str) -> Tuple[int, List[int]]:
+  data = row.split(',')
+  label, pixels = int(data[0]), data[1:]
+  pixels = [int(pixel) for pixel in pixels]
+  return label, pixels
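+
+
+# A quick sketch of the expected row format. The data is assumed here to be
+# MNIST-style CSV, with the label first and pixel values after it; the exact
+# format depends on the data published to the topic:
+#
+#   >>> process_input('7,0,128,255')
+#   (7, [0, 128, 255])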
+ """ + def process(self, element: Tuple[int, PredictionResult]) -> Iterable[str]: + label, prediction_result = element + prediction = prediction_result.inference + yield '{},{}'.format(label, prediction) + + +def run(argv=None, save_main_session=True, test_pipeline=None): + """Entry point for running the pipeline.""" + known_args, pipeline_args = parse_known_args(argv) + pipeline_options = PipelineOptions(pipeline_args) + pipeline_options.view_as(SetupOptions).save_main_session = save_main_session + + interval = known_args.interval + file_pattern = known_args.file_pattern + + pipeline = test_pipeline + if not test_pipeline: + pipeline = beam.Pipeline(options=pipeline_options) + + si_pcoll = pipeline | WatchFilePattern( + file_pattern=file_pattern, + interval=interval, + default_value=known_args.model_path) + + model_handler = KeyedModelHandler( + SklearnModelHandlerNumpy(model_uri=known_args.model_path)) + + label_pixel_tuple = ( + pipeline + | "ReadFromPubSub" >> beam.io.ReadFromPubSub(known_args.topic) + | "ApplyMainInputWindow" >> beam.WindowInto( + beam.transforms.window.FixedWindows(interval)) + | "Decode" >> beam.Map(lambda x: x.decode('utf-8')) + | "PreProcessInputs" >> beam.Map(process_input)) + + predictions = ( + label_pixel_tuple + | "RunInference" >> RunInference( + model_handler=model_handler, model_path_pcoll=si_pcoll) + | "PostProcessOutputs" >> beam.ParDo(PostProcessor())) + + predictions | beam.Map(print) + + # _ = predictions | "WriteOutput" >> beam.io.WriteToText( + # known_args.output, shard_name_template='', append_trailing_newlines=True) + + result = pipeline.run().wait_until_finish() + return result + + +def parse_known_args(argv): + """Parses args for the workflow.""" + parser = argparse.ArgumentParser() + parser.add_argument( + '--input', + dest='input', + required=True, + help='text file with comma separated int values.') + parser.add_argument( + '--output', + dest='output', + required=True, + help='Path to save output predictions.') + parser.add_argument( + '--model_path', + dest='model_path', + required=True, + help='Path to load the Sklearn model for Inference.') + parser.add_argument( + '--topic', + default='projects/apache-beam-testing/topics/anandinguva-model-updates', + help='Path to Pub/Sub topic.') + parser.add_argument( + '--file_pattern', + default='gs://apache-beam-ml/tmp/side_input_test/*.pickle', + help='Glob pattern to watch for an update.') + parser.add_argument( + '--interval', + default=360, + type=int, + help='interval used to look for updates on a given file_pattern.') + return parser.parse_known_args(argv) + + +if __name__ == '__main__': + run() diff --git a/sdks/python/apache_beam/ml/inference/utils.py b/sdks/python/apache_beam/ml/inference/utils.py index f30d8a8f6486..db41a00fde99 100644 --- a/sdks/python/apache_beam/ml/inference/utils.py +++ b/sdks/python/apache_beam/ml/inference/utils.py @@ -19,13 +19,27 @@ """ Util/helper functions used in apache_beam.ml.inference. 
""" +import os +from functools import partial from typing import Any from typing import Dict from typing import Iterable +from typing import List from typing import Optional +from typing import Tuple from typing import Union +import apache_beam as beam +from apache_beam.io.fileio import MatchContinuously +from apache_beam.ml.inference.base import ModelMetdata from apache_beam.ml.inference.base import PredictionResult +from apache_beam.transforms import window +from apache_beam.transforms import trigger +from apache_beam.transforms.userstate import CombiningValueStateSpec +from apache_beam.utils.timestamp import MAX_TIMESTAMP +from apache_beam.utils.timestamp import Timestamp + +_START_TIME_STAMP = Timestamp.now() def _convert_to_result( @@ -46,3 +60,109 @@ def _convert_to_result( y in zip(batch, predictions_per_tensor) ] return [PredictionResult(x, y, model_id) for x, y in zip(batch, predictions)] + + +class _CoverIterToSingleton(beam.DoFn): + """ + Internal only; No backwards compatibility. + + The MatchContinuously transform examines all files present in a given + directory and returns those that have timestamps older than the + pipeline's start time. This can produce an Iterable rather than a + Singleton. This class only returns the file path when it is first + encountered, and it is cached as part of the side input caching mechanism. + If the path is seen again, it will not return anything. + By doing this, we can ensure that the output of this transform can be wrapped + with beam.pvalue.AsSingleton(). + """ + COUNT_STATE = CombiningValueStateSpec('count', combine_fn=sum) + + def process(self, element, count_state=beam.DoFn.StateParam(COUNT_STATE)): + counter = count_state.read() + if counter == 0: + count_state.add(1) + yield element[1] + + +class _GetLatestFileByTimeStamp(beam.DoFn): + """ + Internal only; No backwards compatibility. + + This DoFn checks the timestamps of files against the time that the pipeline + began running. It returns the files that were modified after the pipeline + started. If no such files are found, it returns a default file as fallback. + + """ + TIME_STATE = CombiningValueStateSpec( + 'count', combine_fn=partial(max, default=_START_TIME_STAMP)) + + def __init__(self, default_value): + self._default_value = default_value + + def process( + self, element, time_state=beam.DoFn.StateParam(TIME_STATE) + ) -> List[Tuple[str, ModelMetdata]]: + path, file_metadata = element + new_ts = file_metadata.last_updated_in_seconds + old_ts = time_state.read() + if new_ts > old_ts: + time_state.clear() + time_state.add(new_ts) + return [( + path, + ModelMetdata( + model_id=file_metadata.path, + model_name=os.path.splitext(os.path.basename( + file_metadata.path))[0]))] + else: + path_short_name = None + if self._default_value: + path_short_name = os.path.splitext( + os.path.basename(self._default_value))[0] + return [( + self._default_value, + ModelMetdata( + model_id=self._default_value, model_name=path_short_name))] + + +class WatchFilePattern(beam.PTransform): + def __init__( + self, + file_pattern, + interval=360, + stop_timestamp=MAX_TIMESTAMP, + default_value=None, + ): + """ + Watch for updates using the file pattern using MatchContinuously transform. + + Note: Start timestamp will be defaulted to timestamp when pipeline was run. + All the files matching file_pattern, that are uploaded before the + pipeline started will be discarded. + + Args: + file_pattern: The file path to read from. + interval: Interval at which to check for files in seconds. 
+
+
+class WatchFilePattern(beam.PTransform):
+  def __init__(
+      self,
+      file_pattern,
+      interval=360,
+      stop_timestamp=MAX_TIMESTAMP,
+      default_value=None,
+  ):
+    """
+    Watches for updates to files matching file_pattern using the
+    MatchContinuously transform.
+
+    Note: The start timestamp defaults to the time the pipeline was started;
+    all files matching file_pattern that were uploaded before the pipeline
+    started are discarded.
+
+    Args:
+      file_pattern: The file path pattern to read from.
+      interval: Interval, in seconds, at which to check for files.
+      stop_timestamp: Timestamp after which no more files will be checked.
+      default_value: Model path to emit when no file matching file_pattern
+        has been updated since the pipeline started.
+    """
+    self.file_pattern = file_pattern
+    self.interval = interval
+    self.stop_timestamp = stop_timestamp
+    self._latest_timestamp = None
+    self._default_value = default_value
+
+  def expand(self, pcoll) -> beam.PCollection[ModelMetdata]:
+    return (
+        pcoll
+        | 'MatchContinuously' >> MatchContinuously(
+            file_pattern=self.file_pattern,
+            interval=self.interval,
+            stop_timestamp=self.stop_timestamp)
+        | 'AttachKey' >> beam.Map(lambda x: (x.path, x))
+        | 'GetLatestFileMetaData' >> beam.ParDo(
+            _GetLatestFileByTimeStamp(default_value=self._default_value))
+        | 'AcceptNewSideInputOnly' >> beam.ParDo(_ConvertIterToSingleton())
+        | 'ApplyGlobalWindow' >> beam.WindowInto(
+            window.GlobalWindows(),
+            trigger=trigger.AfterProcessingTime(1),
+            accumulation_mode=trigger.AccumulationMode.DISCARDING))
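+
+
+# A minimal usage sketch (names and paths are illustrative; see
+# sklearn_side_inputs.py above for a complete pipeline):
+#
+#   side_input_pcoll = pipeline | WatchFilePattern(
+#       file_pattern='gs://<bucket>/models/*.pickle',
+#       interval=600,
+#       default_value='gs://<bucket>/models/initial.pickle')
+#   predictions = examples | RunInference(
+#       model_handler=model_handler, model_path_pcoll=side_input_pcoll)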