diff --git a/torchrec/modules/feature_processor.py b/torchrec/modules/feature_processor.py index b4df1c492..48d89a7ba 100644 --- a/torchrec/modules/feature_processor.py +++ b/torchrec/modules/feature_processor.py @@ -40,6 +40,14 @@ def position_weighted_module_update_features( return features +@torch.jit.script_if_tracing +@torch.fx.wrap +def offsets_to_range_traceble( + offsets: torch.Tensor, values: torch.Tensor +) -> torch.Tensor: + return torch.ops.fbgemm.offsets_range(offsets.long(), torch.numel(values)) + + # Will be deprecated soon, please use PositionWeightedProcessor, see full doc below class PositionWeightedModule(BaseFeatureProcessor): """ @@ -86,8 +94,8 @@ def forward( weighted_features: Dict[str, JaggedTensor] = {} for key, position_weight in self.position_weights.items(): - seq = torch.ops.fbgemm.offsets_range( - features[key].offsets().long(), torch.numel(features[key].values()) + seq = offsets_to_range_traceble( + features[key].offsets(), features[key].values() ) weighted_features[key] = JaggedTensor( values=features[key].values(),