Call prefetch of underlying module when using prefetch (#2077)
Summary:
Pull Request resolved: #2077

Right now, we can't use the prefetch pipeline with a model that has a table sharded with data parallel, because the data-parallel lookup is wrapped in `DistributedDataParallel`, which does not expose the underlying module's `prefetch`.

With this change, prefetch unwraps the DDP wrapper and calls `prefetch` on the underlying lookup module, which is well guarded with `if hasattr(emb_op.emb_module, "prefetch"):`.
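For illustration, a minimal, self-contained sketch of the failure mode and the unwrap pattern. The `Lookup` class, the single-process gloo process group, and the call arguments are hypothetical stand-ins for this sketch, not TorchRec code:

import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

# Single-process gloo process group so DDP can be constructed on CPU
# (assumption for the sketch; real jobs already have a process group).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)


class Lookup(nn.Module):
    """Hypothetical stand-in for an embedding lookup module with a prefetch method."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(4, 4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight

    def prefetch(self, sparse_features=None, forward_stream=None) -> None:
        print("prefetch called on the underlying lookup")


ddp_lookup = DistributedDataParallel(Lookup())

# DDP does not forward arbitrary attribute access to the wrapped module, so
# calling ddp_lookup.prefetch(...) directly raises AttributeError.

# The fix: unwrap (possibly nested) DDP wrappers before calling prefetch.
emb_lookup = ddp_lookup
while isinstance(emb_lookup, DistributedDataParallel):
    emb_lookup = emb_lookup.module
emb_lookup.prefetch(sparse_features=None, forward_stream=None)

dist.destroy_process_group()

The sketch mirrors the committed change: a `while` loop rather than a single unwrap, presumably so that nested wrapping is also handled.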

Reviewed By: joshuadeng

Differential Revision: D58213342

fbshipit-source-id: e1ad02b95487dd2488b23b37f1fc4e81431fb30a
henrylhtsang authored and facebook-github-bot committed Jun 6, 2024
1 parent 8393202 commit 5cf175a
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions torchrec/distributed/embedding_types.py
@@ -16,6 +16,7 @@
 from fbgemm_gpu.split_table_batched_embeddings_ops_training import EmbeddingLocation
 from torch import fx, nn
 from torch.nn.modules.module import _addindent
+from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.types import (
     get_tensor_size_bytes,
     ModuleSharder,
@@ -337,6 +338,8 @@ def prefetch(
         """

         for feature, emb_lookup in zip(dist_input, self._lookups):
+            while isinstance(emb_lookup, DistributedDataParallel):
+                emb_lookup = emb_lookup.module
             emb_lookup.prefetch(sparse_features=feature, forward_stream=forward_stream)

     def extra_repr(self) -> str:
