From 9ec1beee350ff9982ce9faf9602cbcaa5621268e Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Tue, 11 Jun 2024 12:39:00 -0700 Subject: [PATCH] Workaround for hardcoded device in PT2 IR Summary: PT2 IR can end up hardcoding the device during tracing, for example: https://fburl.com/code/m478zn5s. This diff provides a workaround util function move_to_copy_nodes_to_device to correctly set hardcoded devices to an appropriate file. Differential Revision: D58426261 --- torchrec/ir/utils.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/torchrec/ir/utils.py b/torchrec/ir/utils.py index 1676f166c..a71595134 100644 --- a/torchrec/ir/utils.py +++ b/torchrec/ir/utils.py @@ -151,3 +151,22 @@ def mark_dynamic_kjt( if kjt._offsets is not None: shapes_collection[kjt._offsets] = (olen,) return shapes_collection + + +def move_to_copy_nodes_to_device( + unflattened_module: nn.Module, + device: torch.device, +) -> nn.Module: + """ + Moves all the copy nodes to the given device. + """ + for nodes in unflattened_module.graph.nodes: + if "_to_copy" in nodes.name: + new_kwargs = {} + for k, v in nodes.kwargs.items(): + if isinstance(v, torch.device): + v = device + new_kwargs[k] = v + nodes.kwargs = new_kwargs + + return unflattened_module