From eddfad121ec5e3bed3ac79fc688965c8a2e486a8 Mon Sep 17 00:00:00 2001 From: Huanyu He Date: Thu, 9 Jan 2025 11:00:45 -0800 Subject: [PATCH] remove temporary import in the try catch Summary: Tensordict is external dependence from PyTorch atm. Reviewed By: sarckk Differential Revision: D66772947 --- torchrec/distributed/embedding.py | 7 ------- torchrec/distributed/embeddingbag.py | 7 ------- torchrec/modules/embedding_modules.py | 8 -------- torchrec/sparse/jagged_tensor.py | 8 -------- 4 files changed, 30 deletions(-) diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index 9c314f6fa..93773cc1f 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -97,13 +97,6 @@ except OSError: pass -try: - from tensordict import TensorDict -except ImportError: - - class TensorDict: - pass - logger: logging.Logger = logging.getLogger(__name__) diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index ca9f6a18e..8cfd16ae9 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -102,13 +102,6 @@ except OSError: pass -try: - from tensordict import TensorDict -except ImportError: - - class TensorDict: - pass - def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor: return ( diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py index 9a1878361..307d66639 100644 --- a/torchrec/modules/embedding_modules.py +++ b/torchrec/modules/embedding_modules.py @@ -21,14 +21,6 @@ from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor -try: - from tensordict import TensorDict -except ImportError: - - class TensorDict: - pass - - @torch.fx.wrap def reorder_inverse_indices( inverse_indices: Optional[Tuple[List[str], torch.Tensor]], diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 8468c9977..336658833 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -47,14 +47,6 @@ except OSError: pass -# OSS -try: - from tensordict import TensorDict -except ImportError: - - class TensorDict: - pass - logger: logging.Logger = logging.getLogger()