
Enable flake8-return
laughingman7743 committed Dec 29, 2024
1 parent 202b253 commit 85bd94c
Showing 18 changed files with 101 additions and 133 deletions.
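flake8-return flags an else/elif branch that follows a branch always exiting via return, raise, continue, or break (RET505-RET508 in Ruff's numbering); every hunk in this commit dedents such a branch. A minimal before/after sketch of the pattern, as a hypothetical example rather than code from this repository:

# Before: flagged as an unnecessary else after return (RET505).
def describe(x: int) -> str:
    if x < 0:
        return "negative"
    else:
        return "non-negative"

# After: the early return makes the else redundant, so its body is dedented.
def describe(x: int) -> str:
    if x < 0:
        return "negative"
    return "non-negative"

Behavior is unchanged; only a level of indentation is removed. With plain flake8 the checks activate once the flake8-return plugin is installed; with Ruff they correspond to selecting the RET rule group.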
6 changes: 2 additions & 4 deletions pyathena/__init__.py
@@ -28,14 +28,12 @@ class DBAPITypeObject(FrozenSet[str]):
     def __eq__(self, other: object):
         if isinstance(other, frozenset):
             return frozenset.__eq__(self, other)
-        else:
-            return other in self
+        return other in self

     def __ne__(self, other: object):
         if isinstance(other, frozenset):
             return frozenset.__ne__(self, other)
-        else:
-            return other not in self
+        return other not in self

     def __hash__(self):
         return frozenset.__hash__(self)
3 changes: 1 addition & 2 deletions pyathena/arrow/async_cursor.py
@@ -60,8 +60,7 @@ def get_default_converter(
     ) -> Union[DefaultArrowTypeConverter, DefaultArrowUnloadTypeConverter, Any]:
         if unload:
             return DefaultArrowUnloadTypeConverter()
-        else:
-            return DefaultArrowTypeConverter()
+        return DefaultArrowTypeConverter()

     @property
     def arraysize(self) -> int:
5 changes: 2 additions & 3 deletions pyathena/arrow/converter.py
@@ -22,10 +22,9 @@
 def _to_date(value: Optional[Union[str, datetime]]) -> Optional[date]:
     if value is None:
         return None
-    elif isinstance(value, datetime):
+    if isinstance(value, datetime):
         return value.date()
-    else:
-        return datetime.strptime(value, "%Y-%m-%d").date()
+    return datetime.strptime(value, "%Y-%m-%d").date()


 _DEFAULT_ARROW_CONVERTERS: Dict[str, Callable[[Optional[str]], Optional[Any]]] = {
3 changes: 1 addition & 2 deletions pyathena/arrow/cursor.py
@@ -59,8 +59,7 @@ def get_default_converter(
     ) -> Union[DefaultArrowTypeConverter, DefaultArrowUnloadTypeConverter, Any]:
         if unload:
             return DefaultArrowUnloadTypeConverter()
-        else:
-            return DefaultArrowTypeConverter()
+        return DefaultArrowTypeConverter()

     @property
     def arraysize(self) -> int:
31 changes: 15 additions & 16 deletions pyathena/arrow/util.py
@@ -29,42 +29,41 @@ def get_athena_type(type_: "DataType") -> Tuple[str, int, int]:

     if type_.id in [types.Type_BOOL]:  # 1
         return "boolean", 0, 0
-    elif type_.id in [types.Type_UINT8, types.Type_INT8]:  # 2, 3
+    if type_.id in [types.Type_UINT8, types.Type_INT8]:  # 2, 3
         return "tinyint", 3, 0
-    elif type_.id in [types.Type_UINT16, types.Type_INT16]:  # 4, 5
+    if type_.id in [types.Type_UINT16, types.Type_INT16]:  # 4, 5
         return "smallint", 5, 0
-    elif type_.id in [types.Type_UINT32, types.Type_INT32]:  # 6, 7
+    if type_.id in [types.Type_UINT32, types.Type_INT32]:  # 6, 7
         return "integer", 10, 0
-    elif type_.id in [types.Type_UINT64, types.Type_INT64]:  # 8, 9
+    if type_.id in [types.Type_UINT64, types.Type_INT64]:  # 8, 9
         return "bigint", 19, 0
-    elif type_.id in [types.Type_HALF_FLOAT, types.Type_FLOAT]:  # 10, 11
+    if type_.id in [types.Type_HALF_FLOAT, types.Type_FLOAT]:  # 10, 11
         return "float", 17, 0
-    elif type_.id in [types.Type_DOUBLE]:  # 12
+    if type_.id in [types.Type_DOUBLE]:  # 12
         return "double", 17, 0
-    elif type_.id in [types.Type_STRING, types.Type_LARGE_STRING]:  # 13, 34
+    if type_.id in [types.Type_STRING, types.Type_LARGE_STRING]:  # 13, 34
         return "varchar", 2147483647, 0
-    elif type_.id in [
+    if type_.id in [
         types.Type_BINARY,
         types.Type_FIXED_SIZE_BINARY,
         types.Type_LARGE_BINARY,
     ]:  # 14, 15, 35
         return "varbinary", 1073741824, 0
-    elif type_.id in [types.Type_DATE32, types.Type_DATE64]:  # 16, 17
+    if type_.id in [types.Type_DATE32, types.Type_DATE64]:  # 16, 17
         return "date", 0, 0
-    elif type_.id == types.Type_TIMESTAMP:  # 18
+    if type_.id == types.Type_TIMESTAMP:  # 18
         return "timestamp", 3, 0
-    elif type_.id in [types.Type_DECIMAL128, types.Decimal256Type]:  # 23, 24
+    if type_.id in [types.Type_DECIMAL128, types.Decimal256Type]:  # 23, 24
         type_ = cast(types.Decimal128Type, type_)
         return "decimal", type_.precision, type_.scale
-    elif type_.id in [
+    if type_.id in [
         types.Type_LIST,
         types.Type_FIXED_SIZE_LIST,
         types.Type_LARGE_LIST,
     ]:  # 25, 32, 36
         return "array", 0, 0
-    elif type_.id in [types.Type_STRUCT]:  # 26
+    if type_.id in [types.Type_STRUCT]:  # 26
         return "row", 0, 0
-    elif type_.id in [types.Type_MAP]:  # 30
+    if type_.id in [types.Type_MAP]:  # 30
         return "map", 0, 0
-    else:
-        return "string", 2147483647, 0
+    return "string", 2147483647, 0
6 changes: 2 additions & 4 deletions pyathena/common.py
@@ -78,8 +78,7 @@ def __next__(self):
         row = self.fetchone()
         if row is None:
             raise StopIteration
-        else:
-            return row
+        return row

     def __iter__(self):
         return self
@@ -482,8 +481,7 @@ def __poll(self, query_id: str) -> Union[AthenaQueryExecution, AthenaCalculation
                 AthenaQueryExecution.STATE_CANCELLED,
             ]:
                 return query_execution
-            else:
-                time.sleep(self._poll_interval)
+            time.sleep(self._poll_interval)

     def _poll(self, query_id: str) -> Union[AthenaQueryExecution, AthenaCalculationExecution]:
         try:
39 changes: 17 additions & 22 deletions pyathena/fastparquet/util.py
@@ -37,34 +37,29 @@ def get_athena_type(type_: "SchemaElement") -> Tuple[str, int, int]:

     if type_.type in [Type.BOOLEAN]:
         return "boolean", 0, 0
-    elif type_.type in [Type.INT32]:
+    if type_.type in [Type.INT32]:
         if type_.converted_type == ConvertedType.DATE:
             return "date", 0, 0
-        else:
-            return "integer", 10, 0
-    elif type_.type in [Type.INT64]:
+        return "integer", 10, 0
+    if type_.type in [Type.INT64]:
         return "bigint", 19, 0
-    elif type_.type in [Type.INT96]:
+    if type_.type in [Type.INT96]:
         return "timestamp", 3, 0
-    elif type_.type in [Type.FLOAT]:
+    if type_.type in [Type.FLOAT]:
         return "float", 17, 0
-    elif type_.type in [Type.DOUBLE]:
+    if type_.type in [Type.DOUBLE]:
         return "double", 17, 0
-    elif type_.type in [Type.BYTE_ARRAY, Type.FIXED_LEN_BYTE_ARRAY]:
+    if type_.type in [Type.BYTE_ARRAY, Type.FIXED_LEN_BYTE_ARRAY]:
         if type_.converted_type == ConvertedType.UTF8:
             return "varchar", 2147483647, 0
-        elif type_.converted_type == ConvertedType.DECIMAL:
+        if type_.converted_type == ConvertedType.DECIMAL:
             return "decimal", type_.precision, type_.scale
-        else:
-            return "varbinary", 1073741824, 0
-    else:
-        if type_.converted_type == ConvertedType.LIST:
-            return "array", 0, 0
-        elif type_.converted_type == ConvertedType.MAP:
-            return "map", 0, 0
-        else:
-            children = getattr(type_, "children", [])
-            if type_.type is None and type_.converted_type is None and children:
-                return "row", 0, 0
-            else:
-                return "string", 2147483647, 0
+        return "varbinary", 1073741824, 0
+    if type_.converted_type == ConvertedType.LIST:
+        return "array", 0, 0
+    if type_.converted_type == ConvertedType.MAP:
+        return "map", 0, 0
+    children = getattr(type_, "children", [])
+    if type_.type is None and type_.converted_type is None and children:
+        return "row", 0, 0
+    return "string", 2147483647, 0
47 changes: 20 additions & 27 deletions pyathena/filesystem/s3.py
@@ -148,8 +148,7 @@ def parse_path(path: str) -> Tuple[str, Optional[str], Optional[str]]:
         match = S3FileSystem.PATTERN_PATH.search(path)
         if match:
             return match.group("bucket"), match.group("key"), match.group("version_id")
-        else:
-            raise ValueError(f"Invalid S3 path format {path}.")
+        raise ValueError(f"Invalid S3 path format {path}.")

     def _head_bucket(self, bucket, refresh: bool = False) -> Optional[S3Object]:
         if bucket not in self.dircache or refresh:
@@ -347,20 +346,19 @@ def info(self, path: str, **kwargs) -> S3Object:

         if cache:
             return cache
-        else:
-            return S3Object(
-                init={
-                    "ContentLength": 0,
-                    "ContentType": None,
-                    "StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
-                    "ETag": None,
-                    "LastModified": None,
-                },
-                type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
-                bucket=bucket,
-                key=key.rstrip("/") if key else None,
-                version_id=version_id,
-            )
+        return S3Object(
+            init={
+                "ContentLength": 0,
+                "ContentType": None,
+                "StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
+                "ETag": None,
+                "LastModified": None,
+            },
+            type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
+            bucket=bucket,
+            key=key.rstrip("/") if key else None,
+            version_id=version_id,
+        )
         if key:
             object_info = self._head_object(path, refresh=refresh, version_id=version_id)
             if object_info:
@@ -369,8 +367,7 @@ def info(self, path: str, **kwargs) -> S3Object:
             bucket_info = self._head_bucket(path, refresh=refresh)
             if bucket_info:
                 return bucket_info
-            else:
-                raise FileNotFoundError(path)
+            raise FileNotFoundError(path)

         response = self._call(
             self._client.list_objects_v2,
@@ -397,8 +394,7 @@ def info(self, path: str, **kwargs) -> S3Object:
                 key=key.rstrip("/") if key else None,
                 version_id=version_id,
             )
-        else:
-            raise FileNotFoundError(path)
+        raise FileNotFoundError(path)

     def find(
         self,
@@ -423,8 +419,7 @@ def find(
         files = []
         if detail:
             return {f.name: f for f in files}
-        else:
-            return [f.name for f in files]
+        return [f.name for f in files]

     def exists(self, path: str, **kwargs) -> bool:
         path = self._strip_protocol(path)
@@ -741,8 +736,7 @@ def checksum(self, path: str, **kwargs):
         info = self.info(path, refresh=refresh)
         if info.get("type") != S3ObjectType.S3_OBJECT_TYPE_DIRECTORY:
             return int(info.get("etag").strip('"').split("-")[0], 16)
-        else:
-            return int(tokenize(info), 16)
+        return int(tokenize(info), 16)

     def sign(self, path: str, expiration: int = 3600, **kwargs):
         bucket, key, version_id = self.parse_path(path)
@@ -1224,9 +1218,8 @@ def _get_ranges(
                 if range_end > end:
                     ranges.append((range_start, end))
                     break
-                else:
-                    ranges.append((range_start, range_end))
-                    range_start += worker_block_size
+                ranges.append((range_start, range_end))
+                range_start += worker_block_size
         else:
             ranges.append((start, end))
         return ranges
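The _get_ranges hunk above is the break variant of the same rule (RET508): when one branch of a loop exits with break, the else branch that follows can be flattened into the loop body. A standalone sketch of that shape, with hypothetical names and the fixed-size blocking the diff suggests:

# Hypothetical illustration of the else-after-break cleanup,
# mirroring the shape of the _get_ranges hunk above.
def block_ranges(start: int, end: int, block_size: int) -> list:
    ranges = []
    range_start = start
    while True:
        range_end = range_start + block_size
        if range_end > end:
            ranges.append((range_start, end))
            break
        # Previously nested under an else; the break above makes it redundant.
        ranges.append((range_start, range_end))
        range_start += block_size
    return ranges

print(block_ranges(0, 10, 4))  # [(0, 4), (4, 8), (8, 10)]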
9 changes: 4 additions & 5 deletions pyathena/model.py
@@ -547,14 +547,13 @@ def serde_serialization_lib(self) -> Optional[str]:
     def compression(self) -> Optional[str]:
         if "write.compression" in self._parameters:  # text or json
             return self._parameters["write.compression"]
-        elif "serde.param.write.compression" in self._parameters:  # text or json
+        if "serde.param.write.compression" in self._parameters:  # text or json
             return self._parameters["serde.param.write.compression"]
-        elif "parquet.compress" in self._parameters:  # parquet
+        if "parquet.compress" in self._parameters:  # parquet
             return self._parameters["parquet.compress"]
-        elif "orc.compress" in self._parameters:  # orc
+        if "orc.compress" in self._parameters:  # orc
             return self._parameters["orc.compress"]
-        else:
-            return None
+        return None

     @property
     def serde_properties(self) -> Dict[str, str]:
3 changes: 1 addition & 2 deletions pyathena/pandas/async_cursor.py
@@ -64,8 +64,7 @@ def get_default_converter(
     ) -> Union[DefaultPandasTypeConverter, Any]:
         if unload:
             return DefaultPandasUnloadTypeConverter()
-        else:
-            return DefaultPandasTypeConverter()
+        return DefaultPandasTypeConverter()

     @property
     def arraysize(self) -> int:
3 changes: 1 addition & 2 deletions pyathena/pandas/cursor.py
@@ -81,8 +81,7 @@ def get_default_converter(
     ) -> Union[DefaultPandasTypeConverter, Any]:
         if unload:
             return DefaultPandasUnloadTypeConverter()
-        else:
-            return DefaultPandasTypeConverter()
+        return DefaultPandasTypeConverter()

     @property
     def arraysize(self) -> int:
9 changes: 3 additions & 6 deletions pyathena/pandas/result_set.py
@@ -85,8 +85,7 @@ def get_chunk(self, size=None):

         if isinstance(self._reader, TextFileReader):
             return self._reader.get_chunk(size)
-        else:
-            return next(self._reader)
+        return next(self._reader)


 class AthenaPandasResultSet(AthenaResultSet):
@@ -166,8 +165,7 @@ def _get_engine(self) -> "str":
                 "Trying to import the above resulted in these errors:"
                 f"{error_msgs}"
             )
-        else:
-            return self._engine
+        return self._engine

     def __s3_file_system(self):
         from pyathena.filesystem.s3 import S3FileSystem
@@ -386,8 +384,7 @@ def _as_pandas(self) -> Union["TextFileReader", "DataFrame"]:
     def as_pandas(self) -> Union[DataFrameIterator, "DataFrame"]:
         if self._chunksize is None:
             return next(self._df_iter)
-        else:
-            return self._df_iter
+        return self._df_iter

     def close(self) -> None:
         import pandas as pd
22 changes: 10 additions & 12 deletions pyathena/pandas/util.py
@@ -79,27 +79,25 @@ def to_sql_type_mappings(col: "Series") -> str:
     col_type = pd.api.types.infer_dtype(col, skipna=True)
     if col_type == "datetime64" or col_type == "datetime":
         return "TIMESTAMP"
-    elif col_type == "timedelta":
+    if col_type == "timedelta":
         return "INT"
-    elif col_type == "timedelta64":
+    if col_type == "timedelta64":
         return "BIGINT"
-    elif col_type == "floating":
+    if col_type == "floating":
         if col.dtype == "float32":
             return "FLOAT"
-        else:
-            return "DOUBLE"
-    elif col_type == "integer":
+        return "DOUBLE"
+    if col_type == "integer":
         if col.dtype == "int32":
             return "INT"
-        else:
-            return "BIGINT"
-    elif col_type == "boolean":
+        return "BIGINT"
+    if col_type == "boolean":
         return "BOOLEAN"
-    elif col_type == "date":
+    if col_type == "date":
         return "DATE"
-    elif col_type == "bytes":
+    if col_type == "bytes":
         return "BINARY"
-    elif col_type in ["complex", "time"]:
+    if col_type in ["complex", "time"]:
         raise ValueError(f"Data type `{col_type}` is not supported")
     return "STRING"

9 changes: 4 additions & 5 deletions pyathena/result_set.py
@@ -330,11 +330,10 @@ def fetchone(
             self._fetch()
         if not self._rows:
             return None
-        else:
-            if self._rownumber is None:
-                self._rownumber = 0
-            self._rownumber += 1
-            return self._rows.popleft()
+        if self._rownumber is None:
+            self._rownumber = 0
+        self._rownumber += 1
+        return self._rows.popleft()

     def fetchmany(
         self, size: Optional[int] = None