Add sideinputs to the RunInference Transform (#25200)
* Add model pcoll param to the RunInference PTransform; fix up parentheses; add time.sleep; update PredictionResult; update PCollection type hint; add model_path param to the ModelHandler
* Add sklearn side input example; refactor name
* Add ModelMetadata and some refactoring; update example
* Refactor _convert_to_result and add it to utils.py; update model path in pytorch_inference.py; add model path to sklearn model; add model id to sklearn inference PredictionResult; clean up base.py; update utils; update base.py
* Add tag to the RunInference DoFn
* Add enable_side_input_loading flag
* Add helper functions; add default value; add prefix/human-readable id to ModelMetadata
* Add docstring, refactor utils code
* Fix pytorch inference tests
* Fix up sklearn inference
* Remove logging
* Add thread Lock when there is an update to the side input
* Check if side input is EmptySideInput
* Add unit test for side input loading
* Remove examples
* Add log when side input path is updated
* Add test to Dataflow; fix snippets tests; revert build settings; fix lint
* Refactor side input loading code; add model_id to PredictionResult
* Add documentation, changelog
* Add Singleton view doc
* Fix whitespace, tests
* Fix weird spacing
* Remove Beam website update
* Revert "fix weird spacing" (this reverts commit f9a61a1)
* Add WatchFilePattern transform
* Undo changes to Beam website
* Pass side inputs only in streaming mode
* Revert "Add WatchFilePattern transform" (this reverts commit ae91ffe)
* Add lines to website page
* Add test
* Add unit test to catch --streaming flag and Singleton SideInput
* Address PR comments
* Add logic to detect windows on side inputs
* Add more tests
* Remove redundant code
* Update test
* Fix lint
* Add postcommit markers
* Remove `and`
* Fix up lint
* Modify message
* Add check for default model
* Update message
* Add validates runner test
* Fix test
* Add PipelineVisitor for RunInference during construction time
* Address comments based on PR
* Remove restriction on the side inputs
* Remove/add tests
* Modify logging
* Add tests
* Add 2.46.0 changelog
* Fix typo

Co-authored-by: Anand Inguva <[email protected]>
Co-authored-by: Danny McCormick <[email protected]>
1 parent 01aa470, commit 10805a2
Showing 10 changed files with 637 additions and 120 deletions.
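The user-facing change is the new model_metadata_pcoll parameter on RunInference: a PCollection of model metadata whose elements tell the transform which model to load while the pipeline is running. As a minimal wiring sketch (the helper name apply_run_inference and the pipeline shape are illustrative; only RunInference(..., model_metadata_pcoll=...) and the ModelMetdata fields come from the example file below):

import apache_beam as beam
from apache_beam.ml.inference import base


def apply_run_inference(examples, model_handler, model_updates):
  # model_updates is a PCollection of base.ModelMetdata (the spelling used in
  # this version of the SDK); each element carries the model_id that
  # RunInference should switch to at runtime.
  return examples | base.RunInference(
      model_handler=model_handler, model_metadata_pcoll=model_updates)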
165 changes: 165 additions & 0 deletions
sdks/python/apache_beam/examples/inference/run_inference_side_inputs.py
@@ -0,0 +1,165 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

""" | ||
Used for internal testing. No backwards compatibility. | ||
""" | ||
|
||
import argparse | ||
import logging | ||
import time | ||
from typing import Iterable | ||
from typing import Optional | ||
from typing import Sequence | ||
|
||
import apache_beam as beam | ||
from apache_beam.ml.inference import base | ||
from apache_beam.options.pipeline_options import PipelineOptions | ||
from apache_beam.options.pipeline_options import SetupOptions | ||
from apache_beam.transforms import trigger | ||
from apache_beam.transforms import window | ||
from apache_beam.transforms.periodicsequence import PeriodicImpulse | ||
from apache_beam.transforms.userstate import CombiningValueStateSpec | ||
|
||
|
||
# create some fake models which returns different inference results. | ||
class FakeModelDefault:
  def predict(self, example: int) -> int:
    return example


class FakeModelAdd(FakeModelDefault):
  def predict(self, example: int) -> int:
    return example + 1


class FakeModelSub(FakeModelDefault):
  def predict(self, example: int) -> int:
    return example - 1


class FakeModelHandlerReturnsPredictionResult(
    base.ModelHandler[int, base.PredictionResult, FakeModelDefault]):
  def __init__(self, clock=None, model_id='model_default'):
    self.model_id = model_id
    self._fake_clock = clock

  def load_model(self):
    if self._fake_clock:
      self._fake_clock.current_time_ns += 500_000_000  # 500ms
    if self.model_id == 'model_add.pkl':
      return FakeModelAdd()
    elif self.model_id == 'model_sub.pkl':
      return FakeModelSub()
    return FakeModelDefault()

  def run_inference(
      self,
      batch: Sequence[int],
      model: FakeModelDefault,
      inference_args=None) -> Iterable[base.PredictionResult]:
    for example in batch:
      yield base.PredictionResult(
          model_id=self.model_id,
          example=example,
          inference=model.predict(example))

  def update_model_path(self, model_path: Optional[str] = None):
    self.model_id = model_path if model_path else self.model_id
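  # RunInference is expected to call update_model_path() with the model_id
  # from each new ModelMetdata element on the model_metadata_pcoll side input
  # and then reload the model, so subsequent run_inference() calls use the
  # newly selected fake model.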

def run(argv=None, save_main_session=True):
  parser = argparse.ArgumentParser()
  first_ts = time.time()
  side_input_interval = 60
  main_input_interval = 20
  # Give some time for Dataflow to start.
  last_ts = first_ts + 1200
  mid_ts = (first_ts + last_ts) / 2

  _, pipeline_args = parser.parse_known_args(argv)
  options = PipelineOptions(pipeline_args)
  options.view_as(SetupOptions).save_main_session = save_main_session

  class GetModel(beam.DoFn):
    def process(self, element) -> Iterable[base.ModelMetdata]:
      if time.time() > mid_ts:
        yield base.ModelMetdata(
            model_id='model_add.pkl', model_name='model_add')
      else:
        yield base.ModelMetdata(
            model_id='model_sub.pkl', model_name='model_sub')

  class _EmitSingletonSideInput(beam.DoFn):
    COUNT_STATE = CombiningValueStateSpec('count', combine_fn=sum)

    def process(self, element, count_state=beam.DoFn.StateParam(COUNT_STATE)):
      _, path = element
      counter = count_state.read()
      if counter == 0:
        count_state.add(1)
        yield path

  def validate_prediction_result(x: base.PredictionResult):
    model_id = x.model_id
    if model_id == 'model_sub.pkl':
      assert (x.example == 1 and x.inference == 0)

    if model_id == 'model_add.pkl':
      assert (x.example == 1 and x.inference == 2)

    if model_id == 'model_default':
      assert (x.example == 1 and x.inference == 1)

  with beam.Pipeline(options=options) as pipeline:
    side_input = (
        pipeline
        | "SideInputPColl" >> PeriodicImpulse(
            first_ts, last_ts, fire_interval=side_input_interval)
        | "GetModelId" >> beam.ParDo(GetModel())
        | "AttachKey" >> beam.Map(lambda x: (x, x))
        # Because the PeriodicImpulse starts at a timestamp before the
        # Dataflow pipeline begins processing data, it can fire multiple
        # times, producing an Iterable instead of a singleton. The
        # _EmitSingletonSideInput DoFn ensures each unique path is emitted
        # only once.
        | "GetSingleton" >> beam.ParDo(_EmitSingletonSideInput())
        | "ApplySideInputWindow" >> beam.WindowInto(
            window.GlobalWindows(),
            trigger=trigger.Repeatedly(trigger.AfterProcessingTime(1)),
            accumulation_mode=trigger.AccumulationMode.DISCARDING))
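    # A global window with a repeated processing-time trigger is the usual
    # Beam pattern for a slowly updating side input: each new model path
    # re-fires the singleton view so RunInference can pick it up without
    # blocking the main input.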

    model_handler = FakeModelHandlerReturnsPredictionResult()
    inference_pcoll = (
        pipeline
        | "MainInputPColl" >> PeriodicImpulse(
            first_ts,
            last_ts,
            fire_interval=main_input_interval,
            apply_windowing=True)
        | beam.Map(lambda x: 1)
        | base.RunInference(
            model_handler=model_handler, model_metadata_pcoll=side_input))

    _ = inference_pcoll | "AssertPredictionResult" >> beam.Map(
        validate_prediction_result)

    _ = inference_pcoll | "Logging" >> beam.Map(logging.info)


if __name__ == '__main__':
  run()