def move_to_copy_nodes_to_device(
    unflattened_module: nn.Module,
    device: torch.device,
) -> nn.Module:
    """
    Retarget the device kwargs of every ``_to_copy`` node to ``device``.

    Walks ``unflattened_module.graph.nodes`` and, for each node whose name
    contains ``"_to_copy"``, replaces every kwarg value that is a
    ``torch.device`` instance with ``device``. All other kwargs are carried
    over untouched, and nodes with other names are left alone.

    Args:
        unflattened_module: module exposing a ``graph`` whose ``nodes`` have
            ``name`` and ``kwargs`` attributes.
        device: device substituted for every ``torch.device`` kwarg value.

    Returns:
        The same ``unflattened_module`` object, with its nodes updated in
        place.
    """
    for node in unflattened_module.graph.nodes:
        # Only nodes named like a copy op are rewritten.
        if "_to_copy" not in node.name:
            continue

        # Build a fresh kwargs mapping and assign it back as a whole rather
        # than mutating the existing mapping in place.
        node.kwargs = {
            key: device if isinstance(value, torch.device) else value
            for key, value in node.kwargs.items()
        }

    return unflattened_module