Skip to content
This repository has been archived by the owner on Jan 27, 2025. It is now read-only.

Commit

Permalink
feat(ingest): s3 - allow selfsigned certificate (datahub-project#6179)
Browse files Browse the repository at this point in the history
  • Loading branch information
mayurinehate authored and cccs-tom committed Nov 18, 2022
1 parent 8777781 commit 8d5e726
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -167,18 +167,24 @@ def get_credentials(self) -> Dict[str, str]:
}
return {}

def get_s3_client(self) -> "S3Client":
def get_s3_client(
    self, verify_ssl: Optional[Union[bool, str]] = None
) -> "S3Client":
    """Build an S3 client from this configuration's session.

    Args:
        verify_ssl: Forwarded unchanged as the client's ``verify`` argument.
            Either a boolean toggling TLS certificate verification or a
            string path to a CA bundle; ``None`` (the default) passes no
            explicit preference.

    Returns:
        The client produced by the session for the ``"s3"`` service, using
        the configured endpoint URL and proxy settings.
    """
    session = self.get_session()
    client_options = {
        "endpoint_url": self.aws_endpoint_url,
        "config": Config(proxies=self.aws_proxy),
        "verify": verify_ssl,
    }
    return session.client("s3", **client_options)

def get_s3_resource(self) -> "S3ServiceResource":
def get_s3_resource(
self, verify_ssl: Optional[Union[bool, str]] = None
) -> "S3ServiceResource":
resource = self.get_session().resource(
"s3",
endpoint_url=self.aws_endpoint_url,
config=Config(proxies=self.aws_proxy),
verify=verify_ssl,
)
# according to: https://stackoverflow.com/questions/32618216/override-s3-endpoint-using-boto3-configuration-file
# boto3 only reads the signature version for s3 from that config file. boto3 automatically changes the endpoint to
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Iterable, Optional
from typing import Iterable, Optional, Union

from datahub.emitter.mce_builder import make_tag_urn
from datahub.ingestion.api.common import PipelineContext
Expand All @@ -23,13 +23,14 @@ def get_s3_tags(
ctx: PipelineContext,
use_s3_bucket_tags: Optional[bool] = False,
use_s3_object_tags: Optional[bool] = False,
verify_ssl: Optional[Union[bool, str]] = None,
) -> Optional[GlobalTagsClass]:
if aws_config is None:
raise ValueError("aws_config not set. Cannot browse s3")
new_tags = GlobalTagsClass(tags=[])
tags_to_add = []
if use_s3_bucket_tags:
s3 = aws_config.get_s3_resource()
s3 = aws_config.get_s3_resource(verify_ssl)
bucket = s3.Bucket(bucket_name)
try:
tags_to_add.extend(
Expand Down
7 changes: 6 additions & 1 deletion metadata-ingestion/src/datahub/ingestion/source/s3/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

import pydantic
from pydantic.fields import Field
Expand Down Expand Up @@ -61,6 +61,11 @@ class DataLakeSourceConfig(PlatformSourceConfigBase, EnvBasedSourceConfigBase):
description="Maximum number of rows to use when inferring schemas for TSV and CSV files.",
)

# TLS certificate verification setting. The default (True) keeps verification
# enabled; a string value is interpreted as a path to a CA bundle.
verify_ssl: Union[bool, str] = Field(
default=True,
description="Either a boolean, in which case it controls whether we verify the server's TLS certificate, or a string, in which case it must be a path to a CA bundle to use.",
)

_rename_path_spec_to_plural = pydantic_renamed_field(
"path_spec", "path_specs", lambda path_spec: [path_spec]
)
Expand Down
9 changes: 7 additions & 2 deletions metadata-ingestion/src/datahub/ingestion/source/s3/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,9 @@ def get_fields(self, table_data: TableData, path_spec: PathSpec) -> List:
if self.source_config.aws_config is None:
raise ValueError("AWS config is required for S3 file sources")

s3_client = self.source_config.aws_config.get_s3_client()
s3_client = self.source_config.aws_config.get_s3_client(
self.source_config.verify_ssl
)

file = smart_open(
table_data.full_path, "rb", transport_params={"client": s3_client}
Expand Down Expand Up @@ -581,6 +583,7 @@ def ingest_table(
self.ctx,
self.source_config.use_s3_bucket_tags,
self.source_config.use_s3_object_tags,
self.source_config.verify_ssl,
)
if s3_tags is not None:
dataset_snapshot.aspects.append(s3_tags)
Expand Down Expand Up @@ -649,7 +652,9 @@ def resolve_templated_folders(self, bucket_name: str, prefix: str) -> Iterable[s
def s3_browser(self, path_spec: PathSpec) -> Iterable[Tuple[str, datetime, int]]:
if self.source_config.aws_config is None:
raise ValueError("aws_config not set. Cannot browse s3")
s3 = self.source_config.aws_config.get_s3_resource()
s3 = self.source_config.aws_config.get_s3_resource(
self.source_config.verify_ssl
)
bucket_name = get_bucket_name(path_spec.include)
logger.debug(f"Scanning bucket: {bucket_name}")
bucket = s3.Bucket(bucket_name)
Expand Down

0 comments on commit 8d5e726

Please sign in to comment.