Skip to content

Commit

Permalink
Add WatchFilePattern transform
Browse files Browse the repository at this point in the history
  • Loading branch information
AnandInguva committed Feb 9, 2023
1 parent ea1625a commit 9394f03
Show file tree
Hide file tree
Showing 6 changed files with 433 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
* Add UDF metrics support for Samza portable mode.
* Option for SparkRunner to avoid the need of SDF output to fit in memory ([#23852](https://github.com/apache/beam/issues/23852)).
This helps e.g. with ParquetIO reads. Turn the feature on by adding experiment `use_bounded_concurrent_output_for_sdf`.
* Add `WatchFilePattern` transform, which can be used as a side input to the RunInference PTransfrom to watch for model updates using a file pattern. ([#24042](https://github.com/apache/beam/issues/24042))

## Breaking Changes

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
A pipeline that uses RunInference PTransform to perform image classification
and uses WatchFilePattern as side input to the RunInference PTransform.
WatchFilePattern is used to watch for a file updates matching the file_pattern
based on timestamps and emits latest model metadata, which is used in
RunInference API for the dynamic model updates without the need for stopping
the beam pipeline.
This pipeline follows the pattern from
https://beam.apache.org/documentation/patterns/side-inputs/
This pipeline expects a PubSub topic as source, which emits an image
path(UTF-8 encoded) that is accessible by the pipeline.
To run the example on DataflowRunner,
python apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py # pylint: disable=line-too-long
--model_path=gs://apache-beam-ml/tmp/side_input_test/resnet152.pth
--project=<your-project>
--re=<your-region>
--temp_location=<your-tmp-location>
--staging_location=<your-staging-location>
--runner=DataflowRunner
--streaming
--interval=10
--num_workers=5
--requirements_file=apache_beam/ml/inference/torch_tests_requirements.txt
"""

import argparse
import io
import logging
import os
from typing import Iterable
from typing import Iterator
from typing import Optional
from typing import Tuple

import apache_beam as beam
import torch
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
from apache_beam.ml.inference.utils import WatchFilePattern
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.runners.runner import PipelineResult
from PIL import Image
from torchvision import models
from torchvision import transforms


def read_image(image_file_name: str,
path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]:
if path_to_dir is not None:
image_file_name = os.path.join(path_to_dir, image_file_name)
with FileSystems().open(image_file_name, 'r') as file:
data = Image.open(io.BytesIO(file.read())).convert('RGB')
return image_file_name, data


def preprocess_image(data: Image.Image) -> torch.Tensor:
image_size = (224, 224)
# Pre-trained PyTorch models expect input images normalized with the
# below values (see: https://pytorch.org/vision/stable/models.html)
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = transforms.Compose([
transforms.Resize(image_size),
transforms.ToTensor(),
normalize,
])
return transform(data)


def filter_empty_lines(text: str) -> Iterator[str]:
if len(text.strip()) > 0:
yield text


class PostProcessor(beam.DoFn):
"""
Return filename, prediction and the model id used to perform the
prediction
"""
def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]:
filename, prediction_result = element
prediction = torch.argmax(prediction_result.inference, dim=0)
yield filename, prediction, prediction_result.model_id


def parse_known_args(argv):
"""Parses args for the workflow."""
parser = argparse.ArgumentParser()
parser.add_argument(
'--topic',
dest='topic',
default='projects/apache-beam-testing/topics/anandinguva-model-updates',
help='PubSub topic emitting absolute path to the images.'
'Path must be accessible by the pipeline.')
parser.add_argument(
'--model_path',
dest='model_path',
default='gs://apache-beam-samples/run_inference/resnet152.pth',
help="Path to the model's state_dict.")
parser.add_argument(
'--file_pattern',
default='gs://apache-beam-ml/tmp/side_input_test/*.pth',
help='Glob pattern to watch for an update.')
parser.add_argument(
'--interval',
default=10,
type=int,
help='Interval used to check for file updates.')

return parser.parse_known_args(argv)


def run(
argv=None,
model_class=None,
model_params=None,
save_main_session=True,
device='CPU',
test_pipeline=None) -> PipelineResult:
"""
Args:
argv: Command line arguments defined for this example.
model_class: Reference to the class definition of the model.
model_params: Parameters passed to the constructor of the model_class.
These will be used to instantiate the model object in the
RunInference PTransform.
save_main_session: Used for internal testing.
device: Device to be used on the Runner. Choices are (CPU, GPU).
test_pipeline: Used for internal testing.
"""
known_args, pipeline_args = parse_known_args(argv)
pipeline_options = PipelineOptions(pipeline_args)
pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

if not model_class:
model_class = models.resnet152
model_params = {'num_classes': 1000}

class PytorchModelHandlerTensorWithBatchSize(PytorchModelHandlerTensor):
def batch_elements_kwargs(self):
return {'min_batch_size': 10, 'max_batch_size': 100}

# In this example we pass keyed inputs to RunInference transform.
# Therefore, we use KeyedModelHandler wrapper over PytorchModelHandler.
model_handler = KeyedModelHandler(
PytorchModelHandlerTensorWithBatchSize(
state_dict_path=known_args.model_path,
model_class=model_class,
model_params=model_params,
device=device))

pipeline = test_pipeline
if not test_pipeline:
pipeline = beam.Pipeline(options=pipeline_options)

side_input = pipeline | WatchFilePattern(
interval=known_args.interval, file_pattern=known_args.file_pattern)

filename_value_pair = (
pipeline
| 'ReadImageNamesFromPubSub' >> beam.io.ReadFromPubSub(known_args.topic)
| 'DecodeBytes' >> beam.Map(lambda x: x.decode('utf-8'))
| 'ReadImageData' >>
beam.Map(lambda image_name: read_image(image_file_name=image_name))
| 'PreprocessImages' >> beam.MapTuple(
lambda file_name, data: (file_name, preprocess_image(data))))
predictions = (
filename_value_pair
| 'PyTorchRunInference' >> RunInference(
model_handler, model_metadata_pcoll=side_input)
| 'ProcessOutput' >> beam.ParDo(PostProcessor()))

_ = predictions | beam.Map(logging.info)

result = pipeline.run()
result.wait_until_finish()
return result


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
run()
6 changes: 4 additions & 2 deletions sdks/python/apache_beam/io/fileio.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,8 @@ def __init__(
start_timestamp=Timestamp.now(),
stop_timestamp=MAX_TIMESTAMP,
match_updated_files=False,
apply_windowing=False):
apply_windowing=False,
empty_match_treatment=EmptyMatchTreatment.ALLOW):
"""Initializes a MatchContinuously transform.
Args:
Expand All @@ -299,6 +300,7 @@ def __init__(
self.stop_ts = stop_timestamp
self.match_upd = match_updated_files
self.apply_windowing = apply_windowing
self.empty_match_treatment = empty_match_treatment

def expand(self, pbegin) -> beam.PCollection[filesystem.FileMetadata]:
# invoke periodic impulse
Expand All @@ -311,7 +313,7 @@ def expand(self, pbegin) -> beam.PCollection[filesystem.FileMetadata]:
match_files = (
impulse
| 'GetFilePattern' >> beam.Map(lambda x: self.file_pattern)
| MatchAll())
| MatchAll(self.empty_match_treatment))

# apply deduplication strategy if required
if self.has_deduplication:
Expand Down
22 changes: 13 additions & 9 deletions sdks/python/apache_beam/ml/inference/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,15 +440,19 @@ def test_run_inference_side_input_in_batch(self):
first_ts + 22,
])

sample_side_input_elements = [(
first_ts + 8,
base.ModelMetdata(
model_id='fake_model_id_1', model_name='fake_model_id_1')),
(
first_ts + 15,
base.ModelMetdata(
model_id='fake_model_id_2',
model_name='fake_model_id_2'))]
sample_side_input_elements = [
(first_ts + 1, base.ModelMetdata(model_id='', model_name='')),
# if model_id is empty string, we use the default model
# handler model URI.
(
first_ts + 8,
base.ModelMetdata(
model_id='fake_model_id_1', model_name='fake_model_id_1')),
(
first_ts + 15,
base.ModelMetdata(
model_id='fake_model_id_2', model_name='fake_model_id_2'))
]

model_handler = FakeModelHandlerReturnsPredictionResult()

Expand Down
Loading

0 comments on commit 9394f03

Please sign in to comment.