diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py index 755b644b7..5e365ba3d 100644 --- a/torchrec/modules/embedding_modules.py +++ b/torchrec/modules/embedding_modules.py @@ -19,12 +19,25 @@ pooling_type_to_str, ) from torchrec.modules.utils import register_custom_op -from torchrec.sparse.jagged_tensor import ( - is_non_strict_exporting, - JaggedTensor, - KeyedJaggedTensor, - KeyedTensor, -) +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor + + +try: + if torch.jit.is_scripting(): + raise Exception() + + from torch.compiler import ( + is_compiling as is_compiler_compiling, + is_dynamo_compiling as is_torchdynamo_compiling, + ) + + def is_non_strict_exporting() -> bool: + return not is_torchdynamo_compiling() and is_compiler_compiling() + +except Exception: + + def is_non_strict_exporting() -> bool: + return False @torch.fx.wrap