diff --git a/docs/reference/offline-stores/trino.md b/docs/reference/offline-stores/trino.md index 446db620e3..fd437a7aa6 100644 --- a/docs/reference/offline-stores/trino.md +++ b/docs/reference/offline-stores/trino.md @@ -27,6 +27,47 @@ offline_store: catalog: memory connector: type: memory + user: trino + source: feast-trino-offline-store + http-scheme: https + ssl-verify: false + x-trino-extra-credential-header: foo=bar, baz=qux + + # enables authentication in Trino connections, pick the one you need + # if you don't need authentication, you can safely remove the whole auth block + auth: + # Basic Auth + type: basic + config: + username: foo + password: $FOO + + # Certificate + type: certificate + config: + cert-file: /path/to/cert/file + key-file: /path/to/key/file + + # JWT + type: jwt + config: + token: $JWT_TOKEN + + # OAuth2 (no config required) + type: oauth2 + + # Kerberos + type: kerberos + config: + config-file: /path/to/kerberos/config/file + service-name: foo + mutual-authentication: true + force-preemptive: true + hostname-override: custom-hostname + sanitize-mutual-error-response: true + principal: principal-name + delegate: true + ca-bundle-file: /path/to/ca/bundle/file online_store: path: data/online_store.db ``` diff --git a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/tests/data_source.py index 67efa6a27f..a5aa53df7a 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/tests/data_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/tests/data_source.py @@ -67,6 +67,11 @@ def __init__( catalog="memory", host="localhost", port=self.exposed_port, + source="trino-python-client", + http_scheme="http", + verify=False, + extra_credential=None, + auth=None, ) def teardown(self): diff --git a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py 
b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py index e0f73404eb..f662cda913 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py +++ b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py @@ -1,12 +1,18 @@ import uuid from datetime import date, datetime -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Union import numpy as np import pandas as pd import pyarrow -from pydantic import StrictStr -from trino.auth import Authentication +from pydantic import Field, FilePath, SecretStr, StrictBool, StrictStr, root_validator +from trino.auth import ( + BasicAuthentication, + CertificateAuthentication, + JWTAuthentication, + KerberosAuthentication, + OAuth2Authentication, +) from feast.data_source import DataSource from feast.errors import InvalidEntityType @@ -32,6 +38,87 @@ from feast.usage import log_exceptions_and_usage +class BasicAuthModel(FeastConfigBaseModel): + username: StrictStr + password: SecretStr + + +class KerberosAuthModel(FeastConfigBaseModel): + config: Optional[FilePath] = Field(default=None, alias="config-file") + service_name: Optional[StrictStr] = Field(default=None, alias="service-name") + mutual_authentication: StrictBool = Field( + default=False, alias="mutual-authentication" + ) + force_preemptive: StrictBool = Field(default=False, alias="force-preemptive") + hostname_override: Optional[StrictStr] = Field( + default=None, alias="hostname-override" + ) + sanitize_mutual_error_response: StrictBool = Field( + default=True, alias="sanitize-mutual-error-response" + ) + principal: Optional[StrictStr] + delegate: StrictBool = False + ca_bundle: Optional[FilePath] = Field(default=None, alias="ca-bundle-file") + + +class JWTAuthModel(FeastConfigBaseModel): + token: SecretStr + + +class CertificateAuthModel(FeastConfigBaseModel): + cert: FilePath = Field(default=None, alias="cert-file") + key: 
FilePath = Field(default=None, alias="key-file") + + +CLASSES_BY_AUTH_TYPE = { + "kerberos": { + "auth_model": KerberosAuthModel, + "trino_auth": KerberosAuthentication, + }, + "basic": { + "auth_model": BasicAuthModel, + "trino_auth": BasicAuthentication, + }, + "jwt": { + "auth_model": JWTAuthModel, + "trino_auth": JWTAuthentication, + }, + "oauth2": { + "auth_model": None, + "trino_auth": OAuth2Authentication, + }, + "certificate": { + "auth_model": CertificateAuthModel, + "trino_auth": CertificateAuthentication, + }, +} + + +class AuthConfig(FeastConfigBaseModel): + type: Literal["kerberos", "basic", "jwt", "oauth2", "certificate"] + config: Optional[Dict[StrictStr, Any]] + + @root_validator + def config_only_nullable_for_oauth2(cls, values): + auth_type = values["type"] + auth_config = values["config"] + if auth_type != "oauth2" and auth_config is None: + raise ValueError(f"config cannot be null for auth type '{auth_type}'") + + return values + + def to_trino_auth(self): + auth_type = self.type + trino_auth_cls = CLASSES_BY_AUTH_TYPE[auth_type]["trino_auth"] + + if auth_type == "oauth2": + return trino_auth_cls() + + model_cls = CLASSES_BY_AUTH_TYPE[auth_type]["auth_model"] + model = model_cls(**self.config) + return trino_auth_cls(**model.dict()) + + class TrinoOfflineStoreConfig(FeastConfigBaseModel): """Online store config for Trino""" @@ -47,6 +134,23 @@ class TrinoOfflineStoreConfig(FeastConfigBaseModel): catalog: StrictStr """ Catalog of the Trino cluster """ + user: StrictStr + """ User of the Trino cluster """ + + source: Optional[StrictStr] = "trino-python-client" + """ ID of the feast's Trino Python client, useful for debugging """ + + http_scheme: Literal["http", "https"] = Field(default="http", alias="http-scheme") + """ HTTP scheme that should be used while establishing a connection to the Trino cluster """ + + verify: StrictBool = Field(default=True, alias="ssl-verify") + """ Whether the SSL certificate emitted by the Trino cluster should be 
verified or not """ + + extra_credential: Optional[StrictStr] = Field( + default=None, alias="x-trino-extra-credential-header" + ) + """ Specifies the HTTP header X-Trino-Extra-Credential, e.g. user1=pwd1, user2=pwd2 """ + connector: Dict[str, str] """ Trino connector to use as well as potential extra parameters. @@ -59,6 +163,16 @@ class TrinoOfflineStoreConfig(FeastConfigBaseModel): dataset: StrictStr = "feast" """ (optional) Trino Dataset name for temporary tables """ + auth: Optional[AuthConfig] + """ + (optional) Authentication mechanism to use when connecting to Trino. Supported options are: + - kerberos + - basic + - jwt + - oauth2 + - certificate + """ + class TrinoRetrievalJob(RetrievalJob): def __init__( @@ -162,9 +276,6 @@ def pull_latest_from_table_or_query( created_timestamp_column: Optional[str], start_date: datetime, end_date: datetime, - user: Optional[str] = None, - auth: Optional[Authentication] = None, - http_scheme: Optional[str] = None, ) -> TrinoRetrievalJob: assert isinstance(config.offline_store, TrinoOfflineStoreConfig) assert isinstance(data_source, TrinoSource) @@ -181,9 +292,7 @@ def pull_latest_from_table_or_query( timestamps.append(created_timestamp_column) timestamp_desc_string = " DESC, ".join(timestamps) + " DESC" field_string = ", ".join(join_key_columns + feature_name_columns + timestamps) - client = _get_trino_client( - config=config, user=user, auth=auth, http_scheme=http_scheme - ) + client = _get_trino_client(config=config) query = f""" SELECT @@ -216,17 +325,12 @@ def get_historical_features( registry: Registry, project: str, full_feature_names: bool = False, - user: Optional[str] = None, - auth: Optional[Authentication] = None, - http_scheme: Optional[str] = None, ) -> TrinoRetrievalJob: assert isinstance(config.offline_store, TrinoOfflineStoreConfig) for fv in feature_views: assert isinstance(fv.batch_source, TrinoSource) - client = _get_trino_client( - config=config, user=user, auth=auth, http_scheme=http_scheme - ) + 
client = _get_trino_client(config=config) table_reference = _get_table_reference_for_new_entity( catalog=config.offline_store.catalog, @@ -307,17 +411,12 @@ def pull_all_from_table_or_query( timestamp_field: str, start_date: datetime, end_date: datetime, - user: Optional[str] = None, - auth: Optional[Authentication] = None, - http_scheme: Optional[str] = None, ) -> RetrievalJob: assert isinstance(config.offline_store, TrinoOfflineStoreConfig) assert isinstance(data_source, TrinoSource) from_expression = data_source.get_table_query_string() - client = _get_trino_client( - config=config, user=user, auth=auth, http_scheme=http_scheme - ) + client = _get_trino_client(config=config) field_string = ", ".join( join_key_columns + feature_name_columns + [timestamp_field] ) @@ -378,21 +477,22 @@ def _upload_entity_df_and_get_entity_schema( # TODO: Ensure that the table expires after some time -def _get_trino_client( - config: RepoConfig, - user: Optional[str], - auth: Optional[Any], - http_scheme: Optional[str], -) -> Trino: - client = Trino( - user=user, - catalog=config.offline_store.catalog, +def _get_trino_client(config: RepoConfig) -> Trino: + auth = None + if config.offline_store.auth is not None: + auth = config.offline_store.auth.to_trino_auth() + + return Trino( host=config.offline_store.host, port=config.offline_store.port, + user=config.offline_store.user, + catalog=config.offline_store.catalog, + source=config.offline_store.source, + http_scheme=config.offline_store.http_scheme, + verify=config.offline_store.verify, + extra_credential=config.offline_store.extra_credential, auth=auth, - http_scheme=http_scheme, ) - return client def _get_entity_df_event_timestamp_range( diff --git a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino_queries.py b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino_queries.py index 97c61f78a6..50472407bc 100644 --- 
a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino_queries.py +++ b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino_queries.py @@ -1,7 +1,6 @@ from __future__ import annotations import datetime -import os import signal from dataclasses import dataclass from enum import Enum @@ -30,34 +29,27 @@ class QueryStatus(Enum): class Trino: def __init__( self, - host: Optional[str] = None, - port: Optional[int] = None, - user: Optional[str] = None, - catalog: Optional[str] = None, - auth: Optional[Any] = None, - http_scheme: Optional[str] = None, - source: Optional[str] = None, - extra_credential: Optional[str] = None, + host: str, + port: int, + user: str, + catalog: str, + source: Optional[str], + http_scheme: str, + verify: bool, + extra_credential: Optional[str], + auth: Optional[trino.Authentication], ): - self.host = host or os.getenv("TRINO_HOST") - self.port = port or os.getenv("TRINO_PORT") - self.user = user or os.getenv("TRINO_USER") - self.catalog = catalog or os.getenv("TRINO_CATALOG") - self.auth = auth or os.getenv("TRINO_AUTH") - self.http_scheme = http_scheme or os.getenv("TRINO_HTTP_SCHEME") - self.source = source or os.getenv("TRINO_SOURCE") - self.extra_credential = extra_credential or os.getenv("TRINO_EXTRA_CREDENTIAL") + self.host = host + self.port = port + self.user = user + self.catalog = catalog + self.source = source + self.http_scheme = http_scheme + self.verify = verify + self.extra_credential = extra_credential + self.auth = auth self._cursor: Optional[Cursor] = None - if self.host is None: - raise ValueError("TRINO_HOST must be set if not passed in") - if self.port is None: - raise ValueError("TRINO_PORT must be set if not passed in") - if self.user is None: - raise ValueError("TRINO_USER must be set if not passed in") - if self.catalog is None: - raise ValueError("TRINO_CATALOG must be set if not passed in") - def _get_cursor(self) -> Cursor: if self._cursor is None: headers = ( @@ -70,9 +62,10 @@ 
def _get_cursor(self) -> Cursor: port=self.port, user=self.user, catalog=self.catalog, - auth=self.auth, - http_scheme=self.http_scheme, source=self.source, + http_scheme=self.http_scheme, + verify=self.verify, + auth=self.auth, http_headers=headers, ).cursor() diff --git a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino_source.py b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino_source.py index f09b79069c..e618e8664e 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino_source.py @@ -227,10 +227,20 @@ def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]: def get_table_column_names_and_types( self, config: RepoConfig ) -> Iterable[Tuple[str, str]]: + auth = None + if config.offline_store.auth is not None: + auth = config.offline_store.auth.to_trino_auth() + client = Trino( catalog=config.offline_store.catalog, host=config.offline_store.host, port=config.offline_store.port, + user=config.offline_store.user, + source=config.offline_store.source, + http_scheme=config.offline_store.http_scheme, + verify=config.offline_store.verify, + extra_credential=config.offline_store.extra_credential, + auth=auth, ) if self.table: table_schema = client.execute_query(