Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor _build_args_kwargs into an instance method on CallArgs + ArgInfo (#2742) #2743

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions torchrec/distributed/test_utils/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1832,6 +1832,103 @@ def forward(
return pred.sum(), pred


class TestModelWithPreprocCollectionArgs(nn.Module):
    """
    Basic module with four postproc modules:
        - postproc on idlist_features for non-weighted EBC
        - postproc on idscore_features for weighted EBC
        - postproc_nested on model input shared by both EBCs; also receives
          list/dict collection args (exercises collection-type pipelining)
        - postproc_outer providing input to postproc_nested (aka nested postproc)

    Args:
        tables: configs for the non-weighted EmbeddingBagCollection
        weighted_tables: configs for the weighted EmbeddingBagCollection
        device: device for the dense arch and both EBCs
        postproc_module_outer: module applied to the raw model input; its
            output is fed into the nested postproc via list/dict args
        postproc_module_nested: module called with the model input plus a list
            and a dict of mixed scalar/tensor arguments
        num_float_features: width of the dense float-feature tensor

    Example:
        >>> TestModelWithPreprocCollectionArgs(
        ...     tables, weighted_tables, device, postproc_outer, postproc_nested
        ... )

    Returns:
        Tuple[torch.Tensor, torch.Tensor]
    """

    # Keys for the dict collection argument passed to the nested postproc.
    CONST_DICT_KEY = "const"
    INPUT_TENSOR_DICT_KEY = "tensor_from_input"
    # NOTE(review): "POSTPTOC" looks like a typo for "POSTPROC"; kept as-is
    # because renaming would change this public class constant.
    POSTPTOC_TENSOR_DICT_KEY = "tensor_from_postproc"

    def __init__(
        self,
        tables: List[EmbeddingBagConfig],
        weighted_tables: List[EmbeddingBagConfig],
        device: torch.device,
        postproc_module_outer: nn.Module,
        postproc_module_nested: nn.Module,
        num_float_features: int = 10,
    ) -> None:
        super().__init__()
        self.dense = TestDenseArch(num_float_features, device)

        self.ebc: EmbeddingBagCollection = EmbeddingBagCollection(
            tables=tables,
            device=device,
        )
        self.weighted_ebc = EmbeddingBagCollection(
            tables=weighted_tables,
            is_weighted=True,
            device=device,
        )
        self.postproc_nonweighted = TestPreprocNonWeighted()
        self.postproc_weighted = TestPreprocWeighted()
        self._postproc_module_outer = postproc_module_outer
        self._postproc_module_nested = postproc_module_nested

    def forward(
        self,
        input: ModelInput,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Runs preproc for EBC and weighted EBC, optionally runs postproc for input.

        Args:
            input: the model input batch
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (summed prediction, prediction)
        """
        modified_input = input

        outer_postproc_input = self._postproc_module_outer(modified_input)

        # Collection args deliberately mix a static scalar, a tensor taken
        # from the input batch, and a tensor produced by another postproc
        # module — the pipeline must resolve each element kind correctly.
        preproc_input_list = [
            1,
            modified_input.float_features,
            outer_postproc_input,
        ]
        preproc_input_dict = {
            self.CONST_DICT_KEY: 1,
            self.INPUT_TENSOR_DICT_KEY: modified_input.float_features,
            self.POSTPTOC_TENSOR_DICT_KEY: outer_postproc_input,
        }

        modified_input = self._postproc_module_nested(
            modified_input, preproc_input_list, preproc_input_dict
        )

        modified_idlist_features = self.postproc_nonweighted(
            modified_input.idlist_features
        )
        modified_idscore_features = self.postproc_weighted(
            modified_input.idscore_features
        )
        # Both postproc helpers return a sequence; element 0 carries the
        # features consumed by the corresponding EBC.
        ebc_out = self.ebc(modified_idlist_features[0])
        weighted_ebc_out = self.weighted_ebc(modified_idscore_features[0])

        pred = torch.cat([ebc_out.values(), weighted_ebc_out.values()], dim=1)
        return pred.sum(), pred


class TestNegSamplingModule(torch.nn.Module):
"""
Basic module to simulate feature augmentation postproc (e.g. neg sampling) for testing
Expand Down
2 changes: 2 additions & 0 deletions torchrec/distributed/train_pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
_to_device, # noqa
_wait_for_batch, # noqa
ArgInfo, # noqa
ArgInfoStepFactory, # noqa
CallArgs, # noqa
DataLoadingThread, # noqa
In, # noqa
Out, # noqa
Expand Down
192 changes: 143 additions & 49 deletions torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
from contextlib import ExitStack
from dataclasses import dataclass
from functools import partial
from typing import cast, List, Optional, Tuple, Type, Union
from typing import cast, Dict, List, Optional, Tuple, Type, Union
from unittest.mock import MagicMock

import torch
from hypothesis import given, settings, strategies as st, Verbosity
from torch import nn, optim
from torch._dynamo.testing import reduce_to_scalar_loss
from torch._dynamo.utils import counters
from torch.fx._symbolic_trace import is_fx_tracing
from torchrec.distributed import DistributedModelParallel
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
Expand All @@ -36,6 +37,7 @@
ModelInput,
TestEBCSharder,
TestModelWithPreproc,
TestModelWithPreprocCollectionArgs,
TestNegSamplingModule,
TestPositionWeightedPreprocModule,
TestSparseNN,
Expand All @@ -58,6 +60,7 @@
TrainPipelineSparseDistCompAutograd,
)
from torchrec.distributed.train_pipeline.utils import (
ArgInfoStep,
DataLoadingThread,
get_h2d_func,
PipelinedForward,
Expand Down Expand Up @@ -1021,23 +1024,26 @@ def test_pipeline_postproc_not_shared_with_arg_transform(self) -> None:

# Check pipelined args
for ebc in [pipelined_ebc, pipelined_weighted_ebc]:
self.assertEqual(len(ebc.forward._args), 1)
self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0])
self.assertEqual(ebc.forward._args[0].is_getitems, [False, True])
self.assertEqual(len(ebc.forward._args[0].postproc_modules), 2)
self.assertIsInstance(
ebc.forward._args[0].postproc_modules[0], PipelinedPostproc
)
self.assertEqual(ebc.forward._args[0].postproc_modules[1], None)
self.assertEqual(len(ebc.forward._args.args), 1)
self.assertEqual(len(ebc.forward._args.kwargs), 0)
self.assertEqual(len(ebc.forward._args.args[0].steps), 2)
[step1, step2] = ebc.forward._args.args[0].steps

self.assertEqual(step1.input_attr, "")
self.assertEqual(step1.is_getitem, False)
self.assertEqual(step2.input_attr, 0)
self.assertEqual(step2.is_getitem, True)
self.assertIsNotNone(step1.postproc_module)
self.assertIsNone(step2.postproc_module)

self.assertEqual(
pipelined_ebc.forward._args[0].postproc_modules[0],
pipelined_ebc.forward._args.args[0].steps[0].postproc_module,
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_nonweighted`.
pipelined_model.module.postproc_nonweighted,
)
self.assertEqual(
pipelined_weighted_ebc.forward._args[0].postproc_modules[0],
pipelined_weighted_ebc.forward._args.args[0].steps[0].postproc_module,
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_weighted`.
pipelined_model.module.postproc_weighted,
Expand All @@ -1048,16 +1054,20 @@ def test_pipeline_postproc_not_shared_with_arg_transform(self) -> None:
input_attr_names = {"idlist_features", "idscore_features"}
for i in range(len(pipeline._pipelined_postprocs)):
postproc_mod = pipeline._pipelined_postprocs[i]
self.assertEqual(len(postproc_mod._args), 1)
self.assertEqual(len(postproc_mod._args.args), 1)
self.assertEqual(len(postproc_mod._args.kwargs), 0)
self.assertEqual(len(postproc_mod._args.args[0].steps), 2)
[step1, step2] = postproc_mod._args.args[0].steps

self.assertTrue(step2.input_attr in input_attr_names)

input_attr_name = postproc_mod._args[0].input_attrs[1]
self.assertTrue(input_attr_name in input_attr_names)
self.assertEqual(postproc_mod._args[0].input_attrs, ["", input_attr_name])
input_attr_names.remove(input_attr_name)
input_attr_names.remove(step2.input_attr)

self.assertEqual(postproc_mod._args[0].is_getitems, [False, False])
self.assertFalse(step1.is_getitem)
self.assertFalse(step2.is_getitem)
# no parent postproc module in FX graph
self.assertEqual(postproc_mod._args[0].postproc_modules, [None, None])
self.assertIsNone(step1.postproc_module)
self.assertIsNone(step2.postproc_module)

# pyre-ignore
@unittest.skipIf(
Expand Down Expand Up @@ -1104,23 +1114,26 @@ def test_pipeline_postproc_recursive(self) -> None:

# Check pipelined args
for ebc in [pipelined_ebc, pipelined_weighted_ebc]:
self.assertEqual(len(ebc.forward._args), 1)
self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0])
self.assertEqual(ebc.forward._args[0].is_getitems, [False, True])
self.assertEqual(len(ebc.forward._args[0].postproc_modules), 2)
self.assertIsInstance(
ebc.forward._args[0].postproc_modules[0], PipelinedPostproc
)
self.assertEqual(ebc.forward._args[0].postproc_modules[1], None)
self.assertEqual(len(ebc.forward._args.args), 1)
self.assertEqual(len(ebc.forward._args.kwargs), 0)
self.assertEqual(len(ebc.forward._args.args[0].steps), 2)
[step1, step2] = ebc.forward._args.args[0].steps

self.assertEqual(step1.input_attr, "")
self.assertEqual(step1.is_getitem, False)
self.assertEqual(step2.input_attr, 0)
self.assertEqual(step2.is_getitem, True)
self.assertIsNotNone(step1.postproc_module)
self.assertIsNone(step2.postproc_module)

self.assertEqual(
pipelined_ebc.forward._args[0].postproc_modules[0],
pipelined_ebc.forward._args.args[0].steps[0].postproc_module,
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_nonweighted`.
pipelined_model.module.postproc_nonweighted,
)
self.assertEqual(
pipelined_weighted_ebc.forward._args[0].postproc_modules[0],
pipelined_weighted_ebc.forward._args.args[0].steps[0].postproc_module,
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_weighted`.
pipelined_model.module.postproc_weighted,
Expand All @@ -1137,35 +1150,41 @@ def test_pipeline_postproc_recursive(self) -> None:
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_nonweighted`.
if postproc_mod == pipelined_model.module.postproc_nonweighted:
self.assertEqual(len(postproc_mod._args), 1)
args = postproc_mod._args[0]
self.assertEqual(args.input_attrs, ["", "idlist_features"])
self.assertEqual(args.is_getitems, [False, False])
self.assertEqual(len(args.postproc_modules), 2)
self.assertEqual(len(postproc_mod._args.args), 1)
self.assertEqual(len(postproc_mod._args.kwargs), 0)
args = postproc_mod._args.args[0]
self.assertEqual(len(args.steps), 2)
self.assertEqual(
args.postproc_modules[0],
parent_postproc_mod,
[step.input_attr for step in args.steps], ["", "idlist_features"]
)
self.assertEqual(args.postproc_modules[1], None)
self.assertEqual(
[step.is_getitem for step in args.steps], [False, False]
)
self.assertEqual(args.steps[0].postproc_module, parent_postproc_mod)
self.assertIsNone(args.steps[1].postproc_module)
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_weighted`.
elif postproc_mod == pipelined_model.module.postproc_weighted:
self.assertEqual(len(postproc_mod._args), 1)
args = postproc_mod._args[0]
self.assertEqual(args.input_attrs, ["", "idscore_features"])
self.assertEqual(args.is_getitems, [False, False])
self.assertEqual(len(args.postproc_modules), 2)
self.assertEqual(len(postproc_mod._args.args), 1)
self.assertEqual(len(postproc_mod._args.kwargs), 0)
args = postproc_mod._args.args[0]
self.assertEqual(len(args.steps), 2)
self.assertEqual(
args.postproc_modules[0],
parent_postproc_mod,
[step.input_attr for step in args.steps], ["", "idscore_features"]
)
self.assertEqual(args.postproc_modules[1], None)
self.assertEqual(
[step.is_getitem for step in args.steps], [False, False]
)
self.assertEqual(args.steps[0].postproc_module, parent_postproc_mod)
self.assertIsNone(args.steps[1].postproc_module)
elif postproc_mod == parent_postproc_mod:
self.assertEqual(len(postproc_mod._args), 1)
args = postproc_mod._args[0]
self.assertEqual(args.input_attrs, [""])
self.assertEqual(args.is_getitems, [False])
self.assertEqual(args.postproc_modules, [None])
self.assertEqual(len(postproc_mod._args.args), 1)
self.assertEqual(len(postproc_mod._args.kwargs), 0)
args = postproc_mod._args.args[0]
self.assertEqual(len(args.steps), 1)
self.assertEqual(args.steps[0].input_attr, "")
self.assertFalse(args.steps[0].is_getitem)
self.assertIsNone(args.steps[0].postproc_module)

# pyre-ignore
@unittest.skipIf(
Expand Down Expand Up @@ -1448,6 +1467,81 @@ def forward(
self.assertEqual(len(pipeline._pipelined_modules), 2)
self.assertEqual(len(pipeline._pipelined_postprocs), 1)

    # pyre-ignore
    @unittest.skipIf(
        not torch.cuda.is_available(),
        "Not enough GPUs, this test requires at least one GPU",
    )
    def test_pipeline_postproc_with_collection_args(self) -> None:
        """
        Exercises scenario when postproc module has an argument that is a list or dict
        with some elements being:
            * static scalars
            * static tensors (e.g. torch.ones())
            * tensors derived from input batch (e.g. input.idlist_features["feature_0"])
            * tensors derived from input batch and other postproc module
              (e.g. other_postproc(input.idlist_features["feature_0"]))

        Verifies that collection elements reaching the postproc at runtime are
        concrete values, never unresolved fx.Node placeholders.
        """
        # Captured so the nested modules can fail the enclosing test case.
        test_runner = self

        class PostprocOuter(nn.Module):
            def __init__(
                self,
            ) -> None:
                super().__init__()

            def forward(
                self,
                model_input: ModelInput,
            ) -> torch.Tensor:
                # Simple deterministic transform of the dense features.
                return model_input.float_features * 0.1

        class PostprocInner(nn.Module):
            def __init__(
                self,
            ) -> None:
                super().__init__()

            def forward(
                self,
                model_input: ModelInput,
                input_list: List[Union[torch.Tensor, int]],
                input_dict: Dict[str, Union[torch.Tensor, int]],
            ) -> ModelInput:
                # Only validate/mutate during eager execution; while fx is
                # symbolically tracing, arguments are proxies and the checks
                # below would be meaningless.
                if not is_fx_tracing():
                    for idx, value in enumerate(input_list):
                        if isinstance(value, torch.fx.Node):
                            test_runner.fail(
                                f"input_list[{idx}] was a fx.Node: {value}"
                            )
                        model_input.float_features += value

                    for key, value in input_dict.items():
                        if isinstance(value, torch.fx.Node):
                            test_runner.fail(
                                f"input_dict[{key}] was a fx.Node: {value}"
                            )
                        model_input.float_features += value

                return model_input

        model = TestModelWithPreprocCollectionArgs(
            tables=self.tables[:-1],  # ignore last table as postproc will remove
            weighted_tables=self.weighted_tables[:-1],  # ignore last table
            device=self.device,
            postproc_module_outer=PostprocOuter(),
            postproc_module_nested=PostprocInner(),
        )

        pipelined_model, pipeline = self._check_output_equal(
            model,
            self.sharding_type,
        )

        # both EBCs (weighted and non-weighted) are pipelined
        self.assertEqual(len(pipeline._pipelined_modules), 2)
        # outer and nested postprocs, plus the weighted/non-weighted feature
        # postprocs, are all pipelined — presumably 4 total; confirm against
        # TestModelWithPreprocCollectionArgs if this count changes.
        self.assertEqual(len(pipeline._pipelined_postprocs), 4)


class EmbeddingTrainPipelineTest(TrainPipelineSparseDistTestBase):
@unittest.skipIf(
Expand Down
Loading
Loading