diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py
index 5e365ba3d..aaec26e02 100644
--- a/torchrec/modules/embedding_modules.py
+++ b/torchrec/modules/embedding_modules.py
@@ -18,28 +18,10 @@
     EmbeddingConfig,
     pooling_type_to_str,
 )
-from torchrec.modules.utils import register_custom_op
+from torchrec.modules.utils import is_non_strict_exporting, register_custom_op
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
 
 
-try:
-    if torch.jit.is_scripting():
-        raise Exception()
-
-    from torch.compiler import (
-        is_compiling as is_compiler_compiling,
-        is_dynamo_compiling as is_torchdynamo_compiling,
-    )
-
-    def is_non_strict_exporting() -> bool:
-        return not is_torchdynamo_compiling() and is_compiler_compiling()
-
-except Exception:
-
-    def is_non_strict_exporting() -> bool:
-        return False
-
-
 @torch.fx.wrap
 def reorder_inverse_indices(
     inverse_indices: Optional[Tuple[List[str], torch.Tensor]],
@@ -233,7 +215,7 @@ def _non_strict_exporting_forward(
             features.weights_or_none(),
             features.lengths_or_none(),
             features.offsets_or_none(),
-        ]
+        ] + [bag.weight for bag in self.embedding_bags.values()]
         dims = [sum(self._lengths_per_embedding)]
         ebc_op = register_custom_op(self, dims)
         outputs = ebc_op(arg_list, batch_size)
diff --git a/torchrec/modules/utils.py b/torchrec/modules/utils.py
index 626a58ed1..90e2b4e3f 100644
--- a/torchrec/modules/utils.py
+++ b/torchrec/modules/utils.py
@@ -20,6 +20,24 @@
 
 lib = torch.library.Library("custom", "FRAGMENT")
 
+try:
+    if torch.jit.is_scripting():
+        raise Exception()
+
+    from torch.compiler import (
+        is_compiling as is_compiler_compiling,
+        is_dynamo_compiling as is_torchdynamo_compiling,
+    )
+
+    def is_non_strict_exporting() -> bool:
+        return not is_torchdynamo_compiling() and is_compiler_compiling()
+
+except Exception:
+
+    def is_non_strict_exporting() -> bool:
+        return False
+
+
 class OpRegistryState:
     """
     State of operator registry.
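
For reference, a minimal sketch of how a caller might consume the relocated is_non_strict_exporting helper from torchrec.modules.utils. The ExampleCollection class and the _regular_forward fallback are hypothetical names used only for illustration; of the pieces below, only is_non_strict_exporting and _non_strict_exporting_forward appear in the diff above.

# Usage sketch, assuming an nn.Module that has a custom-op export path.
# ExampleCollection and _regular_forward are hypothetical, not torchrec code.
import torch
from torchrec.modules.utils import is_non_strict_exporting

class ExampleCollection(torch.nn.Module):
    def forward(self, features):
        # During non-strict torch.export (no dynamo tracing), take the
        # custom-op path so the lookup is captured as a single registered op.
        if is_non_strict_exporting() and not torch.jit.is_scripting():
            return self._non_strict_exporting_forward(features)
        # Eager, strict export, or TorchScript: use the ordinary implementation.
        return self._regular_forward(features)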