Skip to content

Commit

Permalink
Add sideinputs to the RunInference Transform (#25200)
Browse files Browse the repository at this point in the history
* Add model pcoll param to the RunInference Ptransform
fix up parentheses
Add time.sleep
Update PredictionResult
Update Pcollection type hint
Add model_path param in the ModelHandler

* Add sklearn side input example
refactor name

* Add ModeMetadata and some refactoring
Update example

* refactor _convert_to_result and add it to the utils.py
Update model path in pytorch_inference.py
Add model path to sklearn model
Add model id to sklearn inference PredictionResult
clean up base.py
Update utils
update base.py
Update base.py

* Add tag to the RunInference DoFn

* Add enable_side_input_loading flag

* Add helper functions
Add default value
Add prefix/human readable id to ModelMetadata

* Add doc string, refactor utils code

* Fix pytorch inference tests

* Fix up sklearn inference

* Remove logging

* Add thread Lock when there is an update to side input

* Check if side input is EmptySideInput

* Add unit test for side input loading

* Remove examples

* Add log when side input path is updated

* Add test to Dataflow
Fix snippets tests
revert build settings
Fix lint

* Refactor side input loading code
Add model_id to PredictionResult

* Add documentation, changelog

* Add Singleton view doc

* Fix whitespace, tests

* fix weird spacing

* Remove beam website udpate

* Revert "fix weird spacing"

This reverts commit f9a61a1.

* Add WatchFilePattern transform

* undo changes to beam wesbite

* Pass side inputs only in streaming mode

* Revert "Add WatchFilePattern transform"

This reverts commit ae91ffe.

* Add lines to website page

* Add test

* Add unit test to catch --streaming flag and Singleton SideInput

* Addressing PR comments

* Add logic to detect windows on side inputs

* Add more tests

* Remove redundant code

* Update test

* Fix lint

* Add postcommit markers.

* Remove `and`

* fixup lint

* Modify message

* Add check for default model

* Update message

* Add validates runner

* Fix test

* Add PipelineVisitor for RunInference during construction time

* Address comments based on PR

* Remove restriction on the side inputs

* Remove/add tests

* Modify logging

* Add tests

* Add 2.46.0 change log

* fix typo

---------

Co-authored-by: Anand Inguva <[email protected]>
Co-authored-by: Danny McCormick <[email protected]>
  • Loading branch information
3 people authored Feb 2, 2023
1 parent 01aa470 commit 10805a2
Show file tree
Hide file tree
Showing 10 changed files with 637 additions and 120 deletions.
31 changes: 31 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,37 @@
* ([#X](https://github.com/apache/beam/issues/X)).
-->

# [2.46.0] - Unreleased

## Highlights

* RunInference PTransform will accept model paths as SideInputs in Python SDK. ([#24042](https://github.com/apache/beam/issues/24042))

## I/Os

* Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).

## New Features / Improvements

* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).

## Breaking Changes

* X behavior was changed ([#X](https://github.com/apache/beam/issues/X)).

## Deprecations

* X behavior is deprecated and will be removed in X versions ([#X](https://github.com/apache/beam/issues/X)).

## Bugfixes

* Fixed X (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).

## Known Issues

* ([#X](https://github.com/apache/beam/issues/X)).

# [2.45.0] - Unreleased

## Highlights
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Used for internal testing. No backwards compatibility.
"""

import argparse
import logging
import time
from typing import Iterable
from typing import Optional
from typing import Sequence

import apache_beam as beam
from apache_beam.ml.inference import base
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.transforms import trigger
from apache_beam.transforms import window
from apache_beam.transforms.periodicsequence import PeriodicImpulse
from apache_beam.transforms.userstate import CombiningValueStateSpec


# create some fake models which returns different inference results.
class FakeModelDefault:
def predict(self, example: int) -> int:
return example


class FakeModelAdd(FakeModelDefault):
def predict(self, example: int) -> int:
return example + 1


class FakeModelSub(FakeModelDefault):
def predict(self, example: int) -> int:
return example - 1


class FakeModelHandlerReturnsPredictionResult(
base.ModelHandler[int, base.PredictionResult, FakeModelDefault]):
def __init__(self, clock=None, model_id='model_default'):
self.model_id = model_id
self._fake_clock = clock

def load_model(self):
if self._fake_clock:
self._fake_clock.current_time_ns += 500_000_000 # 500ms
if self.model_id == 'model_add.pkl':
return FakeModelAdd()
elif self.model_id == 'model_sub.pkl':
return FakeModelSub()
return FakeModelDefault()

def run_inference(
self,
batch: Sequence[int],
model: FakeModelDefault,
inference_args=None) -> Iterable[base.PredictionResult]:
for example in batch:
yield base.PredictionResult(
model_id=self.model_id,
example=example,
inference=model.predict(example))

def update_model_path(self, model_path: Optional[str] = None):
self.model_id = model_path if model_path else self.model_id


def run(argv=None, save_main_session=True):
parser = argparse.ArgumentParser()
first_ts = time.time()
side_input_interval = 60
main_input_interval = 20
# give some time for dataflow to start.
last_ts = first_ts + 1200
mid_ts = (first_ts + last_ts) / 2

_, pipeline_args = parser.parse_known_args(argv)
options = PipelineOptions(pipeline_args)
options.view_as(SetupOptions).save_main_session = save_main_session

class GetModel(beam.DoFn):
def process(self, element) -> Iterable[base.ModelMetdata]:
if time.time() > mid_ts:
yield base.ModelMetdata(
model_id='model_add.pkl', model_name='model_add')
else:
yield base.ModelMetdata(
model_id='model_sub.pkl', model_name='model_sub')

class _EmitSingletonSideInput(beam.DoFn):
COUNT_STATE = CombiningValueStateSpec('count', combine_fn=sum)

def process(self, element, count_state=beam.DoFn.StateParam(COUNT_STATE)):
_, path = element
counter = count_state.read()
if counter == 0:
count_state.add(1)
yield path

def validate_prediction_result(x: base.PredictionResult):
model_id = x.model_id
if model_id == 'model_sub.pkl':
assert (x.example == 1 and x.inference == 0)

if model_id == 'model_add.pkl':
assert (x.example == 1 and x.inference == 2)

if model_id == 'model_default':
assert (x.example == 1 and x.inference == 1)

with beam.Pipeline(options=options) as pipeline:
side_input = (
pipeline
| "SideInputPColl" >> PeriodicImpulse(
first_ts, last_ts, fire_interval=side_input_interval)
| "GetModelId" >> beam.ParDo(GetModel())
| "AttachKey" >> beam.Map(lambda x: (x, x))
# due to periodic impulse, which has a start timestamp before
# Dataflow pipeline process data, it can trigger in multiple
# firings, causing an Iterable instead of singleton. So, using
# the _EmitSingletonSideInput DoFn will ensure unique path will be
# fired only once.
| "GetSingleton" >> beam.ParDo(_EmitSingletonSideInput())
| "ApplySideInputWindow" >> beam.WindowInto(
window.GlobalWindows(),
trigger=trigger.Repeatedly(trigger.AfterProcessingTime(1)),
accumulation_mode=trigger.AccumulationMode.DISCARDING))

model_handler = FakeModelHandlerReturnsPredictionResult()
inference_pcoll = (
pipeline
| "MainInputPColl" >> PeriodicImpulse(
first_ts,
last_ts,
fire_interval=main_input_interval,
apply_windowing=True)
| beam.Map(lambda x: 1)
| base.RunInference(
model_handler=model_handler, model_metadata_pcoll=side_input))

_ = inference_pcoll | "AssertPredictionResult" >> beam.Map(
validate_prediction_result)

_ = inference_pcoll | "Logging" >> beam.Map(logging.info)


if __name__ == '__main__':
run()
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

# pytype: skip-file

import re
import unittest
from io import StringIO

Expand All @@ -44,40 +43,40 @@

def check_torch_keyed_model_handler():
expected = '''[START torch_keyed_model_handler]
('first_question', PredictionResult(example=tensor([105.]), inference=tensor([523.6982])))
('second_question', PredictionResult(example=tensor([108.]), inference=tensor([538.5867])))
('third_question', PredictionResult(example=tensor([1000.]), inference=tensor([4965.4019])))
('fourth_question', PredictionResult(example=tensor([1013.]), inference=tensor([5029.9180])))
('first_question', PredictionResult(example=tensor([105.]), inference=tensor([523.6982]), model_id='gs://apache-beam-samples/run_inference/five_times_table_torch.pt'))
('second_question', PredictionResult(example=tensor([108.]), inference=tensor([538.5867]), model_id='gs://apache-beam-samples/run_inference/five_times_table_torch.pt'))
('third_question', PredictionResult(example=tensor([1000.]), inference=tensor([4965.4019]), model_id='gs://apache-beam-samples/run_inference/five_times_table_torch.pt'))
('fourth_question', PredictionResult(example=tensor([1013.]), inference=tensor([5029.9180]), model_id='gs://apache-beam-samples/run_inference/five_times_table_torch.pt'))
[END torch_keyed_model_handler] '''.splitlines()[1:-1]
return expected


def check_sklearn_keyed_model_handler(actual):
expected = '''[START sklearn_keyed_model_handler]
('first_question', PredictionResult(example=[105.0], inference=array([525.])))
('second_question', PredictionResult(example=[108.0], inference=array([540.])))
('third_question', PredictionResult(example=[1000.0], inference=array([5000.])))
('fourth_question', PredictionResult(example=[1013.0], inference=array([5065.])))
('first_question', PredictionResult(example=[105.0], inference=array([525.]), model_id='gs://apache-beam-samples/run_inference/five_times_table_sklearn.pkl'))
('second_question', PredictionResult(example=[108.0], inference=array([540.]), model_id='gs://apache-beam-samples/run_inference/five_times_table_sklearn.pkl'))
('third_question', PredictionResult(example=[1000.0], inference=array([5000.]), model_id='gs://apache-beam-samples/run_inference/five_times_table_sklearn.pkl'))
('fourth_question', PredictionResult(example=[1013.0], inference=array([5065.]), model_id='gs://apache-beam-samples/run_inference/five_times_table_sklearn.pkl'))
[END sklearn_keyed_model_handler] '''.splitlines()[1:-1]
assert_matches_stdout(actual, expected)


def check_torch_unkeyed_model_handler():
expected = '''[START torch_unkeyed_model_handler]
PredictionResult(example=tensor([10.]), inference=tensor([52.2325]))
PredictionResult(example=tensor([40.]), inference=tensor([201.1165]))
PredictionResult(example=tensor([60.]), inference=tensor([300.3724]))
PredictionResult(example=tensor([90.]), inference=tensor([449.2563]))
PredictionResult(example=tensor([10.]), inference=tensor([52.2325]), model_id='gs://apache-beam-samples/run_inference/five_times_table_torch.pt')
PredictionResult(example=tensor([40.]), inference=tensor([201.1165]), model_id='gs://apache-beam-samples/run_inference/five_times_table_torch.pt')
PredictionResult(example=tensor([60.]), inference=tensor([300.3724]), model_id='gs://apache-beam-samples/run_inference/five_times_table_torch.pt')
PredictionResult(example=tensor([90.]), inference=tensor([449.2563]), model_id='gs://apache-beam-samples/run_inference/five_times_table_torch.pt')
[END torch_unkeyed_model_handler] '''.splitlines()[1:-1]
return expected


def check_sklearn_unkeyed_model_handler(actual):
expected = '''[START sklearn_unkeyed_model_handler]
PredictionResult(example=array([20.], dtype=float32), inference=array([100.], dtype=float32))
PredictionResult(example=array([40.], dtype=float32), inference=array([200.], dtype=float32))
PredictionResult(example=array([60.], dtype=float32), inference=array([300.], dtype=float32))
PredictionResult(example=array([90.], dtype=float32), inference=array([450.], dtype=float32))
PredictionResult(example=array([20.], dtype=float32), inference=array([100.], dtype=float32), model_id='gs://apache-beam-samples/run_inference/five_times_table_sklearn.pkl')
PredictionResult(example=array([40.], dtype=float32), inference=array([200.], dtype=float32), model_id='gs://apache-beam-samples/run_inference/five_times_table_sklearn.pkl')
PredictionResult(example=array([60.], dtype=float32), inference=array([300.], dtype=float32), model_id='gs://apache-beam-samples/run_inference/five_times_table_sklearn.pkl')
PredictionResult(example=array([90.], dtype=float32), inference=array([450.], dtype=float32), model_id='gs://apache-beam-samples/run_inference/five_times_table_sklearn.pkl')
[END sklearn_unkeyed_model_handler] '''.splitlines()[1:-1]
assert_matches_stdout(actual, expected)

Expand All @@ -103,22 +102,14 @@ def test_check_torch_keyed_model_handler(self, mock_stdout):
runinference.torch_keyed_model_handler()
predicted = mock_stdout.getvalue().splitlines()
expected = check_torch_keyed_model_handler()
actual_stdout = [line.split(':')[0] for line in predicted]
replace_fn = lambda x: re.sub(r"<UnbindBackward\d*>", "<UnbindBackward>", x)
actual_stdout = [replace_fn(x) for x in actual_stdout]
expected_stdout = [line.split(':')[0] for line in expected]
self.assertEqual(actual_stdout, expected_stdout)
self.assertEqual(predicted, expected)

@pytest.mark.uses_pytorch
def test_check_torch_unkeyed_model_handler(self, mock_stdout):
runinference.torch_unkeyed_model_handler()
predicted = mock_stdout.getvalue().splitlines()
expected = check_torch_unkeyed_model_handler()
actual_stdout = [line.split(':')[0] for line in predicted]
replace_fn = lambda x: re.sub(r"<UnbindBackward\d*>", "<UnbindBackward>", x)
actual_stdout = [replace_fn(x) for x in actual_stdout]
expected_stdout = [line.split(':')[0] for line in expected]
self.assertEqual(actual_stdout, expected_stdout)
self.assertEqual(predicted, expected)


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit 10805a2

Please sign in to comment.