Add WatchFilePattern transform
AnandInguva committed Jan 27, 2023
1 parent 73ed494 commit ae91ffe
Showing 3 changed files with 268 additions and 0 deletions.
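For orientation, here is a minimal sketch of how the new transform is meant to be wired into RunInference, mirroring the example added in this commit (the bucket path and helper name are illustrative, not part of the change):

import apache_beam as beam
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy
from apache_beam.ml.inference.utils import WatchFilePattern

def attach_auto_model_updates(pipeline, examples, model_path):
  # Side input that emits new model metadata whenever a newer file matches
  # the glob; before any update arrives it falls back to model_path.
  model_updates = pipeline | WatchFilePattern(
      file_pattern='gs://my-bucket/models/*.pickle',  # illustrative pattern
      interval=60,
      default_value=model_path)
  model_handler = KeyedModelHandler(
      SklearnModelHandlerNumpy(model_uri=model_path))
  # RunInference swaps in the updated model via model_path_pcoll.
  return examples | RunInference(
      model_handler=model_handler, model_path_pcoll=model_updates)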
@@ -0,0 +1,16 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
@@ -0,0 +1,132 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pylint: skip-file

import argparse
from typing import Iterable
from typing import List
from typing import Tuple

import apache_beam as beam
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy
from apache_beam.ml.inference.utils import WatchFilePattern
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.transforms import window


def process_input(row: str) -> Tuple[int, List[int]]:
  """Parses a CSV row into a (label, pixels) tuple of ints."""
  data = row.split(',')
  label, pixels = int(data[0]), data[1:]
  pixels = [int(pixel) for pixel in pixels]
  return label, pixels


class PostProcessor(beam.DoFn):
  """Processes the PredictionResult to get the predicted label.
  Yields a comma-separated string with the true label and the predicted label.
  """
def process(self, element: Tuple[int, PredictionResult]) -> Iterable[str]:
label, prediction_result = element
prediction = prediction_result.inference
yield '{},{}'.format(label, prediction)


def run(argv=None, save_main_session=True, test_pipeline=None):
"""Entry point for running the pipeline."""
known_args, pipeline_args = parse_known_args(argv)
pipeline_options = PipelineOptions(pipeline_args)
pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

interval = known_args.interval
file_pattern = known_args.file_pattern

pipeline = test_pipeline
if not test_pipeline:
pipeline = beam.Pipeline(options=pipeline_options)

si_pcoll = pipeline | WatchFilePattern(
file_pattern=file_pattern,
interval=interval,
default_value=known_args.model_path)

model_handler = KeyedModelHandler(
SklearnModelHandlerNumpy(model_uri=known_args.model_path))

label_pixel_tuple = (
pipeline
| "ReadFromPubSub" >> beam.io.ReadFromPubSub(known_args.topic)
| "ApplyMainInputWindow" >> beam.WindowInto(
beam.transforms.window.FixedWindows(interval))
| "Decode" >> beam.Map(lambda x: x.decode('utf-8'))
| "PreProcessInputs" >> beam.Map(process_input))

predictions = (
label_pixel_tuple
| "RunInference" >> RunInference(
model_handler=model_handler, model_path_pcoll=si_pcoll)
| "PostProcessOutputs" >> beam.ParDo(PostProcessor()))

predictions | beam.Map(print)

# _ = predictions | "WriteOutput" >> beam.io.WriteToText(
# known_args.output, shard_name_template='', append_trailing_newlines=True)

result = pipeline.run().wait_until_finish()
return result


def parse_known_args(argv):
"""Parses args for the workflow."""
parser = argparse.ArgumentParser()
parser.add_argument(
'--input',
dest='input',
required=True,
      help='Text file with comma-separated int values.')
parser.add_argument(
'--output',
dest='output',
required=True,
help='Path to save output predictions.')
parser.add_argument(
'--model_path',
dest='model_path',
required=True,
      help='Path to load the sklearn model for inference.')
parser.add_argument(
'--topic',
default='projects/apache-beam-testing/topics/anandinguva-model-updates',
help='Path to Pub/Sub topic.')
parser.add_argument(
'--file_pattern',
default='gs://apache-beam-ml/tmp/side_input_test/*.pickle',
help='Glob pattern to watch for an update.')
parser.add_argument(
'--interval',
default=360,
type=int,
      help='Interval in seconds at which to check file_pattern for updates.')
return parser.parse_known_args(argv)


if __name__ == '__main__':
run()
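As a rough sketch, the example can also be driven programmatically like this (bucket paths and the interval are placeholders; a streaming-capable runner is assumed since the pipeline reads from Pub/Sub, and the extra --streaming flag is passed through to PipelineOptions):

run(argv=[
    '--input=gs://my-bucket/mnist/test_inputs.csv',
    '--output=gs://my-bucket/predictions/out.txt',
    '--model_path=gs://my-bucket/models/initial.pickle',
    '--file_pattern=gs://my-bucket/models/*.pickle',
    '--interval=60',
    '--streaming',
])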
120 changes: 120 additions & 0 deletions sdks/python/apache_beam/ml/inference/utils.py
@@ -19,13 +19,27 @@
"""
Util/helper functions used in apache_beam.ml.inference.
"""
import os
from functools import partial
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import apache_beam as beam
from apache_beam.io.fileio import MatchContinuously
from apache_beam.ml.inference.base import ModelMetdata
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.transforms import window
from apache_beam.transforms import trigger
from apache_beam.transforms.userstate import CombiningValueStateSpec
from apache_beam.utils.timestamp import MAX_TIMESTAMP
from apache_beam.utils.timestamp import Timestamp

_START_TIME_STAMP = Timestamp.now()


def _convert_to_result(
@@ -46,3 +60,109 @@ def _convert_to_result(
y in zip(batch, predictions_per_tensor)
]
return [PredictionResult(x, y, model_id) for x, y in zip(batch, predictions)]


class _ConvertIterToSingleton(beam.DoFn):
  """
  Internal only; no backwards compatibility.

  The MatchContinuously transform examines all files present in a given
  directory and returns those that have timestamps older than the
  pipeline's start time, so it can produce an Iterable rather than a
  Singleton. This DoFn emits a file path only the first time it is
  encountered; the path is then cached as part of the side input caching
  mechanism, and subsequent occurrences of the same path emit nothing.
  This ensures that the output of this transform can be wrapped with
  beam.pvalue.AsSingleton().
  """
COUNT_STATE = CombiningValueStateSpec('count', combine_fn=sum)

def process(self, element, count_state=beam.DoFn.StateParam(COUNT_STATE)):
counter = count_state.read()
if counter == 0:
count_state.add(1)
yield element[1]


class _GetLatestFileByTimeStamp(beam.DoFn):
  """
  Internal only; no backwards compatibility.

  This DoFn checks the timestamps of files against the time that the pipeline
  began running. It returns the files that were modified after the pipeline
  started. If no such files are found, it returns a default file as a
  fallback.
  """
TIME_STATE = CombiningValueStateSpec(
'count', combine_fn=partial(max, default=_START_TIME_STAMP))

def __init__(self, default_value):
self._default_value = default_value

def process(
self, element, time_state=beam.DoFn.StateParam(TIME_STATE)
) -> List[Tuple[str, ModelMetdata]]:
path, file_metadata = element
new_ts = file_metadata.last_updated_in_seconds
old_ts = time_state.read()
if new_ts > old_ts:
time_state.clear()
time_state.add(new_ts)
return [(
path,
ModelMetdata(
model_id=file_metadata.path,
model_name=os.path.splitext(os.path.basename(
file_metadata.path))[0]))]
else:
path_short_name = None
if self._default_value:
path_short_name = os.path.splitext(
os.path.basename(self._default_value))[0]
return [(
self._default_value,
ModelMetdata(
model_id=self._default_value, model_name=path_short_name))]


class WatchFilePattern(beam.PTransform):
def __init__(
self,
file_pattern,
interval=360,
stop_timestamp=MAX_TIMESTAMP,
default_value=None,
):
"""
Watch for updates using the file pattern using MatchContinuously transform.
Note: Start timestamp will be defaulted to timestamp when pipeline was run.
All the files matching file_pattern, that are uploaded before the
pipeline started will be discarded.
Args:
file_pattern: The file path to read from.
interval: Interval at which to check for files in seconds.
stop_timestamp: Timestamp after which no more files will be checked.
"""
self.file_pattern = file_pattern
self.interval = interval
self.stop_timestamp = stop_timestamp
self._latest_timestamp = None
self._default_value = default_value

def expand(self, pcoll) -> beam.PCollection[ModelMetdata]:
return (
pcoll
| 'MatchContinuously' >> MatchContinuously(
file_pattern=self.file_pattern,
interval=self.interval,
stop_timestamp=self.stop_timestamp)
| "AttachKey" >> beam.Map(lambda x: (x.path, x))
| "GetLatestFileMetaData" >> beam.ParDo(
_GetLatestFileByTimeStamp(default_value=self._default_value))
| "AcceptNewSideInputOnly" >> beam.ParDo(_CoverIterToSingleton())
| 'ApplyGlobalWindow' >> beam.transforms.WindowInto(
window.GlobalWindows(),
trigger=trigger.AfterProcessingTime(1),
accumulation_mode=trigger.AccumulationMode.DISCARDING))
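
Because _ConvertIterToSingleton guarantees singleton-compatible output, the transform can also be consumed directly, outside of RunInference. A minimal sketch, assuming a streaming pipeline and an illustrative bucket path:

import apache_beam as beam
from apache_beam.ml.inference.utils import WatchFilePattern
from apache_beam.options.pipeline_options import PipelineOptions

# Continuous file matching implies a streaming pipeline.
options = PipelineOptions(streaming=True)
with beam.Pipeline(options=options) as p:
  metadata = p | WatchFilePattern(
      file_pattern='gs://my-bucket/models/*.pickle',  # illustrative path
      interval=60)
  # Each element is a ModelMetdata holding the latest file path (model_id)
  # and a short model_name derived from the file name.
  _ = metadata | beam.Map(lambda m: print(m.model_id, m.model_name))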
