diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 971f63902..abdd91f03 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8432,16 +8432,16 @@ def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]: return op.SplitToSequence(self, split_sizes, axis=dim, keepdims=False) -@torch_op("aten::unflatten.int") -def aten_unflatten(self: TReal, dim: INT64, sizes: INT64): +@torch_op("aten::unflatten.int", trace_only=True) +def aten_unflatten(self: TReal, dim: int, sizes: Sequence[INT64]): """unflatten(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a)""" self_size = op.Shape(self) # PyTorch accepts negative dim as reversed counting - self_rank = op.Size(self_size) - dim = self_rank + dim - dim = dim % self_rank + self_rank = len(self.shape) + if dim < 0: + dim = self_rank + dim head_start_idx = op.Constant(value_ints=[0]) head_end_idx = op.Reshape(dim, op.Constant(value_ints=[1])) @@ -8451,8 +8451,16 @@ def aten_unflatten(self: TReal, dim: INT64, sizes: INT64): tail_end_idx = op.Constant(value_ints=[_INT64_MAX]) tail_part_rank = op.Slice(self_size, tail_start_idx, tail_end_idx) - final_shape = op.Concat(head_part_rank, sizes, tail_part_rank, axis=0) + sizes = [op.Reshape(size, op.Constant(value_ints=[1])) for size in sizes] + # corner case 1: head part is None + if dim == 0: + final_shape = op.Concat(*sizes, tail_part_rank, axis=0) + # corner case 2: tail part is None + elif dim == self_rank - 1: + final_shape = op.Concat(head_part_rank, *sizes, axis=0) + else: + final_shape = op.Concat(head_part_rank, *sizes, tail_part_rank, axis=0) return op.Reshape(self, final_shape) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index c6b52be0c..c1d380f9f 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -429,13 +429,6 @@ def _sum_input_wrangler( return args, kwargs -def _unflatten_input_wrangler( - args: list[Any], kwargs: dict[str, Any] -) -> tuple[list[Any], dict[str, Any]]: - args[1] = np.array(args[1], dtype=np.int64) - return args, kwargs - - def _where_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: @@ -1471,14 +1464,9 @@ def _where_input_wrangler( TorchLibOpInfo( "unflatten", core_ops.aten_unflatten, - input_wrangler=_unflatten_input_wrangler, - ) - .xfail( + ).xfail( matcher=lambda sample: any(dim == 0 for dim in sample.input.shape), reason="fixme: Logic not implemented for size 0 inputs in op.Reshape", - ) - .xfail( - reason="fixme: https://github.com/pytorch/pytorch/issues/146336", ), TorchLibOpInfo("unfold", core_ops.aten_unfold), TorchLibOpInfo("ops.aten.unfold", core_ops.aten_unfold),