KJT dynamo path for conditions (#1846)
Summary:
Pull Request resolved: #1846

This change is a no-op for the eager path.

**Dynamo path**

Dynamo tracing cannot trace through various data-dependent conditions, so the Dynamo path avoids those conditions.

Under Dynamo we skip the asserts, and we skip the variable batch check that decides whether the "equal batches" logic can be used, so Dynamo always falls back to the variable-batch path.
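
For illustration, a minimal sketch of that fallback (hypothetical helper name; `torch.compiler.is_compiling()` stands in for the `is_torchdynamo_compiling()` used in the file): under Dynamo the "equal batches" check is skipped, so `single_batch_per_rank` stays `False` and the variable-batch path is taken.

```python
from typing import List

import torch


def _use_equal_batches_path(stride_per_rank: List[int]) -> bool:
    stride = stride_per_rank[0]
    single_batch_per_rank = False
    if not torch.compiler.is_compiling():
        # Eager / TorchScript: check whether every rank has the same batch size.
        single_batch_per_rank = all(s == stride for s in stride_per_rank)
    # Under Dynamo the check is skipped, so this stays False and the caller
    # falls back to the variable-batch (permute_1D) path.
    return single_batch_per_rank
```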

**Comments about TorchScript**

`guard_size_oblivious()` and `_check_is_size()` are not TorchScript-able at the moment, so we have to guard their use behind `not torch.jit.is_scripting()`.
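
A minimal sketch of that guard (illustrative function name; `torch.compiler.is_compiling()` again stands in for `is_torchdynamo_compiling()`): the size-oblivious comparison is only reachable when we are not scripting.

```python
import torch
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious


def _recat_is_non_empty(recat: torch.Tensor) -> bool:
    if not torch.jit.is_scripting() and torch.compiler.is_compiling():
        # Size-oblivious check: does not specialize on the symbolic numel.
        return guard_size_oblivious(recat.numel() > 0)
    # TorchScript / eager path: plain data-dependent check.
    return recat.numel() > 0
```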

`torch.jit._unwrap_optional(recat)` is used to hint TorchScript that `recat` is not None at the places where it is used.

`torch.jit.annotate()` hints TorchScript about the return type of `tensor.tolist()`.
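
Put together, a minimal sketch of these two hints (illustrative function, not the torchrec API; it mirrors the pattern in `dist_init` where the None-check is folded into a separate bool, so TorchScript cannot refine the Optional on its own):

```python
from typing import List, Optional

import torch


def _permute_lengths(lengths: torch.Tensor, recat: Optional[torch.Tensor]) -> List[int]:
    recat_cond: bool = recat is not None
    if recat_cond:
        # TorchScript cannot refine Optional[Tensor] -> Tensor through the
        # separate bool, so _unwrap_optional supplies the "not None" hint.
        lengths = lengths[torch.jit._unwrap_optional(recat)]
    # tolist() needs its return type annotated for TorchScript; annotate()
    # pins it to List[int] here.
    return torch.jit.annotate(List[int], lengths.tolist())
```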

Reviewed By: ezyang

Differential Revision: D55695198

fbshipit-source-id: a98ee025cf4dfe40e18ea28a589c5ed3f2b9ed03
Ivan Kobzarev authored and facebook-github-bot committed May 1, 2024
1 parent a649b4e commit 254cb1e
56 changes: 39 additions & 17 deletions torchrec/sparse/jagged_tensor.py
@@ -8,13 +8,17 @@
# pyre-strict

import abc

import operator

from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch.autograd.profiler import record_function
from torch.fx._pytree import register_pytree_flatten_spec, TreeSpec

# pyre-ignore
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
from torch.utils._pytree import GetAttrKey, KeyEntry, register_pytree_node

from torchrec.streamable import Pipelineable
@@ -342,7 +346,7 @@ def _permute_tensor_by_segments(
segment_sizes,
tensor,
weights,
None,
tensor.numel(),
)
return permuted_tensor, permuted_weights

@@ -772,6 +776,11 @@ def _jt_flatten_spec(t: JaggedTensor, spec: TreeSpec) -> List[Optional[torch.Ten
def _assert_tensor_has_no_elements_or_has_integers(
tensor: torch.Tensor, tensor_name: str
) -> None:
if is_torchdynamo_compiling():
# Skipping the check tensor.numel() == 0 to not guard on pt2 symbolic shapes.
# TODO(ivankobzarev): Use guard_size_oblivious to pass tensor.numel() == 0 once it is torch scriptable.
return

assert tensor.numel() == 0 or tensor.dtype in [
torch.long,
torch.int,
@@ -1404,7 +1413,7 @@ def __init__(
self._stride_per_key_per_rank = stride_per_key_per_rank
self._stride_per_key = [sum(s) for s in self._stride_per_key_per_rank]
self._variable_stride_per_key = True
if not stride_per_key_per_rank:
if stride_per_key_per_rank is not None:
self._stride = 0
elif all(s == self.stride_per_key()[0] for s in self.stride_per_key()):
self._stride = self.stride_per_key()[0]
@@ -1651,8 +1660,9 @@ def from_jt_dict(jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor":
return kjt

def sync(self) -> "KeyedJaggedTensor":
self.length_per_key()
self.offset_per_key()
if not is_torchdynamo_compiling():
self.length_per_key()
self.offset_per_key()
return self

def unsync(self) -> "KeyedJaggedTensor":
@@ -2277,17 +2287,21 @@ def dist_init(
cumsum_lengths[strides_cumsum[1:]] - cumsum_lengths[strides_cumsum[:-1]]
)
with record_function("## all2all_data:recat_values ##"):
if recat is not None and recat.numel() > 0:
recat_cond: bool = recat is not None
if recat_cond and not is_torchdynamo_compiling():
recat_cond = torch.jit._unwrap_optional(recat).numel() > 0

if recat_cond:
lengths, _ = _permute_tensor_by_segments(
lengths,
stride_per_rank_per_key,
recat,
torch.jit._unwrap_optional(recat),
None,
)
values, weights = _permute_tensor_by_segments(
values,
length_per_key,
recat,
torch.jit._unwrap_optional(recat),
weights,
)
if not stride_per_key_per_rank:
@@ -2314,24 +2328,32 @@ def dist_init(
else:
assert stride_per_rank is not None
with record_function("## all2all_data:recat_values ##"):
if recat is not None and recat.numel() > 0:
recat_cond: bool = recat is not None

if not torch.jit.is_scripting() and is_torchdynamo_compiling():
# pyre-ignore
recat_cond = recat_cond and guard_size_oblivious(recat.numel() > 0)
else:
recat_cond = (
recat_cond and torch.jit._unwrap_optional(recat).numel() > 0
)

if recat_cond:
stride = stride_per_rank[0]

# dynamo don't handle generators well
# so had to unroll the original generator into
# this for loop.
single_batch_per_rank = True
for s in stride_per_rank:
if s != stride:
single_batch_per_rank = False
single_batch_per_rank = False
if not is_torchdynamo_compiling():
single_batch_per_rank = all(
s == stride for s in stride_per_rank
)

if single_batch_per_rank:
(
lengths,
values,
weights,
) = torch.ops.fbgemm.permute_2D_sparse_data(
recat,
torch.jit._unwrap_optional(recat),
lengths.view(-1, stride),
values,
weights,
@@ -2344,7 +2366,7 @@ def dist_init(
values,
weights,
) = torch.ops.fbgemm.permute_1D_sparse_data(
recat,
torch.jit._unwrap_optional(recat),
lengths.view(-1),
values,
weights,
