Update PositionWeightedModule to make it jit trace compatible

Summary:
In
```
torch.ops.fbgemm.offsets_range(features[key].offsets().long(), torch.numel(features[key].values()))
```
this subexpression
```
torch.numel(features[key].values())
```

will be traced into a constant, so the range length gets frozen to the example input used for tracing. The new `offsets_to_range_traceble` helper, decorated with `torch.jit.script_if_tracing` and `torch.fx.wrap`, keeps the call dynamic.
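
For context, a minimal, fbgemm-free sketch of the failure mode and of how the two decorators avoid it; `seq_len_plain`, `seq_len_wrapped`, and `Toy` are illustrative names, not part of this commit:
```
import torch


def seq_len_plain(values: torch.Tensor) -> torch.Tensor:
    # torch.numel() returns a Python int, so torch.jit.trace bakes it into the
    # graph as a constant taken from the example input.
    return torch.arange(torch.numel(values))


@torch.jit.script_if_tracing  # scripted (not inlined) when hit during tracing
@torch.fx.wrap  # kept as an opaque call_function node by torch.fx symbolic tracing
def seq_len_wrapped(values: torch.Tensor) -> torch.Tensor:
    return torch.arange(torch.numel(values))


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return seq_len_wrapped(x)


traced = torch.jit.trace(Toy(), (torch.zeros(3),))
# With seq_len_plain the traced module would always return 3 elements;
# with the wrapped helper the length follows the actual input.
print(traced(torch.zeros(5)).shape)  # torch.Size([5])
```
The plain variant traced with a 3-element example would hard-code a length of 3; the wrapped variant is scripted under `torch.jit.trace` and stays a call node under `torch.fx` symbolic tracing, so the length follows the input.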

Reviewed By: snabelkabiya, houseroad

Differential Revision: D53744703

fbshipit-source-id: bf87eeaacf295b82b4591fb53029f66868ab8d86
ZhengkaiZ authored and facebook-github-bot committed Feb 24, 2024
1 parent a8e1675 commit 3b91dd2
12 changes: 10 additions & 2 deletions torchrec/modules/feature_processor.py
```
@@ -40,6 +40,14 @@ def position_weighted_module_update_features(
     return features
 
 
+@torch.jit.script_if_tracing
+@torch.fx.wrap
+def offsets_to_range_traceble(
+    offsets: torch.Tensor, values: torch.Tensor
+) -> torch.Tensor:
+    return torch.ops.fbgemm.offsets_range(offsets.long(), torch.numel(values))
+
+
 # Will be deprecated soon, please use PositionWeightedProcessor, see full doc below
 class PositionWeightedModule(BaseFeatureProcessor):
     """
@@ -86,8 +94,8 @@ def forward(
 
         weighted_features: Dict[str, JaggedTensor] = {}
         for key, position_weight in self.position_weights.items():
-            seq = torch.ops.fbgemm.offsets_range(
-                features[key].offsets().long(), torch.numel(features[key].values())
+            seq = offsets_to_range_traceble(
+                features[key].offsets(), features[key].values()
             )
             weighted_features[key] = JaggedTensor(
                 values=features[key].values(),
```
