From 832166695f9f9734ac254ecc4881ddd724dd47d8 Mon Sep 17 00:00:00 2001
From: laughingman7743
Date: Sun, 12 May 2024 17:33:23 +0900
Subject: [PATCH] WIP

---
 pyathena/filesystem/s3.py | 135 +++++++++++++++++++++++---------------
 1 file changed, 82 insertions(+), 53 deletions(-)

diff --git a/pyathena/filesystem/s3.py b/pyathena/filesystem/s3.py
index bf9b8de5..acbd1b70 100644
--- a/pyathena/filesystem/s3.py
+++ b/pyathena/filesystem/s3.py
@@ -174,6 +174,7 @@ def _head_bucket(self, bucket, refresh: bool = False) -> Optional[S3Object]:
                 bucket=bucket,
                 key=None,
                 version_id=None,
+                delimiter=None,
             )
             self.dircache[bucket] = file
         else:
@@ -207,6 +208,7 @@ def _head_object(
                 bucket=bucket,
                 key=key,
                 version_id=version_id,
+                delimiter=None,
             )
             self.dircache[path] = file
         else:
@@ -231,6 +233,7 @@ def _ls_buckets(self, refresh: bool = False) -> List[S3Object]:
                     bucket=b["Name"],
                     key=None,
                     version_id=None,
+                    delimiter=None,
                 )
                 for b in response["Buckets"]
             ]
@@ -251,58 +254,63 @@ def _ls_dirs(
         bucket, key, version_id = self.parse_path(path)
         if key:
             prefix = f"{key}/{prefix if prefix else ''}"
-        if path not in self.dircache or refresh:
-            files: List[S3Object] = []
-            while True:
-                request: Dict[Any, Any] = {
-                    "Bucket": bucket,
-                    "Prefix": prefix,
-                    "Delimiter": delimiter,
-                }
-                if next_token:
-                    request.update({"ContinuationToken": next_token})
-                if max_keys:
-                    request.update({"MaxKeys": max_keys})
-                response = self._call(
-                    self._client.list_objects_v2,
-                    **request,
-                )
-                files.extend(
-                    S3Object(
-                        init={
-                            "ContentLength": 0,
-                            "ContentType": None,
-                            "StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
-                            "ETag": None,
-                            "LastModified": None,
-                        },
-                        type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
-                        bucket=bucket,
-                        key=c["Prefix"][:-1].rstrip("/"),
-                        version_id=version_id,
-                    )
-                    for c in response.get("CommonPrefixes", [])
-                )
-                files.extend(
-                    S3Object(
-                        init=c,
-                        type=S3ObjectType.S3_OBJECT_TYPE_FILE,
-                        bucket=bucket,
-                        key=c["Key"],
-                    )
-                    for c in response.get("Contents", [])
-                )
-                next_token = response.get("NextContinuationToken")
-                if not next_token:
-                    break
-            if files:
-                self.dircache[path] = files
-        else:
+
+        if path in self.dircache and not refresh:
             cache = self.dircache[path]
             if not isinstance(cache, list):
-                files = [cache]
+                caches = [cache]
             else:
-                files = cache
+                caches = cache
+            if all([f.delimiter == delimiter for f in caches]):
+                return caches
+
+        files: List[S3Object] = []
+        while True:
+            request: Dict[Any, Any] = {
+                "Bucket": bucket,
+                "Prefix": prefix,
+                "Delimiter": delimiter,
+            }
+            if next_token:
+                request.update({"ContinuationToken": next_token})
+            if max_keys:
+                request.update({"MaxKeys": max_keys})
+            response = self._call(
+                self._client.list_objects_v2,
+                **request,
+            )
+            files.extend(
+                S3Object(
+                    init={
+                        "ContentLength": 0,
+                        "ContentType": None,
+                        "StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
+                        "ETag": None,
+                        "LastModified": None,
+                    },
+                    type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
+                    bucket=bucket,
+                    key=c["Prefix"][:-1].rstrip("/"),
+                    version_id=version_id,
+                    delimiter=delimiter,
+                )
+                for c in response.get("CommonPrefixes", [])
+            )
+            files.extend(
+                S3Object(
+                    init=c,
+                    type=S3ObjectType.S3_OBJECT_TYPE_FILE,
+                    bucket=bucket,
+                    key=c["Key"],
+                    delimiter=delimiter,
+                )
+                for c in response.get("Contents", [])
+            )
+            next_token = response.get("NextContinuationToken")
+            if not next_token:
+                break
+        if files:
+            self.dircache[path] = files
         return files
 
     def ls(
@@ -337,6 +345,7 @@ def info(self, path: str, **kwargs) -> S3Object:
                 bucket=bucket,
                 key=None,
                 version_id=None,
+                delimiter=None,
             )
         if not refresh:
             caches: Union[List[S3Object], S3Object] = self._ls_from_cache(path)
@@ -363,6 +372,7 @@ def info(self, path: str, **kwargs) -> S3Object:
                         bucket=bucket,
                         key=key.rstrip("/") if key else None,
                         version_id=version_id,
+                        delimiter=None,
                     )
         if key:
             object_info = self._head_object(path, refresh=refresh, version_id=version_id)
@@ -399,24 +409,31 @@ def info(self, path: str, **kwargs) -> S3Object:
                 bucket=bucket,
                 key=key.rstrip("/") if key else None,
                 version_id=version_id,
+                delimiter=None,
             )
         else:
             raise FileNotFoundError(path)
 
-    def find(
+    def _find(
         self,
         path: str,
         maxdepth: Optional[int] = None,
         withdirs: Optional[bool] = None,
-        detail: bool = False,
         **kwargs,
-    ) -> Union[Dict[str, S3Object], List[str]]:
-        # TODO: Support maxdepth and withdirs
+    ) -> List[S3Object]:
         path = self._strip_protocol(path)
         if path in ["", "/"]:
             raise ValueError("Cannot traverse all files in S3.")
         bucket, key, _ = self.parse_path(path)
         prefix = kwargs.pop("prefix", "")
 
+        if maxdepth:
+            return super().find(
+                path=path,
+                maxdepth=maxdepth,
+                withdirs=withdirs,
+                detail=True,
+                **kwargs
+            ).values()
+
         files = self._ls_dirs(path, prefix=prefix, delimiter="")
         if not files and key:
@@ -424,6 +441,18 @@ def find(
             files = [self.info(path)]
         except FileNotFoundError:
             files = []
+        return files
+
+    def find(
+        self,
+        path: str,
+        maxdepth: Optional[int] = None,
+        withdirs: Optional[bool] = None,
+        detail: bool = False,
+        **kwargs,
+    ) -> Union[Dict[str, S3Object], List[str]]:
+        # TODO: Support withdirs
+        files = self._find(path=path, maxdepth=maxdepth, withdirs=withdirs, **kwargs)
         if detail:
             return {f.name: f for f in files}
         else:
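
A minimal usage sketch of the cache behaviour this patch aims for (not part of the patch itself): "example-bucket" and the "data" prefix are placeholders, and constructing S3FileSystem() with no arguments assumes credentials and region are resolved from the standard AWS environment.

from pyathena.filesystem.s3 import S3FileSystem

fs = S3FileSystem()

# A shallow listing is expected to go through _ls_dirs() with the default
# delimiter "/", so the cached S3Object entries are tagged with delimiter="/".
shallow = fs.ls("s3://example-bucket/data/")

# find() lists recursively via _ls_dirs(..., delimiter=""). Because the cached
# entries above carry a different delimiter, the all(f.delimiter == delimiter)
# check in _ls_dirs() fails and the recursive listing is fetched from S3
# instead of being served from the shallow cache.
deep = fs.find("s3://example-bucket/data")

# When maxdepth is given, find() now delegates to _find(), which falls back to
# fsspec's generic AbstractFileSystem.find() with detail=True.
limited = fs.find("s3://example-bucket/data", maxdepth=1, detail=True)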