fix spark source
Signed-off-by: Danny Chiao <[email protected]>
adchia committed Mar 4, 2022
1 parent fcfca3b commit bdd220f
Showing 2 changed files with 13 additions and 1 deletion.
2 changes: 1 addition & 1 deletion sdk/python/feast/errors.py
@@ -13,7 +13,7 @@ def __init__(self, path):
 class DataSourceNoNameException(Exception):
     def __init__(self):
         super().__init__(
-            "Unable to infer a name for this data source. Either table_ref or name must be specified."
+            "Unable to infer a name for this data source. Either table or name must be specified."
         )


12 changes: 12 additions & 0 deletions in the Spark data source module
@@ -8,6 +8,7 @@
 from pyspark.sql import SparkSession
 
 from feast.data_source import DataSource
+from feast.errors import DataSourceNoNameException
 from feast.infra.offline_stores.offline_utils import get_temp_entity_table_name
 from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto
 from feast.protos.feast.core.SavedDataset_pb2 import (
@@ -30,6 +31,7 @@ class SparkSourceFormat(Enum):
 class SparkSource(DataSource):
     def __init__(
         self,
+        name: Optional[str] = None,
         table: Optional[str] = None,
         query: Optional[str] = None,
         path: Optional[str] = None,
@@ -39,7 +41,15 @@ def __init__(
         field_mapping: Optional[Dict[str, str]] = None,
         date_partition_column: Optional[str] = None,
     ):
+        # If no name is provided, use the table as the default name
+        _name = name
+        if not _name:
+            if table:
+                _name = table
+            else:
+                raise DataSourceNoNameException()
         super().__init__(
+            _name,
             event_timestamp_column,
             created_timestamp_column,
             field_mapping,
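For illustration, a minimal sketch of the new name-inference behavior. The import path and the table/query values are assumptions for the example; only the constructor arguments and DataSourceNoNameException come from the diff above.

from feast.errors import DataSourceNoNameException
# Import path is an assumption; SparkSource is expected to live in the Spark
# offline store contrib module at this point in Feast's history.
from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import SparkSource

# With a table but no explicit name, the name falls back to the table.
stats_source = SparkSource(
    table="driver_hourly_stats",  # hypothetical table name
    event_timestamp_column="event_timestamp",
)
assert stats_source.name == "driver_hourly_stats"

# A query-only source still needs an explicit name; otherwise the
# constructor raises the (reworded) DataSourceNoNameException.
try:
    SparkSource(query="SELECT * FROM driver_hourly_stats")
except DataSourceNoNameException as exc:
    print(exc)  # Unable to infer a name for this data source. ...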
@@ -106,6 +116,7 @@ def from_proto(data_source: DataSourceProto) -> Any:

         spark_options = SparkOptions.from_proto(data_source.custom_options)
         return SparkSource(
+            name=data_source.name,
             field_mapping=dict(data_source.field_mapping),
             table=spark_options.table,
             query=spark_options.query,
@@ -118,6 +129,7 @@ def from_proto(data_source: DataSourceProto) -> Any:

     def to_proto(self) -> DataSourceProto:
         data_source_proto = DataSourceProto(
+            name=self.name,
             type=DataSourceProto.CUSTOM_SOURCE,
             field_mapping=self.field_mapping,
             custom_options=self.spark_options.to_proto(),
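A rough round-trip sketch of why name is now written in to_proto and read back in from_proto, under the same assumptions about the import path and table name as the sketch above.

from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import SparkSource

source = SparkSource(
    table="driver_hourly_stats",  # hypothetical table name
    event_timestamp_column="event_timestamp",
)

# to_proto now sets DataSourceProto.name, and from_proto passes it back into
# the SparkSource constructor, so the (possibly inferred) name survives
# serialization instead of being lost.
restored = SparkSource.from_proto(source.to_proto())
assert restored.name == source.name == "driver_hourly_stats"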
