Commit a2690ac

feat(ingestion/s3): ignore depth mismatched path

eagle-25 committed Jan 12, 2025
1 parent 9897804 commit a2690ac
Showing 4 changed files with 92 additions and 21 deletions.
4 changes: 4 additions & 0 deletions metadata-ingestion/src/datahub/ingestion/source/aws/s3_util.py
@@ -91,3 +91,7 @@ def group_s3_objects_by_dirname(
             dirname = "/"
         grouped_s3_objs[dirname].append(obj)
     return grouped_s3_objs
+
+
+def get_path_depth(key: str, delimiter: str = "/") -> int:
+    return key.lstrip(delimiter).count(delimiter) + 1
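
For orientation, here is the new helper's depth convention exercised on the sample paths used in this commit's tests (a standalone sketch; the asserts are illustrative, not part of the change):

# Standalone sketch: get_path_depth copied from the diff above.
def get_path_depth(key: str, delimiter: str = "/") -> int:
    # Leading delimiters are stripped, so "dir1" and "/dir1" are both depth 1.
    return key.lstrip(delimiter).count(delimiter) + 1

assert get_path_depth("dir1") == 1
assert get_path_depth("/dir1") == 1
assert get_path_depth("/dir1/") == 2  # a trailing delimiter counts as one more level
assert get_path_depth("/dir1/file.txt") == 2
assert get_path_depth("my-folder/dir1/0002.csv") == 3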
15 changes: 11 additions & 4 deletions metadata-ingestion/src/datahub/ingestion/source/s3/source.py
@@ -40,6 +40,7 @@
     get_bucket_name,
     get_bucket_relative_path,
     get_key_prefix,
+    get_path_depth,
     group_s3_objects_by_dirname,
     strip_s3_prefix,
 )
@@ -842,14 +843,14 @@ def get_dir_to_process(
                 return [f"{protocol}{bucket_name}/{folder}"]
         return [f"{protocol}{bucket_name}/{folder}"]

-    def get_folder_info(
+    def get_folders_by_prefix_and_depth(
         self,
         path_spec: PathSpec,
         bucket: "Bucket",
         prefix: str,
     ) -> List[Folder]:
         """
-        Retrieves all the folders in a path by listing all the files in the prefix.
+        Retrieves folders in the prefix whose depth matches path_spec.include.
         If the prefix is a full path then only that folder will be extracted.

         A folder has creation and modification times, size, and a sample file path.
@@ -866,8 +867,14 @@ def get_folder_info(
         Returns:
             List[Folder]: A list of Folder objects representing the partitions found.
         """
+        include_path_depth = get_path_depth(urlparse(path_spec.include).path)
+        s3_objects = (
+            obj
+            for obj in bucket.objects.filter(Prefix=prefix).page_size(PAGE_SIZE)
+            if get_path_depth(obj.key) == include_path_depth
+        )
+
         partitions: List[Folder] = []
-        s3_objects = bucket.objects.filter(Prefix=prefix).page_size(PAGE_SIZE)
         for key, group in group_s3_objects_by_dirname(s3_objects).items():
             file_size = 0
             creation_time = None
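
To see why a too-deep key is dropped, here is a self-contained sketch of the filter above, using the include pattern and keys from this commit's tests (the sample values are taken from the tests, not production defaults):

# Sketch of the depth filter added above, on this commit's test values.
from urllib.parse import urlparse

def get_path_depth(key: str, delimiter: str = "/") -> int:
    return key.lstrip(delimiter).count(delimiter) + 1

include = "s3://my-bucket/{table}/{partition0}/*.csv"
# urlparse(include).path == "/{table}/{partition0}/*.csv" -> depth 3
include_path_depth = get_path_depth(urlparse(include).path)

keys = [
    "my-folder/dir1/0002.csv",              # depth 3 -> kept
    "my-folder/ignore/this/path/0001.csv",  # depth 5 -> ignored
]
kept = [k for k in keys if get_path_depth(k) == include_path_depth]
assert kept == ["my-folder/dir1/0002.csv"]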
@@ -993,7 +1000,7 @@ def s3_browser(self, path_spec: PathSpec, sample_size: int) -> Iterable[BrowsePath]:
                     prefix_to_process = urlparse(dir).path.lstrip("/")

                     folders.extend(
-                        self.get_folder_info(
+                        self.get_folders_by_prefix_and_depth(
                             path_spec, bucket, prefix_to_process
                         )
                     )
79 changes: 63 additions & 16 deletions metadata-ingestion/tests/unit/s3/test_s3_source.py
@@ -1,17 +1,31 @@
 from datetime import datetime
+from itertools import tee
 from typing import List, Tuple
-from unittest.mock import Mock
+from unittest.mock import Mock, patch

 import pytest

 from datahub.emitter.mcp import MetadataChangeProposalWrapper
 from datahub.ingestion.api.common import PipelineContext
 from datahub.ingestion.api.workunit import MetadataWorkUnit
+from datahub.ingestion.source.aws.s3_util import group_s3_objects_by_dirname
 from datahub.ingestion.source.data_lake_common.data_lake_utils import ContainerWUCreator
 from datahub.ingestion.source.data_lake_common.path_spec import PathSpec
 from datahub.ingestion.source.s3.source import S3Source, partitioned_folder_comparator


+def _get_s3_source(path_spec: PathSpec) -> S3Source:
+    return S3Source.create(
+        config_dict={
+            "path_spec": {
+                "include": path_spec.include,
+                "table_name": path_spec.table_name,
+            },
+        },
+        ctx=PipelineContext(run_id="test-s3"),
+    )
+
+
 def test_partition_comparator_numeric_folder_name():
     folder1 = "3"
     folder2 = "12"
@@ -245,22 +259,10 @@ def container_properties_filter(x: MetadataWorkUnit) -> bool:
     }


-def test_get_folder_info():
+def test_get_folders_by_prefix_and_depth_returns_latest_file_in_each_folder():
     """
-    Test S3Source.get_folder_info returns the latest file in each folder
+    Test S3Source.get_folders_by_prefix_and_depth returns the latest file in each folder
     """
-
-    def _get_s3_source(path_spec_: PathSpec) -> S3Source:
-        return S3Source.create(
-            config_dict={
-                "path_spec": {
-                    "include": path_spec_.include,
-                    "table_name": path_spec_.table_name,
-                },
-            },
-            ctx=PipelineContext(run_id="test-s3"),
-        )
-
     # arrange
     path_spec = PathSpec(
         include="s3://my-bucket/{table}/{partition0}/*.csv",
@@ -295,11 +297,56 @@ def _get_s3_source(path_spec_: PathSpec) -> S3Source:
     )

     # act
-    res = _get_s3_source(path_spec).get_folder_info(
+    res = _get_s3_source(path_spec).get_folders_by_prefix_and_depth(
         path_spec, bucket, prefix="/my-folder"
     )

     # assert
     assert len(res) == 2
     assert res[0].sample_file == "s3://my-bucket/my-folder/dir1/0002.csv"
     assert res[1].sample_file == "s3://my-bucket/my-folder/dir2/0001.csv"
+
+
+def test_get_folders_by_prefix_and_depth_ignores_depth_mismatch():
+    """
+    Test S3Source.get_folders_by_prefix_and_depth ignores folders that do not match the depth of path_spec.include.
+    """
+    # arrange
+    path_spec = PathSpec(
+        include="s3://my-bucket/{table}/{partition0}/*.csv",
+        table_name="{table}",
+    )
+
+    bucket = Mock()
+    bucket.objects.filter().page_size = Mock(
+        return_value=[
+            Mock(
+                bucket_name="my-bucket",
+                key="my-folder/ignore/this/path/0001.csv",
+                creation_time=datetime(2025, 1, 1, 1),
+                last_modified=datetime(2025, 1, 1, 1),
+                size=100,
+            ),
+        ]
+    )
+
+    captured_arg0_for_group_s3_objects = None
+
+    def _wrapped_group_s3_objects_by_dirname(s3_objects):
+        nonlocal captured_arg0_for_group_s3_objects
+        copy1, copy2 = tee(s3_objects)
+        captured_arg0_for_group_s3_objects = copy1
+        return group_s3_objects_by_dirname(copy2)
+
+    # act
+    with patch(
+        "datahub.ingestion.source.s3.source.group_s3_objects_by_dirname",
+        wraps=_wrapped_group_s3_objects_by_dirname,
+    ):
+        _get_s3_source(path_spec).get_folders_by_prefix_and_depth(
+            path_spec, bucket, prefix="/my-folder"
+        )
+
+    # assert
+    assert captured_arg0_for_group_s3_objects is not None
+    assert len(list(captured_arg0_for_group_s3_objects)) == 0
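
Since the filtered `s3_objects` is now a lazy generator, the test above cannot inspect it after `group_s3_objects_by_dirname` consumes it; hence the `tee`-plus-`patch` capture. A minimal standalone sketch of the same pattern (the names here are illustrative, not from the codebase):

# Minimal sketch: duplicate a lazy stream so one copy can be inspected
# after the other copy has been consumed.
from itertools import tee

captured = None

def wrapped_consumer(items):
    global captured
    copy1, copy2 = tee(items)   # two independent iterators over the stream
    captured = copy1            # keep one copy for later inspection
    return list(copy2)          # the "real" consumer drains the other

wrapped_consumer(x for x in range(3) if x > 10)  # lazily yields nothing
assert captured is not None
assert list(captured) == []     # nothing survived the filter in the stream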
15 changes: 14 additions & 1 deletion metadata-ingestion/tests/unit/s3/test_s3_util.py
@@ -1,6 +1,11 @@
 from unittest.mock import Mock

-from datahub.ingestion.source.aws.s3_util import group_s3_objects_by_dirname
+import pytest
+
+from datahub.ingestion.source.aws.s3_util import (
+    get_path_depth,
+    group_s3_objects_by_dirname,
+)


 def test_group_s3_objects_by_dirname():
@@ -27,3 +32,11 @@ def test_group_s3_objects_by_dirname_files_in_root_directory():

     assert len(grouped_objects) == 1
     assert grouped_objects["/"] == s3_objects
+
+
+@pytest.mark.parametrize(
+    "path, expected_depth",
+    [("dir1", 1), ("/dir1", 1), ("/dir1/", 2), ("/dir1/file.txt", 2)],
+)
+def test_get_path_depth(path: str, expected_depth: int) -> None:
+    assert get_path_depth(path) == expected_depth
