Skip to content

Commit

Permalink
Workaround for hardcoded device in PT2 IR
Browse files Browse the repository at this point in the history
Summary: PT2 IR can end up hardcoding the device during tracing, for example: https://fburl.com/code/m478zn5s. This diff provides a workaround util function move_to_copy_nodes_to_device to correctly set hardcoded devices to an appropriate file.

Differential Revision: D58426261
  • Loading branch information
PaulZhang12 authored and facebook-github-bot committed Jun 11, 2024
1 parent 2ebc9fc commit 9ec1bee
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions torchrec/ir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,22 @@ def mark_dynamic_kjt(
if kjt._offsets is not None:
shapes_collection[kjt._offsets] = (olen,)
return shapes_collection


def move_to_copy_nodes_to_device(
unflattened_module: nn.Module,
device: torch.device,
) -> nn.Module:
"""
Moves all the copy nodes to the given device.
"""
for nodes in unflattened_module.graph.nodes:
if "_to_copy" in nodes.name:
new_kwargs = {}
for k, v in nodes.kwargs.items():
if isinstance(v, torch.device):
v = device
new_kwargs[k] = v
nodes.kwargs = new_kwargs

return unflattened_module

0 comments on commit 9ec1bee

Please sign in to comment.