diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py index 06239ff7f..549c66eea 100644 --- a/torchrec/quant/embedding_modules.py +++ b/torchrec/quant/embedding_modules.py @@ -300,6 +300,23 @@ def _fx_trec_unwrap_kjt( return indices.int(), offsets.int() +@torch.fx.wrap +def _fx_trec_unwrap_jt( + jt: JaggedTensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forced conversions to support TBE + CPU - int32 or int64, offsets dtype must match + GPU - int32 only, offsets dtype must match + """ + indices = jt.values() + offsets = jt.offsets() + if jt.device().type == "cpu": + return indices, offsets.type(dtype=indices.dtype) + else: + return indices.int(), offsets.int() + + class EmbeddingBagCollection(EmbeddingBagCollectionInterface, ModuleNoCopyMixin): """ This class represents a reimplemented version of the EmbeddingBagCollection