Skip to content

Commit

Permalink
[Relax][PyTorch] Add support for torch.repeat (#17304)
Browse files Browse the repository at this point in the history
* add test

* add support for torch.repeat

* remove debug print
  • Loading branch information
mshr-h authored Aug 27, 2024
1 parent bf7bbef commit 99defd2
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
9 changes: 9 additions & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,14 @@ def _squeeze(self, node: fx.node.Node) -> relax.Var:
dim = None
return self.block_builder.emit(relax.op.squeeze(x, dim))

def _repeat(self, node: fx.node.Node) -> relax.Var:
import torch # type: ignore

args = self.retrieve_args(node)
if isinstance(args[1], (torch.Size, tuple, list)):
return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1])))
return self.block_builder.emit(relax.op.tile(args[0], args[1:]))

def _tile(self, node: fx.node.Node) -> relax.Var:
import torch # type: ignore

Expand Down Expand Up @@ -1484,6 +1492,7 @@ def create_convert_map(self):
"expand": self._expand,
"flatten": self._flatten,
"permute": self._permute,
"repeat": self._repeat,
"reshape": self._reshape,
"split": self._split,
"tile": self._tile,
Expand Down
36 changes: 36 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3311,6 +3311,42 @@ def main(
verify_model(Transpose(), input_info, {}, expected1)


def test_repeat():
class Tile1(Module):
def forward(self, x: torch.Tensor):
return x.repeat(2)

class Tile2(Module):
def forward(self, x: torch.Tensor):
return x.repeat(4, 2)

@tvm.script.ir_module
class expected1:
@R.function
def main(x: R.Tensor((3,), dtype="float32")) -> R.Tensor((6,), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((6,), dtype="float32") = R.tile(x, 2)
gv: R.Tensor((6,), dtype="float32") = lv
R.output(gv)
return gv

@tvm.script.ir_module
class expected2:
@R.function
def main(x: R.Tensor((1, 3), dtype="float32")) -> R.Tensor((4, 6), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2])
gv: R.Tensor((4, 6), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Tile1(), [([3], "float32")], {}, expected1)
verify_model(Tile2(), [([1, 3], "float32")], {}, expected2)
verify_model(Tile2(), [(torch.Size([1, 3]), "float32")], {}, expected2)


def test_view():
input_info = [([1, 2, 3, 4], "float32")]

Expand Down

0 comments on commit 99defd2

Please sign in to comment.