attempt to write specialized reshape code for 1d <-> 2d (#2075)
Summary:
Pull Request resolved: #2075

This diff contains torchrec changes that, together with a separate PyTorch PR, ensure that the uniform-batch tests with view conversions pass in the existing torchrec test suite.

I am deliberately not including the changes to test_pt2_multiprocess.py that I used for local testing on my devserver. I will let the torchrec team decide whether they want those changes in this diff or prefer to land them in a separate one.
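
For context, the pattern being exercised is a 1d -> 2d view under torch.compile. A minimal standalone sketch (the function, backend, and shapes here are illustrative, not part of this diff):

import torch

@torch.compile(backend="eager", fullgraph=True, dynamic=True)
def sum_per_key(lengths: torch.Tensor, num_keys: int, stride: int) -> torch.Tensor:
    # 1d -> 2d view with symbolic sizes: the compiler has to prove
    # numel() == num_keys * stride before it can emit the specialized reshape.
    return lengths.view(num_keys, stride).sum(dim=1)

print(sum_per_key(torch.ones(6, dtype=torch.int64), 2, 3))  # tensor([3, 3])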

Reviewed By: IvanKobzarev

Differential Revision: D58200241

fbshipit-source-id: b68e2cebce7e75b34a36f9ecf6ff546b3a1a3d8c
Shaz Qadeer authored and facebook-github-bot committed Jun 11, 2024
1 parent ac88295 commit b4b6d0b
Showing 3 changed files with 34 additions and 3 deletions.
11 changes: 10 additions & 1 deletion torchrec/distributed/tests/test_pt2_multiprocess.py
@@ -353,7 +353,9 @@ def _test_compile_rank_fn(
         torch._dynamo.config.capture_scalar_outputs = True
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
         opt_fn = torch.compile(
-            dmp, backend=torch_compile_backend, fullgraph=True, dynamic=True
+            dmp,
+            backend=torch_compile_backend,
+            fullgraph=True,
         )
         compile_out = opt_fn(kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb))
         torch.testing.assert_close(eager_out, compile_out)
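
For reference, a rough standalone version of this test pattern (the helper name is hypothetical; dmp and kjt_for_pt2_tracing come from the surrounding test):

import torch

def _assert_compiles_like_eager(model: torch.nn.Module, example_input) -> None:
    # Mirrors the pattern above: allow data-dependent ops in the captured
    # graph, compile with fullgraph=True so any graph break fails loudly,
    # then check parity with eager.
    torch._dynamo.config.capture_scalar_outputs = True
    torch._dynamo.config.capture_dynamic_output_shape_ops = True
    eager_out = model(example_input)
    opt_model = torch.compile(model, backend="eager", fullgraph=True)
    torch.testing.assert_close(eager_out, opt_model(example_input))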
@@ -403,6 +405,13 @@ def disable_cuda_tf32(self) -> bool:
                 _ConvertToVariableBatch.TRUE,
                 "eager",
             ),
+            (
+                _ModelType.EBC,
+                ShardingType.TABLE_WISE.value,
+                _InputType.SINGLE_BATCH,
+                _ConvertToVariableBatch.FALSE,
+                "eager",
+            ),
         ]
     ),
 )
18 changes: 18 additions & 0 deletions torchrec/pt2/checks.py
@@ -11,6 +11,8 @@

 import torch
 
+from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+
 USE_TORCHDYNAMO_COMPILING_PATH: bool = False


@@ -74,3 +76,19 @@ def pt2_checks_all_is_size(list: List[int]) -> List[int]:
     for i in list:
         torch._check_is_size(i)
     return list
+
+
+def pt2_check_size_nonzero(x: torch.Tensor) -> torch.Tensor:
+    if torch.jit.is_scripting() or not is_torchdynamo_compiling():
+        return x
+
+    for i in range(x.dim()):
+        torch._check(x.size(i) > 0)
+    return x
+
+
+def pt2_guard_size_oblivious(x: bool) -> bool:
+    if torch.jit.is_scripting() or not is_torchdynamo_compiling():
+        return x
+
+    return guard_size_oblivious(x)
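
A quick eager-mode illustration of the two new helpers; the compiled-mode behavior is described in the comments, since outside dynamo both are pass-throughs:

import torch

from torchrec.pt2.checks import pt2_check_size_nonzero, pt2_guard_size_oblivious

x = torch.ones(2, 3)

# Outside dynamo (and under TorchScript) both helpers return their input unchanged.
assert pt2_check_size_nonzero(x) is x
assert pt2_guard_size_oblivious(x.numel() != 0) is True

# Under torch.compile, pt2_check_size_nonzero records torch._check(size > 0)
# for every dimension, and pt2_guard_size_oblivious routes the boolean through
# guard_size_oblivious so the trace does not specialize on 0/1-sized inputs.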
8 changes: 6 additions & 2 deletions torchrec/sparse/jagged_tensor.py
@@ -20,8 +20,10 @@
 from torchrec.pt2.checks import (
     is_non_strict_exporting,
     is_torchdynamo_compiling,
+    pt2_check_size_nonzero,
     pt2_checks_all_is_size,
     pt2_checks_tensor_slice,
+    pt2_guard_size_oblivious,
 )
 from torchrec.streamable import Pipelineable
 
@@ -878,8 +880,10 @@ def _maybe_compute_length_per_key(
         _length_per_key_from_stride_per_key(lengths, stride_per_key)
         if variable_stride_per_key
         else (
-            torch.sum(lengths.view(-1, stride), dim=1).tolist()
-            if lengths.numel() != 0
+            torch.sum(
+                pt2_check_size_nonzero(lengths.view(len(keys), stride)), dim=1
+            ).tolist()
+            if pt2_guard_size_oblivious(lengths.numel() != 0)
             else [0] * len(keys)
         )
     )
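
A small eager-mode example of the updated length_per_key computation (values are hypothetical):

import torch

keys = ["f1", "f2"]
stride = 3  # uniform batch size per key
lengths = torch.tensor([1, 0, 2, 0, 3, 1])  # concatenated per-key lengths

# view(len(keys), stride) spells out the leading dimension that view(-1, stride)
# left implicit, so the compiler no longer has to solve for -1 before applying
# the specialized 1d -> 2d reshape.
length_per_key = torch.sum(lengths.view(len(keys), stride), dim=1).tolist()
print(length_per_key)  # [3, 4]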
