Skip to content

Commit

Permalink
feat(ingest): make sink use type annotations (#5899)
Browse files Browse the repository at this point in the history
  • Loading branch information
hsheth2 authored Sep 11, 2022
1 parent 8f72de5 commit 220ae0b
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 122 deletions.
9 changes: 7 additions & 2 deletions metadata-ingestion/src/datahub/ingestion/api/registry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import importlib
import inspect
from typing import Any, Dict, Generic, Type, TypeVar, Union
from typing import Any, Callable, Dict, Generic, Optional, Type, TypeVar, Union

import entrypoints
import typing_inspect
Expand Down Expand Up @@ -38,8 +38,11 @@ def import_path(path: str) -> Any:
class PluginRegistry(Generic[T]):
_mapping: Dict[str, Union[str, Type[T], Exception]]

def __init__(self) -> None:
def __init__(
self, extra_cls_check: Optional[Callable[[Type[T]], None]] = None
) -> None:
self._mapping = {}
self._extra_cls_check = extra_cls_check

def _get_registered_type(self) -> Type[T]:
cls = typing_inspect.get_generic_type(self)
Expand All @@ -55,6 +58,8 @@ def _check_cls(self, cls: Type[T]) -> None:
super_cls = self._get_registered_type()
if not issubclass(cls, super_cls):
raise ValueError(f"must be derived from {super_cls}; got {cls}")
if self._extra_cls_check is not None:
self._extra_cls_check(cls)

def _register(
self, key: str, tp: Union[str, Type[T], Exception], override: bool = False
Expand Down
50 changes: 37 additions & 13 deletions metadata-ingestion/src/datahub/ingestion/api/sink.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import datetime
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Optional
from typing import Any, Generic, Optional, Type, TypeVar, cast

from datahub.configuration.common import ConfigModel
from datahub.ingestion.api.closeable import Closeable
from datahub.ingestion.api.common import PipelineContext, RecordEnvelope, WorkUnit
from datahub.ingestion.api.report import Report
from datahub.utilities.lossy_collections import LossyList
from datahub.utilities.type_annotations import get_class_from_annotation


@dataclass
Expand All @@ -15,7 +17,7 @@ class SinkReport(Report):
records_written_per_second: int = 0
warnings: LossyList[Any] = field(default_factory=LossyList)
failures: LossyList[Any] = field(default_factory=LossyList)
start_time: datetime.datetime = datetime.datetime.now()
start_time: datetime.datetime = field(default_factory=datetime.datetime.now)
current_time: Optional[datetime.datetime] = None
total_duration_in_seconds: Optional[float] = None

Expand Down Expand Up @@ -75,23 +77,47 @@ def on_failure(
pass


# See https://github.com/python/mypy/issues/5374 for why we suppress this mypy error.
@dataclass # type: ignore[misc]
class Sink(Closeable, metaclass=ABCMeta):
SinkReportType = TypeVar("SinkReportType", bound=SinkReport)
SinkConfig = TypeVar("SinkConfig", bound=ConfigModel)
Self = TypeVar("Self", bound="Sink")


class Sink(Generic[SinkConfig, SinkReportType], Closeable, metaclass=ABCMeta):
"""All Sinks must inherit this base class."""

ctx: PipelineContext
config: SinkConfig
report: SinkReportType

@classmethod
@abstractmethod
def create(cls, config_dict: dict, ctx: PipelineContext) -> "Sink":
def get_config_class(cls) -> Type[SinkConfig]:
config_class = get_class_from_annotation(cls, Sink, ConfigModel)
assert config_class, "Sink subclasses must define a config class"
return cast(Type[SinkConfig], config_class)

@classmethod
def get_report_class(cls) -> Type[SinkReportType]:
report_class = get_class_from_annotation(cls, Sink, SinkReport)
assert report_class, "Sink subclasses must define a report class"
return cast(Type[SinkReportType], report_class)

def __init__(self, ctx: PipelineContext, config: SinkConfig):
self.ctx = ctx
self.config = config
self.report = self.get_report_class()()

self.__post_init__()

def __post_init__(self) -> None:
pass

@abstractmethod
@classmethod
def create(cls: Type[Self], config_dict: dict, ctx: PipelineContext) -> "Self":
return cls(ctx, cls.get_config_class().parse_obj(config_dict))

def handle_work_unit_start(self, workunit: WorkUnit) -> None:
pass

@abstractmethod
def handle_work_unit_end(self, workunit: WorkUnit) -> None:
pass

Expand All @@ -102,11 +128,9 @@ def write_record_async(
# must call callback when done.
pass

@abstractmethod
def get_report(self) -> SinkReport:
pass
def get_report(self) -> SinkReportType:
return self.report

@abstractmethod
def close(self) -> None:
pass

Expand Down
23 changes: 2 additions & 21 deletions metadata-ingestion/src/datahub/ingestion/sink/console.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,17 @@
import dataclasses
import logging

from datahub.configuration.common import ConfigModel
from datahub.ingestion.api.common import RecordEnvelope
from datahub.ingestion.api.sink import Sink, SinkReport, WriteCallback

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ConsoleSink(Sink):
report: SinkReport = dataclasses.field(default_factory=SinkReport)

@classmethod
def create(cls, config_dict, ctx):
return cls(ctx)

def handle_work_unit_start(self, wu):
pass

def handle_work_unit_end(self, wu):
pass

class ConsoleSink(Sink[ConfigModel, SinkReport]):
def write_record_async(
self, record_envelope: RecordEnvelope, write_callback: WriteCallback
) -> None:
print(f"{record_envelope}")
if write_callback:
self.report.report_record_written(record_envelope)
write_callback.on_success(record_envelope, {})

def get_report(self):
return self.report

def close(self):
pass
20 changes: 3 additions & 17 deletions metadata-ingestion/src/datahub/ingestion/sink/datahub_kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from datahub.emitter.kafka_emitter import DatahubKafkaEmitter, KafkaEmitterConfig
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.common import PipelineContext, RecordEnvelope, WorkUnit
from datahub.ingestion.api.common import RecordEnvelope, WorkUnit
from datahub.ingestion.api.sink import Sink, SinkReport, WriteCallback
from datahub.metadata.com.linkedin.pegasus2avro.mxe import (
MetadataChangeEvent,
Expand Down Expand Up @@ -35,23 +35,12 @@ def kafka_callback(self, err: Optional[Exception], msg: str) -> None:
self.write_callback.on_success(self.record_envelope, {"msg": msg})


@dataclass
class DatahubKafkaSink(Sink):
config: KafkaSinkConfig
report: SinkReport
class DatahubKafkaSink(Sink[KafkaSinkConfig, SinkReport]):
emitter: DatahubKafkaEmitter

def __init__(self, config: KafkaSinkConfig, ctx: PipelineContext):
super().__init__(ctx)
self.config = config
self.report = SinkReport()
def __post_init__(self):
self.emitter = DatahubKafkaEmitter(self.config)

@classmethod
def create(cls, config_dict: dict, ctx: PipelineContext) -> "DatahubKafkaSink":
config = KafkaSinkConfig.parse_obj(config_dict)
return cls(config, ctx)

def handle_work_unit_start(self, workunit: WorkUnit) -> None:
pass

Expand Down Expand Up @@ -91,8 +80,5 @@ def write_record_async(
f"The datahub-kafka sink only supports MetadataChangeEvent/MetadataChangeProposal[Wrapper] classes, not {type(record)}"
)

def get_report(self):
return self.report

def close(self) -> None:
self.emitter.flush()
22 changes: 4 additions & 18 deletions metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from datahub.configuration.common import ConfigurationError, OperationalError
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.rest_emitter import DatahubRestEmitter
from datahub.ingestion.api.common import PipelineContext, RecordEnvelope, WorkUnit
from datahub.ingestion.api.common import RecordEnvelope, WorkUnit
from datahub.ingestion.api.sink import Sink, SinkReport, WriteCallback
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.graph.client import DatahubClientConfig
Expand Down Expand Up @@ -87,17 +87,11 @@ def shutdown(self, wait=True):
self.executor.shutdown(wait)


@dataclass
class DatahubRestSink(Sink):
config: DatahubRestSinkConfig
class DatahubRestSink(Sink[DatahubRestSinkConfig, DataHubRestSinkReport]):
emitter: DatahubRestEmitter
report: DataHubRestSinkReport
treat_errors_as_warnings: bool = False

def __init__(self, ctx: PipelineContext, config: DatahubRestSinkConfig):
super().__init__(ctx)
self.config = config
self.report = DataHubRestSinkReport()
def __post_init__(self) -> None:
self.emitter = DatahubRestEmitter(
self.config.server,
self.config.token,
Expand Down Expand Up @@ -131,11 +125,6 @@ def __init__(self, ctx: PipelineContext, config: DatahubRestSinkConfig):
bound=self.config.max_pending_requests,
)

@classmethod
def create(cls, config_dict: dict, ctx: PipelineContext) -> "DatahubRestSink":
config = DatahubRestSinkConfig.parse_obj(config_dict)
return cls(ctx, config)

def handle_work_unit_start(self, workunit: WorkUnit) -> None:
if isinstance(workunit, MetadataWorkUnit):
mwu: MetadataWorkUnit = cast(MetadataWorkUnit, workunit)
Expand Down Expand Up @@ -222,14 +211,11 @@ def write_record_async(
except Exception as e:
write_callback.on_failure(record_envelope, e, failure_metadata={})

def get_report(self) -> SinkReport:
return self.report

def close(self):
self.executor.shutdown(wait=True)

def __repr__(self) -> str:
return self.emitter.__repr__()

def configured(self) -> str:
return self.__repr__()
return repr(self)
27 changes: 3 additions & 24 deletions metadata-ingestion/src/datahub/ingestion/sink/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from datahub.configuration.common import ConfigModel
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.common import PipelineContext, RecordEnvelope
from datahub.ingestion.api.common import RecordEnvelope
from datahub.ingestion.api.sink import Sink, SinkReport, WriteCallback
from datahub.metadata.com.linkedin.pegasus2avro.mxe import (
MetadataChangeEvent,
Expand All @@ -20,31 +20,13 @@ class FileSinkConfig(ConfigModel):
filename: str


class FileSink(Sink):
config: FileSinkConfig
report: SinkReport

def __init__(self, ctx: PipelineContext, config: FileSinkConfig):
super().__init__(ctx)
self.config = config
self.report = SinkReport()

class FileSink(Sink[FileSinkConfig, SinkReport]):
def __post_init__(self) -> None:
fpath = pathlib.Path(self.config.filename)
self.file = fpath.open("w")
self.file.write("[\n")
self.wrote_something = False

@classmethod
def create(cls, config_dict: dict, ctx: PipelineContext) -> "FileSink":
config = FileSinkConfig.parse_obj(config_dict)
return cls(ctx, config)

def handle_work_unit_start(self, wu):
self.id = wu.id

def handle_work_unit_end(self, wu):
pass

def write_record_async(
self,
record_envelope: RecordEnvelope[
Expand All @@ -69,9 +51,6 @@ def write_record_async(
self.report.report_record_written(record_envelope)
write_callback.on_success(record_envelope, {})

def get_report(self):
return self.report

def close(self):
self.file.write("\n]")
self.file.close()
12 changes: 11 additions & 1 deletion metadata-ingestion/src/datahub/ingestion/sink/sink_registry.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
import dataclasses
from typing import Type

from datahub.ingestion.api.registry import PluginRegistry
from datahub.ingestion.api.sink import Sink

sink_registry = PluginRegistry[Sink]()

def _check_sink_classes(cls: Type[Sink]) -> None:
assert not dataclasses.is_dataclass(cls), f"Sink {cls} is a dataclass"
assert cls.get_config_class()
assert cls.get_report_class()


sink_registry = PluginRegistry[Sink](extra_cls_check=_check_sink_classes)
sink_registry.register_from_entrypoint("datahub.ingestion.sink.plugins")

# These sinks are always enabled
Expand Down
27 changes: 4 additions & 23 deletions metadata-ingestion/tests/test_helpers/sink_helpers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import List

from datahub.ingestion.api.common import RecordEnvelope, WorkUnit
from datahub.configuration.common import ConfigModel
from datahub.ingestion.api.common import RecordEnvelope
from datahub.ingestion.api.sink import Sink, SinkReport, WriteCallback
from datahub.ingestion.run.pipeline import PipelineContext


class RecordingSinkReport(SinkReport):
Expand All @@ -13,28 +13,9 @@ def report_record_written(self, record_envelope: RecordEnvelope) -> None:
self.received_records.append(record_envelope)


class RecordingSink(Sink):
def __init__(self):
self.sink_report = RecordingSinkReport()

@classmethod
def create(cls, config_dict: dict, ctx: PipelineContext) -> "Sink":
return cls()

def handle_work_unit_start(self, workunit: WorkUnit) -> None:
pass

def handle_work_unit_end(self, workunit: WorkUnit) -> None:
pass

class RecordingSink(Sink[ConfigModel, RecordingSinkReport]):
def write_record_async(
self, record_envelope: RecordEnvelope, callback: WriteCallback
) -> None:
self.sink_report.report_record_written(record_envelope)
self.report.report_record_written(record_envelope)
callback.on_success(record_envelope, {})

def get_report(self) -> SinkReport:
return self.sink_report

def close(self) -> None:
pass
Loading

0 comments on commit 220ae0b

Please sign in to comment.