diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index ad8ea5868f7c..15f94451b8f0 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -51,7 +51,9 @@ _NANOSECOND_TO_MILLISECOND = 1_000_000 _NANOSECOND_TO_MICROSECOND = 1_000 -T = TypeVar('T') +ModelT = TypeVar('ModelT') +ExampleT = TypeVar('ExampleT') +PredictionT = TypeVar('PredictionT') def _to_milliseconds(time_ns: int) -> int: @@ -62,14 +64,15 @@ def _to_microseconds(time_ns: int) -> int: return int(time_ns / _NANOSECOND_TO_MICROSECOND) -class InferenceRunner: +class InferenceRunner(Generic[ExampleT, PredictionT, ModelT]): """Implements running inferences for a framework.""" - def run_inference(self, batch: List[Any], model: Any) -> Iterable[Any]: + def run_inference(self, batch: List[ExampleT], + model: ModelT) -> Iterable[PredictionT]: """Runs inferences on a batch of examples and returns an Iterable of Predictions.""" raise NotImplementedError(type(self)) - def get_num_bytes(self, batch: Any) -> int: + def get_num_bytes(self, batch: List[ExampleT]) -> int: """Returns the number of bytes of data for a batch.""" return len(pickle.dumps(batch)) @@ -78,13 +81,14 @@ def get_metrics_namespace(self) -> str: return 'RunInference' -class ModelLoader(Generic[T]): +class ModelLoader(Generic[ExampleT, PredictionT, ModelT]): """Has the ability to load an ML model.""" - def load_model(self) -> T: + def load_model(self) -> ModelT: """Loads and initializes a model for processing.""" raise NotImplementedError(type(self)) - def get_inference_runner(self) -> InferenceRunner: + def get_inference_runner( + self) -> InferenceRunner[ExampleT, PredictionT, ModelT]: """Returns an implementation of InferenceRunner for this model.""" raise NotImplementedError(type(self)) @@ -97,19 +101,22 @@ def batch_elements_kwargs(self) -> Mapping[str, Any]: return {} -class RunInference(beam.PTransform): +class RunInference(beam.PTransform[beam.PCollection[ExampleT], + beam.PCollection[PredictionT]]): """An extensible transform for running inferences. Args: model_loader: An implementation of ModelLoader. clock: A clock implementing get_current_time_in_microseconds. """ - def __init__(self, model_loader: ModelLoader, clock=time): + def __init__( + self, model_loader: ModelLoader[ExampleT, PredictionT, Any], clock=time): self._model_loader = model_loader self._clock = clock # TODO(BEAM-14208): Add batch_size back off in the case there # are functional reasons large batch sizes cannot be handled. - def expand(self, pcoll: beam.PCollection) -> beam.PCollection: + def expand( + self, pcoll: beam.PCollection[ExampleT]) -> beam.PCollection[PredictionT]: resource_hints = self._model_loader.get_resource_hints() return ( pcoll @@ -170,14 +177,12 @@ def update( self._inference_request_batch_byte_size.update(examples_byte_size) -class _RunInferenceDoFn(beam.DoFn): +class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]): """A DoFn implementation generic to frameworks.""" - def __init__(self, model_loader: ModelLoader, clock): + def __init__( + self, model_loader: ModelLoader[ExampleT, PredictionT, Any], clock): self._model_loader = model_loader - self._inference_runner = model_loader.get_inference_runner() self._shared_model_handle = shared.Shared() - self._metrics_collector = _MetricsCollector( - self._inference_runner.get_metrics_namespace()) self._clock = clock self._model = None @@ -199,6 +204,9 @@ def load(): return self._shared_model_handle.acquire(load) def setup(self): + self._inference_runner = self._model_loader.get_inference_runner() + self._metrics_collector = _MetricsCollector( + self._inference_runner.get_metrics_namespace()) self._model = self._load_model() def process(self, batch): diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index d4bf6518bda9..41b166ba78dd 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -21,6 +21,7 @@ import unittest from typing import Any from typing import Iterable +from typing import List import apache_beam as beam from apache_beam.metrics.metric import MetricsFilter @@ -35,18 +36,18 @@ def predict(self, example: int) -> int: return example + 1 -class FakeInferenceRunner(base.InferenceRunner): +class FakeInferenceRunner(base.InferenceRunner[int, int, FakeModel]): def __init__(self, clock=None): self._fake_clock = clock - def run_inference(self, batch: Any, model: Any) -> Iterable[Any]: + def run_inference(self, batch: List[int], model: FakeModel) -> Iterable[int]: if self._fake_clock: self._fake_clock.current_time_ns += 3_000_000 # 3 milliseconds for example in batch: yield model.predict(example) -class FakeModelLoader(base.ModelLoader): +class FakeModelLoader(base.ModelLoader[int, int, FakeModel]): def __init__(self, clock=None): self._fake_clock = clock diff --git a/sdks/python/apache_beam/ml/inference/pytorch.py b/sdks/python/apache_beam/ml/inference/pytorch.py index d7e24d618238..d591c6867d89 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch.py +++ b/sdks/python/apache_beam/ml/inference/pytorch.py @@ -30,7 +30,9 @@ from apache_beam.ml.inference.base import ModelLoader -class PytorchInferenceRunner(InferenceRunner): +class PytorchInferenceRunner(InferenceRunner[torch.Tensor, + PredictionResult, + torch.nn.Module]): """ This class runs Pytorch inferences with the run_inference method. It also has other methods to get the bytes of a batch of Tensors as well as the namespace @@ -66,7 +68,9 @@ def get_metrics_namespace(self) -> str: return 'RunInferencePytorch' -class PytorchModelLoader(ModelLoader): +class PytorchModelLoader(ModelLoader[torch.Tensor, + PredictionResult, + torch.nn.Module]): """ Implementation of the ModelLoader interface for PyTorch. NOTE: This API and its implementation are under development and @@ -96,18 +100,17 @@ def __init__( else: self._device = torch.device('cpu') self._model_class = model_class - self.model_params = model_params - self._inference_runner = PytorchInferenceRunner(device=self._device) + self._model_params = model_params def load_model(self) -> torch.nn.Module: """Loads and initializes a Pytorch model for processing.""" - model = self._model_class(**self.model_params) + model = self._model_class(**self._model_params) model.to(self._device) file = FileSystems.open(self._state_dict_path, 'rb') model.load_state_dict(torch.load(file)) model.eval() return model - def get_inference_runner(self) -> InferenceRunner: + def get_inference_runner(self) -> PytorchInferenceRunner: """Returns a Pytorch implementation of InferenceRunner.""" - return self._inference_runner + return PytorchInferenceRunner(device=self._device) diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py index 80530146c2cc..7f91169ea43b 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py @@ -23,6 +23,7 @@ from typing import List import numpy +from sklearn.base import BaseEstimator from apache_beam.io.filesystems import FileSystems from apache_beam.ml.inference.api import PredictionResult @@ -41,9 +42,11 @@ class ModelFileType(enum.Enum): JOBLIB = 2 -class SklearnInferenceRunner(InferenceRunner): +class SklearnInferenceRunner(InferenceRunner[numpy.ndarray, + PredictionResult, + BaseEstimator]): def run_inference(self, batch: List[numpy.ndarray], - model: Any) -> Iterable[PredictionResult]: + model: BaseEstimator) -> Iterable[PredictionResult]: # vectorize data for better performance vectorized_batch = numpy.stack(batch, axis=0) predictions = model.predict(vectorized_batch) @@ -54,7 +57,9 @@ def get_num_bytes(self, batch: List[numpy.ndarray]) -> int: return sum(sys.getsizeof(element) for element in batch) -class SklearnModelLoader(ModelLoader): +class SklearnModelLoader(ModelLoader[numpy.ndarray, + PredictionResult, + BaseEstimator]): """ Implementation of the ModelLoader interface for scikit learn. NOTE: This API and its implementation are under development and @@ -66,9 +71,8 @@ def __init__( model_uri: str = ''): self._model_file_type = model_file_type self._model_uri = model_uri - self._inference_runner = SklearnInferenceRunner() - def load_model(self): + def load_model(self) -> BaseEstimator: """Loads and initializes a model for processing.""" file = FileSystems.open(self._model_uri, 'rb') if self._model_file_type == ModelFileType.PICKLE: @@ -84,4 +88,4 @@ def load_model(self): raise AssertionError('Unsupported serialization type.') def get_inference_runner(self) -> SklearnInferenceRunner: - return self._inference_runner + return SklearnInferenceRunner() diff --git a/sdks/python/apache_beam/transforms/periodicsequence_test.py b/sdks/python/apache_beam/transforms/periodicsequence_test.py index b18bf75d0709..e2fe264fbce5 100644 --- a/sdks/python/apache_beam/transforms/periodicsequence_test.py +++ b/sdks/python/apache_beam/transforms/periodicsequence_test.py @@ -76,7 +76,7 @@ def test_periodicimpulse_windowing_on_si(self): assert_that(actual, equal_to(k)) def test_periodicimpulse_default_start(self): - default_parameters = inspect.signature(PeriodicImpulse).parameters + default_parameters = inspect.signature(PeriodicImpulse.__init__).parameters it = default_parameters["start_timestamp"].default duration = 1 et = it + duration diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py index 86421ed16766..f3e57951e373 100644 --- a/sdks/python/apache_beam/transforms/ptransform.py +++ b/sdks/python/apache_beam/transforms/ptransform.py @@ -49,6 +49,7 @@ class and wrapper class that allows lambda functions to be used as from typing import Any from typing import Callable from typing import Dict +from typing import Generic from typing import List from typing import Mapping from typing import Optional @@ -99,6 +100,8 @@ class and wrapper class that allows lambda functions to be used as _LOGGER = logging.getLogger(__name__) T = TypeVar('T') +InputT = TypeVar('InputT') +OutputT = TypeVar('OutputT') PTransformT = TypeVar('PTransformT', bound='PTransform') ConstructorFn = Callable[ ['beam_runner_api_pb2.PTransform', Optional[Any], 'PipelineContext'], Any] @@ -328,7 +331,7 @@ def visit_dict(self, pvalueish, sibling, pairs, context): self.visit(p, sibling, pairs, context) -class PTransform(WithTypeHints, HasDisplayData): +class PTransform(WithTypeHints, HasDisplayData, Generic[InputT, OutputT]): """A transform object used to modify one or more PCollections. Subclasses must define an expand() method that will be used when the transform @@ -522,7 +525,7 @@ def _clone(self, new_label): transform.label = new_label return transform - def expand(self, input_or_inputs): + def expand(self, input_or_inputs: InputT) -> OutputT: raise NotImplementedError def __str__(self):