Skip to content

Commit

Permalink
python/pytorch: Integrate alive_progress into ShardReader
Browse files Browse the repository at this point in the history
Signed-off-by: Soham Manoli <[email protected]>
  • Loading branch information
msoham123 committed Jul 5, 2024
1 parent b1a5afd commit ac71ef7
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 11 deletions.
3 changes: 2 additions & 1 deletion python/aistore/pytorch/dev_requirements
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
torch==2.1.1
torchdata==0.7.1
torchdata==0.7.1
alive_progress==3.1.5
8 changes: 7 additions & 1 deletion python/aistore/pytorch/shard_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from aistore.pytorch.utils import get_basename
from aistore.sdk.types import ArchiveSettings
from aistore.pytorch.base_iter_dataset import AISBaseIterDataset
from alive_progress import alive_it


class AISShardReader(AISBaseIterDataset):
Expand All @@ -23,6 +24,7 @@ class AISShardReader(AISBaseIterDataset):
prefix_map (Dict(AISSource, Union[str, List[str]]), optional): Map of Bucket objects to list of prefixes that only allows
objects with the specified prefixes to be used from each source
etl_name (str, optional): Optional ETL on the AIS cluster to apply to each object
disable_output (bool, optional): Disables console shard reading progress indicator
Yields:
Tuple[str, List[bytes]]: Each item is a tuple where the first element is the basename of the shard
Expand All @@ -34,9 +36,11 @@ def __init__(
bucket_list: Union[Bucket, List[Bucket]],
prefix_map: Dict[Bucket, Union[str, List[str]]] = {},
etl_name: str = None,
show_progress: bool = False,
):
super().__init__(bucket_list, prefix_map)
self._etl_name = etl_name
self._show_progress = show_progress

def _get_sample_iter_from_source(self, source: Bucket, prefix: str) -> Iterable:
"""
Expand Down Expand Up @@ -70,7 +74,9 @@ def _get_sample_iter_from_source(self, source: Bucket, prefix: str) -> Iterable:

# for each basename, get the byte data for each file and yield in dictionary
shard = source.object(entry.name)
for basename, files in samples_dict.items():
for basename, files in alive_it(
samples_dict.items(), title=entry.name, disable=not self._show_progress
):
content_dict = {}
for file_name in files:
file_prefix = file_name.split(".")[-1]
Expand Down
14 changes: 5 additions & 9 deletions python/aistore/sdk/request_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,10 @@ def create_new_session(self) -> Session:
"""
request_session = session()
if "https" in self._endpoint:
self._set_session_verification(
request_session, self._skip_verify, self._ca_cert
)
self._set_session_verification(request_session)
return request_session

def _set_session_verification(
self, request_session: Session, skip_verify: bool, ca_cert: str
):
def _set_session_verification(self, request_session: Session):
"""
Set session verify value for validating the server's SSL certificate
The requests library allows this to be a boolean or a string path to the cert
Expand All @@ -67,11 +63,11 @@ def _set_session_verification(
2. Cert path from env var.
3. True (verify with system's approved CA list)
"""
if skip_verify:
if self._skip_verify:
request_session.verify = False
return
if ca_cert:
request_session.verify = ca_cert
if self._ca_cert:
request_session.verify = self._ca_cert
return
env_crt = os.getenv(AIS_SERVER_CRT)
request_session.verify = env_crt if env_crt else True
Expand Down

0 comments on commit ac71ef7

Please sign in to comment.