Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ingest): allow selfsigned certificate in s3 source #6179

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -167,18 +167,24 @@ def get_credentials(self) -> Dict[str, str]:
}
return {}

def get_s3_client(self) -> "S3Client":
def get_s3_client(
    self, verify_ssl: Optional[Union[bool, str]] = None
) -> "S3Client":
    """Build a boto3 S3 client from this AWS configuration.

    Args:
        verify_ssl: Forwarded as boto3's ``verify`` option. Per the
            config description elsewhere in this change: a boolean
            toggles TLS certificate verification, a string is a path
            to a CA bundle to use. ``None`` defers to boto3's default
            behavior (verification enabled).

    Returns:
        An S3 client bound to ``self.aws_endpoint_url`` and routed
        through ``self.aws_proxy``.
    """
    return self.get_session().client(
        "s3",
        endpoint_url=self.aws_endpoint_url,
        config=Config(proxies=self.aws_proxy),
        verify=verify_ssl,
    )

def get_s3_resource(self) -> "S3ServiceResource":
def get_s3_resource(
self, verify_ssl: Optional[Union[bool, str]] = None
) -> "S3ServiceResource":
resource = self.get_session().resource(
"s3",
endpoint_url=self.aws_endpoint_url,
config=Config(proxies=self.aws_proxy),
verify=verify_ssl,
)
# according to: https://stackoverflow.com/questions/32618216/override-s3-endpoint-using-boto3-configuration-file
# boto3 only reads the signature version for s3 from that config file. boto3 automatically changes the endpoint to
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Iterable, Optional
from typing import Iterable, Optional, Union

from datahub.emitter.mce_builder import make_tag_urn
from datahub.ingestion.api.common import PipelineContext
Expand All @@ -23,13 +23,14 @@ def get_s3_tags(
ctx: PipelineContext,
use_s3_bucket_tags: Optional[bool] = False,
use_s3_object_tags: Optional[bool] = False,
verify_ssl: Optional[Union[bool, str]] = None,
) -> Optional[GlobalTagsClass]:
if aws_config is None:
raise ValueError("aws_config not set. Cannot browse s3")
new_tags = GlobalTagsClass(tags=[])
tags_to_add = []
if use_s3_bucket_tags:
s3 = aws_config.get_s3_resource()
s3 = aws_config.get_s3_resource(verify_ssl)
bucket = s3.Bucket(bucket_name)
try:
tags_to_add.extend(
Expand Down
7 changes: 6 additions & 1 deletion metadata-ingestion/src/datahub/ingestion/source/s3/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

import pydantic
from pydantic.fields import Field
Expand Down Expand Up @@ -61,6 +61,11 @@ class DataLakeSourceConfig(PlatformSourceConfigBase, EnvBasedSourceConfigBase):
description="Maximum number of rows to use when inferring schemas for TSV and CSV files.",
)

verify_ssl: Union[bool, str] = Field(
default=True,
description="Either a boolean, in which case it controls whether we verify the server's TLS certificate, or a string, in which case it must be a path to a CA bundle to use.",
)

_rename_path_spec_to_plural = pydantic_renamed_field(
"path_spec", "path_specs", lambda path_spec: [path_spec]
)
Expand Down
9 changes: 7 additions & 2 deletions metadata-ingestion/src/datahub/ingestion/source/s3/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,9 @@ def get_fields(self, table_data: TableData, path_spec: PathSpec) -> List:
if self.source_config.aws_config is None:
raise ValueError("AWS config is required for S3 file sources")

s3_client = self.source_config.aws_config.get_s3_client()
s3_client = self.source_config.aws_config.get_s3_client(
self.source_config.verify_ssl
)

file = smart_open(
table_data.full_path, "rb", transport_params={"client": s3_client}
Expand Down Expand Up @@ -581,6 +583,7 @@ def ingest_table(
self.ctx,
self.source_config.use_s3_bucket_tags,
self.source_config.use_s3_object_tags,
self.source_config.verify_ssl,
)
if s3_tags is not None:
dataset_snapshot.aspects.append(s3_tags)
Expand Down Expand Up @@ -649,7 +652,9 @@ def resolve_templated_folders(self, bucket_name: str, prefix: str) -> Iterable[s
def s3_browser(self, path_spec: PathSpec) -> Iterable[Tuple[str, datetime, int]]:
if self.source_config.aws_config is None:
raise ValueError("aws_config not set. Cannot browse s3")
s3 = self.source_config.aws_config.get_s3_resource()
s3 = self.source_config.aws_config.get_s3_resource(
self.source_config.verify_ssl
)
bucket_name = get_bucket_name(path_spec.include)
logger.debug(f"Scanning bucket: {bucket_name}")
bucket = s3.Bucket(bucket_name)
Expand Down