Linting
expediamatt committed Nov 30, 2023
1 parent c1e4ec4 commit f15a03c
Showing 5 changed files with 44 additions and 62 deletions.
45 changes: 26 additions & 19 deletions sdk/python/feast/expediagroup/pydantic_models/data_source_model.py
@@ -12,17 +12,13 @@
 from pydantic import Field as PydanticField
 from typing_extensions import Annotated, Self

-from feast.data_source import (
-    KafkaSource,
-    RequestSource,
-)
+from feast.data_source import KafkaSource, RequestSource
 from feast.expediagroup.pydantic_models.field_model import FieldModel
 from feast.expediagroup.pydantic_models.stream_format_model import (
     AnyStreamFormat,
     AvroFormatModel,
     JsonFormatModel,
     ProtoFormatModel,
-
 )
 from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
     SparkSource,
@@ -167,6 +163,7 @@ def from_data_source(
             timestamp_field=data_source.timestamp_field,
         )

+
 AnyBatchDataSource = Annotated[
     Union[RequestSourceModel, SparkSourceModel],
     PydanticField(discriminator="model_type"),
@@ -176,6 +173,7 @@ def from_data_source(
 SUPPORTED_MESSAGE_FORMATS = [AvroFormatModel, JsonFormatModel, ProtoFormatModel]
 SUPPORTED_KAFKA_BATCH_SOURCES = [RequestSourceModel, SparkSourceModel]

+
 class KafkaSourceModel(DataSourceModel):
     """
     Pydantic Model of a Feast KafkaSource.
@@ -213,8 +211,10 @@ def to_data_source(self) -> KafkaSource:
             description=self.description,
             tags=self.tags,
             owner=self.owner,
-            batch_source=self.batch_source.to_data_source(),
-            watermark_delay_threshold=self.watermark_delay_threshold
+            batch_source=self.batch_source.to_data_source()
+            if self.batch_source
+            else None,
+            watermark_delay_threshold=self.watermark_delay_threshold,
         )

     @classmethod
@@ -229,23 +229,24 @@ def from_data_source(
             A KafkaSourceModel.
         """

-
         class_ = getattr(
-                sys.modules[__name__],
-                type(data_source.kafka_options.message_format).__name__ + "Model",
-            )
+            sys.modules[__name__],
+            type(data_source.kafka_options.message_format).__name__ + "Model",
+        )
         if class_ not in SUPPORTED_MESSAGE_FORMATS:
             raise ValueError(
                 "Data Source message format is not a supported stream format."
             )
-        message_format = class_.from_stream_format(data_source.kafka_options.message_format)
+        message_format = class_.from_stream_format(
+            data_source.kafka_options.message_format
+        )

         batch_source = None
         if data_source.batch_source:
             class_ = getattr(
-                    sys.modules[__name__],
-                    type(data_source.batch_source).__name__ + "Model",
-                )
+                sys.modules[__name__],
+                type(data_source.batch_source).__name__ + "Model",
+            )
             if class_ not in SUPPORTED_KAFKA_BATCH_SOURCES:
                 raise ValueError(
                     "Kafka Source's batch source type is not a supported data source type."
@@ -256,10 +257,16 @@ def from_data_source(
             name=data_source.name,
             timestamp_field=data_source.timestamp_field,
             message_format=message_format,
-            kafka_bootstrap_servers=data_source.kafka_options.kafka_bootstrap_servers if data_source.kafka_options.kafka_bootstrap_servers else "",
-            topic=data_source.kafka_options.topic if data_source.kafka_options.topic else "",
+            kafka_bootstrap_servers=data_source.kafka_options.kafka_bootstrap_servers
+            if data_source.kafka_options.kafka_bootstrap_servers
+            else "",
+            topic=data_source.kafka_options.topic
+            if data_source.kafka_options.topic
+            else "",
             created_timestamp_column=data_source.created_timestamp_column,
-            field_mapping=data_source.field_mapping if data_source.field_mapping else None,
+            field_mapping=data_source.field_mapping
+            if data_source.field_mapping
+            else None,
             description=data_source.description,
             tags=data_source.tags if data_source.tags else None,
             owner=data_source.owner,
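Note: the getattr(sys.modules[__name__], ...) calls in the from_data_source hunks above resolve a pydantic model class dynamically: take the runtime type name of the Feast object, append "Model", look that name up in the current module, then validate it against an allow-list. A self-contained sketch of the idiom under illustrative names (Foo and FooModel are hypothetical, not classes in this repo):

import sys


class Foo:
    pass


class FooModel:
    pass


SUPPORTED = [FooModel]

obj = Foo()
# Map the runtime type name ("Foo") to its model class ("FooModel") in this module.
class_ = getattr(sys.modules[__name__], type(obj).__name__ + "Model")
if class_ not in SUPPORTED:
    raise ValueError("unsupported type")
assert class_ is FooModel

The allow-list check matters because getattr would happily return any attribute whose name matches, so the lookup is constrained to the conversions the module actually supports.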
@@ -271,6 +278,6 @@ def from_data_source(
 # https://blog.devgenius.io/deserialize-child-classes-with-pydantic-that-gonna-work-784230e1cf83
 # This lets us discriminate child classes of DataSourceModel with type hints.
 AnyDataSource = Annotated[
-    Union[RequestSourceModel, SparkSourceModel, KafkaSourceModel],
+    Union[RequestSourceModel, SparkSourceModel],
     PydanticField(discriminator="model_type"),
 ]
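Note: PydanticField(discriminator="model_type"), as used above, is pydantic's tagged-union mechanism. Each member model declares a Literal field naming its own type, and pydantic reads that field to pick the class to parse into, rather than trying every member of the Union in order. A minimal sketch of the pattern in pydantic v1 style (parse_obj_as); CatModel, DogModel, and AnyPet are illustrative names, not classes in this repo:

from typing import Literal, Union

from pydantic import BaseModel, parse_obj_as
from pydantic import Field as PydanticField
from typing_extensions import Annotated


class CatModel(BaseModel):
    model_type: Literal["CatModel"] = "CatModel"
    meows: bool


class DogModel(BaseModel):
    model_type: Literal["DogModel"] = "DogModel"
    barks: bool


# The discriminator tells pydantic to dispatch on "model_type" when parsing.
AnyPet = Annotated[Union[CatModel, DogModel], PydanticField(discriminator="model_type")]

pet = parse_obj_as(AnyPet, {"model_type": "DogModel", "barks": True})
assert isinstance(pet, DogModel)

The same Annotated-plus-discriminator shape is what AnyDataSource, AnyBatchDataSource, and AnyStreamFormat rely on in this diff.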
sdk/python/feast/expediagroup/pydantic_models/feature_service.py
@@ -53,7 +53,6 @@ def from_feature_service(
         cls,
         feature_service: FeatureService,
     ) -> Self:  # type: ignore
-
         features = []
         for feature in feature_service._features:
             class_ = getattr(
sdk/python/feast/expediagroup/pydantic_models/feature_view.py
@@ -14,9 +14,9 @@

 from feast.expediagroup.pydantic_models.data_source_model import (
     AnyBatchDataSource,
+    KafkaSourceModel,
     RequestSourceModel,
     SparkSourceModel,
-    KafkaSourceModel,
 )
 from feast.expediagroup.pydantic_models.entity_model import EntityModel
 from feast.expediagroup.pydantic_models.field_model import FieldModel
@@ -137,7 +137,9 @@ def from_feature_view(
         # on a parameter.
         stream_source = None
         if feature_view.stream_source:
-            stream_source = KafkaSourceModel.from_data_source(feature_view.stream_source)
+            stream_source = KafkaSourceModel.from_data_source(
+                feature_view.stream_source
+            )
         return cls(
             name=feature_view.name,
             original_entities=[
sdk/python/feast/expediagroup/pydantic_models/stream_format_model.py
@@ -1,15 +1,10 @@
-from typing import Dict, Literal, Optional, Union
+from typing import Literal, Union

 from pydantic import BaseModel
 from pydantic import Field as PydanticField
 from typing_extensions import Annotated, Self

-from feast.data_format import (
-    StreamFormat,
-    AvroFormat,
-    JsonFormat,
-    ProtoFormat
-)
+from feast.data_format import AvroFormat, JsonFormat, ProtoFormat


 class StreamFormatModel(BaseModel):
@@ -52,9 +47,7 @@ def to_stream_format(self) -> AvroFormat:
         Returns:
             An AvroFormat.
         """
-        return AvroFormat(
-            schema_json=self.schoma
-        )
+        return AvroFormat(schema_json=self.schoma)

     @classmethod
     def from_stream_format(
@@ -67,9 +60,7 @@ def from_stream_format(
         Returns:
             An AvroFormatModel.
         """
-        return cls(
-            schoma=avro_format.schema_json
-        )
+        return cls(schoma=avro_format.schema_json)


 class JsonFormatModel(StreamFormatModel):
@@ -87,9 +78,7 @@ def to_stream_format(self) -> JsonFormat:
         Returns:
             A JsonFormat.
         """
-        return JsonFormat(
-            schema_json=self.schoma
-        )
+        return JsonFormat(schema_json=self.schoma)

     @classmethod
     def from_stream_format(
@@ -102,9 +91,7 @@ def from_stream_format(
         Returns:
             A JsonFormatModel.
         """
-        return cls(
-            schoma=json_format.schema_json
-        )
+        return cls(schoma=json_format.schema_json)


 class ProtoFormatModel(StreamFormatModel):
@@ -122,9 +109,7 @@ def to_stream_format(self) -> ProtoFormat:
         Returns:
             A ProtoFormat.
         """
-        return ProtoFormat(
-            class_path=self.class_path
-        )
+        return ProtoFormat(class_path=self.class_path)

     @classmethod
     def from_stream_format(
@@ -137,14 +122,12 @@ def from_stream_format(
         Returns:
             A ProtoFormatModel.
         """
-        return cls(
-            class_path=proto_format.class_path
-        )
+        return cls(class_path=proto_format.class_path)


 # https://blog.devgenius.io/deserialize-child-classes-with-pydantic-that-gonna-work-784230e1cf83
 # This lets us discriminate child classes of DataSourceModel with type hints.
 AnyStreamFormat = Annotated[
     Union[AvroFormatModel, JsonFormatModel, ProtoFormatModel],
     PydanticField(discriminator="format"),
-]
\ No newline at end of file
+]
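Aside: the odd field name schoma in these models is the stand-in for the wrapped format's schema string; the likely reason (an inference, not stated in this diff) is that "schema" collides with pydantic BaseModel's built-in .schema() method. A short round-trip sketch of how these converters compose, assuming the import paths shown in this PR:

from feast.data_format import AvroFormat
from feast.expediagroup.pydantic_models.stream_format_model import AvroFormatModel

# Feast format -> pydantic model -> Feast format; the trip should be lossless.
avro = AvroFormat(schema_json='{"type": "record", "name": "t", "fields": []}')
model = AvroFormatModel.from_stream_format(avro)
assert model.schoma == avro.schema_json
assert model.to_stream_format().schema_json == avro.schema_json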
19 changes: 5 additions & 14 deletions sdk/python/tests/unit/test_pydantic_models.py
@@ -18,21 +18,14 @@
 import pandas as pd
 from pydantic import BaseModel

-from feast.data_source import (
-    KafkaSource,
-    RequestSource,
-)
-from feast.data_format import (
-    AvroFormat,
-    JsonFormat,
-    ProtoFormat
-)
+from feast.data_format import AvroFormat
+from feast.data_source import KafkaSource, RequestSource
 from feast.entity import Entity
 from feast.expediagroup.pydantic_models.data_source_model import (
     AnyDataSource,
+    KafkaSourceModel,
     RequestSourceModel,
     SparkSourceModel,
-    KafkaSourceModel,
 )
 from feast.expediagroup.pydantic_models.entity_model import EntityModel
 from feast.expediagroup.pydantic_models.feature_service import FeatureServiceModel
@@ -217,7 +210,6 @@ def test_idempotent_sparksource_conversion():


 def test_idempotent_kafkasource_conversion():
-
     schema = [
         Field(name="f1", dtype=Float32),
         Field(name="f2", dtype=Bool),
@@ -226,7 +218,7 @@ def test_idempotent_kafkasource_conversion():
         name="source",
         schema=schema,
         description="desc",
-        tags={"tag1": "val1"},,
+        tags={"tag1": "val1"},
         owner="feast",
     )

@@ -243,7 +235,6 @@ def test_idempotent_kafkasource_conversion():
         field_mapping={"source_thing": "thing_val"},
         owner="[email protected]",
         watermark_delay_threshold=timedelta(days=1),
-
     )

     pydantic_obj = KafkaSourceModel.from_data_source(python_obj)
@@ -346,7 +337,7 @@ def test_idempotent_featureview_with_streaming_source_conversion():
         timestamp_field="whatevs, just a string",
         message_format=AvroFormat(schema_json="whatevs, also just a string"),
         batch_source=request_source,
-        description="Bob's used message formats emporium is open 24/7"
+        description="Bob's used message formats emporium is open 24/7",
     )
     feature_view = FeatureView(
         name="my-feature-view",
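The tests in this file share one round-trip idempotency pattern: build a Feast object, convert it to its pydantic model, convert back, and assert the result equals the original. A condensed sketch of that shape, with illustrative field values, assuming the RequestSourceModel converters defined in this PR:

from feast.data_source import RequestSource
from feast.expediagroup.pydantic_models.data_source_model import RequestSourceModel
from feast.field import Field
from feast.types import Float32

python_obj = RequestSource(
    name="source",
    schema=[Field(name="f1", dtype=Float32)],
    description="desc",
    owner="feast",
)

# Feast -> pydantic -> Feast; the assertion relies on data sources comparing by value.
pydantic_obj = RequestSourceModel.from_data_source(python_obj)
assert python_obj == pydantic_obj.to_data_source()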
