From 3b91dd27bdc35546c3d3c2a32858cea3b68e93f6 Mon Sep 17 00:00:00 2001
From: Zhengkai Zhang
Date: Fri, 23 Feb 2024 20:37:21 -0800
Subject: [PATCH] Update PositionWeightedModule to make it jit trace compatible

Summary:
In

```
torch.ops.fbgemm.offsets_range(
    features[key].offsets().long(), torch.numel(features[key].values())
)
```

this part

```
torch.numel(features[key].values())
```

will be traced into a constant, so a traced module is only correct for
inputs with the same number of values as the tracing example. Moving the
call into a helper decorated with @torch.jit.script_if_tracing keeps the
numel computation in the graph.

Reviewed By: snabelkabiya, houseroad

Differential Revision: D53744703

fbshipit-source-id: bf87eeaacf295b82b4591fb53029f66868ab8d86
---
 torchrec/modules/feature_processor.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/torchrec/modules/feature_processor.py b/torchrec/modules/feature_processor.py
index b4df1c492..48d89a7ba 100644
--- a/torchrec/modules/feature_processor.py
+++ b/torchrec/modules/feature_processor.py
@@ -40,6 +40,14 @@ def position_weighted_module_update_features(
     return features
 
 
+@torch.jit.script_if_tracing
+@torch.fx.wrap
+def offsets_to_range_traceble(
+    offsets: torch.Tensor, values: torch.Tensor
+) -> torch.Tensor:
+    return torch.ops.fbgemm.offsets_range(offsets.long(), torch.numel(values))
+
+
 # Will be deprecated soon, please use PositionWeightedProcessor, see full doc below
 class PositionWeightedModule(BaseFeatureProcessor):
     """
@@ -86,8 +94,8 @@ def forward(
         weighted_features: Dict[str, JaggedTensor] = {}
 
         for key, position_weight in self.position_weights.items():
-            seq = torch.ops.fbgemm.offsets_range(
-                features[key].offsets().long(), torch.numel(features[key].values())
+            seq = offsets_to_range_traceble(
+                features[key].offsets(), features[key].values()
             )
             weighted_features[key] = JaggedTensor(
                 values=features[key].values(),
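Note: a minimal, self-contained repro of the failure mode above. This is an
illustrative sketch, not part of the patch: the function names
dynamic_numel/static_numel are hypothetical, and torch.arange stands in for
torch.ops.fbgemm.offsets_range so the snippet runs without fbgemm installed.

```python
import torch


@torch.jit.script_if_tracing
def dynamic_numel(values: torch.Tensor) -> torch.Tensor:
    # Compiled with TorchScript when reached from torch.jit.trace, so
    # torch.numel stays in the graph and is re-evaluated for every input.
    return torch.arange(torch.numel(values))


def static_numel(values: torch.Tensor) -> torch.Tensor:
    # Under torch.jit.trace, torch.numel(values) returns a plain Python int,
    # which the tracer bakes into the graph as a constant.
    return torch.arange(torch.numel(values))


if __name__ == "__main__":
    traced_static = torch.jit.trace(static_numel, (torch.zeros(4),))
    traced_dynamic = torch.jit.trace(dynamic_numel, (torch.zeros(4),))

    bigger = torch.zeros(7)
    print(traced_static(bigger).shape)   # torch.Size([4]): stale constant
    print(traced_dynamic(bigger).shape)  # torch.Size([7]): size recomputed
```

The @torch.fx.wrap decorator in the patch plays the analogous role for
torch.fx symbolic tracing: it keeps the helper as a single opaque call in
the FX graph rather than tracing through its body.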