Setting to use torch dynamo compiling path in eager (#2045)
Summary:
Pull Request resolved: #2045

1/ Refactor to use is_torchdynamo_compiling() from torchrec.pt2.checks instead of duplicating the fallback definition in each module.

2/ is_torchdynamo_compiling() guards an alternative logic path. Our tests exercise that path only under compilation, so shape mismatches and similar errors on it are easy to miss. =>
We need a way to cover it with eager tests, without compilation. =>
Introduce a global setting that forces the is_torchdynamo_compiling() path in eager mode, for test coverage and debugging.

Enable this setting in test_pt2_multiprocess, so that the first eager iteration runs on the is_torchdynamo_compiling() path.
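
For illustration, a minimal sketch of how the new setting can be toggled around an eager test iteration (modeled on the test_pt2_multiprocess change below; the commented-out forward pass stands in for whatever sharded module the test runs):

import torchrec.pt2.checks as checks

# Force the is_torchdynamo_compiling() branch even though nothing is being compiled,
# so that a plain eager iteration exercises the PT2 code path.
checks.set_use_torchdynamo_compiling_path(True)

# With a torch that provides torch.compiler, the shared helper now reports True in
# eager mode; on older torch versions the BC fallback still returns False.
print(checks.is_torchdynamo_compiling())

# ... run an ordinary eager forward pass of the sharded model here ...

# Restore the default so subsequent tests use the regular eager path.
checks.set_use_torchdynamo_compiling_path(False)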

Reviewed By: PaulZhang12, gnahzg

Differential Revision: D57860075

fbshipit-source-id: a033be81367b814afa47a7b22bd68d7eccf4f991
Ivan Kobzarev authored and facebook-github-bot committed May 28, 2024
1 parent 9a50de2 commit 3ca8e8b
Showing 6 changed files with 43 additions and 51 deletions.
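
For context, the duplication removed by point 1/ is the per-module backward-compat fallback that several files carried; after this change they import the shared helper instead. A rough before/after sketch (not verbatim from any single file):

# Before: each module defined its own fallback for torch versions without torch.compiler.
try:
    from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling
except Exception:

    def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
        return False

# After: one shared helper, which can also be forced on in eager mode for test coverage.
from torchrec.pt2.checks import is_torchdynamo_compiling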
10 changes: 1 addition & 9 deletions torchrec/distributed/dist_data.py
@@ -28,6 +28,7 @@
from torchrec.distributed.global_settings import get_propogate_device
from torchrec.distributed.types import Awaitable, QuantizedCommCodecs, rank_device
from torchrec.fx.utils import fx_marker
from torchrec.pt2.checks import is_torchdynamo_compiling
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

try:
@@ -47,15 +48,6 @@
pass


try:
    from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling

except Exception:

    def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
        return False


logger: logging.Logger = logging.getLogger()


8 changes: 1 addition & 7 deletions torchrec/distributed/quant_embeddingbag.py
@@ -59,20 +59,14 @@
)
from torchrec.modules.embedding_modules import EmbeddingBagCollectionInterface
from torchrec.modules.feature_processor_ import FeatureProcessorsCollection
from torchrec.pt2.checks import is_torchdynamo_compiling
from torchrec.quant.embedding_modules import (
EmbeddingBagCollection as QuantEmbeddingBagCollection,
FeatureProcessedEmbeddingBagCollection as QuantFeatureProcessedEmbeddingBagCollection,
MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
)
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor

try:
    from torch._dynamo import is_compiling as is_torchdynamo_compiling
except Exception:

    def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
        return False


def get_device_from_parameter_sharding(ps: ParameterSharding) -> str:
# pyre-ignore
16 changes: 12 additions & 4 deletions torchrec/distributed/tests/test_pt2_multiprocess.py
@@ -17,6 +17,7 @@

import torch
import torchrec
import torchrec.pt2.checks
from hypothesis import given, settings, strategies as st, Verbosity
from torchrec.distributed.embedding import EmbeddingCollectionSharder
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
@@ -94,6 +95,11 @@ class _InputType(Enum):
VARIABLE_BATCH = 2


class _ConvertToVariableBatch(Enum):
    FALSE = 0
    TRUE = 1


class EBCSharderFixedShardingType(EmbeddingBagCollectionSharder):
    def __init__(
        self,
@@ -333,6 +339,8 @@ def _test_compile_rank_fn(
kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb)

torchrec.distributed.comm_ops.set_use_sync_collectives(True)
torchrec.pt2.checks.set_use_torchdynamo_compiling_path(True)

dmp.train(True)

eager_out = dmp(kjt_ft)
@@ -385,14 +393,14 @@ def disable_cuda_tf32(self) -> bool:
_ModelType.EBC,
ShardingType.TABLE_WISE.value,
_InputType.SINGLE_BATCH,
True,
_ConvertToVariableBatch.TRUE,
"eager",
),
(
_ModelType.EBC,
ShardingType.COLUMN_WISE.value,
_InputType.SINGLE_BATCH,
True,
_ConvertToVariableBatch.TRUE,
"eager",
),
]
@@ -406,7 +414,7 @@ def test_compile_multiprocess(
_ModelType,
str,
_InputType,
bool,
_ConvertToVariableBatch,
str,
],
) -> None:
@@ -421,6 +429,6 @@
sharding_type=sharding_type,
kernel_type=kernel_type,
input_type=input_type,
convert_to_vb=tovb,
convert_to_vb=tovb == _ConvertToVariableBatch.TRUE,
torch_compile_backend=compile_backend,
)
9 changes: 1 addition & 8 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -49,17 +49,10 @@
TrainPipelineContext,
)
from torchrec.distributed.types import Awaitable
from torchrec.pt2.checks import is_torchdynamo_compiling
from torchrec.streamable import Multistreamable


try:
    from torch._dynamo import is_compiling as is_torchdynamo_compiling
except Exception:

    def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
        return False


logger: logging.Logger = logging.getLogger(__name__)


24 changes: 22 additions & 2 deletions torchrec/pt2/checks.py
@@ -11,22 +11,42 @@

import torch

USE_TORCHDYNAMO_COMPILING_PATH: bool = False


def set_use_torchdynamo_compiling_path(val: bool) -> None:
    global USE_TORCHDYNAMO_COMPILING_PATH
    USE_TORCHDYNAMO_COMPILING_PATH = val


def get_use_torchdynamo_compiling_path() -> bool:
    global USE_TORCHDYNAMO_COMPILING_PATH
    return USE_TORCHDYNAMO_COMPILING_PATH


try:
    if torch.jit.is_scripting():
        raise Exception()

    from torch.compiler import (
        is_compiling as is_compiler_compiling,
-       is_dynamo_compiling as is_torchdynamo_compiling,
+       is_dynamo_compiling as _is_torchdynamo_compiling,
    )

    def is_torchdynamo_compiling() -> bool:
        if torch.jit.is_scripting():
            return False

        # Cannot read the global variable directly here, as TorchScript does not support it
        # (it parses the full method source even when guarded by torch.jit.is_scripting())
        return get_use_torchdynamo_compiling_path() or _is_torchdynamo_compiling()

    def is_non_strict_exporting() -> bool:
        return not is_torchdynamo_compiling() and is_compiler_compiling()

except Exception:
    # BC for torch versions without compiler and torch deploy path
-   def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
+   def is_torchdynamo_compiling() -> bool:
        return False

    def is_non_strict_exporting() -> bool:
27 changes: 6 additions & 21 deletions torchrec/sparse/jagged_tensor.py
@@ -17,7 +17,12 @@
from torch.autograd.profiler import record_function
from torch.fx._pytree import register_pytree_flatten_spec, TreeSpec
from torch.utils._pytree import GetAttrKey, KeyEntry, register_pytree_node
from torchrec.pt2.checks import pt2_checks_all_is_size, pt2_checks_tensor_slice
from torchrec.pt2.checks import (
    is_non_strict_exporting,
    is_torchdynamo_compiling,
    pt2_checks_all_is_size,
    pt2_checks_tensor_slice,
)
from torchrec.streamable import Pipelineable

try:
@@ -38,26 +43,6 @@
except ImportError:
pass

try:
    if torch.jit.is_scripting():
        raise Exception()

    from torch.compiler import (
        is_compiling as is_compiler_compiling,
        is_dynamo_compiling as is_torchdynamo_compiling,
    )

    def is_non_strict_exporting() -> bool:
        return not is_torchdynamo_compiling() and is_compiler_compiling()

except Exception:
    # BC for torch versions without compiler and torch deploy path
    def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
        return False

    def is_non_strict_exporting() -> bool:
        return False


def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
    if is_torchdynamo_compiling():