diff --git a/test/export/test_export.py b/test/export/test_export.py
index a7a188b6f3f38d..cec463fa3dc0e9 100644
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -2097,6 +2097,32 @@ def forward(self, x):
         ):
             export(Module(), (torch.tensor(1, device="cpu"),))
 
+    def test_float_conversion(self):
+        class Module(torch.nn.Module):
+            def forward(self, x):
+                return x.float()
+
+        ep = export(Module(), (torch.tensor(1, dtype=torch.float),))
+        ops = []
+        for node in ep.graph.nodes:
+            if node.op == "call_function":
+                ops.append(node.target)
+        self.assertGreater(len(ops), 0)
+        for op in ops:
+            self.assertIn(op, (torch.ops.aten._to_copy.default,))
+
+    def test_device_to_mutation_float(self):
+        class Module(torch.nn.Module):
+            def forward(self, x):
+                y = x.float()
+                y.add_(1)
+                return y, x
+
+        with self.assertRaisesRegex(
+            RuntimeError, "cannot mutate tensors with frozen storage"
+        ):
+            export(Module(), (torch.tensor(1, dtype=torch.float),))
+
     def test_module(self):
         class MyLinear(torch.nn.Module):
             def __init__(self):
diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py
index 1762059eedf22f..dfef5951ab26f6 100644
--- a/torch/_subclasses/functional_tensor.py
+++ b/torch/_subclasses/functional_tensor.py
@@ -17,6 +17,27 @@
 not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")
 
 
+# NOTE: Some special handling for tensor conversion during export is needed.
+# Normally, when tracing through the model with tensor.to(), the maybe-aliasing
+# relationship between input and output tensors will be baked into the graph.
+# For example, if we get a tensor on the cpu device and call tensor.to("cpu"),
+# it will become a no-op in the graph. For whole-graph capture this is not
+# sound, so we need to do something different. Instead, in export we try to
+# preserve the tensor conversion by forcing a non-semantic-breaking
+# aten::_to_copy operator to be traced into the graph, and subsequently ban
+# mutations on all such converted tensors.
+# In addition to patching the .to() call in functionalization, we also have to
+# patch similar methods like float() and cpu(), because they intentionally do
+# not fall back to .to(), even though they have the same behavior according to
+# the PyTorch documentation (https://pytorch.org/docs/stable/generated/torch.Tensor.float.html);
+# thus we simply force them to go through a .to() call.
+def _conversion_method_template(**extra_kwargs):
+    def _(self, *args, **kwargs):
+        return self.to(*args, **{**kwargs, **extra_kwargs})
+
+    return _
+
+
 class FunctionalTensor(torch.Tensor):
     """
     Functional tensors represent tensors that will remove mutations
@@ -225,6 +246,24 @@ def to(self, *args, **kwargs):
             return super().to(*args, **{**kwargs, "copy": True})
         return super().to(*args, **kwargs)
 
+    def cuda(self, device=None, *args, **kwargs):
+        device = device or torch.cuda.current_device()
+        if len(args) > 0:
+            return self.to(device, *args, **kwargs)
+        else:
+            return self.to(device=device, **kwargs)
+
+    char = _conversion_method_template(dtype=torch.int8)
+    cpu = _conversion_method_template(device=torch.device("cpu"))
+    bfloat16 = _conversion_method_template(dtype=torch.bfloat16)
+    byte = _conversion_method_template(dtype=torch.uint8)
+    double = _conversion_method_template(dtype=torch.float64)
+    float = _conversion_method_template(dtype=torch.float32)
+    bool = _conversion_method_template(dtype=torch.bool)
+    half = _conversion_method_template(dtype=torch.float16)
+    int = _conversion_method_template(dtype=torch.int32)
+    long = _conversion_method_template(dtype=torch.int64)
+
 
 class FunctionalTensorMode(TorchDispatchMode):
     def __init__(self, pre_dispatch=False, export=False, _allow_token_discovery=False):
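For reference, a minimal sketch of the behavior the new tests exercise, assuming a build with this patch applied; the module names ToFloat and MutateConverted below are illustrative only and do not appear in the patch. Exporting a forward that calls x.float() should now leave an explicit aten._to_copy node in the graph instead of eliding the conversion, and mutating the converted tensor should be rejected because its storage is frozen.

import torch
from torch.export import export


class ToFloat(torch.nn.Module):  # illustrative name, not part of the patch
    def forward(self, x):
        return x.float()


# The dtype conversion stays in the captured graph as an explicit copy.
ep = export(ToFloat(), (torch.tensor(1, dtype=torch.float),))
assert any(
    node.op == "call_function" and node.target == torch.ops.aten._to_copy.default
    for node in ep.graph.nodes
)


class MutateConverted(torch.nn.Module):  # illustrative name, not part of the patch
    def forward(self, x):
        y = x.float()
        y.add_(1)  # in-place mutation of the converted (frozen-storage) tensor
        return y, x


# Mutating a converted tensor is banned, matching the new test expectation.
try:
    export(MutateConverted(), (torch.tensor(1, dtype=torch.float),))
except RuntimeError as e:
    assert "cannot mutate tensors with frozen storage" in str(e)

The _conversion_method_template helper exists so that float(), cpu(), half(), and the other conversion methods all route through FunctionalTensor.to(), making the copy-forcing .to() override the single code path that covers every conversion method.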