Add typing information to RunInference. (#17762)
robertwb authored May 31, 2022
1 parent 4f059f7 commit ca33943
Showing 6 changed files with 53 additions and 34 deletions.
38 changes: 23 additions & 15 deletions sdks/python/apache_beam/ml/inference/base.py
@@ -51,7 +51,9 @@
_NANOSECOND_TO_MILLISECOND = 1_000_000
_NANOSECOND_TO_MICROSECOND = 1_000

-T = TypeVar('T')
+ModelT = TypeVar('ModelT')
+ExampleT = TypeVar('ExampleT')
+PredictionT = TypeVar('PredictionT')


def _to_milliseconds(time_ns: int) -> int:
@@ -62,14 +64,15 @@ def _to_microseconds(time_ns: int) -> int:
return int(time_ns / _NANOSECOND_TO_MICROSECOND)


-class InferenceRunner:
+class InferenceRunner(Generic[ExampleT, PredictionT, ModelT]):
"""Implements running inferences for a framework."""
-def run_inference(self, batch: List[Any], model: Any) -> Iterable[Any]:
+def run_inference(self, batch: List[ExampleT],
+model: ModelT) -> Iterable[PredictionT]:
"""Runs inferences on a batch of examples and
returns an Iterable of Predictions."""
raise NotImplementedError(type(self))

-def get_num_bytes(self, batch: Any) -> int:
+def get_num_bytes(self, batch: List[ExampleT]) -> int:
"""Returns the number of bytes of data for a batch."""
return len(pickle.dumps(batch))

@@ -78,13 +81,14 @@ def get_metrics_namespace(self) -> str:
return 'RunInference'


-class ModelLoader(Generic[T]):
+class ModelLoader(Generic[ExampleT, PredictionT, ModelT]):
"""Has the ability to load an ML model."""
-def load_model(self) -> T:
+def load_model(self) -> ModelT:
"""Loads and initializes a model for processing."""
raise NotImplementedError(type(self))

-def get_inference_runner(self) -> InferenceRunner:
+def get_inference_runner(
+self) -> InferenceRunner[ExampleT, PredictionT, ModelT]:
"""Returns an implementation of InferenceRunner for this model."""
raise NotImplementedError(type(self))

@@ -97,19 +101,22 @@ def batch_elements_kwargs(self) -> Mapping[str, Any]:
return {}


-class RunInference(beam.PTransform):
+class RunInference(beam.PTransform[beam.PCollection[ExampleT],
+beam.PCollection[PredictionT]]):
"""An extensible transform for running inferences.
Args:
model_loader: An implementation of ModelLoader.
clock: A clock implementing get_current_time_in_microseconds.
"""
-def __init__(self, model_loader: ModelLoader, clock=time):
+def __init__(
+self, model_loader: ModelLoader[ExampleT, PredictionT, Any], clock=time):
self._model_loader = model_loader
self._clock = clock

# TODO(BEAM-14208): Add batch_size back off in the case there
# are functional reasons large batch sizes cannot be handled.
-def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
+def expand(
+self, pcoll: beam.PCollection[ExampleT]) -> beam.PCollection[PredictionT]:
resource_hints = self._model_loader.get_resource_hints()
return (
pcoll
@@ -170,14 +177,12 @@ def update(
self._inference_request_batch_byte_size.update(examples_byte_size)


-class _RunInferenceDoFn(beam.DoFn):
+class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
"""A DoFn implementation generic to frameworks."""
-def __init__(self, model_loader: ModelLoader, clock):
+def __init__(
+self, model_loader: ModelLoader[ExampleT, PredictionT, Any], clock):
self._model_loader = model_loader
-self._inference_runner = model_loader.get_inference_runner()
self._shared_model_handle = shared.Shared()
-self._metrics_collector = _MetricsCollector(
-self._inference_runner.get_metrics_namespace())
self._clock = clock
self._model = None

@@ -199,6 +204,9 @@ def load():
return self._shared_model_handle.acquire(load)

def setup(self):
+self._inference_runner = self._model_loader.get_inference_runner()
+self._metrics_collector = _MetricsCollector(
+self._inference_runner.get_metrics_namespace())
self._model = self._load_model()

def process(self, batch):
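With these generics in place, a framework integration pins ExampleT, PredictionT, and ModelT once, and a type checker can follow element types from the input PCollection through RunInference. Below is a minimal, self-contained sketch of the typed interfaces; the Echo* names are hypothetical stand-ins, not part of this commit.

from typing import Iterable, List

import apache_beam as beam
from apache_beam.ml.inference import base


class EchoModel:
  def predict(self, example: int) -> int:
    return example


class EchoInferenceRunner(base.InferenceRunner[int, int, EchoModel]):
  def run_inference(self, batch: List[int],
                    model: EchoModel) -> Iterable[int]:
    # Run the (trivial) model on each example in the batch.
    for example in batch:
      yield model.predict(example)


class EchoModelLoader(base.ModelLoader[int, int, EchoModel]):
  def load_model(self) -> EchoModel:
    return EchoModel()

  def get_inference_runner(self) -> EchoInferenceRunner:
    return EchoInferenceRunner()


with beam.Pipeline() as p:
  _ = p | beam.Create([1, 2, 3]) | base.RunInference(EchoModelLoader())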
7 changes: 4 additions & 3 deletions sdks/python/apache_beam/ml/inference/base_test.py
@@ -21,6 +21,7 @@
import unittest
from typing import Any
from typing import Iterable
+from typing import List

import apache_beam as beam
from apache_beam.metrics.metric import MetricsFilter
@@ -35,18 +36,18 @@ def predict(self, example: int) -> int:
return example + 1


-class FakeInferenceRunner(base.InferenceRunner):
+class FakeInferenceRunner(base.InferenceRunner[int, int, FakeModel]):
def __init__(self, clock=None):
self._fake_clock = clock

-def run_inference(self, batch: Any, model: Any) -> Iterable[Any]:
+def run_inference(self, batch: List[int], model: FakeModel) -> Iterable[int]:
if self._fake_clock:
self._fake_clock.current_time_ns += 3_000_000 # 3 milliseconds
for example in batch:
yield model.predict(example)


-class FakeModelLoader(base.ModelLoader):
+class FakeModelLoader(base.ModelLoader[int, int, FakeModel]):
def __init__(self, clock=None):
self._fake_clock = clock

17 changes: 10 additions & 7 deletions sdks/python/apache_beam/ml/inference/pytorch.py
@@ -30,7 +30,9 @@
from apache_beam.ml.inference.base import ModelLoader


-class PytorchInferenceRunner(InferenceRunner):
+class PytorchInferenceRunner(InferenceRunner[torch.Tensor,
+PredictionResult,
+torch.nn.Module]):
"""
This class runs Pytorch inferences with the run_inference method. It also has
other methods to get the bytes of a batch of Tensors as well as the namespace
@@ -66,7 +68,9 @@ def get_metrics_namespace(self) -> str:
return 'RunInferencePytorch'


-class PytorchModelLoader(ModelLoader):
+class PytorchModelLoader(ModelLoader[torch.Tensor,
+PredictionResult,
+torch.nn.Module]):
""" Implementation of the ModelLoader interface for PyTorch.
NOTE: This API and its implementation are under development and
@@ -96,18 +100,17 @@ def __init__(
else:
self._device = torch.device('cpu')
self._model_class = model_class
-self.model_params = model_params
-self._inference_runner = PytorchInferenceRunner(device=self._device)
+self._model_params = model_params

def load_model(self) -> torch.nn.Module:
"""Loads and initializes a Pytorch model for processing."""
-model = self._model_class(**self.model_params)
+model = self._model_class(**self._model_params)
model.to(self._device)
file = FileSystems.open(self._state_dict_path, 'rb')
model.load_state_dict(torch.load(file))
model.eval()
return model

-def get_inference_runner(self) -> InferenceRunner:
+def get_inference_runner(self) -> PytorchInferenceRunner:
"""Returns a Pytorch implementation of InferenceRunner."""
-return self._inference_runner
+return PytorchInferenceRunner(device=self._device)
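A usage sketch for the PyTorch loader. The constructor argument names are inferred from the attributes visible in this diff (the full signature is elided above); LinearModel and the model path are hypothetical.

import apache_beam as beam
import torch
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch import PytorchModelLoader


class LinearModel(torch.nn.Module):
  def __init__(self, input_dim: int, output_dim: int):
    super().__init__()
    self.linear = torch.nn.Linear(input_dim, output_dim)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.linear(x)


model_loader = PytorchModelLoader(
    state_dict_path='gs://my-bucket/linear_model.pt',  # hypothetical path
    model_class=LinearModel,
    model_params={'input_dim': 1, 'output_dim': 1})

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([torch.Tensor([1.0]), torch.Tensor([2.0])])
      | RunInference(model_loader))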
16 changes: 10 additions & 6 deletions sdks/python/apache_beam/ml/inference/sklearn_inference.py
@@ -23,6 +23,7 @@
from typing import List

import numpy
+from sklearn.base import BaseEstimator

from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.api import PredictionResult
@@ -41,9 +42,11 @@ class ModelFileType(enum.Enum):
JOBLIB = 2


-class SklearnInferenceRunner(InferenceRunner):
+class SklearnInferenceRunner(InferenceRunner[numpy.ndarray,
+PredictionResult,
+BaseEstimator]):
def run_inference(self, batch: List[numpy.ndarray],
-model: Any) -> Iterable[PredictionResult]:
+model: BaseEstimator) -> Iterable[PredictionResult]:
# vectorize data for better performance
vectorized_batch = numpy.stack(batch, axis=0)
predictions = model.predict(vectorized_batch)
@@ -54,7 +57,9 @@ def get_num_bytes(self, batch: List[numpy.ndarray]) -> int:
return sum(sys.getsizeof(element) for element in batch)


-class SklearnModelLoader(ModelLoader):
+class SklearnModelLoader(ModelLoader[numpy.ndarray,
+PredictionResult,
+BaseEstimator]):
""" Implementation of the ModelLoader interface for scikit learn.
NOTE: This API and its implementation are under development and
@@ -66,9 +71,8 @@ def __init__(
model_uri: str = ''):
self._model_file_type = model_file_type
self._model_uri = model_uri
-self._inference_runner = SklearnInferenceRunner()

-def load_model(self):
+def load_model(self) -> BaseEstimator:
"""Loads and initializes a model for processing."""
file = FileSystems.open(self._model_uri, 'rb')
if self._model_file_type == ModelFileType.PICKLE:
@@ -84,4 +88,4 @@ def load_model(self):
raise AssertionError('Unsupported serialization type.')

def get_inference_runner(self) -> SklearnInferenceRunner:
-return self._inference_runner
+return SklearnInferenceRunner()
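A usage sketch for the scikit-learn loader, assuming a model that was pickled to the given (hypothetical) path:

import numpy
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.sklearn_inference import ModelFileType
from apache_beam.ml.inference.sklearn_inference import SklearnModelLoader

model_loader = SklearnModelLoader(
    model_file_type=ModelFileType.PICKLE,
    model_uri='gs://my-bucket/model.pkl')  # hypothetical path

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([numpy.array([1.0]), numpy.array([2.0])])
      | RunInference(model_loader))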
2 changes: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ def test_periodicimpulse_windowing_on_si(self):
assert_that(actual, equal_to(k))

def test_periodicimpulse_default_start(self):
-default_parameters = inspect.signature(PeriodicImpulse).parameters
+default_parameters = inspect.signature(PeriodicImpulse.__init__).parameters
it = default_parameters["start_timestamp"].default
duration = 1
et = it + duration
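The one-line change above is a consequence of PTransform becoming Generic (see the ptransform.py diff below): on some Python versions, inspect.signature(cls) for a Generic subclass resolves through typing's machinery rather than the class's __init__, hiding the default values, so inspecting __init__ directly is the robust way to read defaults. A self-contained sketch of the distinction, with a hypothetical Impulse class:

import inspect
from typing import Generic, TypeVar

T = TypeVar('T')


class Impulse(Generic[T]):  # hypothetical stand-in for a Generic PTransform
  def __init__(self, start_timestamp: float = 0.0):
    self.start_timestamp = start_timestamp


# Reading defaults off __init__ works regardless of Generic inheritance.
params = inspect.signature(Impulse.__init__).parameters
assert params['start_timestamp'].default == 0.0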
7 changes: 5 additions & 2 deletions sdks/python/apache_beam/transforms/ptransform.py
@@ -49,6 +49,7 @@ class and wrapper class that allows lambda functions to be used as
from typing import Any
from typing import Callable
from typing import Dict
+from typing import Generic
from typing import List
from typing import Mapping
from typing import Optional
@@ -99,6 +100,8 @@ class and wrapper class that allows lambda functions to be used as
_LOGGER = logging.getLogger(__name__)

T = TypeVar('T')
+InputT = TypeVar('InputT')
+OutputT = TypeVar('OutputT')
PTransformT = TypeVar('PTransformT', bound='PTransform')
ConstructorFn = Callable[
['beam_runner_api_pb2.PTransform', Optional[Any], 'PipelineContext'], Any]
@@ -328,7 +331,7 @@ def visit_dict(self, pvalueish, sibling, pairs, context):
self.visit(p, sibling, pairs, context)


-class PTransform(WithTypeHints, HasDisplayData):
+class PTransform(WithTypeHints, HasDisplayData, Generic[InputT, OutputT]):
"""A transform object used to modify one or more PCollections.
Subclasses must define an expand() method that will be used when the transform
@@ -522,7 +525,7 @@ def _clone(self, new_label):
transform.label = new_label
return transform

-def expand(self, input_or_inputs):
+def expand(self, input_or_inputs: InputT) -> OutputT:
raise NotImplementedError

def __str__(self):
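With PTransform parameterized by its input and output types, a subclass can declare the element types it consumes and produces, as RunInference now does above. A minimal sketch (the WordLengths transform is illustrative, not part of the commit):

import apache_beam as beam


class WordLengths(beam.PTransform[beam.PCollection[str],
                                  beam.PCollection[int]]):
  def expand(self, pcoll: beam.PCollection[str]) -> beam.PCollection[int]:
    # Map each string element to its length; a checker can now infer
    # that the output PCollection carries ints.
    return pcoll | beam.Map(len)


with beam.Pipeline() as p:
  _ = p | beam.Create(['a', 'bb', 'ccc']) | WordLengths()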
