diff --git a/.github/workflows/pr_integration_tests.yml b/.github/workflows/pr_integration_tests.yml index 169482b038..7699d70381 100644 --- a/.github/workflows/pr_integration_tests.yml +++ b/.github/workflows/pr_integration_tests.yml @@ -38,6 +38,7 @@ jobs: env: ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }} ECR_REPOSITORY: feast-python-server + # Note: the image tags should be in sync with sdk/python/feast/infra/aws.py:_get_docker_image_version run: | docker build \ --file sdk/python/feast/infra/feature_servers/aws_lambda/Dockerfile \ diff --git a/sdk/python/feast/constants.py b/sdk/python/feast/constants.py index efcea6d8f1..1e88e60435 100644 --- a/sdk/python/feast/constants.py +++ b/sdk/python/feast/constants.py @@ -17,7 +17,7 @@ # Maximum interval(secs) to wait between retries for retry function MAX_WAIT_INTERVAL: str = "60" -AWS_LAMBDA_FEATURE_SERVER_IMAGE = "feastdev/feature-server:aws" +AWS_LAMBDA_FEATURE_SERVER_IMAGE = "feastdev/feature-server" AWS_LAMBDA_FEATURE_SERVER_REPOSITORY = "feast-python-server" # feature_store.yaml environment variable name for remote feature server diff --git a/sdk/python/feast/feature_server.py b/sdk/python/feast/feature_server.py index 27f77124d9..6eb27894fa 100644 --- a/sdk/python/feast/feature_server.py +++ b/sdk/python/feast/feature_server.py @@ -7,6 +7,7 @@ import feast from feast import proto_json from feast.protos.feast.serving.ServingService_pb2 import GetOnlineFeaturesRequest +from feast.type_map import feast_value_type_to_python_type def get_app(store: "feast.FeatureStore"): @@ -36,7 +37,10 @@ async def get_online_features(request: Request): raise HTTPException(status_code=500, detail="Uneven number of columns") entity_rows = [ - {k: v.val[idx] for k, v in request_proto.entities.items()} + { + k: feast_value_type_to_python_type(v.val[idx]) + for k, v in request_proto.entities.items() + } for idx in range(num_entities) ] @@ -45,7 +49,9 @@ async def get_online_features(request: Request): ).proto # Convert the Protobuf object to JSON and return it - return MessageToDict(response_proto, preserving_proto_field_name=True) + return MessageToDict( # type: ignore + response_proto, preserving_proto_field_name=True, float_precision=18 + ) except Exception as e: # Print the original exception on the server side logger.exception(e) diff --git a/sdk/python/feast/infra/aws.py b/sdk/python/feast/infra/aws.py index f0348d2901..2ff38de9f9 100644 --- a/sdk/python/feast/infra/aws.py +++ b/sdk/python/feast/infra/aws.py @@ -1,6 +1,8 @@ import base64 +import hashlib import logging import os +import subprocess import uuid from datetime import datetime from pathlib import Path @@ -10,6 +12,7 @@ from colorama import Fore, Style +from feast import flags_helper from feast.constants import ( AWS_LAMBDA_FEATURE_SERVER_IMAGE, AWS_LAMBDA_FEATURE_SERVER_REPOSITORY, @@ -88,18 +91,18 @@ def update_infra( ) ecr_client = boto3.client("ecr") + docker_image_version = _get_docker_image_version() repository_uri = self._create_or_get_repository_uri(ecr_client) - version = _get_version_for_aws() # Only download & upload the docker image if it doesn't already exist in ECR if not ecr_client.batch_get_image( repositoryName=AWS_LAMBDA_FEATURE_SERVER_REPOSITORY, - imageIds=[{"imageTag": version}], + imageIds=[{"imageTag": docker_image_version}], ).get("images"): image_uri = self._upload_docker_image( - ecr_client, repository_uri, version + ecr_client, repository_uri, docker_image_version ) else: - image_uri = f"{repository_uri}:{version}" + image_uri = f"{repository_uri}:{docker_image_version}" self._deploy_feature_server(project, image_uri) @@ -154,11 +157,10 @@ def _deploy_feature_server(self, project: str, image_uri: str): # feature views, feature services, and other definitions does not update lambda). _logger.info(" Updating AWS Lambda...") - lambda_client.update_function_configuration( - FunctionName=resource_name, - Environment={ - "Variables": {FEATURE_STORE_YAML_ENV_NAME: config_base64} - }, + aws_utils.update_lambda_function_environment( + lambda_client, + resource_name, + {"Variables": {FEATURE_STORE_YAML_ENV_NAME: config_base64}}, ) api = aws_utils.get_first_api_gateway(api_gateway_client, resource_name) @@ -235,7 +237,7 @@ def get_feature_server_endpoint(self) -> Optional[str]: return f"https://{api_id}.execute-api.{region}.amazonaws.com" def _upload_docker_image( - self, ecr_client, repository_uri: str, version: str + self, ecr_client, repository_uri: str, docker_image_version: str ) -> str: """ Pulls the AWS Lambda docker image from Dockerhub and uploads it to AWS ECR. @@ -258,12 +260,11 @@ def _upload_docker_image( raise DockerDaemonNotRunning() + dockerhub_image = f"{AWS_LAMBDA_FEATURE_SERVER_IMAGE}:{docker_image_version}" _logger.info( - f"Pulling remote image {Style.BRIGHT + Fore.GREEN}{AWS_LAMBDA_FEATURE_SERVER_IMAGE}{Style.RESET_ALL}" + f"Pulling remote image {Style.BRIGHT + Fore.GREEN}{dockerhub_image}{Style.RESET_ALL}" ) - for line in docker_client.api.pull( - AWS_LAMBDA_FEATURE_SERVER_IMAGE, stream=True, decode=True - ): + for line in docker_client.api.pull(dockerhub_image, stream=True, decode=True): _logger.debug(f" {line}") auth_token = ecr_client.get_authorization_token()["authorizationData"][0][ @@ -280,14 +281,14 @@ def _upload_docker_image( ) _logger.debug(f" {login_status}") - image = docker_client.images.get(AWS_LAMBDA_FEATURE_SERVER_IMAGE) - image_remote_name = f"{repository_uri}:{version}" + image = docker_client.images.get(dockerhub_image) + image_remote_name = f"{repository_uri}:{docker_image_version}" _logger.info( f"Pushing local image to remote {Style.BRIGHT + Fore.GREEN}{image_remote_name}{Style.RESET_ALL}" ) image.tag(image_remote_name) for line in docker_client.api.push( - repository_uri, tag=version, stream=True, decode=True + repository_uri, tag=docker_image_version, stream=True, decode=True ): _logger.debug(f" {line}") @@ -310,21 +311,53 @@ def _create_or_get_repository_uri(self, ecr_client): def _get_lambda_name(project: str): lambda_prefix = AWS_LAMBDA_FEATURE_SERVER_REPOSITORY - lambda_suffix = f"{project}-{_get_version_for_aws()}" + lambda_suffix = f"{project}-{_get_docker_image_version()}" # AWS Lambda name can't have the length greater than 64 bytes. - # This usually occurs during integration tests or when feast is - # installed in editable mode (pip install -e), where feast version is long + # This usually occurs during integration tests where feast version is long if len(lambda_prefix) + len(lambda_suffix) >= 63: - lambda_suffix = base64.b64encode(lambda_suffix.encode()).decode()[:40] + lambda_suffix = hashlib.md5(lambda_suffix.encode()).hexdigest() return f"{lambda_prefix}-{lambda_suffix}" -def _get_version_for_aws(): - """Returns Feast version with certain characters replaced. +def _get_docker_image_version() -> str: + """Returns a version for the feature server Docker image. + + For public Feast releases this equals to the Feast SDK version modified by replacing "." with "_". + For example, Feast SDK version "0.14.1" would correspond to Docker image version "0_14_1". + + For integration tests this equals to the git commit hash of HEAD. This is necessary, + because integration tests need to use images built from the same commit hash. + + During development (when Feast is installed in editable mode) this equals to the Feast SDK version + modified by removing the "dev..." suffix and replacing "." with "_". For example, Feast SDK version + "0.14.1.dev41+g1cbfa225.d20211103" would correspond to Docker image version "0_14_1". This way, + Feast SDK will use an already existing Docker image built during the previous public release. - This allows the version to be included in names for AWS resources. """ - return get_version().replace(".", "_").replace("+", "_") + if flags_helper.is_test(): + # Note: this should be in sync with https://github.com/feast-dev/feast/blob/6fbe01b6e9a444dc77ec3328a54376f4d9387664/.github/workflows/pr_integration_tests.yml#L41 + return ( + subprocess.check_output( + ["git", "rev-parse", "HEAD"], cwd=Path(__file__).resolve().parent + ) + .decode() + .strip() + ) + else: + version = get_version() + if "dev" in version: + version = version[: version.find("dev") - 1].replace(".", "_") + _logger.warning( + "You are trying to use AWS Lambda feature server while Feast is in a development mode. " + f"Feast will use a docker image version {version} derived from Feast SDK " + f"version {get_version()}. If you want to update the Feast SDK version, make " + "sure to first fetch all new release tags from Github and then reinstall the library:\n" + "> git fetch --all --tags\n" + "> pip install -e sdk/python" + ) + else: + version = version.replace(".", "_") + return version class S3RegistryStore(RegistryStore): diff --git a/sdk/python/feast/infra/feature_servers/aws_lambda/Dockerfile b/sdk/python/feast/infra/feature_servers/aws_lambda/Dockerfile index 4d46abd3db..3384bca293 100644 --- a/sdk/python/feast/infra/feature_servers/aws_lambda/Dockerfile +++ b/sdk/python/feast/infra/feature_servers/aws_lambda/Dockerfile @@ -9,7 +9,7 @@ COPY protos protos COPY README.md README.md # Install Feast for AWS with Lambda dependencies -RUN pip3 install -e 'sdk/python[aws,redis]' +RUN pip3 install -e 'sdk/python[aws]' RUN pip3 install -r sdk/python/feast/infra/feature_servers/aws_lambda/requirements.txt --target "${LAMBDA_TASK_ROOT}" # Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) diff --git a/sdk/python/feast/infra/utils/aws_utils.py b/sdk/python/feast/infra/utils/aws_utils.py index 28d48a489e..6211c75e37 100644 --- a/sdk/python/feast/infra/utils/aws_utils.py +++ b/sdk/python/feast/infra/utils/aws_utils.py @@ -2,7 +2,7 @@ import os import tempfile import uuid -from typing import Dict, Iterator, Optional, Tuple +from typing import Any, Dict, Iterator, Optional, Tuple import pandas as pd import pyarrow as pa @@ -60,6 +60,7 @@ def get_bucket_and_key(s3_path: str) -> Tuple[str, str]: wait=wait_exponential(multiplier=1, max=4), retry=retry_if_exception_type(ConnectionClosedError), stop=stop_after_attempt(5), + reraise=True, ) def execute_redshift_statement_async( redshift_data_client, cluster_id: str, database: str, user: str, query: str @@ -96,6 +97,7 @@ class RedshiftStatementNotFinishedError(Exception): wait=wait_exponential(multiplier=1, max=30), retry=retry_if_exception_type(RedshiftStatementNotFinishedError), stop=stop_after_delay(300), # 300 seconds + reraise=True, ) def wait_for_redshift_statement(redshift_data_client, statement: dict) -> None: """Waits for the Redshift statement to finish. Raises RedshiftQueryError if the statement didn't succeed. @@ -426,6 +428,29 @@ def delete_lambda_function(lambda_client, function_name: str) -> Dict: return lambda_client.delete_function(FunctionName=function_name) +@retry( + wait=wait_exponential(multiplier=1, max=4), + retry=retry_if_exception_type(ClientError), + stop=stop_after_attempt(5), + reraise=True, +) +def update_lambda_function_environment( + lambda_client, function_name: str, environment: Dict[str, Any] +) -> None: + """ + Update AWS Lambda function environment. The function is retried multiple times in case another action is + currently being run on the lambda (e.g. it's being created or being updated in parallel). + Args: + lambda_client: AWS Lambda client. + function_name: Name of the AWS Lambda function. + environment: The desired lambda environment. + + """ + lambda_client.update_function_configuration( + FunctionName=function_name, Environment=environment + ) + + def get_first_api_gateway(api_gateway_client, api_gateway_name: str) -> Optional[Dict]: """ Get the first API Gateway with the given name. Note, that API Gateways can have the same name. diff --git a/sdk/python/feast/online_response.py b/sdk/python/feast/online_response.py index 070c86abbc..62bd86ae9e 100644 --- a/sdk/python/feast/online_response.py +++ b/sdk/python/feast/online_response.py @@ -45,6 +45,12 @@ def __init__(self, online_response_proto: GetOnlineFeaturesResponse): online_response_proto: GetOnlineResponse proto object to construct from. """ self.proto = online_response_proto + # Delete DUMMY_ENTITY_ID from proto if it exists + for item in self.proto.field_values: + if DUMMY_ENTITY_ID in item.statuses: + del item.statuses[DUMMY_ENTITY_ID] + if DUMMY_ENTITY_ID in item.fields: + del item.fields[DUMMY_ENTITY_ID] @property def field_values(self): @@ -57,13 +63,9 @@ def to_dict(self) -> Dict[str, Any]: """ Converts GetOnlineFeaturesResponse features into a dictionary form. """ - fields = [ - k - for row in self.field_values - for k, _ in row.statuses.items() - if k != DUMMY_ENTITY_ID - ] - features_dict: Dict[str, List[Any]] = {k: list() for k in fields} + features_dict: Dict[str, List[Any]] = { + k: list() for row in self.field_values for k, _ in row.statuses.items() + } for row in self.field_values: for feature in features_dict.keys(): @@ -77,9 +79,7 @@ def to_df(self) -> pd.DataFrame: Converts GetOnlineFeaturesResponse features into Panda dataframe form. """ - return pd.DataFrame(self.to_dict()).drop( - DUMMY_ENTITY_ID, axis=1, errors="ignore" - ) + return pd.DataFrame(self.to_dict()) def _infer_online_entity_rows( diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index 5bbe09d8e0..70e64c845c 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -44,6 +44,11 @@ "gcp_cloudrun": "feast.infra.feature_servers.gcp_cloudrun.config.GcpCloudRunFeatureServerConfig", } +FEATURE_SERVER_TYPE_FOR_PROVIDER = { + "aws": "aws_lambda", + "gcp": "gcp_cloudrun", +} + class FeastBaseModel(BaseModel): """ Feast Pydantic Configuration Class """ @@ -226,15 +231,12 @@ def _validate_feature_server_config(cls, values): if "provider" not in values: raise FeastProviderNotSetError() - # Make sure that the type is not set, since we will set it based on the provider. - if "type" in values["feature_server"]: - raise FeastFeatureServerTypeSetError(values["feature_server"]["type"]) - - # Set the default type. We only support AWS Lambda for now. - if values["provider"] == "aws": - values["feature_server"]["type"] = "aws_lambda" - - feature_server_type = values["feature_server"]["type"] + feature_server_type = FEATURE_SERVER_TYPE_FOR_PROVIDER.get(values["provider"]) + defined_type = values["feature_server"].get("type") + # Make sure that the type is either not set, or set correctly, since it's defined by the provider + if defined_type not in (None, feature_server_type): + raise FeastFeatureServerTypeSetError(defined_type) + values["feature_server"]["type"] = feature_server_type # Validate the dict to ensure one of the union types match try: diff --git a/sdk/python/feast/type_map.py b/sdk/python/feast/type_map.py index 1de54800b9..9992517ba6 100644 --- a/sdk/python/feast/type_map.py +++ b/sdk/python/feast/type_map.py @@ -47,7 +47,11 @@ def feast_value_type_to_python_type(field_value_proto: ProtoValue) -> Any: Returns: Python native type representation/version of the given field_value_proto """ - field_value_dict = MessageToDict(field_value_proto) + field_value_dict = MessageToDict(field_value_proto, float_precision=18) # type: ignore + + # This can happen when proto_json.patch() has been called before this call, which is true for a feature server + if not isinstance(field_value_dict, dict): + return field_value_dict for k, v in field_value_dict.items(): if "List" in k: diff --git a/sdk/python/setup.py b/sdk/python/setup.py index 02216609e7..58ae08e6ff 100644 --- a/sdk/python/setup.py +++ b/sdk/python/setup.py @@ -64,6 +64,7 @@ "tqdm==4.*", "fastapi>=0.68.0", "uvicorn[standard]>=0.14.0", + "proto-plus<1.19.7", ] GCP_REQUIRED = [ @@ -113,7 +114,7 @@ "firebase-admin==4.5.2", "pre-commit", "assertpy==1.1", - "pip-tools" + "pip-tools", ] + GCP_REQUIRED + REDIS_REQUIRED + AWS_REQUIRED DEV_REQUIRED = ["mypy-protobuf==1.*", "grpcio-testing==1.*"] + CI_REQUIRED diff --git a/sdk/python/tests/integration/feature_repos/integration_test_repo_config.py b/sdk/python/tests/integration/feature_repos/integration_test_repo_config.py index 6e4f42cff8..e4ff667764 100644 --- a/sdk/python/tests/integration/feature_repos/integration_test_repo_config.py +++ b/sdk/python/tests/integration/feature_repos/integration_test_repo_config.py @@ -22,6 +22,7 @@ class IntegrationTestRepoConfig: full_feature_names: bool = True infer_features: bool = False + python_feature_server: bool = False def __repr__(self) -> str: return "-".join( diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py index dc84a3a531..67f8b04f64 100644 --- a/sdk/python/tests/integration/feature_repos/repo_configuration.py +++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py @@ -1,4 +1,5 @@ import importlib +import json import os import tempfile import uuid @@ -9,6 +10,7 @@ from typing import Any, Dict, List, Optional import pandas as pd +import yaml from feast import FeatureStore, FeatureView, RepoConfig, driver_test_data from feast.constants import FULL_REPO_CONFIGS_MODULE_ENV_NAME @@ -51,10 +53,10 @@ IntegrationTestRepoConfig(), ] if os.getenv("FEAST_IS_LOCAL_TEST", "False") != "True": - IntegrationTestRepoConfig(online_store=REDIS_CONFIG), - # GCP configurations DEFAULT_FULL_REPO_CONFIGS.extend( [ + IntegrationTestRepoConfig(online_store=REDIS_CONFIG), + # GCP configurations IntegrationTestRepoConfig( provider="gcp", offline_store_creator=BigQueryDataSourceCreator, @@ -70,6 +72,7 @@ provider="aws", offline_store_creator=RedshiftDataSourceCreator, online_store=DYNAMO_CONFIG, + python_feature_server=True, ), IntegrationTestRepoConfig( provider="aws", @@ -211,6 +214,7 @@ class Environment: test_repo_config: IntegrationTestRepoConfig feature_store: FeatureStore data_source_creator: DataSourceCreator + python_feature_server: bool end_date: datetime = field( default=datetime.utcnow().replace(microsecond=0, second=0, minute=0) @@ -241,15 +245,36 @@ def construct_test_environment( online_store = test_repo_config.online_store with tempfile.TemporaryDirectory() as repo_dir_name: + if test_repo_config.python_feature_server: + from feast.infra.feature_servers.aws_lambda.config import ( + AwsLambdaFeatureServerConfig, + ) + + feature_server = AwsLambdaFeatureServerConfig( + enabled=True, + execution_role_name="arn:aws:iam::402087665549:role/lambda_execution_role", + ) + + registry = f"s3://feast-integration-tests/registries/{project}/registry.db" + else: + feature_server = None + registry = str(Path(repo_dir_name) / "registry.db") + config = RepoConfig( - registry=str(Path(repo_dir_name) / "registry.db"), + registry=registry, project=project, provider=test_repo_config.provider, offline_store=offline_store_config, online_store=online_store, repo_path=repo_dir_name, + feature_server=feature_server, ) - fs = FeatureStore(config=config) + + # Create feature_store.yaml out of the config + with open(Path(repo_dir_name) / "feature_store.yaml", "w") as f: + yaml.safe_dump(json.loads(config.json()), f) + + fs = FeatureStore(repo_dir_name) # We need to initialize the registry, because if nothing is applied in the test before tearing down # the feature store, that will cause the teardown method to blow up. fs.registry._initialize_registry() @@ -258,6 +283,7 @@ def construct_test_environment( test_repo_config=test_repo_config, feature_store=fs, data_source_creator=offline_creator, + python_feature_server=test_repo_config.python_feature_server, ) try: diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index c90021f9ce..c47f2bbfd0 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -1,12 +1,16 @@ import datetime import itertools import os +import time import unittest from datetime import timedelta +from typing import Any, Dict, List, Union +import assertpy import numpy as np import pandas as pd import pytest +import requests from feast import Entity, Feature, FeatureService, FeatureView, ValueType from feast.errors import ( @@ -14,6 +18,7 @@ RequestDataNotFoundInEntityRowsException, ) from tests.integration.feature_repos.repo_configuration import ( + Environment, construct_universal_feature_views, ) from tests.integration.feature_repos.universal.entities import ( @@ -165,16 +170,97 @@ def test_write_to_online_store(environment, universal_data_sources): ], entity_rows=[{"driver": 123}], ).to_df() - assert df["avg_daily_trips"].iloc[0] == 14 - assert df["acc_rate"].iloc[0] == 0.91 - assert df["conv_rate"].iloc[0] == 0.85 + assertpy.assert_that(df["avg_daily_trips"].iloc[0]).is_equal_to(14) + assertpy.assert_that(df["acc_rate"].iloc[0]).is_close_to(0.91, 1e-6) + assertpy.assert_that(df["conv_rate"].iloc[0]).is_close_to(0.85, 1e-6) + + +def _get_online_features_dict_remotely( + endpoint: str, + features: Union[List[str], FeatureService], + entity_rows: List[Dict[str, Any]], + full_feature_names: bool = False, +) -> Dict[str, List[Any]]: + """Sends the online feature request to a remote feature server (through endpoint) and returns the feature dict. + + The output should be identical to: + + >>> fs.get_online_features(features=features, entity_rows=entity_rows, full_feature_names=full_feature_names).to_dict() + + This makes it easy to test the remote feature server by comparing the output to the local method. + + """ + request = { + # Convert list of dicts (entity_rows) into dict of lists (entities) for json request + "entities": {key: [row[key] for row in entity_rows] for key in entity_rows[0]}, + "full_feature_names": full_feature_names, + } + # Either set features of feature_service depending on the parameter + if isinstance(features, list): + request["features"] = features + else: + request["feature_service"] = features.name + for _ in range(25): + # Send the request to the remote feature server and get the response in JSON format + response = requests.post( + f"{endpoint}/get-online-features", json=request, timeout=30 + ).json() + # Retry if the response is internal server error, which can happen when lambda is being restarted + if response.get("message") != "Internal Server Error": + break + # Sleep between retries to give the server some time to start + time.sleep(1) + else: + raise Exception("Failed to get online features from remote feature server") + keys = response["field_values"][0]["statuses"].keys() + # Get rid of unnecessary structure in the response, leaving list of dicts + response = [row["fields"] for row in response["field_values"]] + # Convert list of dicts (response) into dict of lists which is the format of the return value + return {key: [row.get(key) for row in response] for key in keys} + + +def get_online_features_dict( + environment: Environment, + features: Union[List[str], FeatureService], + entity_rows: List[Dict[str, Any]], + full_feature_names: bool = False, +) -> Dict[str, List[Any]]: + """Get the online feature values from both SDK and remote feature servers, assert equality and return values. + + Always use this method instead of fs.get_online_features(...) in this test file. + + """ + online_features = environment.feature_store.get_online_features( + features=features, + entity_rows=entity_rows, + full_feature_names=full_feature_names, + ) + assertpy.assert_that(online_features).is_not_none() + dict1 = online_features.to_dict() + + endpoint = environment.feature_store.get_feature_server_endpoint() + # If endpoint is None, it means that the remote feature server isn't configured + if endpoint is not None: + dict2 = _get_online_features_dict_remotely( + endpoint=endpoint, + features=features, + entity_rows=entity_rows, + full_feature_names=full_feature_names, + ) + + # Make sure that the two dicts are equal + assertpy.assert_that(dict1).is_equal_to(dict2) + elif environment.python_feature_server: + raise ValueError( + "feature_store.get_feature_server_endpoint() is None while python feature server is enabled" + ) + return dict1 @pytest.mark.integration @pytest.mark.universal @pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v)) def test_online_retrieval(environment, universal_data_sources, full_feature_names): - fs = environment.feature_store entities, datasets, data_sources = universal_data_sources feature_views = construct_universal_feature_views(data_sources) @@ -237,7 +323,9 @@ def test_online_retrieval(environment, universal_data_sources, full_feature_name ] location_pairs = np.array(list(itertools.permutations(entities["location"], 2))) - sample_location_pairs = location_pairs[np.random.choice(len(location_pairs), 10)].T + sample_location_pairs = location_pairs[ + np.random.choice(len(location_pairs), 10) + ].T.tolist() origins_df = datasets["location"][ datasets["location"]["location_id"].isin(sample_location_pairs[0]) ] @@ -270,23 +358,24 @@ def test_online_retrieval(environment, universal_data_sources, full_feature_name unprefixed_feature_refs.remove("conv_rate_plus_100") unprefixed_feature_refs.remove("conv_rate_plus_val_to_add") - online_features = fs.get_online_features( + online_features_dict = get_online_features_dict( + environment=environment, features=feature_refs, entity_rows=entity_rows, full_feature_names=full_feature_names, ) - assert online_features is not None # Test that the on demand feature views compute properly even if the dependent conv_rate # feature isn't requested. - online_features_no_conv_rate = fs.get_online_features( + online_features_no_conv_rate = get_online_features_dict( + environment=environment, features=[ref for ref in feature_refs if ref != "driver_stats:conv_rate"], entity_rows=entity_rows, full_feature_names=full_feature_names, ) + assert online_features_no_conv_rate is not None - online_features_dict = online_features.to_dict() keys = online_features_dict.keys() assert ( len(keys) == len(feature_refs) + 2 @@ -340,13 +429,14 @@ def test_online_retrieval(environment, universal_data_sources, full_feature_name ) # Check what happens for missing values - missing_responses_dict = fs.get_online_features( + missing_responses_dict = get_online_features_dict( + environment=environment, features=feature_refs, entity_rows=[ {"driver": 0, "customer_id": 0, "val_to_add": 100, "driver_age": 125} ], full_feature_names=full_feature_names, - ).to_dict() + ) assert missing_responses_dict is not None for unprefixed_feature_ref in unprefixed_feature_refs: if unprefixed_feature_ref not in {"num_rides", "avg_ride_length", "driver_age"}: @@ -358,22 +448,24 @@ def test_online_retrieval(environment, universal_data_sources, full_feature_name # Check what happens for missing request data with pytest.raises(RequestDataNotFoundInEntityRowsException): - fs.get_online_features( + get_online_features_dict( + environment=environment, features=feature_refs, entity_rows=[{"driver": 0, "customer_id": 0}], full_feature_names=full_feature_names, - ).to_dict() + ) # Also with request data with pytest.raises(RequestDataNotFoundInEntityRowsException): - fs.get_online_features( + get_online_features_dict( + environment=environment, features=feature_refs, entity_rows=[{"driver": 0, "customer_id": 0, "val_to_add": 20}], full_feature_names=full_feature_names, - ).to_dict() + ) assert_feature_service_correctness( - fs, + environment, feature_service, entity_rows, full_feature_names, @@ -395,7 +487,7 @@ def test_online_retrieval(environment, universal_data_sources, full_feature_name ) ] assert_feature_service_entity_mapping_correctness( - fs, + environment, feature_service_entity_mapping, entity_rows, full_feature_names, @@ -511,7 +603,7 @@ def get_latest_feature_values_from_dataframes( def assert_feature_service_correctness( - fs, + environment, feature_service, entity_rows, full_feature_names, @@ -520,14 +612,12 @@ def assert_feature_service_correctness( orders_df, global_df, ): - feature_service_response = fs.get_online_features( + feature_service_online_features_dict = get_online_features_dict( + environment=environment, features=feature_service, entity_rows=entity_rows, full_feature_names=full_feature_names, ) - assert feature_service_response is not None - - feature_service_online_features_dict = feature_service_response.to_dict() feature_service_keys = feature_service_online_features_dict.keys() assert ( @@ -560,7 +650,7 @@ def assert_feature_service_correctness( def assert_feature_service_entity_mapping_correctness( - fs, + environment, feature_service, entity_rows, full_feature_names, @@ -571,14 +661,12 @@ def assert_feature_service_entity_mapping_correctness( destinations_df, ): if full_feature_names: - feature_service_response = fs.get_online_features( + feature_service_online_features_dict = get_online_features_dict( + environment=environment, features=feature_service, entity_rows=entity_rows, full_feature_names=full_feature_names, ) - assert feature_service_response is not None - - feature_service_online_features_dict = feature_service_response.to_dict() feature_service_keys = feature_service_online_features_dict.keys() assert ( @@ -609,7 +697,8 @@ def assert_feature_service_entity_mapping_correctness( else: # using 2 of the same FeatureView without full_feature_names=True will result in collision with pytest.raises(FeatureNameCollisionError): - feature_service_response = fs.get_online_features( + get_online_features_dict( + environment=environment, features=feature_service, entity_rows=entity_rows, full_feature_names=full_feature_names,