feat: Push to Redshift batch source offline store directly #2819

Merged: 26 commits, Jun 22, 2022
2 changes: 2 additions & 0 deletions sdk/python/feast/feature_store.py
@@ -1423,6 +1423,8 @@ def _write_to_offline_store(
feature_view = self.get_feature_view(
feature_view_name, allow_registry_cache=allow_registry_cache
)
df = df.reset_index(drop=True)

table = pa.Table.from_pandas(df)
provider = self._get_provider()
provider.ingest_df_to_offline_store(feature_view, table)
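For context on the reset_index change above: pa.Table.from_pandas keeps a non-default pandas index around as an extra __index_level_0__ column, which would then trip the column-order validation in the offline store. A minimal standalone sketch of the behavior (illustration only, not Feast code):

    import pandas as pd
    import pyarrow as pa

    df = pd.DataFrame({"driver_id": [1001, 1002]}, index=[7, 8])
    # A non-default index leaks into the Arrow schema as an extra column.
    print(pa.Table.from_pandas(df).column_names)  # ['driver_id', '__index_level_0__']

    df = df.reset_index(drop=True)  # what _write_to_offline_store now does first
    print(pa.Table.from_pandas(df).column_names)  # ['driver_id']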
65 changes: 65 additions & 0 deletions sdk/python/feast/infra/offline_stores/redshift.py
@@ -3,6 +3,7 @@
from datetime import datetime
from pathlib import Path
from typing import (
Any,
Callable,
ContextManager,
Dict,
@@ -41,6 +42,7 @@
from feast.registry import BaseRegistry
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.saved_dataset import SavedDatasetStorage
from feast.type_map import feast_value_type_to_pa, redshift_to_feast_value_type
from feast.usage import log_exceptions_and_usage


@@ -297,6 +299,69 @@ def write_logged_features(
fail_if_exists=False,
)

@staticmethod
def offline_write_batch(
config: RepoConfig,
feature_view: FeatureView,
table: pyarrow.Table,
progress: Optional[Callable[[int], Any]],
):
if not feature_view.batch_source:
raise ValueError(
"feature view does not have a batch source to persist offline data"
)
if not isinstance(config.offline_store, RedshiftOfflineStoreConfig):
raise ValueError(
f"offline store config is of type {type(config.offline_store)} when redshift type required"
)
if not isinstance(feature_view.batch_source, RedshiftSource):
raise ValueError(
f"feature view batch source is {type(feature_view.batch_source)} not redshift source"
)
redshift_options = feature_view.batch_source.redshift_options
redshift_client = aws_utils.get_redshift_data_client(
config.offline_store.region
)

column_name_to_type = feature_view.batch_source.get_table_column_names_and_types(
config
)
pa_schema_list = []
column_names = []
for column_name, redshift_type in column_name_to_type:
pa_schema_list.append(
(
column_name,
feast_value_type_to_pa(redshift_to_feast_value_type(redshift_type)),
)
)
column_names.append(column_name)
pa_schema = pa.schema(pa_schema_list)
if column_names != table.column_names:
raise ValueError(
f"Input dataframe has incorrect schema or wrong order, expected columns are: {column_names}"
)

if table.schema != pa_schema:
table = table.cast(pa_schema)

s3_resource = aws_utils.get_s3_resource(config.offline_store.region)

aws_utils.upload_arrow_table_to_redshift(
table=table,
redshift_data_client=redshift_client,
cluster_id=config.offline_store.cluster_id,
database=redshift_options.database
or config.offline_store.database, # Users can define database in the source if needed but it's not required.
user=config.offline_store.user,
s3_resource=s3_resource,
s3_path=f"{config.offline_store.s3_staging_location}/push/{uuid.uuid4()}.parquet",
iam_role=config.offline_store.iam_role,
table_name=redshift_options.table,
schema=pa_schema,
fail_if_exists=False,
)


class RedshiftRetrievalJob(RetrievalJob):
def __init__(
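Taken together, a hedged sketch of how this new push path might be exercised end to end (the feature view name and columns below are hypothetical, and _write_to_offline_store is the private entry point shown in the feature_store.py hunk above):

    from datetime import datetime

    import pandas as pd
    from feast import FeatureStore

    store = FeatureStore(repo_path=".")  # assumes a repo configured with the Redshift offline store
    df = pd.DataFrame(
        {
            "driver_id": [1001],
            "conv_rate": [0.85],
            "event_timestamp": [datetime.utcnow()],
            "created": [datetime.utcnow()],
        }
    )
    # Routed via provider.ingest_df_to_offline_store -> RedshiftOfflineStore.offline_write_batch
    store._write_to_offline_store("driver_hourly_stats", df)

Note that the column order must match the batch source schema exactly; otherwise offline_write_batch raises the ValueError shown above.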
4 changes: 2 additions & 2 deletions sdk/python/feast/infra/passthrough_provider.py
@@ -103,14 +103,14 @@ def online_write_batch(
def offline_write_batch(
self,
config: RepoConfig,
table: FeatureView,
feature_view: FeatureView,
data: pa.Table,
progress: Optional[Callable[[int], Any]],
) -> None:
set_usage_attribute("provider", self.__class__.__name__)

if self.offline_store:
self.offline_store.offline_write_batch(config, table, data, progress)
self.offline_store.offline_write_batch(config, feature_view, data, progress)

@log_exceptions_and_usage(sampler=RatioSampler(ratio=0.001))
def online_read(
58 changes: 57 additions & 1 deletion sdk/python/feast/infra/utils/aws_utils.py
@@ -235,6 +235,15 @@ def upload_df_to_redshift(
)


def delete_redshift_table(
redshift_data_client, cluster_id: str, database: str, user: str, table_name: str,
):
drop_query = f"DROP {table_name} IF EXISTS"
execute_redshift_statement(
redshift_data_client, cluster_id, database, user, drop_query,
)


def upload_arrow_table_to_redshift(
table: Union[pyarrow.Table, Path],
redshift_data_client,
@@ -320,7 +329,7 @@ def upload_arrow_table_to_redshift(
cluster_id,
database,
user,
f"{create_query}; {copy_query}",
f"{create_query}; {copy_query};",
)
finally:
# Clean up S3 temporary data
@@ -371,6 +380,53 @@ def temporarily_upload_df_to_redshift(
)


@contextlib.contextmanager
def temporarily_upload_arrow_table_to_redshift(
table: Union[pyarrow.Table, Path],
redshift_data_client,
cluster_id: str,
database: str,
user: str,
s3_resource,
iam_role: str,
s3_path: str,
table_name: str,
schema: Optional[pyarrow.Schema] = None,
fail_if_exists: bool = True,
) -> Iterator[None]:
"""Uploads a Arrow Table to Redshift as a new table with cleanup logic.

This is essentially the same as upload_arrow_table_to_redshift (see its docstring for full details),
but this function is a context manager and must be used in a `with` block. For example:

>>> with temporarily_upload_arrow_table_to_redshift(...): # doctest: +SKIP
...     pass  # Query the `table_name` table in Redshift here
>>> # Once the `with` block exits, `table_name` has been dropped

"""
# Upload the Arrow table to Redshift
upload_arrow_table_to_redshift(
table,
redshift_data_client,
cluster_id,
database,
user,
s3_resource,
s3_path,
iam_role,
table_name,
schema,
fail_if_exists,
)

yield

# Clean up the uploaded Redshift table
execute_redshift_statement(
redshift_data_client, cluster_id, database, user, f"DROP TABLE {table_name}",
)
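A hedged sketch of the intended call pattern for this context manager (the region, cluster, role, and table identifiers below are placeholders):

    from feast.infra.utils import aws_utils

    redshift_client = aws_utils.get_redshift_data_client("us-west-2")
    s3 = aws_utils.get_s3_resource("us-west-2")

    with aws_utils.temporarily_upload_arrow_table_to_redshift(
        table,  # a pyarrow.Table (or a Path to a Parquet file) defined elsewhere
        redshift_client,
        cluster_id="my-cluster",
        database="dev",
        user="admin",
        s3_resource=s3,
        iam_role="arn:aws:iam::123456789012:role/redshift-s3",  # placeholder role
        s3_path="s3://my-bucket/staging/tmp.parquet",
        table_name="tmp_push_table",
    ):
        ...  # query tmp_push_table here; it is dropped once the block exits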


def download_s3_directory(s3_resource, bucket: str, key: str, local_dir: str):
"""Download the S3 directory to a local disk"""
bucket_obj = s3_resource.Bucket(bucket)
85 changes: 49 additions & 36 deletions sdk/python/tests/conftest.py
@@ -33,6 +33,7 @@
from tests.integration.feature_repos.repo_configuration import (
AVAILABLE_OFFLINE_STORES,
AVAILABLE_ONLINE_STORES,
OFFLINE_STORE_TO_PROVIDER_CONFIG,
Environment,
TestData,
construct_test_environment,
@@ -196,16 +197,24 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):
"""
if "environment" in metafunc.fixturenames:
markers = {m.name: m for m in metafunc.definition.own_markers}

offline_stores = None
if "universal_offline_stores" in markers:
offline_stores = AVAILABLE_OFFLINE_STORES
# Offline stores can be explicitly requested
if "only" in markers["universal_offline_stores"].kwargs:
offline_stores = [
OFFLINE_STORE_TO_PROVIDER_CONFIG.get(store_name)
for store_name in markers["universal_offline_stores"].kwargs["only"]
if store_name in OFFLINE_STORE_TO_PROVIDER_CONFIG
]
else:
offline_stores = AVAILABLE_OFFLINE_STORES
else:
# default offline store for testing online store dimension
offline_stores = [("local", FileDataSourceCreator)]

online_stores = None
if "universal_online_stores" in markers:
# Online stores are explicitly requested
# Online stores can be explicitly requested
if "only" in markers["universal_online_stores"].kwargs:
online_stores = [
AVAILABLE_ONLINE_STORES.get(store_name)
@@ -240,40 +249,44 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):
extra_dimensions.append({"go_feature_retrieval": True})

configs = []
for provider, offline_store_creator in offline_stores:
for online_store, online_store_creator in online_stores:
for dim in extra_dimensions:
config = {
"provider": provider,
"offline_store_creator": offline_store_creator,
"online_store": online_store,
"online_store_creator": online_store_creator,
**dim,
}
# temporary Go works only with redis
if config.get("go_feature_retrieval") and (
not isinstance(online_store, dict)
or online_store["type"] != "redis"
):
continue

# aws lambda works only with dynamo
if (
config.get("python_feature_server")
and config.get("provider") == "aws"
and (
if offline_stores:
for provider, offline_store_creator in offline_stores:
for online_store, online_store_creator in online_stores:
for dim in extra_dimensions:
config = {
"provider": provider,
"offline_store_creator": offline_store_creator,
"online_store": online_store,
"online_store_creator": online_store_creator,
**dim,
}
# temporary Go works only with redis
if config.get("go_feature_retrieval") and (
not isinstance(online_store, dict)
or online_store["type"] != "dynamodb"
)
):
continue

c = IntegrationTestRepoConfig(**config)

if c not in _config_cache:
_config_cache[c] = c

configs.append(_config_cache[c])
or online_store["type"] != "redis"
):
continue

# aws lambda works only with dynamo
if (
config.get("python_feature_server")
and config.get("provider") == "aws"
and (
not isinstance(online_store, dict)
or online_store["type"] != "dynamodb"
)
):
continue

c = IntegrationTestRepoConfig(**config)

if c not in _config_cache:
_config_cache[c] = c

configs.append(_config_cache[c])
else:
# No offline stores requested -> setting the default or first available
offline_stores = [("local", FileDataSourceCreator)]

metafunc.parametrize(
"environment", configs, indirect=True, ids=[str(c) for c in configs]
sdk/python/tests/integration/feature_repos/repo_configuration.py
@@ -74,6 +74,13 @@
"connection_string": "127.0.0.1:6001,127.0.0.1:6002,127.0.0.1:6003",
}

OFFLINE_STORE_TO_PROVIDER_CONFIG: Dict[str, Tuple[str, Type[DataSourceCreator]]] = {
"file": ("local", FileDataSourceCreator),
"gcp": ("gcp", BigQueryDataSourceCreator),
"redshift": ("aws", RedshiftDataSourceCreator),
"snowflake": ("aws", RedshiftDataSourceCreator),
}

AVAILABLE_OFFLINE_STORES: List[Tuple[str, Type[DataSourceCreator]]] = [
("local", FileDataSourceCreator),
]
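For reference, a hedged sketch of how the new `only` kwarg handled in conftest.py resolves store names through this mapping (the result shown assumes the dictionary above):

    only = ["file", "redshift"]
    offline_stores = [
        OFFLINE_STORE_TO_PROVIDER_CONFIG[name]
        for name in only
        if name in OFFLINE_STORE_TO_PROVIDER_CONFIG
    ]
    # -> [("local", FileDataSourceCreator), ("aws", RedshiftDataSourceCreator)]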
@@ -11,8 +11,9 @@


@pytest.mark.integration
@pytest.mark.universal_online_stores
def test_writing_incorrect_order_fails(environment, universal_data_sources):
@pytest.mark.universal_offline_stores(only=["file", "redshift"])
@pytest.mark.universal_online_stores(only=["sqlite"])
def test_writing_columns_in_incorrect_order_fails(environment, universal_data_sources):
# TODO(kevjumba) handle incorrect order later, for now schema must be in the order that the filesource is in
store = environment.feature_store
_, _, data_sources = universal_data_sources
@@ -59,7 +60,8 @@ def test_writing_incorrect_order_fails(environment, universal_data_sources):


@pytest.mark.integration
@pytest.mark.universal_online_stores
@pytest.mark.universal_offline_stores(only=["file", "redshift"])
@pytest.mark.universal_online_stores(only=["sqlite"])
def test_writing_incorrect_schema_fails(environment, universal_data_sources):
# TODO(kevjumba) handle incorrect order later, for now schema must be in the order that the filesource is in
store = environment.feature_store
@@ -107,7 +109,8 @@ def test_writing_incorrect_schema_fails(environment, universal_data_sources):


@pytest.mark.integration
@pytest.mark.universal_online_stores
@pytest.mark.universal_offline_stores(only=["file", "redshift"])
@pytest.mark.universal_online_stores(only=["sqlite"])
def test_writing_consecutively_to_offline_store(environment, universal_data_sources):
store = environment.feature_store
_, _, data_sources = universal_data_sources