From 3ea57c5971c0888c76ab1ff47603733eef183b45 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Fri, 23 Feb 2024 18:05:05 -0800 Subject: [PATCH] tolist() support for FunctionalTensor Summary: Support tolist() for FunctionalTensor for KJT in torch.export bypass-github-pytorch-ci-checks Reviewed By: ezyang Differential Revision: D53731064 fbshipit-source-id: a226bef5de4cdbf6aa0dcc1a6dee28c0874ac484 --- torchrec/distributed/tests/test_pt2.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/torchrec/distributed/tests/test_pt2.py b/torchrec/distributed/tests/test_pt2.py index 60b15befd..c809691eb 100644 --- a/torchrec/distributed/tests/test_pt2.py +++ b/torchrec/distributed/tests/test_pt2.py @@ -240,3 +240,18 @@ def test_maybe_compute_kjt_to_jt_dict(self) -> None: # TODO: turn on AOT Inductor test once the support is ready test_aot_inductor=False, ) + + def test_tensor_tolist(self) -> None: + class M(torch.nn.Module): + def forward(self, kjt: KeyedJaggedTensor): + return kjt.values().tolist() + + kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1]) + self._test_kjt_input_module( + M(), + kjt.keys(), + (kjt._values, kjt._lengths), + test_dynamo=False, + test_aot_inductor=False, + test_pt2_ir_export=True, + )