diff --git a/Makefile b/Makefile
index 35becab0cd..76ee18248d 100644
--- a/Makefile
+++ b/Makefile
@@ -139,6 +139,33 @@ test-python-universal-trino:
 		not test_universal_types" \
 		sdk/python/tests
 
+# To use Athena as an offline store, you need to create an Athena database and an S3 bucket on AWS. See https://docs.aws.amazon.com/athena/latest/ug/getting-started.html
+# Modify the environment variables ATHENA_DATA_SOURCE, ATHENA_DATABASE, and ATHENA_S3_BUCKET_NAME to change the Athena data source, database, and S3 bucket used by the tests.
+# If tests fail with the pytest -n 8 option, change the number to 1.
+test-python-universal-athena:
+	PYTHONPATH='.' \
+	FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.offline_stores.contrib.athena_repo_configuration \
+	PYTEST_PLUGINS=feast.infra.offline_stores.contrib.athena_offline_store.tests \
+	FEAST_USAGE=False IS_TEST=True \
+	ATHENA_DATA_SOURCE=AwsDataCatalog \
+	ATHENA_DATABASE=default \
+	ATHENA_S3_BUCKET_NAME=feast-integration-tests \
+	python -m pytest -n 8 --integration \
+		-k "not test_go_feature_server and \
+		not test_logged_features_validation and \
+		not test_lambda and \
+		not test_feature_logging and \
+		not test_offline_write and \
+		not test_push_offline and \
+		not test_historical_retrieval_with_validation and \
+		not test_historical_features_persisting and \
+		not test_historical_retrieval_fails_on_validation and \
+		not gcs_registry and \
+		not s3_registry" \
+		sdk/python/tests
+
+
+
 test-python-universal-postgres:
 	PYTHONPATH='.' \
 	FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.offline_stores.contrib.postgres_repo_configuration \
diff --git a/protos/feast/core/DataSource.proto b/protos/feast/core/DataSource.proto
index 62f5859ee8..5258618f3b 100644
--- a/protos/feast/core/DataSource.proto
+++ b/protos/feast/core/DataSource.proto
@@ -49,6 +49,7 @@ message DataSource {
     PUSH_SOURCE = 9;
     BATCH_TRINO = 10;
     BATCH_SPARK = 11;
+    BATCH_ATHENA = 12;
   }
 
   // Unique name of data source within the project
@@ -171,6 +172,22 @@ message DataSource {
     string database = 4;
   }
 
+  // Defines options for DataSource that sources features from an Athena Query
+  message AthenaOptions {
+    // Athena table name
+    string table = 1;
+
+    // SQL query that returns a table containing feature data. Must contain an event_timestamp column, and respective
+    // entity columns
+    string query = 2;
+
+    // Athena database name
+    string database = 3;
+
+    // Athena data source (data catalog) name
+    string data_source = 4;
+  }
+
   // Defines options for DataSource that sources features from a Snowflake Query
   message SnowflakeOptions {
     // Snowflake table name
@@ -242,5 +259,6 @@ message DataSource {
     PushOptions push_options = 22;
     SparkOptions spark_options = 27;
     TrinoOptions trino_options = 30;
+    AthenaOptions athena_options = 35;
   }
 }
diff --git a/protos/feast/core/FeatureService.proto b/protos/feast/core/FeatureService.proto
index 51b9c6c02a..80d32eb4de 100644
--- a/protos/feast/core/FeatureService.proto
+++ b/protos/feast/core/FeatureService.proto
@@ -60,6 +60,7 @@ message LoggingConfig {
     RedshiftDestination redshift_destination = 5;
     SnowflakeDestination snowflake_destination = 6;
     CustomDestination custom_destination = 7;
+    AthenaDestination athena_destination = 8;
   }
 
   message FileDestination {
@@ -80,6 +81,11 @@ message LoggingConfig {
     string table_name = 1;
   }
 
+  message AthenaDestination {
+    // Destination table name. data_source and database will be taken from an offline store config
+    string table_name = 1;
+  }
+
   message SnowflakeDestination {
     // Destination table name.
Schema and database will be taken from an offline store config string table_name = 1; diff --git a/protos/feast/core/SavedDataset.proto b/protos/feast/core/SavedDataset.proto index 53f06f73a9..111548aa48 100644 --- a/protos/feast/core/SavedDataset.proto +++ b/protos/feast/core/SavedDataset.proto @@ -59,6 +59,7 @@ message SavedDatasetStorage { DataSource.TrinoOptions trino_storage = 8; DataSource.SparkOptions spark_storage = 9; DataSource.CustomSourceOptions custom_storage = 10; + DataSource.AthenaOptions athena_storage = 11; } } diff --git a/sdk/python/feast/__init__.py b/sdk/python/feast/__init__.py index 5d1663f7cb..d043f1a973 100644 --- a/sdk/python/feast/__init__.py +++ b/sdk/python/feast/__init__.py @@ -5,6 +5,9 @@ from importlib_metadata import PackageNotFoundError, version as _version # type: ignore from feast.infra.offline_stores.bigquery_source import BigQuerySource +from feast.infra.offline_stores.contrib.athena_offline_store.athena_source import ( + AthenaSource, +) from feast.infra.offline_stores.file_source import FileSource from feast.infra.offline_stores.redshift_source import RedshiftSource from feast.infra.offline_stores.snowflake_source import SnowflakeSource @@ -50,4 +53,5 @@ "SnowflakeSource", "PushSource", "RequestSource", + "AthenaSource", ] diff --git a/sdk/python/feast/batch_feature_view.py b/sdk/python/feast/batch_feature_view.py index db96704640..f714573810 100644 --- a/sdk/python/feast/batch_feature_view.py +++ b/sdk/python/feast/batch_feature_view.py @@ -14,6 +14,7 @@ "SnowflakeSource", "SparkSource", "TrinoSource", + "AthenaSource", } diff --git a/sdk/python/feast/data_source.py b/sdk/python/feast/data_source.py index 931568f4e2..76b012e585 100644 --- a/sdk/python/feast/data_source.py +++ b/sdk/python/feast/data_source.py @@ -156,6 +156,7 @@ def to_proto(self) -> DataSourceProto.KinesisOptions: DataSourceProto.SourceType.BATCH_SNOWFLAKE: "feast.infra.offline_stores.snowflake_source.SnowflakeSource", DataSourceProto.SourceType.BATCH_TRINO: "feast.infra.offline_stores.contrib.trino_offline_store.trino_source.TrinoSource", DataSourceProto.SourceType.BATCH_SPARK: "feast.infra.offline_stores.contrib.spark_offline_store.spark_source.SparkSource", + DataSourceProto.SourceType.BATCH_ATHENA: "feast.infra.offline_stores.contrib.athena_offline_store.athena_source.AthenaSource", DataSourceProto.SourceType.STREAM_KAFKA: "feast.data_source.KafkaSource", DataSourceProto.SourceType.STREAM_KINESIS: "feast.data_source.KinesisSource", DataSourceProto.SourceType.REQUEST_SOURCE: "feast.data_source.RequestSource", @@ -183,6 +184,7 @@ class DataSource(ABC): maintainer. timestamp_field (optional): Event timestamp field used for point in time joins of feature values. + date_partition_column (optional): Timestamp column used for partitioning. Not supported by all offline stores. """ name: str @@ -192,6 +194,7 @@ class DataSource(ABC): description: str tags: Dict[str, str] owner: str + date_partition_column: str def __init__( self, @@ -203,6 +206,7 @@ def __init__( description: Optional[str] = "", tags: Optional[Dict[str, str]] = None, owner: Optional[str] = "", + date_partition_column: Optional[str] = None, ): """ Creates a DataSource object. @@ -220,6 +224,7 @@ def __init__( tags (optional): A dictionary of key-value pairs to store arbitrary metadata. owner (optional): The owner of the data source, typically the email of the primary maintainer. + date_partition_column (optional): Timestamp column used for partitioning. 
Not supported by all stores """ self.name = name self.timestamp_field = timestamp_field or "" @@ -237,6 +242,9 @@ def __init__( self.description = description or "" self.tags = tags or {} self.owner = owner or "" + self.date_partition_column = ( + date_partition_column if date_partition_column else "" + ) def __hash__(self): return hash((self.name, self.timestamp_field)) @@ -256,6 +264,7 @@ def __eq__(self, other): or self.timestamp_field != other.timestamp_field or self.created_timestamp_column != other.created_timestamp_column or self.field_mapping != other.field_mapping + or self.date_partition_column != other.date_partition_column or self.description != other.description or self.tags != other.tags or self.owner != other.owner diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/__init__.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py new file mode 100644 index 0000000000..bbbc6170e1 --- /dev/null +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py @@ -0,0 +1,711 @@ +import contextlib +import uuid +from datetime import datetime +from pathlib import Path +from typing import ( + Callable, + ContextManager, + Dict, + Iterator, + List, + Optional, + Tuple, + Union, +) + +import numpy as np +import pandas as pd +import pyarrow +import pyarrow as pa +from pydantic import StrictStr +from pydantic.typing import Literal +from pytz import utc + +from feast import OnDemandFeatureView +from feast.data_source import DataSource +from feast.errors import InvalidEntityType +from feast.feature_logging import LoggingConfig, LoggingSource +from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView +from feast.infra.offline_stores import offline_utils +from feast.infra.offline_stores.contrib.athena_offline_store.athena_source import ( + AthenaLoggingDestination, + AthenaSource, + SavedDatasetAthenaStorage, +) +from feast.infra.offline_stores.offline_store import ( + OfflineStore, + RetrievalJob, + RetrievalMetadata, +) +from feast.infra.utils import aws_utils +from feast.registry import BaseRegistry, Registry +from feast.repo_config import FeastConfigBaseModel, RepoConfig +from feast.saved_dataset import SavedDatasetStorage +from feast.usage import log_exceptions_and_usage + + +class AthenaOfflineStoreConfig(FeastConfigBaseModel): + """Offline store config for AWS Athena""" + + type: Literal["athena"] = "athena" + """ Offline store type selector""" + + data_source: StrictStr + """ athena data source ex) AwsDataCatalog """ + + region: StrictStr + """ Athena's AWS region """ + + database: StrictStr + """ Athena database name """ + + s3_staging_location: StrictStr + """ S3 path for importing & exporting data to Athena """ + + +class AthenaOfflineStore(OfflineStore): + @staticmethod + @log_exceptions_and_usage(offline_store="athena") + def pull_latest_from_table_or_query( + config: RepoConfig, + data_source: DataSource, + join_key_columns: List[str], + feature_name_columns: List[str], + timestamp_field: str, + created_timestamp_column: Optional[str], + start_date: datetime, + end_date: datetime, + ) -> RetrievalJob: + assert isinstance(data_source, AthenaSource) + assert isinstance(config.offline_store, AthenaOfflineStoreConfig) + + from_expression = 
data_source.get_table_query_string(config) + + partition_by_join_key_string = ", ".join(join_key_columns) + if partition_by_join_key_string != "": + partition_by_join_key_string = ( + "PARTITION BY " + partition_by_join_key_string + ) + timestamp_columns = [timestamp_field] + if created_timestamp_column: + timestamp_columns.append(created_timestamp_column) + timestamp_desc_string = " DESC, ".join(timestamp_columns) + " DESC" + field_string = ", ".join( + join_key_columns + feature_name_columns + timestamp_columns + ) + + date_partition_column = data_source.date_partition_column + + athena_client = aws_utils.get_athena_data_client(config.offline_store.region) + s3_resource = aws_utils.get_s3_resource(config.offline_store.region) + + start_date = start_date.astimezone(tz=utc) + end_date = end_date.astimezone(tz=utc) + + query = f""" + SELECT + {field_string} + {f", {repr(DUMMY_ENTITY_VAL)} AS {DUMMY_ENTITY_ID}" if not join_key_columns else ""} + FROM ( + SELECT {field_string}, + ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS _feast_row + FROM {from_expression} + WHERE {timestamp_field} BETWEEN TIMESTAMP '{start_date.strftime('%Y-%m-%d %H:%M:%S')}' AND TIMESTAMP '{end_date.strftime('%Y-%m-%d %H:%M:%S')}' + {"AND "+date_partition_column+" >= '"+start_date.strftime('%Y-%m-%d')+"' AND "+date_partition_column+" <= '"+end_date.strftime('%Y-%m-%d')+"' " if date_partition_column != "" and date_partition_column is not None else ''} + ) + WHERE _feast_row = 1 + """ + # When materializing a single feature view, we don't need full feature names. On demand transforms aren't materialized + return AthenaRetrievalJob( + query=query, + athena_client=athena_client, + s3_resource=s3_resource, + config=config, + full_feature_names=False, + ) + + @staticmethod + @log_exceptions_and_usage(offline_store="athena") + def pull_all_from_table_or_query( + config: RepoConfig, + data_source: DataSource, + join_key_columns: List[str], + feature_name_columns: List[str], + timestamp_field: str, + start_date: datetime, + end_date: datetime, + ) -> RetrievalJob: + assert isinstance(data_source, AthenaSource) + from_expression = data_source.get_table_query_string(config) + + field_string = ", ".join( + join_key_columns + feature_name_columns + [timestamp_field] + ) + + athena_client = aws_utils.get_athena_data_client(config.offline_store.region) + s3_resource = aws_utils.get_s3_resource(config.offline_store.region) + + date_partition_column = data_source.date_partition_column + + query = f""" + SELECT {field_string} + FROM {from_expression} + WHERE {timestamp_field} BETWEEN TIMESTAMP '{start_date.astimezone(tz=utc).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]}' AND TIMESTAMP '{end_date.astimezone(tz=utc).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]}' + {"AND "+date_partition_column+" >= '"+start_date.strftime('%Y-%m-%d')+"' AND "+date_partition_column+" <= '"+end_date.strftime('%Y-%m-%d')+"' " if date_partition_column != "" and date_partition_column is not None else ''} + """ + + return AthenaRetrievalJob( + query=query, + athena_client=athena_client, + s3_resource=s3_resource, + config=config, + full_feature_names=False, + ) + + @staticmethod + @log_exceptions_and_usage(offline_store="athena") + def get_historical_features( + config: RepoConfig, + feature_views: List[FeatureView], + feature_refs: List[str], + entity_df: Union[pd.DataFrame, str], + registry: Registry, + project: str, + full_feature_names: bool = False, + ) -> RetrievalJob: + assert isinstance(config.offline_store, 
AthenaOfflineStoreConfig) + + athena_client = aws_utils.get_athena_data_client(config.offline_store.region) + s3_resource = aws_utils.get_s3_resource(config.offline_store.region) + + # get pandas dataframe consisting of 1 row (LIMIT 1) and generate the schema out of it + entity_schema = _get_entity_schema( + entity_df, athena_client, config, s3_resource + ) + + # find timestamp column of entity df.(default = "event_timestamp"). Exception occurs if there are more than two timestamp columns. + entity_df_event_timestamp_col = ( + offline_utils.infer_event_timestamp_from_entity_df(entity_schema) + ) + + # get min,max of event_timestamp. + entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range( + entity_df, + entity_df_event_timestamp_col, + athena_client, + config, + ) + + @contextlib.contextmanager + def query_generator() -> Iterator[str]: + + table_name = offline_utils.get_temp_entity_table_name() + + _upload_entity_df(entity_df, athena_client, config, s3_resource, table_name) + + expected_join_keys = offline_utils.get_expected_join_keys( + project, feature_views, registry + ) + + offline_utils.assert_expected_columns_in_entity_df( + entity_schema, expected_join_keys, entity_df_event_timestamp_col + ) + + # Build a query context containing all information required to template the Athena SQL query + query_context = offline_utils.get_feature_view_query_context( + feature_refs, + feature_views, + registry, + project, + entity_df_event_timestamp_range, + ) + + # Generate the Athena SQL query from the query context + query = offline_utils.build_point_in_time_query( + query_context, + left_table_query_string=table_name, + entity_df_event_timestamp_col=entity_df_event_timestamp_col, + entity_df_columns=entity_schema.keys(), + query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN, + full_feature_names=full_feature_names, + ) + + try: + yield query + finally: + + # Always clean up the temp Athena table + aws_utils.execute_athena_query( + athena_client, + config.offline_store.data_source, + config.offline_store.database, + f"DROP TABLE IF EXISTS {config.offline_store.database}.{table_name}", + ) + + bucket = config.offline_store.s3_staging_location.replace( + "s3://", "" + ).split("/", 1)[0] + aws_utils.delete_s3_directory( + s3_resource, bucket, "entity_df/" + table_name + "/" + ) + + return AthenaRetrievalJob( + query=query_generator, + athena_client=athena_client, + s3_resource=s3_resource, + config=config, + full_feature_names=full_feature_names, + on_demand_feature_views=OnDemandFeatureView.get_requested_odfvs( + feature_refs, project, registry + ), + metadata=RetrievalMetadata( + features=feature_refs, + keys=list(entity_schema.keys() - {entity_df_event_timestamp_col}), + min_event_timestamp=entity_df_event_timestamp_range[0], + max_event_timestamp=entity_df_event_timestamp_range[1], + ), + ) + + @staticmethod + def write_logged_features( + config: RepoConfig, + data: Union[pyarrow.Table, Path], + source: LoggingSource, + logging_config: LoggingConfig, + registry: BaseRegistry, + ): + destination = logging_config.destination + assert isinstance(destination, AthenaLoggingDestination) + + athena_client = aws_utils.get_athena_data_client(config.offline_store.region) + s3_resource = aws_utils.get_s3_resource(config.offline_store.region) + if isinstance(data, Path): + s3_path = f"{config.offline_store.s3_staging_location}/logged_features/{uuid.uuid4()}" + else: + s3_path = f"{config.offline_store.s3_staging_location}/logged_features/{uuid.uuid4()}.parquet" + + 
aws_utils.upload_arrow_table_to_athena( + table=data, + athena_client=athena_client, + data_source=config.offline_store.data_source, + database=config.offline_store.database, + s3_resource=s3_resource, + s3_path=s3_path, + table_name=destination.table_name, + schema=source.get_schema(registry), + fail_if_exists=False, + ) + + +class AthenaRetrievalJob(RetrievalJob): + def __init__( + self, + query: Union[str, Callable[[], ContextManager[str]]], + athena_client, + s3_resource, + config: RepoConfig, + full_feature_names: bool, + on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None, + metadata: Optional[RetrievalMetadata] = None, + ): + """Initialize AthenaRetrievalJob object. + + Args: + query: Athena SQL query to execute. Either a string, or a generator function that handles the artifact cleanup. + athena_client: boto3 athena client + s3_resource: boto3 s3 resource object + config: Feast repo config + full_feature_names: Whether to add the feature view prefixes to the feature names + on_demand_feature_views (optional): A list of on demand transforms to apply at retrieval time + """ + + if not isinstance(query, str): + self._query_generator = query + else: + + @contextlib.contextmanager + def query_generator() -> Iterator[str]: + assert isinstance(query, str) + yield query + + self._query_generator = query_generator + self._athena_client = athena_client + self._s3_resource = s3_resource + self._config = config + self._full_feature_names = full_feature_names + self._on_demand_feature_views = ( + on_demand_feature_views if on_demand_feature_views else [] + ) + self._metadata = metadata + + @property + def full_feature_names(self) -> bool: + return self._full_feature_names + + @property + def on_demand_feature_views(self) -> Optional[List[OnDemandFeatureView]]: + return self._on_demand_feature_views + + def get_temp_s3_path(self) -> str: + return ( + self._config.offline_store.s3_staging_location + + "/unload/" + + str(uuid.uuid4()) + ) + + def get_temp_table_dml_header( + self, temp_table_name: str, temp_external_location: str + ) -> str: + temp_table_dml_header = f""" + CREATE TABLE {temp_table_name} + WITH ( + external_location = '{temp_external_location}', + format = 'parquet', + write_compression = 'snappy' + ) + as + """ + return temp_table_dml_header + + @log_exceptions_and_usage + def _to_df_internal(self) -> pd.DataFrame: + with self._query_generator() as query: + temp_table_name = "_" + str(uuid.uuid4()).replace("-", "") + temp_external_location = self.get_temp_s3_path() + return aws_utils.unload_athena_query_to_df( + self._athena_client, + self._config.offline_store.data_source, + self._config.offline_store.database, + self._s3_resource, + temp_external_location, + self.get_temp_table_dml_header(temp_table_name, temp_external_location) + + query, + temp_table_name, + ) + + @log_exceptions_and_usage + def _to_arrow_internal(self) -> pa.Table: + with self._query_generator() as query: + temp_table_name = "_" + str(uuid.uuid4()).replace("-", "") + temp_external_location = self.get_temp_s3_path() + return aws_utils.unload_athena_query_to_pa( + self._athena_client, + self._config.offline_store.data_source, + self._config.offline_store.database, + self._s3_resource, + temp_external_location, + self.get_temp_table_dml_header(temp_table_name, temp_external_location) + + query, + temp_table_name, + ) + + @property + def metadata(self) -> Optional[RetrievalMetadata]: + return self._metadata + + def persist(self, storage: SavedDatasetStorage): + assert isinstance(storage, 
SavedDatasetAthenaStorage) + self.to_athena(table_name=storage.athena_options.table) + + @log_exceptions_and_usage + def to_athena(self, table_name: str) -> None: + + if self.on_demand_feature_views: + transformed_df = self.to_df() + + _upload_entity_df( + transformed_df, + self._athena_client, + self._config, + self._s3_resource, + table_name, + ) + + return + + with self._query_generator() as query: + query = f'CREATE TABLE "{table_name}" AS ({query});\n' + + aws_utils.execute_athena_query( + self._athena_client, + self._config.offline_store.data_source, + self._config.offline_store.database, + query, + ) + + +def _upload_entity_df( + entity_df: Union[pd.DataFrame, str], + athena_client, + config: RepoConfig, + s3_resource, + table_name: str, +): + if isinstance(entity_df, pd.DataFrame): + # If the entity_df is a pandas dataframe, upload it to Athena + aws_utils.upload_df_to_athena( + athena_client, + config.offline_store.data_source, + config.offline_store.database, + s3_resource, + f"{config.offline_store.s3_staging_location}/entity_df/{table_name}/{table_name}.parquet", + table_name, + entity_df, + ) + elif isinstance(entity_df, str): + # If the entity_df is a string (SQL query), create a Athena table out of it + aws_utils.execute_athena_query( + athena_client, + config.offline_store.data_source, + config.offline_store.database, + f"CREATE TABLE {table_name} AS ({entity_df})", + ) + else: + raise InvalidEntityType(type(entity_df)) + + +def _get_entity_schema( + entity_df: Union[pd.DataFrame, str], + athena_client, + config: RepoConfig, + s3_resource, +) -> Dict[str, np.dtype]: + if isinstance(entity_df, pd.DataFrame): + return dict(zip(entity_df.columns, entity_df.dtypes)) + + elif isinstance(entity_df, str): + # get pandas dataframe consisting of 1 row (LIMIT 1) and generate the schema out of it + entity_df_sample = AthenaRetrievalJob( + f"SELECT * FROM ({entity_df}) LIMIT 1", + athena_client, + s3_resource, + config, + full_feature_names=False, + ).to_df() + return dict(zip(entity_df_sample.columns, entity_df_sample.dtypes)) + else: + raise InvalidEntityType(type(entity_df)) + + +def _get_entity_df_event_timestamp_range( + entity_df: Union[pd.DataFrame, str], + entity_df_event_timestamp_col: str, + athena_client, + config: RepoConfig, +) -> Tuple[datetime, datetime]: + if isinstance(entity_df, pd.DataFrame): + entity_df_event_timestamp = entity_df.loc[ + :, entity_df_event_timestamp_col + ].infer_objects() + if pd.api.types.is_string_dtype(entity_df_event_timestamp): + entity_df_event_timestamp = pd.to_datetime( + entity_df_event_timestamp, utc=True + ) + entity_df_event_timestamp_range = ( + entity_df_event_timestamp.min().to_pydatetime(), + entity_df_event_timestamp.max().to_pydatetime(), + ) + elif isinstance(entity_df, str): + # If the entity_df is a string (SQL query), determine range + # from table + statement_id = aws_utils.execute_athena_query( + athena_client, + config.offline_store.data_source, + config.offline_store.database, + f"SELECT MIN({entity_df_event_timestamp_col}) AS min, MAX({entity_df_event_timestamp_col}) AS max " + f"FROM ({entity_df})", + ) + res = aws_utils.get_athena_query_result(athena_client, statement_id) + entity_df_event_timestamp_range = ( + datetime.strptime( + res["Rows"][1]["Data"][0]["VarCharValue"], "%Y-%m-%d %H:%M:%S.%f" + ), + datetime.strptime( + res["Rows"][1]["Data"][1]["VarCharValue"], "%Y-%m-%d %H:%M:%S.%f" + ), + ) + else: + raise InvalidEntityType(type(entity_df)) + + return entity_df_event_timestamp_range + + 
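For orientation, a minimal usage sketch (illustrative only, not part of this patch) of how the offline store config and source defined above fit together; the region, database, bucket, and table names below are placeholder assumptions:

from feast.infra.offline_stores.contrib.athena_offline_store.athena import (
    AthenaOfflineStoreConfig,
)
from feast.infra.offline_stores.contrib.athena_offline_store.athena_source import (
    AthenaSource,
)

# Offline store settings that would normally live in feature_store.yaml (type: athena).
offline_store_config = AthenaOfflineStoreConfig(
    data_source="AwsDataCatalog",  # Athena data catalog
    region="ap-northeast-2",  # placeholder AWS region
    database="sampledb",  # placeholder Athena database
    s3_staging_location="s3://my-bucket/feast",  # placeholder S3 staging path
)

# A batch source backed by an existing Athena table (placeholder names).
driver_stats = AthenaSource(
    table="driver_stats",
    timestamp_field="event_timestamp",
    created_timestamp_column="created",
    database="sampledb",
    data_source="AwsDataCatalog",
)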
+MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN = """ +/* + Compute a deterministic hash for the `left_table_query_string` that will be used throughout + all the logic as the field to GROUP BY the data +*/ +WITH entity_dataframe AS ( + SELECT *, + {{entity_df_event_timestamp_col}} AS entity_timestamp + {% for featureview in featureviews %} + {% if featureview.entities %} + ,( + {% for entity in featureview.entities %} + CAST({{entity}} as VARCHAR) || + {% endfor %} + CAST({{entity_df_event_timestamp_col}} AS VARCHAR) + ) AS {{featureview.name}}__entity_row_unique_id + {% else %} + ,CAST({{entity_df_event_timestamp_col}} AS VARCHAR) AS {{featureview.name}}__entity_row_unique_id + {% endif %} + {% endfor %} + FROM {{ left_table_query_string }} +), + +{% for featureview in featureviews %} + +{{ featureview.name }}__entity_dataframe AS ( + SELECT + {{ featureview.entities | join(', ')}}{% if featureview.entities %},{% else %}{% endif %} + entity_timestamp, + {{featureview.name}}__entity_row_unique_id + FROM entity_dataframe + GROUP BY + {{ featureview.entities | join(', ')}}{% if featureview.entities %},{% else %}{% endif %} + entity_timestamp, + {{featureview.name}}__entity_row_unique_id +), + +/* + This query template performs the point-in-time correctness join for a single feature set table + to the provided entity table. + + 1. We first join the current feature_view to the entity dataframe that has been passed. + This JOIN has the following logic: + - For each row of the entity dataframe, only keep the rows where the `timestamp_field` + is less than the one provided in the entity dataframe + - If there a TTL for the current feature_view, also keep the rows where the `timestamp_field` + is higher the the one provided minus the TTL + - For each row, Join on the entity key and retrieve the `entity_row_unique_id` that has been + computed previously + + The output of this CTE will contain all the necessary information and already filtered out most + of the data that is not relevant. 
+*/ + +{{ featureview.name }}__subquery AS ( + SELECT + {{ featureview.timestamp_field }} as event_timestamp, + {{ featureview.created_timestamp_column ~ ' as created_timestamp,' if featureview.created_timestamp_column else '' }} + {{ featureview.entity_selections | join(', ')}}{% if featureview.entity_selections %},{% else %}{% endif %} + {% for feature in featureview.features %} + {{ feature }} as {% if full_feature_names %}{{ featureview.name }}__{{featureview.field_mapping.get(feature, feature)}}{% else %}{{ featureview.field_mapping.get(feature, feature) }}{% endif %}{% if loop.last %}{% else %}, {% endif %} + {% endfor %} + FROM {{ featureview.table_subquery }} + WHERE {{ featureview.timestamp_field }} <= from_iso8601_timestamp('{{ featureview.max_event_timestamp }}') + {% if featureview.date_partition_column != "" and featureview.date_partition_column is not none %} + AND {{ featureview.date_partition_column }} <= '{{ featureview.max_event_timestamp[:10] }}' + {% endif %} + + {% if featureview.ttl == 0 %}{% else %} + AND {{ featureview.timestamp_field }} >= from_iso8601_timestamp('{{ featureview.min_event_timestamp }}') + {% if featureview.date_partition_column != "" and featureview.date_partition_column is not none %} + AND {{ featureview.date_partition_column }} >= '{{ featureview.min_event_timestamp[:10] }}' + {% endif %} + {% endif %} + +), + +{{ featureview.name }}__base AS ( + SELECT + subquery.*, + entity_dataframe.entity_timestamp, + entity_dataframe.{{featureview.name}}__entity_row_unique_id + FROM {{ featureview.name }}__subquery AS subquery + INNER JOIN {{ featureview.name }}__entity_dataframe AS entity_dataframe + ON TRUE + AND subquery.event_timestamp <= entity_dataframe.entity_timestamp + + {% if featureview.ttl == 0 %}{% else %} + AND subquery.event_timestamp >= entity_dataframe.entity_timestamp - {{ featureview.ttl }} * interval '1' second + {% endif %} + + {% for entity in featureview.entities %} + AND subquery.{{ entity }} = entity_dataframe.{{ entity }} + {% endfor %} +), + +/* + 2. If the `created_timestamp_column` has been set, we need to + deduplicate the data first. This is done by calculating the + `MAX(created_at_timestamp)` for each event_timestamp. + We then join the data on the next CTE +*/ +{% if featureview.created_timestamp_column %} +{{ featureview.name }}__dedup AS ( + SELECT + {{featureview.name}}__entity_row_unique_id, + event_timestamp, + MAX(created_timestamp) as created_timestamp + FROM {{ featureview.name }}__base + GROUP BY {{featureview.name}}__entity_row_unique_id, event_timestamp +), +{% endif %} + +/* + 3. The data has been filtered during the first CTE "*__base" + Thus we only need to compute the latest timestamp of each feature. 
+*/ +{{ featureview.name }}__latest AS ( + SELECT + event_timestamp, + {% if featureview.created_timestamp_column %}created_timestamp,{% endif %} + {{featureview.name}}__entity_row_unique_id + FROM + ( + SELECT base.*, + ROW_NUMBER() OVER( + PARTITION BY base.{{featureview.name}}__entity_row_unique_id + ORDER BY base.event_timestamp DESC{% if featureview.created_timestamp_column %},base.created_timestamp DESC{% endif %} + ) AS row_number + FROM {{ featureview.name }}__base as base + {% if featureview.created_timestamp_column %} + INNER JOIN {{ featureview.name }}__dedup as dedup + ON TRUE + AND base.{{featureview.name}}__entity_row_unique_id = dedup.{{featureview.name}}__entity_row_unique_id + AND base.event_timestamp = dedup.event_timestamp + AND base.created_timestamp = dedup.created_timestamp + {% endif %} + ) + WHERE row_number = 1 +), + +/* + 4. Once we know the latest value of each feature for a given timestamp, + we can join again the data back to the original "base" dataset +*/ +{{ featureview.name }}__cleaned AS ( + SELECT base.* + FROM {{ featureview.name }}__base as base + INNER JOIN {{ featureview.name }}__latest as latest + ON TRUE + AND base.{{featureview.name}}__entity_row_unique_id = latest.{{featureview.name}}__entity_row_unique_id + AND base.event_timestamp = latest.event_timestamp + {% if featureview.created_timestamp_column %} + AND base.created_timestamp = latest.created_timestamp + {% endif %} +){% if loop.last %}{% else %}, {% endif %} + + +{% endfor %} +/* + Joins the outputs of multiple time travel joins to a single table. + The entity_dataframe dataset being our source of truth here. + */ + +SELECT {{ final_output_feature_names | join(', ')}} +FROM entity_dataframe as entity_df +{% for featureview in featureviews %} +LEFT JOIN ( + SELECT + {{featureview.name}}__entity_row_unique_id + {% for feature in featureview.features %} + ,{% if full_feature_names %}{{ featureview.name }}__{{featureview.field_mapping.get(feature, feature)}}{% else %}{{ featureview.field_mapping.get(feature, feature) }}{% endif %} + {% endfor %} + FROM {{ featureview.name }}__cleaned +) as cleaned +ON TRUE +AND entity_df.{{featureview.name}}__entity_row_unique_id = cleaned.{{featureview.name}}__entity_row_unique_id +{% endfor %} +""" diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena_source.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena_source.py new file mode 100644 index 0000000000..f96dc0d048 --- /dev/null +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena_source.py @@ -0,0 +1,343 @@ +from typing import Callable, Dict, Iterable, Optional, Tuple + +from feast import type_map +from feast.data_source import DataSource +from feast.errors import DataSourceNoNameException, DataSourceNotFoundException +from feast.feature_logging import LoggingDestination +from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto +from feast.protos.feast.core.FeatureService_pb2 import ( + LoggingConfig as LoggingConfigProto, +) +from feast.protos.feast.core.SavedDataset_pb2 import ( + SavedDatasetStorage as SavedDatasetStorageProto, +) +from feast.repo_config import RepoConfig +from feast.saved_dataset import SavedDatasetStorage +from feast.value_type import ValueType + + +class AthenaSource(DataSource): + def __init__( + self, + *, + timestamp_field: Optional[str] = "", + table: Optional[str] = None, + database: Optional[str] = None, + data_source: Optional[str] = None, + 
created_timestamp_column: Optional[str] = None, + field_mapping: Optional[Dict[str, str]] = None, + date_partition_column: Optional[str] = None, + query: Optional[str] = None, + name: Optional[str] = None, + description: Optional[str] = "", + tags: Optional[Dict[str, str]] = None, + owner: Optional[str] = "", + ): + """ + Creates a AthenaSource object. + + Args: + timestamp_field : event timestamp column. + table (optional): Athena table where the features are stored. + database: Athena Database Name + data_source (optional): Athena data source + created_timestamp_column (optional): Timestamp column indicating when the + row was created, used for deduplicating rows. + field_mapping (optional): A dictionary mapping of column names in this data + source to column names in a feature table or view. + date_partition_column : Timestamp column used for partitioning. + query (optional): The query to be executed to obtain the features. + name (optional): Name for the source. Defaults to the table_ref if not specified. + description (optional): A human-readable description. + tags (optional): A dictionary of key-value pairs to store arbitrary metadata. + owner (optional): The owner of the athena source, typically the email of the primary + maintainer. + + + """ + + _database = "default" if table and not database else database + self.athena_options = AthenaOptions( + table=table, query=query, database=_database, data_source=data_source + ) + + if table is None and query is None: + raise ValueError('No "table" argument provided.') + + # If no name, use the table as the default name. + if name is None and table is None: + raise DataSourceNoNameException() + _name = name or table + assert _name + + super().__init__( + name=_name if _name else "", + timestamp_field=timestamp_field, + created_timestamp_column=created_timestamp_column, + field_mapping=field_mapping, + date_partition_column=date_partition_column, + description=description, + tags=tags, + owner=owner, + ) + + @staticmethod + def from_proto(data_source: DataSourceProto): + """ + Creates a AthenaSource from a protobuf representation of a AthenaSource. + + Args: + data_source: A protobuf representation of a AthenaSource + + Returns: + A AthenaSource object based on the data_source protobuf. + """ + return AthenaSource( + name=data_source.name, + timestamp_field=data_source.timestamp_field, + table=data_source.athena_options.table, + database=data_source.athena_options.database, + data_source=data_source.athena_options.data_source, + created_timestamp_column=data_source.created_timestamp_column, + field_mapping=dict(data_source.field_mapping), + date_partition_column=data_source.date_partition_column, + query=data_source.athena_options.query, + description=data_source.description, + tags=dict(data_source.tags), + ) + + # Note: Python requires redefining hash in child classes that override __eq__ + def __hash__(self): + return super().__hash__() + + def __eq__(self, other): + if not isinstance(other, AthenaSource): + raise TypeError( + "Comparisons should only involve AthenaSource class objects." 
+            )
+
+        return (
+            super().__eq__(other)
+            and self.athena_options.table == other.athena_options.table
+            and self.athena_options.query == other.athena_options.query
+            and self.athena_options.database == other.athena_options.database
+            and self.athena_options.data_source == other.athena_options.data_source
+        )
+
+    @property
+    def table(self):
+        """Returns the table of this Athena source."""
+        return self.athena_options.table
+
+    @property
+    def database(self):
+        """Returns the database of this Athena source."""
+        return self.athena_options.database
+
+    @property
+    def query(self):
+        """Returns the Athena query of this Athena source."""
+        return self.athena_options.query
+
+    @property
+    def data_source(self):
+        """Returns the Athena data_source of this Athena source."""
+        return self.athena_options.data_source
+
+    def to_proto(self) -> DataSourceProto:
+        """
+        Converts an AthenaSource object to its protobuf representation.
+
+        Returns:
+            A DataSourceProto object.
+        """
+        data_source_proto = DataSourceProto(
+            type=DataSourceProto.BATCH_ATHENA,
+            name=self.name,
+            timestamp_field=self.timestamp_field,
+            created_timestamp_column=self.created_timestamp_column,
+            field_mapping=self.field_mapping,
+            date_partition_column=self.date_partition_column,
+            description=self.description,
+            tags=self.tags,
+            athena_options=self.athena_options.to_proto(),
+        )
+
+        return data_source_proto
+
+    def validate(self, config: RepoConfig):
+        # As long as the query gets successfully executed, or the table exists,
+        # the data source is validated. We don't need the results though.
+        self.get_table_column_names_and_types(config)
+
+    def get_table_query_string(self, config: Optional[RepoConfig] = None) -> str:
+        """Returns a string that can directly be used to reference this table in SQL."""
+        if self.table:
+            data_source = self.data_source
+            database = self.database
+            if config:
+                data_source = config.offline_store.data_source
+                database = config.offline_store.database
+            return f'"{data_source}"."{database}"."{self.table}"'
+        else:
+            return f"({self.query})"
+
+    @staticmethod
+    def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
+        return type_map.athena_to_feast_value_type
+
+    def get_table_column_names_and_types(
+        self, config: RepoConfig
+    ) -> Iterable[Tuple[str, str]]:
+        """
+        Returns a mapping of column names to types for this Athena source.
+ + Args: + config: A RepoConfig describing the feature repo + """ + from botocore.exceptions import ClientError + + from feast.infra.offline_stores.contrib.athena_offline_store.athena import ( + AthenaOfflineStoreConfig, + ) + from feast.infra.utils import aws_utils + + assert isinstance(config.offline_store, AthenaOfflineStoreConfig) + + client = aws_utils.get_athena_data_client(config.offline_store.region) + if self.table: + try: + table = client.get_table_metadata( + CatalogName=self.data_source, + DatabaseName=self.database, + TableName=self.table, + ) + except ClientError as e: + raise aws_utils.AthenaError(e) + + # The API returns valid JSON with empty column list when the table doesn't exist + if len(table["TableMetadata"]["Columns"]) == 0: + raise DataSourceNotFoundException(self.table) + + columns = table["TableMetadata"]["Columns"] + else: + statement_id = aws_utils.execute_athena_query( + client, + config.offline_store.data_source, + config.offline_store.database, + f"SELECT * FROM ({self.query}) LIMIT 1", + ) + columns = aws_utils.get_athena_query_result(client, statement_id)[ + "ResultSetMetadata" + ]["ColumnInfo"] + + return [(column["Name"], column["Type"].upper()) for column in columns] + + +class AthenaOptions: + """ + Configuration options for a Athena data source. + """ + + def __init__( + self, + table: Optional[str], + query: Optional[str], + database: Optional[str], + data_source: Optional[str], + ): + self.table = table or "" + self.query = query or "" + self.database = database or "" + self.data_source = data_source or "" + + @classmethod + def from_proto(cls, athena_options_proto: DataSourceProto.AthenaOptions): + """ + Creates a AthenaOptions from a protobuf representation of a Athena option. + + Args: + athena_options_proto: A protobuf representation of a DataSource + + Returns: + A AthenaOptions object based on the athena_options protobuf. + """ + athena_options = cls( + table=athena_options_proto.table, + query=athena_options_proto.query, + database=athena_options_proto.database, + data_source=athena_options_proto.data_source, + ) + + return athena_options + + def to_proto(self) -> DataSourceProto.AthenaOptions: + """ + Converts an AthenaOptionsProto object to its protobuf representation. + + Returns: + A AthenaOptionsProto protobuf. 
+ """ + athena_options_proto = DataSourceProto.AthenaOptions( + table=self.table, + query=self.query, + database=self.database, + data_source=self.data_source, + ) + + return athena_options_proto + + +class SavedDatasetAthenaStorage(SavedDatasetStorage): + _proto_attr_name = "athena_storage" + + athena_options: AthenaOptions + + def __init__( + self, + table_ref: str, + query: str = None, + database: str = None, + data_source: str = None, + ): + self.athena_options = AthenaOptions( + table=table_ref, query=query, database=database, data_source=data_source + ) + + @staticmethod + def from_proto(storage_proto: SavedDatasetStorageProto) -> SavedDatasetStorage: + + return SavedDatasetAthenaStorage( + table_ref=AthenaOptions.from_proto(storage_proto.athena_storage).table + ) + + def to_proto(self) -> SavedDatasetStorageProto: + return SavedDatasetStorageProto(athena_storage=self.athena_options.to_proto()) + + def to_data_source(self) -> DataSource: + return AthenaSource(table=self.athena_options.table) + + +class AthenaLoggingDestination(LoggingDestination): + _proto_kind = "athena_destination" + + table_name: str + + def __init__(self, *, table_name: str): + self.table_name = table_name + + @classmethod + def from_proto(cls, config_proto: LoggingConfigProto) -> "LoggingDestination": + return AthenaLoggingDestination( + table_name=config_proto.athena_destination.table_name, + ) + + def to_proto(self) -> LoggingConfigProto: + return LoggingConfigProto( + athena_destination=LoggingConfigProto.AthenaDestination( + table_name=self.table_name + ) + ) + + def to_data_source(self) -> DataSource: + return AthenaSource(table=self.table_name) diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/__init__.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py new file mode 100644 index 0000000000..92e0d6e5f6 --- /dev/null +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py @@ -0,0 +1,130 @@ +import os +import uuid +from typing import Dict, List, Optional + +import pandas as pd + +from feast import AthenaSource +from feast.data_source import DataSource +from feast.feature_logging import LoggingDestination +from feast.infra.offline_stores.contrib.athena_offline_store.athena import ( + AthenaOfflineStoreConfig, +) +from feast.infra.offline_stores.contrib.athena_offline_store.athena_source import ( + AthenaLoggingDestination, + SavedDatasetAthenaStorage, +) +from feast.infra.utils import aws_utils +from feast.repo_config import FeastConfigBaseModel +from tests.integration.feature_repos.universal.data_source_creator import ( + DataSourceCreator, +) + + +class AthenaDataSourceCreator(DataSourceCreator): + + tables: List[str] = [] + + def __init__(self, project_name: str, *args, **kwargs): + super().__init__(project_name) + self.client = aws_utils.get_athena_data_client("ap-northeast-2") + self.s3 = aws_utils.get_s3_resource("ap-northeast-2") + data_source = ( + os.environ.get("ATHENA_DATA_SOURCE") + if os.environ.get("ATHENA_DATA_SOURCE") + else "AwsDataCatalog" + ) + database = ( + os.environ.get("ATHENA_DATABASE") + if os.environ.get("ATHENA_DATABASE") + else "default" + ) + bucket_name = ( + os.environ.get("ATHENA_S3_BUCKET_NAME") + if 
os.environ.get("ATHENA_S3_BUCKET_NAME") + else "feast-integration-tests" + ) + self.offline_store_config = AthenaOfflineStoreConfig( + data_source=f"{data_source}", + region="ap-northeast-2", + database=f"{database}", + s3_staging_location=f"s3://{bucket_name}/test_dir", + ) + + def create_data_source( + self, + df: pd.DataFrame, + destination_name: str, + suffix: Optional[str] = None, + timestamp_field="ts", + created_timestamp_column="created_ts", + field_mapping: Dict[str, str] = None, + ) -> DataSource: + + table_name = destination_name + s3_target = ( + self.offline_store_config.s3_staging_location + + "/" + + self.project_name + + "/" + + table_name + + "/" + + table_name + + ".parquet" + ) + + aws_utils.upload_df_to_athena( + self.client, + self.offline_store_config.data_source, + self.offline_store_config.database, + self.s3, + s3_target, + table_name, + df, + ) + + self.tables.append(table_name) + + return AthenaSource( + table=table_name, + timestamp_field=timestamp_field, + created_timestamp_column=created_timestamp_column, + field_mapping=field_mapping or {"ts_1": "ts"}, + database=self.offline_store_config.database, + data_source=self.offline_store_config.data_source, + ) + + def create_saved_dataset_destination(self) -> SavedDatasetAthenaStorage: + table = self.get_prefixed_table_name( + f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}" + ) + self.tables.append(table) + + return SavedDatasetAthenaStorage( + table_ref=table, + database=self.offline_store_config.database, + data_source=self.offline_store_config.data_source, + ) + + def create_logged_features_destination(self) -> LoggingDestination: + table = self.get_prefixed_table_name( + f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}" + ) + self.tables.append(table) + + return AthenaLoggingDestination(table_name=table) + + def create_offline_store_config(self) -> FeastConfigBaseModel: + return self.offline_store_config + + def get_prefixed_table_name(self, suffix: str) -> str: + return f"{self.project_name}_{suffix}" + + def teardown(self): + for table in self.tables: + aws_utils.execute_athena_query( + self.client, + self.offline_store_config.data_source, + self.offline_store_config.database, + f"DROP TABLE IF EXISTS {table}", + ) diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_repo_configuration.py b/sdk/python/feast/infra/offline_stores/contrib/athena_repo_configuration.py new file mode 100644 index 0000000000..32376eb652 --- /dev/null +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_repo_configuration.py @@ -0,0 +1,15 @@ +from tests.integration.feature_repos.integration_test_repo_config import ( + IntegrationTestRepoConfig, +) +from tests.integration.feature_repos.universal.data_sources.athena import ( + AthenaDataSourceCreator, +) + +FULL_REPO_CONFIGS = [ + IntegrationTestRepoConfig( + provider="aws", + offline_store_creator=AthenaDataSourceCreator, + ), +] + +AVAILABLE_OFFLINE_STORES = [("aws", AthenaDataSourceCreator)] diff --git a/sdk/python/feast/infra/offline_stores/offline_utils.py b/sdk/python/feast/infra/offline_stores/offline_utils.py index 8b963a864b..829d46c5ca 100644 --- a/sdk/python/feast/infra/offline_stores/offline_utils.py +++ b/sdk/python/feast/infra/offline_stores/offline_utils.py @@ -93,6 +93,9 @@ class FeatureViewQueryContext: entity_selections: List[str] min_event_timestamp: Optional[str] max_event_timestamp: str + date_partition_column: Optional[ + str + ] # this attribute is added because partition pruning affects Athena's query performance. 
 def get_feature_view_query_context(
@@ -142,6 +145,11 @@ def get_feature_view_query_context(
             feature_view.batch_source.created_timestamp_column,
         )
 
+        date_partition_column = reverse_field_mapping.get(
+            feature_view.batch_source.date_partition_column,
+            feature_view.batch_source.date_partition_column,
+        )
+
         max_event_timestamp = to_naive_utc(entity_df_timestamp_range[1]).isoformat()
         min_event_timestamp = None
         if feature_view.ttl:
@@ -162,6 +170,7 @@ def get_feature_view_query_context(
             entity_selections=entity_selections,
             min_event_timestamp=min_event_timestamp,
             max_event_timestamp=max_event_timestamp,
+            date_partition_column=date_partition_column,
         )
         query_context.append(context)
 
diff --git a/sdk/python/feast/infra/utils/aws_utils.py b/sdk/python/feast/infra/utils/aws_utils.py
index 3c8ad9d71b..72c40e4fc2 100644
--- a/sdk/python/feast/infra/utils/aws_utils.py
+++ b/sdk/python/feast/infra/utils/aws_utils.py
@@ -21,7 +21,7 @@
     RedshiftQueryError,
     RedshiftTableNameTooLong,
 )
-from feast.type_map import pa_to_redshift_value_type
+from feast.type_map import pa_to_athena_value_type, pa_to_redshift_value_type
 
 try:
     import boto3
@@ -32,7 +32,6 @@
     raise FeastExtrasDependencyImportError("aws", str(e))
 
 
-
 REDSHIFT_TABLE_NAME_MAX_LENGTH = 127
 
 
@@ -672,3 +671,331 @@ def list_s3_files(aws_region: str, path: str) -> List[str]:
         contents = objects["Contents"]
         files = [f"s3://{bucket}/{content['Key']}" for content in contents]
         return files
+
+
+# Athena
+
+
+def get_athena_data_client(aws_region: str):
+    """
+    Get the Athena boto3 client for the given AWS region.
+    """
+    return boto3.client("athena", config=Config(region_name=aws_region))
+
+
+@retry(
+    wait=wait_exponential(multiplier=1, max=4),
+    retry=retry_if_exception_type(ConnectionClosedError),
+    stop=stop_after_attempt(5),
+    reraise=True,
+)
+def execute_athena_query_async(
+    athena_data_client, data_source: str, database: str, query: str
+) -> dict:
+    """Execute Athena statement asynchronously. Does not wait for the query to finish.
+
+    Raises AthenaQueryError if the statement couldn't be executed.
+
+    Args:
+        athena_data_client: Athena boto3 client
+        data_source: Athena data source (data catalog) name
+        database: Athena database name
+        query: The SQL query to execute
+
+    Returns: JSON response
+
+    """
+    try:
+        # return athena_data_client.execute_statement(
+        return athena_data_client.start_query_execution(
+            QueryString=query,
+            QueryExecutionContext={"Database": database},
+            WorkGroup="primary",
+        )
+
+    except ClientError as e:
+        raise AthenaQueryError(e)
+
+
+class AthenaStatementNotFinishedError(Exception):
+    pass
+
+
+@retry(
+    wait=wait_exponential(multiplier=1, max=30),
+    retry=retry_if_exception_type(AthenaStatementNotFinishedError),
+    reraise=True,
+)
+def wait_for_athena_execution(athena_data_client, execution: dict) -> None:
+    """Waits for the Athena statement to finish. Raises AthenaQueryError if the statement didn't succeed.
+
+    We use exponential backoff for checking the query state until it's not running. The backoff starts with
+    0.1 seconds and doubles exponentially until reaching 30 seconds, at which point the backoff is fixed.
+
+    Args:
+        athena_data_client: Athena boto3 client
+        execution: The Athena execution to wait for (result of execute_athena_query_async)
+
+    Returns: None
+
+    """
+    response = athena_data_client.get_query_execution(
+        QueryExecutionId=execution["QueryExecutionId"]
+    )
+    if response["QueryExecution"]["Status"]["State"] in ("QUEUED", "RUNNING"):
+        raise AthenaStatementNotFinishedError  # Retry
+    if response["QueryExecution"]["Status"]["State"] != "SUCCEEDED":
+        raise AthenaQueryError(response)  # Don't retry. Raise exception.
+
+
+def drop_temp_table(
+    athena_data_client, data_source: str, database: str, temp_table: str
+):
+    query = f"DROP TABLE `{database}.{temp_table}`"
+    execute_athena_query_async(athena_data_client, data_source, database, query)
+
+
+def execute_athena_query(
+    athena_data_client,
+    data_source: str,
+    database: str,
+    query: str,
+    temp_table: str = None,
+) -> str:
+    """Execute an Athena statement synchronously. Waits for the query to finish.
+
+    Raises AthenaQueryError if the statement couldn't be executed, or if the query runs but finishes with errors.
+
+
+    Args:
+        athena_data_client: Athena boto3 client
+        data_source: Athena data source (data catalog) name
+        database: Athena database name
+        query: The SQL query to execute
+        temp_table: temp table name to be deleted after query execution.
+
+    Returns: Query execution ID
+
+    """
+
+    execution = execute_athena_query_async(
+        athena_data_client, data_source, database, query
+    )
+    wait_for_athena_execution(athena_data_client, execution)
+    if temp_table is not None:
+        drop_temp_table(athena_data_client, data_source, database, temp_table)
+
+    return execution["QueryExecutionId"]
+
+
+def get_athena_query_result(athena_data_client, query_execution_id: str) -> dict:
+    """Get the Athena query result."""
+    response = athena_data_client.get_query_results(QueryExecutionId=query_execution_id)
+    return response["ResultSet"]
+
+
+class AthenaError(Exception):
+    def __init__(self, details):
+        super().__init__(f"Athena API failed. Details: {details}")
+
+
+class AthenaQueryError(Exception):
+    def __init__(self, details):
+        super().__init__(f"Athena SQL Query failed to finish. Details: {details}")
+
+
+class AthenaTableNameTooLong(Exception):
+    def __init__(self, table_name: str):
+        super().__init__(
+            f"Athena table (data catalog) names have a maximum length of 255 characters, but the table name {table_name} has length {len(table_name)} characters."
+        )
+
+
+def unload_athena_query_to_pa(
+    athena_data_client,
+    data_source: str,
+    database: str,
+    s3_resource,
+    s3_path: str,
+    query: str,
+    temp_table: str,
+) -> pa.Table:
+    """Unload Athena Query results to S3 and get the results in PyArrow Table format"""
+    bucket, key = get_bucket_and_key(s3_path)
+
+    execute_athena_query_and_unload_to_s3(
+        athena_data_client, data_source, database, query, temp_table
+    )
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        download_s3_directory(s3_resource, bucket, key, temp_dir)
+        delete_s3_directory(s3_resource, bucket, key)
+        return pq.read_table(temp_dir)
+
+
+def unload_athena_query_to_df(
+    athena_data_client,
+    data_source: str,
+    database: str,
+    s3_resource,
+    s3_path: str,
+    query: str,
+    temp_table: str,
+) -> pd.DataFrame:
+    """Unload Athena Query results to S3 and get the results in Pandas DataFrame format"""
+    table = unload_athena_query_to_pa(
+        athena_data_client,
+        data_source,
+        database,
+        s3_resource,
+        s3_path,
+        query,
+        temp_table,
+    )
+    return table.to_pandas()
+
+
+def execute_athena_query_and_unload_to_s3(
+    athena_data_client,
+    data_source: str,
+    database: str,
+    query: str,
+    temp_table: str,
+) -> None:
+    """Unload Athena Query results to S3
+
+    Args:
+        athena_data_client: Athena boto3 client
+        data_source: Athena data source (data catalog) name
+        database: Athena database name
+        query: The SQL query to execute
+        temp_table: temp table name to be deleted after query execution.
+
+    """
+
+    execute_athena_query(athena_data_client, data_source, database, query, temp_table)
+
+
+def upload_df_to_athena(
+    athena_client,
+    data_source: str,
+    database: str,
+    s3_resource,
+    s3_path: str,
+    table_name: str,
+    df: pd.DataFrame,
+):
+    """Uploads a Pandas DataFrame to S3 and registers it as a new Athena table.
+
+    The caller is responsible for deleting the table when no longer necessary.
+
+    Args:
+        athena_client: Athena boto3 client
+        data_source: Athena data source (data catalog) name
+        database: Athena database name
+        s3_resource: S3 Resource object
+        s3_path: S3 path where the Parquet file is temporarily uploaded
+        table_name: The name of the new Data Catalog table where we copy the dataframe
+        df: The Pandas DataFrame to upload
+
+    Raises:
+        AthenaTableNameTooLong: The specified table name is too long.
+    """
+
+    # Drop the index so that we don't have unnecessary columns
+    df.reset_index(drop=True, inplace=True)
+
+    # Convert Pandas DataFrame into PyArrow table and compile the Athena table schema.
+    # Note, if the underlying data has missing values,
+    # pandas will convert those values to np.nan if the dtypes are numerical (floats, ints, etc.) or boolean.
+    # If the dtype is 'object', then missing values are inferred as python `None`s.
+    # More details at:
+    # https://pandas.pydata.org/pandas-docs/stable/user_guide/missing_data.html#values-considered-missing
+    table = pa.Table.from_pandas(df)
+    upload_arrow_table_to_athena(
+        table,
+        athena_client,
+        data_source=data_source,
+        database=database,
+        s3_resource=s3_resource,
+        s3_path=s3_path,
+        table_name=table_name,
+    )
+
+
+def upload_arrow_table_to_athena(
+    table: Union[pyarrow.Table, Path],
+    athena_client,
+    data_source: str,
+    database: str,
+    s3_resource,
+    s3_path: str,
+    table_name: str,
+    schema: Optional[pyarrow.Schema] = None,
+    fail_if_exists: bool = True,
+):
+    """Uploads an Arrow Table to S3 and creates an external Athena table over it.
+
+    Here's how the upload process works:
+    1. PyArrow Table is serialized into a Parquet format on local disk
+    2. The Parquet file is uploaded to S3
+    3.
+        4. The local Parquet file is cleaned up (the temporary S3 data is currently left
+           in place; see the commented-out cleanup below)
+
+    Args:
+        table: The Arrow Table or a Path to a Parquet dataset to upload
+        athena_client: Athena boto3 client
+        data_source: Athena data source (catalog) name
+        database: Athena database name
+        s3_resource: S3 Resource object
+        s3_path: S3 path where the Parquet file is temporarily uploaded
+        table_name: The name of the new Athena table the data is copied into
+        schema: (Optional) Arrow schema to convert into the Athena table schema;
+            defaults to the schema of `table` when an Arrow Table is passed
+        fail_if_exists: Fail if a table with this name already exists instead of
+            appending data to the existing table
+
+    Raises:
+        AthenaTableNameTooLong: The specified table name is too long.
+    """
+    DATA_CATALOG_TABLE_NAME_MAX_LENGTH = 255
+
+    if len(table_name) > DATA_CATALOG_TABLE_NAME_MAX_LENGTH:
+        raise AthenaTableNameTooLong(table_name)
+
+    if isinstance(table, pyarrow.Table) and not schema:
+        schema = table.schema
+
+    if not schema:
+        raise ValueError("Schema must be specified when data is passed as a Path")
+
+    bucket, key = get_bucket_and_key(s3_path)
+
+    column_query_list = ", ".join(
+        [f"`{field.name}` {pa_to_athena_value_type(field.type)}" for field in schema]
+    )
+
+    with tempfile.TemporaryFile(suffix=".parquet") as parquet_temp_file:
+        pq.write_table(table, parquet_temp_file)
+        parquet_temp_file.seek(0)
+        s3_resource.Object(bucket, key).put(Body=parquet_temp_file)
+
+    create_query = (
+        f"CREATE EXTERNAL TABLE {database}.{table_name} "
+        f"({column_query_list}) "
+        f"STORED AS PARQUET "
+        f"LOCATION '{s3_path[:s3_path.rfind('/')]}' "
+        f"TBLPROPERTIES('parquet.compress' = 'SNAPPY') "
+    )
+
+    try:
+        execute_athena_query(
+            athena_client,
+            data_source,
+            database,
+            f"{create_query}",
+        )
+    finally:
+        pass
+        # Clean up S3 temporary data
+        # for file_path in uploaded_files:
+        #     s3_resource.Object(bucket, file_path).delete()
diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py
index 5bc25faee5..34df1a215f 100644
--- a/sdk/python/feast/repo_config.py
+++ b/sdk/python/feast/repo_config.py
@@ -58,6 +58,7 @@
     "spark": "feast.infra.offline_stores.contrib.spark_offline_store.spark.SparkOfflineStore",
     "trino": "feast.infra.offline_stores.contrib.trino_offline_store.trino.TrinoOfflineStore",
     "postgres": "feast.infra.offline_stores.contrib.postgres_offline_store.postgres.PostgreSQLOfflineStore",
+    "athena": "feast.infra.offline_stores.contrib.athena_offline_store.athena.AthenaOfflineStore",
 }
 
 FEATURE_SERVER_CONFIG_CLASS_FOR_TYPE = {
diff --git a/sdk/python/feast/templates/athena/__init__.py b/sdk/python/feast/templates/athena/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/sdk/python/feast/templates/athena/example.py b/sdk/python/feast/templates/athena/example.py
new file mode 100644
index 0000000000..768a2709dc
--- /dev/null
+++ b/sdk/python/feast/templates/athena/example.py
@@ -0,0 +1,108 @@
+import os
+from datetime import datetime, timedelta
+
+import pandas as pd
+
+from feast import Entity, Feature, FeatureStore, FeatureView, ValueType
+from feast.infra.offline_stores.contrib.athena_offline_store.athena_source import (
+    AthenaSource,
+)
+
+
+def test_end_to_end():
+
+    try:
+        fs = FeatureStore("feature_repo/")
+
+        driver_hourly_stats = AthenaSource(
+            timestamp_field="event_timestamp",
+            table="driver_stats",
+            # table="driver_stats_partitioned",
+            database="sampledb",
+            data_source="AwsDataCatalog",
+            created_timestamp_column="created",
+            # date_partition_column="std_date"
+        )
+
+        driver = Entity(
+            name="driver_id",
+            value_type=ValueType.INT64,
+            description="driver id",
+        )
+
+        driver_hourly_stats_view = FeatureView(
+            name="driver_hourly_stats",
+            entities=["driver_id"],
+            ttl=timedelta(days=365),
+            features=[
+                Feature(name="conv_rate", dtype=ValueType.FLOAT),
+                Feature(name="acc_rate", dtype=ValueType.FLOAT),
+                Feature(name="avg_daily_trips", dtype=ValueType.INT64),
+            ],
+            online=True,
+            batch_source=driver_hourly_stats,
+        )
+
+        # apply repository
+        fs.apply([driver_hourly_stats, driver, driver_hourly_stats_view])
+
+        print(fs.list_data_sources())
+        print(fs.list_feature_views())
+
+        entity_df = pd.DataFrame(
+            {"driver_id": [1001], "event_timestamp": [datetime.now()]}
+        )
+
+        # Read features from offline store
+
+        feature_vector = (
+            fs.get_historical_features(
+                features=["driver_hourly_stats:conv_rate"], entity_df=entity_df
+            )
+            .to_df()
+            .to_dict()
+        )
+        conv_rate = feature_vector["conv_rate"][0]
+        print(conv_rate)
+        assert conv_rate > 0
+
+        # load data into online store
+        fs.materialize_incremental(end_date=datetime.now())
+
+        online_response = fs.get_online_features(
+            features=[
+                "driver_hourly_stats:conv_rate",
+                "driver_hourly_stats:acc_rate",
+                "driver_hourly_stats:avg_daily_trips",
+            ],
+            entity_rows=[{"driver_id": 1002}],
+        )
+        online_response_dict = online_response.to_dict()
+        print(online_response_dict)
+
+    except Exception as e:
+        print(e)
+    finally:
+        # tear down feature store
+        fs.teardown()
+
+
+def test_cli():
+    os.system("PYTHONPATH=$PYTHONPATH:/$(pwd) feast -c feature_repo apply")
+    try:
+        os.system("PYTHONPATH=$PYTHONPATH:/$(pwd) ")
+        with open("output", "r") as f:
+            output = f.read()
+
+        if "Pulling latest features from my offline store" not in output:
+            raise Exception(
+                'Failed to successfully use provider from CLI. See "output" for more details.'
+            )
+    finally:
+        os.system("PYTHONPATH=$PYTHONPATH:/$(pwd) feast -c feature_repo teardown")
+
+
+if __name__ == "__main__":
+    # pass
+    test_end_to_end()
+    test_cli()
diff --git a/sdk/python/feast/templates/athena/feature_store.yaml b/sdk/python/feast/templates/athena/feature_store.yaml
new file mode 100644
index 0000000000..13e7898e86
--- /dev/null
+++ b/sdk/python/feast/templates/athena/feature_store.yaml
@@ -0,0 +1,13 @@
+project: repo
+registry: registry.db
+provider: aws
+online_store:
+    type: sqlite
+    path: online_store.db
+offline_store:
+    type: athena
+    region: ap-northeast-2
+    database: sampledb
+    data_source: AwsDataCatalog
+    s3_staging_location: s3://sagemaker-yelo-test
+entity_key_serialization_version: 2
\ No newline at end of file
diff --git a/sdk/python/feast/type_map.py b/sdk/python/feast/type_map.py
index ed4b7cba59..a9dc4e25da 100644
--- a/sdk/python/feast/type_map.py
+++ b/sdk/python/feast/type_map.py
@@ -791,3 +791,60 @@ def pg_type_code_to_arrow(code: int) -> str:
     return feast_value_type_to_pa(
         pg_type_to_feast_value_type(pg_type_code_to_pg_type(code))
     )
+
+
+def athena_to_feast_value_type(athena_type_as_str: str) -> ValueType:
+    # Type names from https://docs.aws.amazon.com/athena/latest/ug/data-types.html
+    type_map = {
+        "null": ValueType.UNKNOWN,
+        "boolean": ValueType.BOOL,
+        "tinyint": ValueType.INT32,
+        "smallint": ValueType.INT32,
+        "int": ValueType.INT32,
+        "bigint": ValueType.INT64,
+        "double": ValueType.DOUBLE,
+        "float": ValueType.FLOAT,
+        "binary": ValueType.BYTES,
+        "char": ValueType.STRING,
+        "varchar": ValueType.STRING,
+        "string": ValueType.STRING,
+        "timestamp": ValueType.UNIX_TIMESTAMP,
+        # skip date, decimal, array, map, struct
+    }
+    return type_map[athena_type_as_str.lower()]
+
+
+def pa_to_athena_value_type(pa_type: pyarrow.DataType) -> str:
+    # PyArrow types: https://arrow.apache.org/docs/python/api/datatypes.html
+    # Type names from https://docs.aws.amazon.com/athena/latest/ug/data-types.html
+    pa_type_as_str = str(pa_type).lower()
+    if pa_type_as_str.startswith("timestamp"):
+        return "timestamp"
+
+    if pa_type_as_str.startswith("date"):
+        return "date"
+
+    if pa_type_as_str.startswith("decimal"):
+        return pa_type_as_str
+
+    # We also have to take into account how Arrow types map to Parquet types;
+    # the mappings below have been adjusted accordingly.
+    type_map = {
+        "null": "null",
+        "bool": "boolean",
+        "int8": "tinyint",
+        "int16": "smallint",
+        "int32": "int",
+        "int64": "bigint",
+        "uint8": "tinyint",
+        "uint16": "tinyint",
+        "uint32": "tinyint",
+        "uint64": "tinyint",
+        "float": "float",
+        "double": "double",
+        "binary": "binary",
+        "string": "string",
+    }
+
+    return type_map[pa_type_as_str]
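Below is a minimal, illustrative sketch (not part of the patch itself) of how the new type-mapping helpers added to feast.type_map above are expected to behave once this change is applied. The schema and column names are hypothetical; only the imports and the two helper functions come from the diff.

import pyarrow as pa

from feast.type_map import athena_to_feast_value_type, pa_to_athena_value_type

# Hypothetical Arrow schema for a driver_stats-style table.
schema = pa.schema(
    [
        ("driver_id", pa.int64()),
        ("conv_rate", pa.float32()),
        ("event_timestamp", pa.timestamp("us")),
    ]
)

# Arrow -> Athena DDL types, as used when building the CREATE EXTERNAL TABLE
# statement in upload_arrow_table_to_athena.
athena_types = {field.name: pa_to_athena_value_type(field.type) for field in schema}
print(athena_types)
# expected: {'driver_id': 'bigint', 'conv_rate': 'float', 'event_timestamp': 'timestamp'}

# Athena -> Feast value types, as used when inferring feature dtypes from an Athena table.
print(athena_to_feast_value_type("bigint"))     # ValueType.INT64
print(athena_to_feast_value_type("timestamp"))  # ValueType.UNIX_TIMESTAMP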