diff --git a/sdk/python/feast/errors.py b/sdk/python/feast/errors.py index 06da5280da..258db41882 100644 --- a/sdk/python/feast/errors.py +++ b/sdk/python/feast/errors.py @@ -85,11 +85,11 @@ def __init__(self, online_store_class_name: str): ) -class FeastStoreConfigInvalidName(Exception): - def __init__(self, store_config_class_name: str, store_type: str): +class FeastClassInvalidName(Exception): + def __init__(self, class_name: str, class_type: str): super().__init__( - f"Config Class '{store_config_class_name}' " - f"should end with the string `{store_type}Config`.'" + f"Config Class '{class_name}' " + f"should end with the string `{class_type}`.'" ) diff --git a/sdk/python/feast/importer.py b/sdk/python/feast/importer.py new file mode 100644 index 0000000000..5dcd7c71c1 --- /dev/null +++ b/sdk/python/feast/importer.py @@ -0,0 +1,28 @@ +import importlib + +from feast import errors + + +def get_class_from_type(module_name: str, class_name: str, class_type: str): + if not class_name.endswith(class_type): + raise errors.FeastClassInvalidName(class_name, class_type) + + # Try importing the module that contains the custom provider + try: + module = importlib.import_module(module_name) + except Exception as e: + # The original exception can be anything - either module not found, + # or any other kind of error happening during the module import time. + # So we should include the original error as well in the stack trace. + raise errors.FeastModuleImportError(module_name, class_type) from e + + # Try getting the provider class definition + try: + _class = getattr(module, class_name) + except AttributeError: + # This can only be one type of error, when class_name attribute does not exist in the module + # So we don't have to include the original exception here + raise errors.FeastClassImportError( + module_name, class_name, class_type=class_type + ) from None + return _class diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index d6fb88c7d7..905d0fd1dc 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -1,5 +1,4 @@ import abc -import importlib from datetime import datetime from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union @@ -8,7 +7,7 @@ import pyarrow from tqdm import tqdm -from feast import errors +from feast import errors, importer from feast.entity import Entity from feast.feature_table import FeatureTable from feast.feature_view import FeatureView @@ -156,24 +155,9 @@ def get_provider(config: RepoConfig, repo_path: Path) -> Provider: # For example, provider 'foo.bar.MyProvider' will be parsed into 'foo.bar' and 'MyProvider' module_name, class_name = config.provider.rsplit(".", 1) - # Try importing the module that contains the custom provider - try: - module = importlib.import_module(module_name) - except Exception as e: - # The original exception can be anything - either module not found, - # or any other kind of error happening during the module import time. - # So we should include the original error as well in the stack trace. - raise errors.FeastModuleImportError(module_name, "provider") from e - - # Try getting the provider class definition - try: - ProviderCls = getattr(module, class_name) - except AttributeError: - # This can only be one type of error, when class_name attribute does not exist in the module - # So we don't have to include the original exception here - raise errors.FeastClassImportError(module_name, class_name) from None - - return ProviderCls(config, repo_path) + cls = importer.get_class_from_type(module_name, class_name, "Provider") + + return cls(config, repo_path) def _get_requested_feature_views_to_features_dict( diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index 00bbaf05cf..5587fb5905 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -1,4 +1,3 @@ -import importlib from pathlib import Path from typing import Any @@ -7,7 +6,7 @@ from pydantic.error_wrappers import ErrorWrapper from pydantic.typing import Dict, Optional, Union -from feast import errors +from feast.importer import get_class_from_type from feast.telemetry import log_exceptions # These dict exists so that: @@ -185,33 +184,6 @@ def __repr__(self) -> str: ) -def get_config_class_from_type( - module_name: str, config_class_name: str, store_type: str -): - if not config_class_name.endswith(f"{store_type}Config"): - raise errors.FeastStoreConfigInvalidName(config_class_name, store_type) - - # Try importing the module that contains the custom provider - try: - module = importlib.import_module(module_name) - except Exception as e: - # The original exception can be anything - either module not found, - # or any other kind of error happening during the module import time. - # So we should include the original error as well in the stack trace. - raise errors.FeastModuleImportError(module_name, store_type) from e - - # Try getting the provider class definition - try: - online_store_config_class = getattr(module, config_class_name) - except AttributeError: - # This can only be one type of error, when class_name attribute does not exist in the module - # So we don't have to include the original exception here - raise errors.FeastClassImportError( - module_name, config_class_name, class_type=f"{store_type}Config" - ) from None - return online_store_config_class - - def get_online_config_from_type(online_store_type: str): if online_store_type in ONLINE_STORE_CLASS_FOR_TYPE: online_store_type = ONLINE_STORE_CLASS_FOR_TYPE[online_store_type] @@ -220,7 +192,7 @@ def get_online_config_from_type(online_store_type: str): module_name, online_store_class_type = online_store_type.rsplit(".", 1) config_class_name = f"{online_store_class_type}Config" - return get_config_class_from_type(module_name, config_class_name, "OnlineStore") + return get_class_from_type(module_name, config_class_name, config_class_name) def get_offline_config_from_type(offline_store_type: str): @@ -231,7 +203,7 @@ def get_offline_config_from_type(offline_store_type: str): module_name, offline_store_class_type = offline_store_type.rsplit(".", 1) config_class_name = f"{offline_store_class_type}Config" - return get_config_class_from_type(module_name, config_class_name, "OfflineStore") + return get_class_from_type(module_name, config_class_name, config_class_name) def load_repo_config(repo_path: Path) -> RepoConfig: diff --git a/sdk/python/tests/test_cli_local.py b/sdk/python/tests/test_cli_local.py index 288a246245..4399819073 100644 --- a/sdk/python/tests/test_cli_local.py +++ b/sdk/python/tests/test_cli_local.py @@ -164,18 +164,18 @@ def test_3rd_party_providers() -> None: assertpy.assert_that(return_code).is_equal_to(1) assertpy.assert_that(output).contains(b"Provider 'feast123' is not implemented") # Check with incorrect third-party provider name (with dots) - with setup_third_party_provider_repo("feast_foo.provider") as repo_path: + with setup_third_party_provider_repo("feast_foo.Provider") as repo_path: return_code, output = runner.run_with_output(["apply"], cwd=repo_path) assertpy.assert_that(return_code).is_equal_to(1) assertpy.assert_that(output).contains( - b"Could not import provider module 'feast_foo'" + b"Could not import Provider module 'feast_foo'" ) # Check with incorrect third-party provider name (with dots) with setup_third_party_provider_repo("foo.FooProvider") as repo_path: return_code, output = runner.run_with_output(["apply"], cwd=repo_path) assertpy.assert_that(return_code).is_equal_to(1) assertpy.assert_that(output).contains( - b"Could not import provider 'FooProvider' from module 'foo'" + b"Could not import Provider 'FooProvider' from module 'foo'" ) # Check with correct third-party provider name with setup_third_party_provider_repo("foo.provider.FooProvider") as repo_path: