Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy committed Oct 2, 2024
1 parent a22a348 commit 3666765
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 28 deletions.
46 changes: 19 additions & 27 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
"""
import math
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as _np
import onnx.onnx_ml_pb2
Expand All @@ -56,17 +56,7 @@ def get_type(elem_type: Union[str, int]) -> str:
if isinstance(elem_type, str):
return elem_type

try:
return str(helper.tensor_dtype_to_np_dtype(elem_type))
except:
try:
from onnx.mapping import ( # pylint: disable=import-outside-toplevel
TENSOR_TYPE_TO_NP_TYPE,
)
except ImportError as exception:
raise ImportError("Unable to import onnx which is required {}".format(exception))

return str(TENSOR_TYPE_TO_NP_TYPE[elem_type])
return str(helper.tensor_dtype_to_np_dtype(elem_type))


def get_constant(
Expand Down Expand Up @@ -244,14 +234,17 @@ def _impl_v13(cls, bb, inputs, attr, params):
class BinaryBase(OnnxOpConverter):
"""Converts an onnx BinaryBase node into an equivalent Relax expression."""

numpy_op = None
relax_op = None
numpy_op: Callable = None
relax_op: Callable = None

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
print(inputs)
if cls.numpy_op is None or cls.relax_op is None:
raise ValueError("Numpy and Relax operators must be defined for BinaryBase.")
if all([isinstance(inp, relax.Constant) for inp in inputs]):
output = cls.numpy_op(inputs[0].data.numpy(), inputs[1].data.numpy())
output = cls.numpy_op( # pylint: disable=not-callable
inputs[0].data.numpy(), inputs[1].data.numpy()
)
return relax.const(output, inputs[0].struct_info.dtype)
if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
x = (
Expand All @@ -264,7 +257,7 @@ def _impl_v1(cls, bb, inputs, attr, params):
if isinstance(inputs[1], relax.PrimValue)
else inputs[1].data.numpy()
)
return relax.PrimValue(cls.numpy_op(x, y))
return relax.PrimValue(cls.numpy_op(x, y)) # pylint: disable=not-callable

return cls.relax_op(inputs[0], inputs[1])

Expand Down Expand Up @@ -1267,22 +1260,22 @@ def _impl_v1(cls, bb, inputs, attr, params):
class MultiInputBase(OnnxOpConverter):
"""Converts an onnx MultiInputBase node into an equivalent Relax expression."""

numpy_op = None
relax_op = None
numpy_op: Callable = None
relax_op: Callable = None

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
if cls.numpy_op is None or cls.relax_op is None:
raise NotImplementedError("numpy_op and relax_op must be defined for MultiInputBase")
if all([isinstance(inp, relax.Constant) for inp in inputs]):
np_inputs = [inp.data.numpy() for inp in inputs]
output = cls.numpy_op(*np_inputs)
output = cls.numpy_op(*np_inputs) # pylint: disable=not-callable
return relax.const(output, output.dtype)

# Expand inputs, stack them, then perform minimum over the new axis.
inputs = [bb.normalize(relax.op.expand_dims(i, axis=0)) for i in inputs]
stacked_tensor = relax.op.concat(inputs, axis=0)
return cls.relax_op(stacked_tensor, axis=0)
return cls.relax_op(stacked_tensor, axis=0) # pylint: disable=not-callable


class Min(MultiInputBase):
Expand Down Expand Up @@ -1960,7 +1953,7 @@ class GlobalAveragePool(OnnxOpConverter):
@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
rank = len(inputs[0].struct_info.shape)
axes = [i for i in range(2, rank)]
axes = list(range(2, rank))
return relax.op.mean(inputs[0], axis=axes, keepdims=True)


Expand All @@ -1970,7 +1963,7 @@ class GlobalMaxPool(OnnxOpConverter):
@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
rank = len(inputs[0].struct_info.shape)
axes = [i for i in range(2, rank)]
axes = list(range(2, rank))
return relax.op.max(inputs[0], axis=axes, keepdims=True)


Expand All @@ -1982,7 +1975,7 @@ def _impl_v2(cls, bb, inputs, attr, params):
p = attr.get("p", 2.0)
dtype = inputs[0].struct_info.dtype
rank = len(inputs[0].struct_info.shape)
axes = [i for i in range(2, rank)]
axes = list(range(2, rank))
x_abs = relax.op.abs(inputs[0])
x_p = relax.op.power(x_abs, relax.const(p, dtype=dtype))
x_sum = relax.op.sum(x_p, axes, keepdims=True)
Expand Down Expand Up @@ -2610,7 +2603,8 @@ def _impl_v11(cls, bb, inputs, attr, params):
chunk_size, dim_size = int(split), input_shape[axis]
if dim_size % chunk_size != 0:
raise ValueError(
f"Dimension of size {dim_size} along axis {axis} must be evenly divisible by chunk size {chunk_size}"
f"Dimension of size {dim_size} along axis {axis} must be "
f"evenly divisible by chunk size {chunk_size}"
)
split = dim_size // chunk_size

Expand Down Expand Up @@ -2775,8 +2769,6 @@ def _get_convert_map():
"Resize": Resize,
"Einsum": Einsum,
"Range": Range,
"Greater": Greater,
"Reciprocal": Reciprocal,
"OneHot": OneHot,
"Unique": Unique,
# "NonZero": NonZero,
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/op/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
# pylint: disable=import-outside-toplevel, redefined-builtin, unused-argument
"""Set operators."""
from typing import Optional, Tuple, Union
from typing import Optional, Union

import numpy as np # type: ignore
import tvm
Expand Down

0 comments on commit 3666765

Please sign in to comment.