Skip to content

Commit

Permalink
refactor(ingest): streamline pydantic configs (#6011)
Browse files Browse the repository at this point in the history
  • Loading branch information
hsheth2 authored Sep 26, 2022
1 parent 9d0a2de commit f227bd9
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,12 @@ class BaseTimeWindowConfig(ConfigModel):
# `start_time` and `end_time` will be populated by the pre-validators.
# However, we must specify a "default" value here or pydantic will complain
# if those fields are not set by the user.
end_time: datetime = Field(default=None, description="Latest date of usage to consider. Default: Current time in UTC") # type: ignore
end_time: datetime = Field(
default_factory=lambda: datetime.now(tz=timezone.utc),
description="Latest date of usage to consider. Default: Current time in UTC",
)
start_time: datetime = Field(default=None, description="Earliest date of usage to consider. Default: Last full day in UTC (or hour, depending on `bucket_duration`)") # type: ignore

@pydantic.validator("end_time", pre=True, always=True)
def default_end_time(
cls, v: Any, *, values: Dict[str, Any], **kwargs: Any
) -> datetime:
return v or datetime.now(tz=timezone.utc)

@pydantic.validator("start_time", pre=True, always=True)
def default_start_time(
cls, v: Any, *, values: Dict[str, Any], **kwargs: Any
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def pydantic_renamed_field(
old_name: str,
new_name: str,
transform: Callable = _default_rename_transform,
print_warning: bool = True,
) -> classmethod:
def _validate_field_rename(cls: Type, values: dict) -> dict:
if old_name in values:
Expand All @@ -22,10 +23,12 @@ def _validate_field_rename(cls: Type, values: dict) -> dict:
f"Cannot specify both {old_name} and {new_name} in the same config. Note that {old_name} has been deprecated in favor of {new_name}."
)
else:
warnings.warn(
f"The {old_name} is deprecated, please use {new_name} instead.",
UserWarning,
)
if print_warning:
warnings.warn(
f"The {old_name} is deprecated, please use {new_name} instead.",
UserWarning,
stacklevel=2,
)
values[new_name] = transform(values.pop(old_name))
return values

Expand Down
1 change: 1 addition & 0 deletions metadata-ingestion/src/datahub/ingestion/source/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ class DBTConfig(StatefulIngestionConfigBase):
description="When enabled, applies the mappings that are defined through the `query_tag_mapping` directives.",
)
write_semantics: str = Field(
# TODO: Replace with the WriteSemantics enum.
default="PATCH",
description='Whether the new tags, terms and owners to be added will override the existing ones added only by this source or not. Value for this config can be "PATCH" or "OVERRIDE"',
)
Expand Down
37 changes: 14 additions & 23 deletions metadata-ingestion/src/datahub/ingestion/source/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@
import json
import logging
import os.path
import pathlib
from dataclasses import dataclass, field
from enum import auto
from io import BufferedReader
from pathlib import Path
from typing import Any, Dict, Iterable, Iterator, Optional, Tuple, Union

import ijson
from pydantic import root_validator, validator
from pydantic import validator
from pydantic.fields import Field

from datahub.configuration.common import ConfigEnum, ConfigModel
from datahub.configuration.validate_field_rename import pydantic_renamed_field
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.decorators import (
SupportStatus,
Expand Down Expand Up @@ -43,8 +44,10 @@ class FileReadMode(ConfigEnum):


class FileSourceConfig(ConfigModel):
filename: Optional[str] = Field(None, description="Path to file to ingest.")
path: str = Field(
filename: Optional[str] = Field(
None, description="[deprecated in favor or `path`] The file to ingest."
)
path: pathlib.Path = Field(
description="Path to folder or file to ingest. If pointed to a folder, all files with extension {file_extension} (default json) within that folder will be processed."
)
file_extension: str = Field(
Expand All @@ -61,18 +64,9 @@ class FileSourceConfig(ConfigModel):
100 * 1000 * 1000 # Must be at least 100MB before we use streaming mode
)

@root_validator(pre=True)
def filename_populates_path_if_present(
cls, values: Dict[str, Any]
) -> Dict[str, Any]:
if "path" not in values and "filename" in values:
values["path"] = values["filename"]
elif values.get("filename"):
raise ValueError(
"Both path and filename should not be provided together. Use one. We recommend using path. filename is deprecated."
)

return values
_filename_populates_path_if_present = pydantic_renamed_field(
"filename", "path", print_warning=False
)

@validator("file_extension", always=True)
def add_leading_dot_to_extension(cls, v: str) -> str:
Expand Down Expand Up @@ -179,16 +173,13 @@ def create(cls, config_dict, ctx):
return cls(ctx, config)

def get_filenames(self) -> Iterable[str]:
is_file = os.path.isfile(self.config.path)
is_dir = os.path.isdir(self.config.path)
if is_file:
if self.config.path.is_file():
self.report.total_num_files = 1
return [self.config.path]
if is_dir:
p = Path(self.config.path)
return [str(self.config.path)]
elif self.config.path.is_dir():
files_and_stats = [
(str(x), os.path.getsize(x))
for x in list(p.glob(f"*{self.config.file_extension}"))
for x in list(self.config.path.glob(f"*{self.config.file_extension}"))
if x.is_file()
]
self.report.total_num_files = len(files_and_stats)
Expand Down

0 comments on commit f227bd9

Please sign in to comment.