Skip to content

Commit

Permalink
refactor: Introduce client factory for dynamic session management and…
Browse files Browse the repository at this point in the history
… conditional refresh handling
  • Loading branch information
sagar-salvi-apptware committed Aug 19, 2024
1 parent b5cc27e commit 5e7fc8f
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 23 deletions.
22 changes: 18 additions & 4 deletions metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import DefaultDict, Dict, Iterable, List, Optional
from typing import TYPE_CHECKING, DefaultDict, Dict, Iterable, List, Optional

from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.decorators import (
Expand Down Expand Up @@ -33,6 +33,9 @@
StatefulIngestionSourceBase,
)

if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient


@platform_name("SageMaker")
@config_class(SagemakerSourceConfig)
Expand All @@ -56,6 +59,7 @@ def __init__(self, config: SagemakerSourceConfig, ctx: PipelineContext):
self.report = SagemakerSourceReport()
self.sagemaker_client = config.sagemaker_client
self.env = config.env
self.client_factory = ClientFactory(config)

@classmethod
def create(cls, config_dict, ctx):
Expand Down Expand Up @@ -92,9 +96,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
# extract jobs if specified
if self.source_config.extract_jobs is not False:
job_processor = JobProcessor(
sagemaker_client=self.source_config.get_auto_refreshing_sagemaker_client()
if self.source_config.allowed_cred_refresh()
else self.sagemaker_client,
sagemaker_client=self.client_factory.get_client,
env=self.env,
report=self.report,
job_type_filter=self.source_config.extract_jobs,
Expand All @@ -120,3 +122,15 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:

def get_report(self):
    """Return the report object stored on ``self.report``."""
    current_report = self.report
    return current_report


class ClientFactory:
    """Hand out SageMaker clients according to the config's refresh policy.

    One client is read from ``config.sagemaker_client`` at construction time
    and kept on the instance. ``get_client`` returns that stored client unless
    ``config.allowed_cred_refresh()`` is truthy, in which case it reads
    ``config.sagemaker_client`` again on every call.
    """

    def __init__(self, config: SagemakerSourceConfig):
        """Remember ``config`` and capture an initial client from it."""
        self.config = config
        # NOTE(review): ``sagemaker_client`` is accessed as an attribute here;
        # presumably it is a property that builds a client - confirm in the
        # config class.
        self._cached_client = config.sagemaker_client

    def get_client(self) -> "SageMakerClient":
        """Return the client to use for the next SageMaker call."""
        if not self.config.allowed_cred_refresh():
            # Refresh not allowed: reuse the client captured in __init__.
            return self._cached_client
        # Refresh allowed: always go back to the config for the client so the
        # auto-refresh logic applies.
        return self.config.sagemaker_client
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,6 @@ class SagemakerSourceConfig(
def sagemaker_client(self):
    """Return whatever ``self.get_sagemaker_client()`` produces."""
    client = self.get_sagemaker_client()
    return client

def get_auto_refreshing_sagemaker_client(self):
    """
    Return the bound ``get_sagemaker_client`` method without calling it.

    Callers invoke the returned callable each time they need a client;
    presumably each invocation yields a fresh client (verify against
    ``get_sagemaker_client``).
    """
    client_getter = self.get_sagemaker_client
    return client_getter


@dataclass
class SagemakerSourceReport(StaleEntityRemovalSourceReport):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from types import MethodType
from typing import (
TYPE_CHECKING,
Any,
Callable,
DefaultDict,
Dict,
Iterable,
Expand Down Expand Up @@ -148,8 +148,7 @@ class JobProcessor:
"""

# boto3 SageMaker client
sagemaker_client: Any

sagemaker_client: Callable[[], "SageMakerClient"]
env: str
report: SagemakerSourceReport
# config filter for specific job types to ingest (see metadata-ingestion README)
Expand All @@ -172,7 +171,7 @@ class JobProcessor:

def get_jobs(self, job_type: JobType, job_spec: JobInfo) -> List[Any]:
jobs = []
paginator = self.get_sagemaker_client().get_paginator(job_spec.list_command)
paginator = self.sagemaker_client().get_paginator(job_spec.list_command)
for page in paginator.paginate():
page_jobs: List[Any] = page[job_spec.list_key]

Expand Down Expand Up @@ -270,7 +269,7 @@ def get_job_details(self, job_name: str, job_type: JobType) -> Dict[str, Any]:
describe_command = job_type_to_info[job_type].describe_command
describe_name_key = job_type_to_info[job_type].describe_name_key

return getattr(self.get_sagemaker_client(), describe_command)(
return getattr(self.sagemaker_client(), describe_command)(
**{describe_name_key: job_name}
)

Expand Down Expand Up @@ -941,8 +940,3 @@ def process_transform_job(self, job: Dict[str, Any]) -> SageMakerJob:
output_datasets=output_datasets,
input_jobs=input_jobs,
)

def get_sagemaker_client(self) -> "SageMakerClient":
    """Resolve ``self.sagemaker_client`` to a usable client object.

    If the attribute holds a bound method (``types.MethodType``) it is called
    and its result returned; any other value is returned as-is.
    """
    candidate = self.sagemaker_client
    # Only bound methods are invoked; plain functions and other callables are
    # passed through untouched.
    return candidate() if isinstance(candidate, MethodType) else candidate
15 changes: 13 additions & 2 deletions metadata-ingestion/tests/unit/sagemaker/test_sagemaker_source.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest.mock import patch

from botocore.stub import Stubber
from freezegun import freeze_time

Expand Down Expand Up @@ -220,8 +222,17 @@ def test_sagemaker_ingest(tmp_path, pytestconfig):
{"ModelName": "the-second-model"},
)

mce_objects = [wu.metadata for wu in sagemaker_source_instance.get_workunits()]
write_metadata_file(tmp_path / "sagemaker_mces.json", mce_objects)
# Patch the client factory's get_client method to return the stubbed client for jobs
with patch.object(
sagemaker_source_instance.client_factory,
"get_client",
return_value=sagemaker_source_instance.sagemaker_client,
):
# Run the test and generate the MCEs
mce_objects = [
wu.metadata for wu in sagemaker_source_instance.get_workunits()
]
write_metadata_file(tmp_path / "sagemaker_mces.json", mce_objects)

# Verify the output.
test_resources_dir = pytestconfig.rootpath / "tests/unit/sagemaker"
Expand Down

0 comments on commit 5e7fc8f

Please sign in to comment.