From 5cf175a0202a39dddc6715151271d6fe14f1b7d3 Mon Sep 17 00:00:00 2001
From: Henry Tsang
Date: Thu, 6 Jun 2024 10:28:32 -0700
Subject: [PATCH] Call prefetch of underlying module when using prefetch
 (#2077)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2077

Right now, we can't use the prefetch pipeline with a model that has a table
sharded with data parallel: the lookup is wrapped in `DistributedDataParallel`,
which does not expose `prefetch`. This change unwraps the DDP wrapper and calls
`prefetch` on the underlying lookup module, which is well guarded with
`if hasattr(emb_op.emb_module, "prefetch"):`.

Reviewed By: joshuadeng

Differential Revision: D58213342

fbshipit-source-id: e1ad02b95487dd2488b23b37f1fc4e81431fb30a
---
 torchrec/distributed/embedding_types.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py
index 61f5641f3..ab2c6a103 100644
--- a/torchrec/distributed/embedding_types.py
+++ b/torchrec/distributed/embedding_types.py
@@ -16,6 +16,7 @@
 from fbgemm_gpu.split_table_batched_embeddings_ops_training import EmbeddingLocation
 from torch import fx, nn
 from torch.nn.modules.module import _addindent
+from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.types import (
     get_tensor_size_bytes,
     ModuleSharder,
@@ -337,6 +338,8 @@ def prefetch(
         """

         for feature, emb_lookup in zip(dist_input, self._lookups):
+            while isinstance(emb_lookup, DistributedDataParallel):
+                emb_lookup = emb_lookup.module
             emb_lookup.prefetch(sparse_features=feature, forward_stream=forward_stream)

     def extra_repr(self) -> str:
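
Note (illustrative, not part of the patch): `DistributedDataParallel` does not
forward methods that exist only on the wrapped module, which is why the loop in
the diff unwraps down to `.module` before calling `prefetch`. The sketch below
shows the same pattern in isolation; `FakeLookup` and its `prefetch` signature
are hypothetical stand-ins for the actual TorchRec lookup modules, and the
single-process gloo setup exists only to make the example runnable on CPU.

    import os

    import torch
    import torch.distributed as dist
    from torch import nn
    from torch.nn.parallel import DistributedDataParallel


    class FakeLookup(nn.Module):
        # Hypothetical stand-in for a TorchRec lookup module; only the
        # `prefetch` method matters for this illustration.
        def __init__(self) -> None:
            super().__init__()
            self.emb = nn.Embedding(10, 4)

        def forward(self, ids: torch.Tensor) -> torch.Tensor:
            return self.emb(ids)

        def prefetch(self, sparse_features: torch.Tensor, forward_stream=None) -> None:
            print("prefetch called with", sparse_features.tolist())


    def main() -> None:
        # Single-process gloo group, just enough to construct a DDP wrapper on CPU.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("gloo", rank=0, world_size=1)

        wrapped = DistributedDataParallel(FakeLookup())

        # DDP does not forward methods defined only on the wrapped module.
        print("hasattr(wrapped, 'prefetch') =", hasattr(wrapped, "prefetch"))  # False

        # Same unwrapping pattern as the patch: walk through DDP wrappers
        # until we reach the module that actually implements prefetch.
        lookup = wrapped
        while isinstance(lookup, DistributedDataParallel):
            lookup = lookup.module
        lookup.prefetch(sparse_features=torch.tensor([1, 2, 3]))

        dist.destroy_process_group()


    if __name__ == "__main__":
        main()

Using a `while` loop rather than a single `.module` access keeps the unwrap
robust in case the lookup is ever wrapped more than once.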