[Relax][PyTorch] Add support for torch.nn.functional.conv* (#17325)
* add test for functional conv1d

* add support for functional conv1d

* cleanup conv1d

* add test for functional conv_transpose1d

* add support for functional conv_transpose1d

* add test for functional conv_transpose2d

* add support for functional conv_transpose2d

* add test for functional conv3d

* add support for functional conv3d
mshr-h authored Sep 3, 2024
1 parent 42bffc3 commit 0e9c683
Showing 2 changed files with 275 additions and 61 deletions.
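As a usage sketch (not part of this commit), the new functional mappings can be exercised by tracing a module that calls torch.nn.functional.conv1d with torch.fx and importing the graph through the Relax frontend's from_fx, which now routes "conv1d" to _conv1d_functional (see the convert-map change below). The module name and shapes are illustrative only; the added tests for conv_transpose1d, conv_transpose2d, and conv3d follow the same pattern.

import torch
import torch.nn.functional as F
from torch import fx
from tvm.relax.frontend.torch import from_fx

class FunctionalConv1d(torch.nn.Module):  # illustrative name, not from the PR
    def __init__(self):
        super().__init__()
        # Weight layout is (out_channels, in_channels, kernel_width), i.e. OIW.
        self.weight = torch.nn.Parameter(torch.randn(6, 3, 7))

    def forward(self, x):
        # Traces to a call_function node named "conv1d" -> _conv1d_functional.
        return F.conv1d(x, self.weight)

graph_model = fx.symbolic_trace(FunctionalConv1d())
# input_info is a list of (shape, dtype) pairs for the model inputs.
mod = from_fx(graph_model, [((1, 3, 10), "float32")])
mod.show()  # prints the imported Relax IRModule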
284 changes: 223 additions & 61 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -740,61 +740,140 @@ def _linear_functional(self, node: fx.node.Node) -> relax.Var:
bias = args[2] if len(args) > 2 else None
return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32"))

def _conv1d(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
weight = self.params[module.weight]

def _conv1d_impl(
self,
x: relax.Expr,
weight: relax.Expr,
bias: Optional[relax.Expr],
strides: Optional[Tuple],
padding: Optional[Tuple],
dilation: Optional[Tuple],
groups: Optional[Tuple],
) -> relax.Var:
conv1d = self.block_builder.emit(
relax.op.nn.conv1d(
x,
weight,
strides=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
strides=strides,
padding=padding,
dilation=dilation,
groups=groups,
data_layout="NCW",
kernel_layout="OIW",
out_dtype="float32",
)
)

if module.bias is None:
if bias is None:
return conv1d

bias = self.params[module.bias]
assert len(self.shape_of(bias)) == 1
bias = relax.op.reshape(bias, (1, -1, 1))

return self.block_builder.emit(relax.op.add(conv1d, bias))

def _conv3d(self, node: fx.node.Node) -> relax.Var:
def _conv1d(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
weight = self.params[module.weight]
bias = None
if module.bias is not None:
bias = self.params[module.bias]

conv3d = self.block_builder.emit(
relax.op.nn.conv3d(
return self._conv1d_impl(
x,
weight,
bias=bias,
strides=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)

def _conv1d_functional(self, node: fx.node.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
stride = args[3] if len(args) > 3 else 1
padding = args[4] if len(args) > 4 else 0
dilation = args[5] if len(args) > 5 else 1
groups = args[6] if len(args) > 6 else 1
return self._conv1d_impl(
x,
weight,
bias=bias,
strides=stride,
padding=padding,
dilation=dilation,
groups=groups,
)

def _conv1d_transpose_impl(
self,
x: relax.Expr,
weight: relax.Expr,
bias: Optional[relax.Expr],
strides: Optional[Tuple],
padding: Optional[Tuple],
dilation: Optional[Tuple],
groups: Optional[Tuple],
) -> relax.Var:
conv1d_transpose = self.block_builder.emit(
relax.op.nn.conv1d_transpose(
x,
weight,
strides=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
data_layout="NCDHW",
kernel_layout="OIDHW",
strides=strides,
padding=padding,
dilation=dilation,
groups=groups,
data_layout="NCW",
kernel_layout="OIW",
out_dtype="float32",
)
)

if module.bias is None:
return conv3d
if bias is None:
return conv1d_transpose

bias = self.params[module.bias]
assert len(self.shape_of(bias)) == 1
bias = relax.op.reshape(bias, (1, -1, 1, 1, 1))
bias = relax.op.reshape(bias, (1, -1, 1))
return self.block_builder.emit(relax.op.add(conv1d_transpose, bias))

return self.block_builder.emit(relax.op.add(conv3d, bias))
def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
weight = self.params[module.weight]
bias = None
if module.bias is not None:
bias = self.params[module.bias]

return self._conv1d_transpose_impl(
x,
weight,
bias=bias,
strides=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)

def _conv1d_transpose_functional(self, node: fx.node.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
stride = args[3] if len(args) > 3 else 1
padding = args[4] if len(args) > 4 else 0
dilation = args[5] if len(args) > 5 else 1
groups = args[6] if len(args) > 6 else 1
return self._conv1d_transpose_impl(
x,
weight,
bias=bias,
strides=stride,
padding=padding,
dilation=dilation,
groups=groups,
)

def _conv2d_impl(
self,
@@ -826,71 +905,150 @@ def _conv2d_impl(
bias = relax.op.reshape(bias, (1, -1, 1, 1))
return self.block_builder.emit(relax.op.add(conv2d, bias))

def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var:
def _conv2d(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
weight = self.params[module.weight]
bias = None
if module.bias is not None:
bias = self.params[module.bias]

conv1d_transpose = self.block_builder.emit(
relax.op.nn.conv1d_transpose(
return self._conv2d_impl(
x,
weight,
bias=bias,
strides=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)

def _conv2d_functional(self, node: fx.node.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
stride = args[3] if len(args) > 3 else 1
padding = args[4] if len(args) > 4 else 0
dilation = args[5] if len(args) > 5 else 1
groups = args[6] if len(args) > 6 else 1
return self._conv2d_impl(
x,
weight,
bias=bias,
strides=stride,
padding=padding,
dilation=dilation,
groups=groups,
)

def _conv2d_transpose_impl(
self,
x: relax.Expr,
weight: relax.Expr,
bias: Optional[relax.Expr],
strides: Optional[Tuple],
padding: Optional[Tuple],
dilation: Optional[Tuple],
groups: Optional[Tuple],
) -> relax.Var:
conv2d_transpose = self.block_builder.emit(
relax.op.nn.conv2d_transpose(
x,
weight,
strides=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
data_layout="NCW",
kernel_layout="OIW",
strides=strides,
padding=padding,
dilation=dilation,
groups=groups,
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="float32",
)
)

if module.bias is None:
return conv1d_transpose
if bias is None:
return conv2d_transpose

bias = self.params[module.bias]
assert len(self.shape_of(bias)) == 1
bias = relax.op.reshape(bias, (1, -1, 1))

return self.block_builder.emit(relax.op.add(conv1d_transpose, bias))
bias = relax.op.reshape(bias, (1, -1, 1, 1))
return self.block_builder.emit(relax.op.add(conv2d_transpose, bias))

def _conv2d_transpose(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
weight = self.params[module.weight]
bias = None
if module.bias is not None:
bias = self.params[module.bias]

conv2d_transpose = self.block_builder.emit(
relax.op.nn.conv2d_transpose(
return self._conv2d_transpose_impl(
x,
weight,
bias=bias,
strides=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)

def _conv2d_transpose_functional(self, node: fx.node.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
stride = args[3] if len(args) > 3 else 1
padding = args[4] if len(args) > 4 else 0
dilation = args[5] if len(args) > 5 else 1
groups = args[6] if len(args) > 6 else 1
return self._conv2d_transpose_impl(
x,
weight,
bias=bias,
strides=stride,
padding=padding,
dilation=dilation,
groups=groups,
)

def _conv3d_impl(
self,
x: relax.Expr,
weight: relax.Expr,
bias: Optional[relax.Expr],
strides: Optional[Tuple],
padding: Optional[Tuple],
dilation: Optional[Tuple],
groups: Optional[Tuple],
):
conv3d = self.block_builder.emit(
relax.op.nn.conv3d(
x,
weight,
strides=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
data_layout="NCHW",
kernel_layout="OIHW",
strides=strides,
padding=padding,
dilation=dilation,
groups=groups,
data_layout="NCDHW",
kernel_layout="OIDHW",
out_dtype="float32",
)
)

if module.bias is None:
return conv2d_transpose

bias = self.params[module.bias]
if bias is None:
return conv3d
assert len(self.shape_of(bias)) == 1
bias = relax.op.reshape(bias, (1, -1, 1, 1))

return self.block_builder.emit(relax.op.add(conv2d_transpose, bias))
bias = relax.op.reshape(bias, (1, -1, 1, 1, 1))
return self.block_builder.emit(relax.op.add(conv3d, bias))

def _conv2d(self, node: fx.node.Node) -> relax.Var:
def _conv3d(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
weight = self.params[module.weight]
bias = None
if module.bias is not None:
bias = self.params[module.bias]

return self._conv2d_impl(
return self._conv3d_impl(
x,
weight,
bias=bias,
strides=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)

def _conv2d_functional(self, node: fx.node.Node) -> relax.Var:
def _conv3d_functional(self, node: fx.node.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
stride = args[3] if len(args) > 3 else 1
padding = args[4] if len(args) > 4 else 0
dilation = args[5] if len(args) > 5 else 1
groups = args[6] if len(args) > 6 else 1
return self._conv2d_impl(
return self._conv3d_impl(
x,
weight,
bias=bias,
@@ -1482,7 +1640,11 @@ def create_convert_map(self):
"type": self._type,
"astype": self._type,
"matmul": self._matmul,
"conv1d": self._conv1d_functional,
"conv_transpose1d": self._conv1d_transpose_functional,
"conv2d": self._conv2d_functional,
"conv_transpose2d": self._conv2d_transpose_functional,
"conv3d": self._conv3d_functional,
"linear": self._linear_functional,
"addmm": self._addmm,
"baddbmm": self._baddbmm,
