From 28f6e3398bf9f470d9f44cfc5991435ba0c52b08 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Mon, 30 Oct 2023 17:03:46 -0400 Subject: [PATCH] Don't assume env vars are set in model handler (#29200) * Don't assume env vars are set in model handler * Patch notebook --- examples/notebooks/beam-ml/run_custom_inference.ipynb | 1 + sdks/python/apache_beam/ml/inference/base.py | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/notebooks/beam-ml/run_custom_inference.ipynb b/examples/notebooks/beam-ml/run_custom_inference.ipynb index df81ae5af56f..a66c5847de0e 100644 --- a/examples/notebooks/beam-ml/run_custom_inference.ipynb +++ b/examples/notebooks/beam-ml/run_custom_inference.ipynb @@ -356,6 +356,7 @@ " model_name: The spaCy model name. Default is en_core_web_sm.\n", " \"\"\"\n", " self._model_name = model_name\n", + " self._env_vars = {}\n", "\n", " def load_model(self) -> Language:\n", " \"\"\"Loads and initializes a model for processing.\"\"\"\n", diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 45c5078c13cf..fc8ac59a1fb7 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -512,7 +512,7 @@ def __init__( 'postprocessing functions defined into a keyed model handler. All ' 'pre/postprocessing functions must be defined on the outer model' 'handler.') - self._env_vars = unkeyed._env_vars + self._env_vars = getattr(unkeyed, '_env_vars', {}) self._unkeyed = unkeyed return @@ -553,7 +553,7 @@ def __init__( 'overriding the KeyedModelHandler.batch_elements_kwargs() method.', hints, batch_kwargs) - env_vars = mh._env_vars + env_vars = getattr(mh, '_env_vars', {}) if len(env_vars) > 0: logging.warning( 'mh %s defines the following _env_vars which will be ignored %s. ' @@ -816,7 +816,7 @@ def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]): 'pre/postprocessing functions must be defined on the outer model' 'handler.') self._unkeyed = unkeyed - self._env_vars = unkeyed._env_vars + self._env_vars = getattr(unkeyed, '_env_vars', {}) def load_model(self) -> ModelT: return self._unkeyed.load_model() @@ -895,7 +895,7 @@ def __init__( preprocess_fn: the preprocessing function to use. """ self._base = base - self._env_vars = base._env_vars + self._env_vars = getattr(base, '_env_vars', {}) self._preprocess_fn = preprocess_fn def load_model(self) -> ModelT: @@ -951,7 +951,7 @@ def __init__( postprocess_fn: the preprocessing function to use. """ self._base = base - self._env_vars = base._env_vars + self._env_vars = getattr(base, '_env_vars', {}) self._postprocess_fn = postprocess_fn def load_model(self) -> ModelT: