Skip to content

Commit

Permalink
add better support for aten::to and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Feb 29, 2020
1 parent 4378be9 commit 8336237
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 5 deletions.
10 changes: 10 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,15 +623,25 @@ def _impl(inputs, input_types):
def _to():
def _impl(inputs, input_types):
data = inputs[0]
if inputs[3] in ["cpu", "cuda"]:
return data
# special handling for aten::to(data, 6, _, _, _) case
# 6 means dtype = float
# this happens when converting upsampling with scale factor
cast_func = {
6: float,
3: int,
}
cast_func_expr = {
6: lambda x: _op.cast(x, "float32"),
3: lambda x: _op.cast(x, "int32"),
}
if inputs[1] in cast_func and not isinstance(data, _expr.Expr):
return cast_func[inputs[1]](data)
elif inputs[1] in cast_func and isinstance(data, _expr.Expr):
return cast_func_expr[inputs[1]](data)
return data

return _impl

def _upsample(method):
Expand Down
42 changes: 37 additions & 5 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,16 @@ def verify_model(model_name, input_data=[],
ctx_list=ctx_list()):
"""Assert that the output of a compiled model matches with that of its
baseline."""
if len(input_data) == 0:
if isinstance(model_name, str):
baseline_model, baseline_input = load_model(model_name)
elif isinstance(input_data, torch.Tensor):
elif isinstance(input_data, list):
baseline_model = model_name
baseline_input = input_data
elif isinstance(input_data, torch.Tensor) or len(input_data.shape) == 0:
baseline_model = model_name
baseline_input = [input_data]
else:
assert isinstance(input_data, list)
baseline_model = model_name
baseline_input = input_data
assert False, "Unexpected input format"

if torch.cuda.is_available():
baseline_model = baseline_model.cuda()
Expand Down Expand Up @@ -672,6 +673,36 @@ def forward(self, x):
verify_model(Upsample(scale=2, mode="bilinear", align_corners=True), inp)
verify_model(Upsample(size=(50, 50), mode="bilinear", align_corners=True), inp)

def test_to():
""" test for aten::to(...) """
class ToCPU(Module):
def __init__(self):
super().__init__()

def forward(self, x):
return x.to("cpu")

class ToFloat(Module):
def __init__(self):
super().__init__()

def forward(self, x):
return x.float()

class ToInt(Module):
def __init__(self):
super().__init__()

def forward(self, x):
return x.int()

verify_model(ToCPU().eval(), torch.rand((1, 3, 32, 32)))
verify_model(ToFloat().eval(), torch.zeros((1, 3, 32, 32), dtype=torch.int))
verify_model(ToFloat().eval(), torch.tensor(2, dtype=torch.int))
verify_model(ToInt().eval(), torch.zeros((1, 3, 32, 32)))
verify_model(ToInt().eval(), torch.tensor(2.0))


# Model tests
def test_resnet18():
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -801,6 +832,7 @@ def forward(self, inp):
test_forward_pow()
test_forward_chunk()
test_upsample()
test_to()

# Model tests
test_resnet18()
Expand Down

0 comments on commit 8336237

Please sign in to comment.