diff --git a/feathr_project/feathr/registry/_feathr_registry_client.py b/feathr_project/feathr/registry/_feathr_registry_client.py index 0851d5aae..d09d051e5 100644 --- a/feathr_project/feathr/registry/_feathr_registry_client.py +++ b/feathr_project/feathr/registry/_feathr_registry_client.py @@ -21,7 +21,7 @@ from feathr.definition.feature import Feature, FeatureBase from feathr.definition.feature_derivations import DerivedFeature from feathr.definition.repo_definitions import RepoDefinitions -from feathr.definition.source import GenericSource, HdfsSource, InputContext, JdbcSource, SnowflakeSource, Source +from feathr.definition.source import GenericSource, HdfsSource, InputContext, JdbcSource, SnowflakeSource, Source, SparkSqlSource from feathr.definition.transformation import ExpressionTransformation, Transformation, WindowAggTransformation from feathr.definition.typed_key import TypedKey from feathr.registry.feature_registry import FeathrRegistry @@ -260,6 +260,9 @@ def source_to_def(v: Source) -> dict: elif isinstance(v, GenericSource): ret = v.to_dict() ret["name"] = v.name + elif isinstance(v, SparkSqlSource): + ret = v.to_dict() + ret["name"] = v.name else: raise ValueError(f"Unsupported source type {v.__class__}") if hasattr(v, "preprocessing") and v.preprocessing: @@ -281,6 +284,17 @@ def dict_to_source(v: dict) -> Source: source = None if type == INPUT_CONTEXT: source = InputContext() + elif type == "sparksql": + source = SparkSqlSource(name=v["attributes"]["name"], + sql=v["attributes"].get("sql"), + table=v["attributes"].get("table"), + preprocessing=_correct_function_indentation( + v["attributes"].get("preprocessing")), + event_timestamp_column=v["attributes"].get( + "eventTimestampColumn"), + timestamp_format=v["attributes"].get( + "timestampFormat"), + registry_tags=v["attributes"].get("tags", {})) elif type == "jdbc": source = JdbcSource(name=v["attributes"]["name"], url=v["attributes"].get("url"),