Skip to content

Commit

Permalink
Added experimental annotation to fixes #22564 (#22565)
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanthompson591 authored Aug 4, 2022
1 parent afb7dc9 commit c8d92b0
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
7 changes: 7 additions & 0 deletions sdks/python/apache_beam/ml/inference/pytorch_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.utils.annotations import experimental

__all__ = [
'PytorchModelHandlerTensor',
'PytorchModelHandlerKeyedTensor',
]


def _load_model(
Expand Down Expand Up @@ -145,6 +151,7 @@ def get_metrics_namespace(self) -> str:
return 'RunInferencePytorch'


@experimental(extra_message="No backwards-compatibility guarantees.")
class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
PredictionResult,
torch.nn.Module]):
Expand Down
7 changes: 7 additions & 0 deletions sdks/python/apache_beam/ml/inference/sklearn_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,19 @@
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.utils.annotations import experimental

try:
import joblib
except ImportError:
# joblib is an optional dependency.
pass

__all__ = [
'SklearnModelHandlerNumpy',
'SklearnModelHandlerPandas',
]


class ModelFileType(enum.Enum):
"""Defines how a model file is serialized. Options are pickle or joblib."""
Expand Down Expand Up @@ -132,6 +138,7 @@ def get_num_bytes(self, batch: Sequence[pandas.DataFrame]) -> int:
return sum(sys.getsizeof(element) for element in batch)


@experimental(extra_message="No backwards-compatibility guarantees.")
class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
PredictionResult,
BaseEstimator]):
Expand Down

0 comments on commit c8d92b0

Please sign in to comment.