forked from pytorch/torchrec
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Handle KJT with zero batch size for Column-Wise sharded EmbeddingBagC…
…ollection Summary: Support new use case where some ranks have no embedding ids to look up i.e. `kjt.values() == torch.tensor([])`. In such cases, the expectation is for the returned embedding to be of shape `[0,emb_dim]`, as 0 is a valid tensor dimension. This diff adds support for this use case, now only for Column-Wise (CW) sharding + EmbeddingBagCollection use case. Changes in this diff: 1) CW sharding and VLE uses FBGEMM kernel to permute pooled embeddings, which doesn't work for 0-dim tensors. If tensor has no elements (i.e. 0-dim), permute doesn't do anything so we can return early. We could also support this via an `if` statement in TorchRec codebase, but we run into FX tracing issues in this case. 2) In comm_ops.py, `[output.view(B_local, -1) for output in outputs_by_rank]` isn't supported if `output` tensor has 0 dim as it will error out with `RuntimeError: cannot reshape tensor of 0 elements into shape [0, -1] because the unspecified dimension size -1 can be any value and is ambiguous`. Instead, we can explicitly create a view with the bsz and emb_dim dimensions which will hold even if `output` is 0-dim (outputs will have shape `[0,emb_dim_for_rank]` 3) Added new unit test cases Reviewed By: dstaay-fb Differential Revision: D69156551
- Loading branch information
1 parent
ea1cc27
commit 1cb8152
Showing
3 changed files
with
46 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters