Skip to content

Commit

Permalink
Fix Op(unflatten) (#2070)
Browse files Browse the repository at this point in the history
The op was failing because it was not implemented as a trace-only function; it is now marked trace_only and handles the dim/sizes arguments in Python.
  • Loading branch information
titaiwangms authored and justinchuby committed Feb 21, 2025
1 parent 946a940 commit 997ad6e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 19 deletions.
20 changes: 14 additions & 6 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8432,16 +8432,16 @@ def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]:
return op.SplitToSequence(self, split_sizes, axis=dim, keepdims=False)


@torch_op("aten::unflatten.int")
def aten_unflatten(self: TReal, dim: INT64, sizes: INT64):
@torch_op("aten::unflatten.int", trace_only=True)
def aten_unflatten(self: TReal, dim: int, sizes: Sequence[INT64]):
"""unflatten(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a)"""

self_size = op.Shape(self)

# PyTorch accepts negative dim as reversed counting
self_rank = op.Size(self_size)
dim = self_rank + dim
dim = dim % self_rank
self_rank = len(self.shape)
if dim < 0:
dim = self_rank + dim

head_start_idx = op.Constant(value_ints=[0])
head_end_idx = op.Reshape(dim, op.Constant(value_ints=[1]))
Expand All @@ -8451,8 +8451,16 @@ def aten_unflatten(self: TReal, dim: INT64, sizes: INT64):
tail_end_idx = op.Constant(value_ints=[_INT64_MAX])
tail_part_rank = op.Slice(self_size, tail_start_idx, tail_end_idx)

final_shape = op.Concat(head_part_rank, sizes, tail_part_rank, axis=0)
sizes = [op.Reshape(size, op.Constant(value_ints=[1])) for size in sizes]

# corner case 1: head part is None
if dim == 0:
final_shape = op.Concat(*sizes, tail_part_rank, axis=0)
# corner case 2: tail part is None
elif dim == self_rank - 1:
final_shape = op.Concat(head_part_rank, *sizes, axis=0)
else:
final_shape = op.Concat(head_part_rank, *sizes, tail_part_rank, axis=0)
return op.Reshape(self, final_shape)


Expand Down
14 changes: 1 addition & 13 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,13 +429,6 @@ def _sum_input_wrangler(
return args, kwargs


def _unflatten_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
args[1] = np.array(args[1], dtype=np.int64)
return args, kwargs


def _where_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
Expand Down Expand Up @@ -1471,14 +1464,9 @@ def _where_input_wrangler(
TorchLibOpInfo(
"unflatten",
core_ops.aten_unflatten,
input_wrangler=_unflatten_input_wrangler,
)
.xfail(
).xfail(
matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
reason="fixme: Logic not implemented for size 0 inputs in op.Reshape",
)
.xfail(
reason="fixme: https://github.com/pytorch/pytorch/issues/146336",
),
TorchLibOpInfo("unfold", core_ops.aten_unfold),
TorchLibOpInfo("ops.aten.unfold", core_ops.aten_unfold),
Expand Down

0 comments on commit 997ad6e

Please sign in to comment.