[Relax][PyTorch] Support more unary ops for ExportedProgram importer (apache#17421)

* support more unary ops

* support clamp

* support gelu

* support hardsigmoid

* support hardswish

* support hardtanh

* support leaky_relu

* support log_softmax

* support round

* support softmax

* support tril and triu

* skip flaky test
mshr-h authored Sep 27, 2024
1 parent 42ff98b commit 176d01e
Showing 5 changed files with 812 additions and 80 deletions.
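
For context, a minimal sketch (not part of this diff) of how the newly supported ops can be exercised end to end. It assumes the `from_exported_program` entry point exposed by `tvm.relax.frontend.torch`; the module name and shapes below are illustrative only.

# Illustrative sketch only: module name and shapes are made up; assumes
# tvm.relax.frontend.torch.from_exported_program is the importer entry point.
import torch
from torch.export import export

from tvm.relax.frontend.torch import from_exported_program


class UnaryOpsModel(torch.nn.Module):
    def forward(self, x):
        # exercises two of the ops handled by this commit: clamp and hardswish
        return torch.nn.functional.hardswish(torch.clamp(x, min=-1.0, max=1.0))


exported_program = export(UnaryOpsModel().eval(), args=(torch.randn(2, 3),))
mod = from_exported_program(exported_program)  # returns a Relax IRModule
mod.show()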
74 changes: 74 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -111,6 +111,80 @@ def convert(node: fx.Node) -> relax.Var:

        return convert

    def _clamp(self, node: fx.Node) -> relax.Expr:
        args = self.retrieve_args(node)
        a_min = args[1] if len(args) > 1 else node.kwargs["min"]
        a_max = args[2] if len(args) > 2 else node.kwargs["max"]
        if not isinstance(a_min, (int, float)):
            raise ValueError(
                f"TVM only supports constant min value for torch.clamp/clip, "
                f"but got {a_min} with type {type(a_min)}"
            )
        if not isinstance(a_max, (int, float)):
            raise ValueError(
                f"TVM only supports constant max value for torch.clamp/clip, "
                f"but got {a_max} with type {type(a_max)}"
            )
        return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))

    def _gelu(self, node: fx.Node) -> relax.Expr:
        approximate = node.kwargs.get("approximate", "none")
        if approximate == "none":
            return self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]]))
        elif approximate == "tanh":
            return self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]]))
        else:
            raise KeyError("Unrecognized approximate algorithm for gelu: {}.".format(approximate))

    def _hardsigmoid(self, node: fx.Node) -> relax.Var:
        args = self.retrieve_args(node)
        x = args[0]
        dtype = x.struct_info.dtype
        # hardsigmoid(x) = clip(x + 3, 0, 6) / 6
        x0 = relax.op.add(x, relax.const(3, dtype))
        x1 = relax.op.clip(x0, 0, 6)
        return self.block_builder.emit(relax.op.divide(x1, relax.const(6, dtype)))

    def _hardswish(self, node: fx.Node) -> relax.Var:
        args = self.retrieve_args(node)
        x = args[0]
        dtype = x.struct_info.dtype
        # hardswish(x) = x * clip(x + 3, 0, 6) / 6
        x0 = relax.op.add(x, relax.const(3, dtype))
        x1 = relax.op.clip(x0, 0, 6)
        x2 = relax.op.divide(x1, relax.const(6, dtype))
        return self.block_builder.emit(relax.op.multiply(x, x2))

    def _leakyrelu(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
        alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("negative_slope", 0.01)
        return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha))

    def _log_softmax(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
        return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))

    def _round(self, node: fx.Node) -> relax.Expr:
        if node.kwargs.get("decimals", 0) != 0:
            raise ValueError("specifying decimals for round is not supported yet")
        arg = self.env[node.args[0]]
        return self.block_builder.emit(relax.op.round(arg))

    def _softmax(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
        return self.block_builder.emit(relax.op.nn.softmax(x, dim))

    def _tril_triu(self, op: Callable) -> Callable:
        from torch import fx

        def convert(node: fx.Node) -> relax.Var:
            x = self.env[node.args[0]]
            k = node.args[1] if len(node.args) > 1 else node.kwargs.get("diagonal", 0)
            assert isinstance(k, int)
            return self.block_builder.emit(op(x, k))

        return convert

    ########## Neural Network ##########

    def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
38 changes: 38 additions & 0 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -64,13 +64,51 @@ def create_input_vars(

        return parameters_buffers_constants, user_inputs

    ########## Unary Ops ##########

    def _hardtanh(self, node: fx.Node) -> relax.Expr:
        args = self.retrieve_args(node)
        x = args[0]
        min_val = args[1] if len(args) > 1 else node.kwargs.get("min_val", -1.0)
        max_val = args[2] if len(args) > 2 else node.kwargs.get("max_val", 1.0)
        return self.block_builder.emit(relax.op.clip(x, min_val, max_val))

    def create_convert_map(
        self,
    ) -> Dict[str, Callable[[fx.Node], relax.Var]]:
        return {
            # unary
            "acos.default": self._unary_op(relax.op.acos),
            "acosh.default": self._unary_op(relax.op.acosh),
            "asin.default": self._unary_op(relax.op.asin),
            "asinh.default": self._unary_op(relax.op.asinh),
            "atan.default": self._unary_op(relax.op.atan),
            "atanh.default": self._unary_op(relax.op.atanh),
            "clamp.default": self._clamp,
            "cos.default": self._unary_op(relax.op.cos),
            "cosh.default": self._unary_op(relax.op.cosh),
            "dropout.default": lambda node: self.env[node.args[0]],
            "exp.default": self._unary_op(relax.op.exp),
            "gelu.default": self._gelu,
            "hardsigmoid.default": self._hardsigmoid,
            "hardswish.default": self._hardswish,
            "hardtanh.default": self._hardtanh,
            "leaky_relu.default": self._leakyrelu,
            "log_softmax.int": self._log_softmax,
            "neg.default": self._unary_op(relax.op.negative),
            "relu.default": self._unary_op(relax.op.nn.relu),
            "round.default": self._round,
            "rsqrt.default": self._unary_op(relax.op.rsqrt),
            "sigmoid.default": self._unary_op(relax.op.sigmoid),
            "silu.default": self._unary_op(relax.op.nn.silu),
            "sin.default": self._unary_op(relax.op.sin),
            "sinh.default": self._unary_op(relax.op.sinh),
            "softmax.int": self._softmax,
            "sqrt.default": self._unary_op(relax.op.sqrt),
            "tan.default": self._unary_op(relax.op.tan),
            "tanh.default": self._unary_op(relax.op.tanh),
            "tril.default": self._tril_triu(relax.op.tril),
            "triu.default": self._tril_triu(relax.op.triu),
            # neural network
            "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
            "conv2d.default": self._conv2d,
74 changes: 0 additions & 74 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -62,82 +62,19 @@ def _fetch_attr(self, model, target: str):

    ########## Unary Ops ##########

    def _clamp(self, node: fx.Node) -> relax.Expr:
        args = self.retrieve_args(node)
        a_min = args[1] if len(args) > 1 else node.kwargs["min"]
        a_max = args[2] if len(args) > 2 else node.kwargs["max"]
        if not isinstance(a_min, (int, float)):
            raise ValueError(
                f"TVM only supports constant min value for torch.clamp/clip, "
                f"but got {a_min} with type {type(a_min)}"
            )
        if not isinstance(a_max, (int, float)):
            raise ValueError(
                f"TVM only supports constant max value for torch.clamp/clip, "
                f"but got {a_max} with type {type(a_max)}"
            )
        return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))

    def _gelu(self, node: fx.Node) -> relax.Expr:
        approximate = node.kwargs.get("approximate", "none")
        if approximate == "none":
            return self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]]))
        elif approximate == "tanh":
            return self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]]))
        else:
            raise KeyError("Unregonized approximate algorithm for gelu: {}.".format(approximate))

    def _hardsigmoid(self, node: fx.Node) -> relax.Var:
        args = self.retrieve_args(node)
        x = args[0]
        dtype = x.struct_info.dtype
        x0 = relax.op.add(x, relax.const(3, dtype))
        x1 = relax.op.clip(x0, 0, 6)
        return self.block_builder.emit(relax.op.divide(x1, relax.const(6, dtype)))

    def _hardswish(self, node: fx.Node) -> relax.Var:
        args = self.retrieve_args(node)
        x = args[0]
        dtype = x.struct_info.dtype
        x0 = relax.op.add(x, relax.const(3, dtype))
        x1 = relax.op.clip(x0, 0, 6)
        x2 = relax.op.divide(x1, relax.const(6, dtype))
        return self.block_builder.emit(relax.op.multiply(x, x2))

    def _leakyrelu(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
        alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("negative_slope", 0.01)
        return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha))

    def _leakyrelu_module(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
        module = self.named_modules[node.target]
        alpha = module.negative_slope
        return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha))

    def _log_softmax(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
        return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))

    def _log_softmax_module(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
        module = self.named_modules[node.target]
        dim = module.dim
        assert dim is not None
        return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))

    def _round(self, node: fx.Node) -> relax.Expr:
        if node.kwargs.get("decimals", 0) != 0:
            raise ValueError("specifying decimals for round is not supported yet")
        arg = self.env[node.args[0]]
        return self.block_builder.emit(relax.op.round(arg))

    def _softmax(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
        return self.block_builder.emit(relax.op.nn.softmax(x, dim))

    def _softmax_module(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
        module = self.named_modules[node.target]
@@ -159,17 +96,6 @@ def convert(node: fx.Node) -> relax.Var:

        return convert

    def _tril_triu(self, op: Callable) -> Callable:
        from torch import fx

        def convert(node: fx.Node) -> relax.Var:
            x = self.env[node.args[0]]
            k = node.args[1] if len(node.args) > 1 else node.kwargs.get("diagonal", 0)
            assert isinstance(k, int)
            return self.block_builder.emit(op(x, k))

        return convert

    ########## Binary Ops ##########

    def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable:
