diff --git a/pytorch_to_returnn/torch/fft.py b/pytorch_to_returnn/torch/fft.py new file mode 100644 index 0000000..e69de29 diff --git a/pytorch_to_returnn/torch/tensor.py b/pytorch_to_returnn/torch/tensor.py index 1e50ac6..4ee6ee0 100644 --- a/pytorch_to_returnn/torch/tensor.py +++ b/pytorch_to_returnn/torch/tensor.py @@ -98,15 +98,18 @@ def clone(self): def cpu(self): return self # ignore + def cuda(self): + return self # ignore + @property def device(self): class DeviceDummy: type = None return DeviceDummy() - def flatten(self): + def flatten(self, start_dim=0, end_dim=-1): from .nn.functional import flatten - return flatten(self) + return flatten(self, start_dim, end_dim) def view(self, *shape): from .nn.functional import reshape