Skip to content

Commit

Permalink
fix missing device in fake operator (#2203)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2203

# context
* when the input tensors are on the meta device, the fake operator is called
* however, the device information is unintentionally dropped, so the output tensor ends up on the default device (cpu)
* this is incorrect behavior

Reviewed By: gnahzg, iamzainhuda

Differential Revision: D57077813
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Jul 3, 2024
1 parent 144fba9 commit 2b3bc0d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
17 changes: 17 additions & 0 deletions torchrec/ir/tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,23 @@ def test_dynamic_shape_ebc(self) -> None:
self.assertEqual(eager_out[i].shape, tensor.shape)
assert torch.allclose(eager_out[i], tensor)

def test_ir_custom_op_device(self) -> None:
    """Check that the encapsulated model's outputs follow the input device.

    The model and the input features are moved to each device in turn, and
    every output tensor must report that same device type.
    """
    model = self.generate_model()
    model.fpebc1 = copy.deepcopy(model.ebc1)
    model.fpebc2 = copy.deepcopy(model.ebc1)
    feature1 = KeyedJaggedTensor.from_offsets_sync(
        keys=["f1", "f2", "f3"],
        values=torch.tensor([0, 1, 2, 3, 2, 3]),
        offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]),
    )

    model, _ = encapsulate_ir_modules(model, JsonSerializer)
    for device_type in ["cpu", "cuda", "meta"]:
        # CUDA is optional on the test host; skip that case instead of
        # erroring out when no GPU runtime is available.
        if device_type == "cuda" and not torch.cuda.is_available():
            continue
        device = torch.device(device_type)
        outputs = model.to(device)(feature1.to(device))
        for output in outputs:
            self.assertEqual(output.device.type, device.type)

def test_deserialized_device(self) -> None:
model = self.generate_model()
id_list_features = KeyedJaggedTensor.from_offsets_sync(
Expand Down
13 changes: 9 additions & 4 deletions torchrec/ir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch

from torch import nn
from torch.export import Dim, ExportedProgram, ShapesCollection
from torch.export import Dim, ShapesCollection
from torch.export.dynamic_shapes import _Dim as DIM
from torchrec.ir.types import SerializerInterface
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
Expand All @@ -37,16 +37,21 @@ def ir_custom_op_impl(
if t is not None:
device = t.device
break
logger.info(f"torch.ops.torchrec.ir_custom_op -> ({batch_size}, {dim})")
logger.info(f"torch.ops.torchrec.ir_custom_op -> ({batch_size}, {dim}) {device}")
return torch.empty(batch_size, dim, device=device)


@torch.library.register_fake("torchrec::ir_custom_op")
def ir_custom_op_fake(
    tensors: List[Optional[torch.Tensor]], batch_size: int, dim: int
) -> torch.Tensor:
    """Fake implementation of ``torchrec::ir_custom_op``.

    Args:
        tensors: optional input tensors; the first non-``None`` entry decides
            the device of the output.
        batch_size: size of the first dimension of the output.
        dim: size of the second dimension of the output.

    Returns:
        An uninitialized ``(batch_size, dim)`` tensor on the device of the
        first non-``None`` input. When every entry is ``None`` the device is
        left as ``None`` and ``torch.empty`` picks its default device.
    """
    # Propagate the input device; without it the output would always be
    # created on the default device, even for meta-device inputs.
    device = next((t.device for t in tensors if t is not None), None)
    logger.info(f"ir_custom_op_fake -> ({batch_size}, {dim}) {device}")
    return torch.empty(batch_size, dim, device=device)


def encapsulate_ir_modules(
Expand Down

0 comments on commit 2b3bc0d

Please sign in to comment.