diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index 91102238b..34794f9f0 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -9,6 +9,7 @@
 import abc
 import operator
+
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -21,6 +22,12 @@ try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu"
+    )
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu"
+    )
 except OSError:
     pass
@@ -128,18 +135,10 @@ def _assert_offsets_or_lengths_is_provided(
 
 
 @torch.fx.wrap
+# keep for legacy use cases
 def _regroup_keyed_tensors(
     keyed_tensors: List["KeyedTensor"], groups: List[List[str]]
 ) -> List[torch.Tensor]:
-    # Shortcut for no re-grouping
-    if len(keyed_tensors) == len(groups):
-        match = True
-        for kt, group in zip(keyed_tensors, groups):
-            if kt.keys() != group:
-                match = False
-                break
-        if match:
-            return [kt.values() for kt in keyed_tensors]
     embedding_dicts = [keyed_tensor.to_dict() for keyed_tensor in keyed_tensors]
     lengths = [keyed_tensor.length_per_key() for keyed_tensor in keyed_tensors]
@@ -165,6 +164,97 @@ def _regroup_keyed_tensors(
     return list(rearranged_values.split(split_lengths, dim=key_dim))
 
 
+@torch.fx.wrap
+def _all_keys_used_once(
+    keyed_tensors: List["KeyedTensor"], groups: List[List[str]]
+) -> bool:
+    flat_keys: List[str] = []
+    flat_groups: List[str] = []
+    for keyed_tensor in keyed_tensors:
+        flat_keys.extend(keyed_tensor.keys())
+    for sub_group in groups:
+        flat_groups.extend(sub_group)
+    # jit.script does not support set, so we use a dict to represent the set
+    key_set: Dict[str, int] = {key: 1 for key in flat_keys}
+    group_set: Dict[str, int] = {key: 1 for key in flat_groups}
+    return len(key_set) == len(group_set) == len(flat_keys) == len(flat_groups)
+
+
+@torch.fx.wrap
+def _fbgemm_permute_pooled_embs(
+    keyed_tensors: List["KeyedTensor"], groups: List[List[str]]
+) -> List[torch.Tensor]:
+    keys, lengths, values = _desugar_keyed_tensors(keyed_tensors)
+    permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups(
+        keys, lengths, groups
+    )
+    values = torch.concat(values, dim=1)
+    permuted_values = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
+        values,
+        offsets.to(device=values.device),
+        permute.to(device=values.device),
+        inv_offsets.to(device=values.device),
+        inv_permute.to(device=values.device),
+    )
+    return list(torch.split(permuted_values, splits, dim=1))
+
+
+@torch.fx.wrap
+def _desugar_keyed_tensors(
+    kts: List["KeyedTensor"],
+) -> Tuple[List[List[str]], List[List[int]], List[torch.Tensor]]:
+    """
+    Desugar a list of KeyedTensors into basic data structures (keys, length_per_key, values).
+    """
+    return (
+        [kt.keys() for kt in kts],
+        [kt.length_per_key() for kt in kts],
+        [kt.values() for kt in kts],
+    )
+
+
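# ---------------------------------------------------------------------------
# Illustrative sketch (toy values assumed for illustration; not part of the
# patch): what the fast-path helpers above compute.  _desugar_keyed_tensors
# flattens the KeyedTensors into (keys, length_per_key, values), and
# _remap_to_groups (defined next) turns the requested grouping into the
# permutation metadata that fbgemm applies to the pooled values in one shot.
keys = [["f1", "f2"], ["f3"]]     # keys of two KeyedTensors
key_lengths = [[2, 3], [4]]       # embedding width of each key
groups = [["f3", "f1"], ["f2"]]   # desired regrouping
permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups(
    keys, key_lengths, groups
)
assert permute.tolist() == [2, 0, 1]      # where each grouped key sits in flat_keys
assert inv_permute.tolist() == [1, 2, 0]  # inverse permutation, used for backward
# offsets / inv_offsets are the cumulative widths before / after permutation,
# and splits ([6, 3] here) are the widths of the regrouped outputs, so
# _fbgemm_permute_pooled_embs can concat once, permute once, and split once.
# ---------------------------------------------------------------------------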
+@torch.fx.wrap
+def _remap_to_groups(
+    keys: List[List[str]],
+    key_lengths: List[List[int]],
+    groups: List[List[str]],
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
+    """
+    Given the keys and lengths per key of each input KeyedTensor and the desired
+    groups, return the permute indices, inverse permute indices, offsets,
+    inverse offsets, and splits. The output is used to re-arrange values based
+    on groups with a single cat operation.
+    """
+
+    lengths: List[int] = []
+    flat_keys: List[str] = []
+    flat_groups: List[str] = []
+
+    for sub_keys_length in key_lengths:
+        lengths.extend(sub_keys_length)
+    for sub_keys in keys:
+        flat_keys.extend(sub_keys)
+
+    for sub_group in groups:
+        flat_groups.extend(sub_group)
+
+    key_splits = [len(sub_group) for sub_group in groups]
+
+    index_map = {key: idx for idx, key in enumerate(flat_keys)}
+    permute = [index_map[key] for key in flat_groups]
+    inv_lengths = [lengths[i] for i in permute]
+    splits = _sum_by_splits(inv_lengths, key_splits)
+
+    inv_permute = [0] * len(permute)
+    for i, p in enumerate(permute):
+        inv_permute[p] = i
+
+    offsets = torch.tensor(_cumsum(lengths), dtype=torch.int64)
+    inv_offsets = torch.tensor(_cumsum(inv_lengths), dtype=torch.int64)
+    permute = torch.tensor(permute, dtype=torch.int64)
+    inv_permute = torch.tensor(inv_permute, dtype=torch.int64)
+
+    return permute, inv_permute, offsets, inv_offsets, splits
+
+
 def _values_string(values: torch.Tensor, start: int, end: int) -> str:
     size = values.size()
     if len(size) == 1:
@@ -2474,18 +2564,22 @@ def to_dict(self) -> Dict[str, torch.Tensor]:
     def regroup(
         keyed_tensors: List["KeyedTensor"], groups: List[List[str]]
     ) -> List[torch.Tensor]:
-        return _regroup_keyed_tensors(keyed_tensors, groups)
+        # Fast path: every key is used exactly once, so the groups can be
+        # produced with a single fbgemm permute of the pooled embeddings.
+        if _all_keys_used_once(keyed_tensors, groups):
+            return _fbgemm_permute_pooled_embs(keyed_tensors, groups)
+        else:  # fall back to the slow path otherwise
+            return _regroup_keyed_tensors(keyed_tensors, groups)
 
     @staticmethod
     def regroup_as_dict(
         keyed_tensors: List["KeyedTensor"], groups: List[List[str]], keys: List[str]
     ) -> Dict[str, torch.Tensor]:
+        ret: Dict[str, torch.Tensor] = {}
         assert len(groups) == len(keys), "Groups and keys should have same length"
-        embeddings_list = _regroup_keyed_tensors(keyed_tensors, groups)
-        embeddings_dict: Dict[str, torch.Tensor] = {}
+        tensor_list = KeyedTensor.regroup(keyed_tensors, groups)
         for i, key in enumerate(keys):
-            embeddings_dict[key] = embeddings_list[i]
-        return embeddings_dict
+            ret[key] = tensor_list[i]
+        return ret
 
     @torch.jit.unused
     def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
diff --git a/torchrec/sparse/tests/jagged_tensor_benchmark.py b/torchrec/sparse/tests/jagged_tensor_benchmark.py
index 84a0897e4..ae6368f0c 100644
--- a/torchrec/sparse/tests/jagged_tensor_benchmark.py
+++ b/torchrec/sparse/tests/jagged_tensor_benchmark.py
@@ -8,65 +8,214 @@
 
 # pyre-strict
 
-import time
+import functools
 import timeit
-from typing import Callable, List, Tuple
+from typing import Any, Callable, Dict, List
 
 import click
 
-import numpy as np
-
 import torch
-from torchrec.sparse.jagged_tensor import _regroup_keyed_tensors, KeyedTensor
+from torchrec.distributed.benchmark.benchmark_utils import benchmark, BenchmarkResult
+from torchrec.sparse.jagged_tensor import (
+    _regroup_keyed_tensors,
+    KeyedJaggedTensor,
+    KeyedTensor,
+)
+from torchrec.sparse.tests.utils import build_groups, build_kts
 
 
-def prepare_benchmark(
-    dense_features: int, sparse_features: int
-) -> Tuple[List["KeyedTensor"], List[List[str]]]:
-    key_dim = 1
-    tensor_list_1 = [torch.randn(2, 3) for i in range(dense_features)]
-    keys_1 = [f"dense_{i}" for i in range(dense_features)]
-    kt_1 = KeyedTensor.from_tensor_list(keys_1, tensor_list_1, key_dim)
-    tensor_list_2 = [torch.randn(2, 3) for i in range(sparse_features)]
-    keys_2 = [f"sparse_{i}" for i in range(sparse_features)]
-    kt_2 = KeyedTensor.from_tensor_list(keys_2, tensor_list_2, key_dim)
-    return ([kt_1, kt_2], [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]])
 
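# ---------------------------------------------------------------------------
# Illustrative sketch (small assumed shapes; not part of the patch): the
# removed prepare_benchmark() above is superseded by the shared helpers in
# torchrec/sparse/tests/utils.py, and the new benchmark below times the two
# code paths sketched here.  Assumes the fbgemm permute_pooled_embs ops are
# available on CPU.
import torch
from torchrec.sparse.jagged_tensor import _regroup_keyed_tensors, KeyedTensor
from torchrec.sparse.tests.utils import build_kts

kts = build_kts(
    dense_features=2,
    sparse_features=2,
    dim_dense=8,
    dim_sparse=16,
    batch_size=4,
    device=torch.device("cpu"),
    run_backward=False,
)
groups = [["dense_0", "sparse_1"], ["dense_1", "sparse_0"]]  # every key used once
prod = KeyedTensor.regroup(kts, groups)         # new fbgemm fast path
fallback = _regroup_keyed_tensors(kts, groups)  # legacy dict-based path
assert all(torch.allclose(a, b) for a, b in zip(prod, fallback))
# ---------------------------------------------------------------------------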
+class DummyModel(torch.nn.Module):
+    # pyre-ignore
+    def forward(self, *args, **kwargs) -> None:
+        pass
 
 
 def bench(
     name: str,
-    fn: Callable[[List["KeyedTensor"], List[List[str]]], List[torch.Tensor]],
-    n_dense: int,
-    n_sparse: int,
+    labels: torch.Tensor,
+    batch_size: int,
+    feature_count: int,
+    device_type: str,
+    run_backward: bool,
+    fn: Callable[..., List[torch.Tensor]],
+    fn_kwargs: Dict[str, Any],
 ) -> None:
-    input_data = prepare_benchmark(n_dense, n_sparse)
-    start = time.perf_counter()
-    for _ in range(3):
-        fn(input_data[0], input_data[1])
-    end = time.perf_counter()
-    print(f"warmup time {(end-start)*1000:.1f}ms")
-    results = timeit.repeat(
-        lambda: fn(input_data[0], input_data[1]), number=10, repeat=10
-    )
-    p_50 = np.percentile(np.asarray(results), 50)
-    print(f"{name} {p_50*1000:.1f}us")
+    # initial call
+    fn(**fn_kwargs)
+
+    def wrapped_func(
+        model: torch.nn.Module,  # not used
+        bench_inputs: List[KeyedJaggedTensor],  # not used
+        fn: Callable[..., List[torch.Tensor]],
+        fn_kwargs: Dict[str, Any],
+        run_backward: bool,
+    ) -> None:
+        result = fn(**fn_kwargs)
+        if run_backward:
+            vectors = [tensor.sum(dim=1) for tensor in result]
+            pred = vectors[0]
+            for vector in vectors[1:]:
+                pred = pred.mul(vector)
+            loss = torch.nn.functional.l1_loss(pred, labels)
+            loss.sum().backward()
+
+    if device_type == "cuda":
+        result = benchmark(
+            name=name,
+            model=DummyModel(),
+            warmup_inputs=[],
+            bench_inputs=[],
+            prof_inputs=[],
+            world_size=1,
+            output_dir="",
+            num_benchmarks=20,
+            func_to_benchmark=functools.partial(
+                wrapped_func, fn=fn, run_backward=run_backward, fn_kwargs=fn_kwargs
+            ),
+            benchmark_func_kwargs={},
+            rank=0,
+            enable_logging=False,
+        )
+
+    else:  # cpu
+        model = DummyModel()
+        times = timeit.repeat(
+            lambda: wrapped_func(
+                model=model,
+                bench_inputs=[],
+                fn=fn,
+                fn_kwargs=fn_kwargs,
+                run_backward=run_backward,
+            ),
+            number=1,
+            repeat=20,
+        )
+        result = BenchmarkResult(
+            short_name=name,
+            elapsed_time=torch.tensor(times),
+            max_mem_allocated=[0],
+        )
+
+    print(
+        f" {name : <{35}} | B: {batch_size : <{8}} | F: {feature_count : <{8}} | device: {device_type : <{8}} | Runtime (P90): {result.runtime_percentile(90):5.1f} ms | Memory (P90): {result.max_mem_percentile(90):5.1f}"
+    )
 
 
 @click.command()
+@click.option(
+    "--cuda_matrix",
+    type=bool,
+    default=False,
+    help="Run a full GPU matrix, overrides relevant settings",
+)
+@click.option(
+    "--run_backward",
+    type=bool,
+    default=False,
+    help="run backward (forward always runs)",
+)
+@click.option(
+    "--device_type",
+    type=str,
+    default="cuda",
+    help="device type",
+)
 @click.option(
     "--n_dense",
     type=int,
-    default=2000,
-    help="Total number of dense embeddings to be used.",
+    default=20,
+    help="Total number of dense embeddings.",
+)
+@click.option(
+    "--dim_dense",
+    type=int,
+    default=64,
+    help="Dim of each dense embedding.",
 )
 @click.option(
     "--n_sparse",
-    default=3000,
+    default=1000,
     help="Total number of sparse embeddings to be used.",
 )
-def main(n_dense: int, n_sparse: int) -> None:
-    bench("regular ", _regroup_keyed_tensors, n_dense, n_sparse)
+@click.option(
+    "--dim_sparse",
+    type=int,
+    default=128,
+    help="Dim of each sparse embedding.",
+)
+@click.option(
+    "--batch_size",
+    type=int,
+    default=1024,
+    help="Batch size.",
+)
+@click.option(
+    "--n_groups",
+    type=int,
+    default=2,
+    help="Total number of regrouped output groups.",
+)
+def main(
+    cuda_matrix: bool,
+    run_backward: bool,
+    device_type: str,
+    n_dense: int,
+    n_sparse: int,
+    dim_dense: int,
+    dim_sparse: int,
+    batch_size: int,
+    n_groups: int,
+) -> None:
+    if cuda_matrix:
+        n_denses = [64, 128, 256, 512, 1024]
+        n_sparses = [16, 32, 64, 128, 256]
+        batch_sizes = [512, 1024, 2048, 4096]
+        device_types = ["cuda"]
+    else:
+        n_denses = [n_dense]
+        n_sparses = [n_sparse]
+        batch_sizes = [batch_size]
+        device_types = [device_type]
+
+    for device_type in device_types:
+        for batch_size in batch_sizes:
+            for n_dense, n_sparse in zip(n_denses, n_sparses):
+
+                device = torch.device(device_type)
+                kts = build_kts(
+                    n_dense,
+                    n_sparse,
+                    dim_dense,
+                    dim_sparse,
+                    batch_size,
+                    device,
+                    run_backward,
+                )
+                labels = torch.randint(
+                    0, 1, (batch_size,), device=torch.device(device_type)
+                ).float()
+                groups = build_groups(kts, n_groups)
+                bench(
+                    "[fallback] _regroup_keyed_tensors",
+                    labels,
+                    batch_size,
+                    n_dense + n_sparse,
+                    device_type,
+                    run_backward,
+                    _regroup_keyed_tensors,
+                    {"keyed_tensors": kts, "groups": groups},
+                )
+                bench(
+                    "[prod] KeyedTensor.regroup",
+                    labels,
+                    batch_size,
+                    n_dense + n_sparse,
+                    device_type,
+                    run_backward,
+                    KeyedTensor.regroup,
+                    {"keyed_tensors": kts, "groups": groups},
+                )
 
 
 if __name__ == "__main__":
diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py
index 67b6d77c8..9efeb444c 100644
--- a/torchrec/sparse/tests/test_jagged_tensor.py
+++ b/torchrec/sparse/tests/test_jagged_tensor.py
@@ -9,13 +9,14 @@
 
 import unittest
-from typing import List, Tuple
+from typing import Dict, List, Tuple
 
 import torch
 import torch.utils._pytree as pytree
 from torch.testing import FileCheck
 from torchrec.fx import symbolic_trace
 from torchrec.sparse.jagged_tensor import (
+    _regroup_keyed_tensors,
     ComputeJTDictToKJT,
     ComputeKJTToJTDict,
     JaggedTensor,
@@ -24,6 +25,7 @@
     KeyedTensor,
     kjt_is_equal,
 )
+from torchrec.sparse.tests.utils import build_groups, build_kts
 
 torch.fx.wrap("len")
@@ -2170,10 +2172,10 @@ def test_regroup_single_kt(self) -> None:
 
     def test_regroup_multiple_kt(self) -> None:
         key_dim = 1
-        tensor_list_1 = [torch.randn(2, 3) for i in range(3)]
+        tensor_list_1 = [torch.randn(2, 4), torch.randn(2, 8), torch.randn(2, 2)]
         keys_1 = ["dense_0", "dense_1", "dense_2"]
         kt_1 = KeyedTensor.from_tensor_list(keys_1, tensor_list_1, key_dim)
-        tensor_list_2 = [torch.randn(2, 3) for i in range(2)]
+        tensor_list_2 = [torch.randn(2, 3), torch.randn(2, 10)]
         keys_2 = ["sparse_0", "sparse_1"]
         kt_2 = KeyedTensor.from_tensor_list(keys_2, tensor_list_2, key_dim)
         grouped_tensors = KeyedTensor.regroup(
@@ -2194,11 +2196,117 @@ def test_regroup_multiple_kt(self) -> None:
             )
         )
 
+    def test_regroup_backward_skips_and_duplicates(self) -> None:
+        kts = build_kts(
+            dense_features=20,
+            sparse_features=20,
+            dim_dense=64,
+            dim_sparse=128,
+            batch_size=128,
+            device=torch.device("cpu"),
+            run_backward=True,
+        )
+        groups = build_groups(kts=kts, num_groups=2, skips=True, duplicates=True)
+        labels = torch.randint(0, 1, (128,), device=torch.device("cpu")).float()
+
+        tensor_groups = KeyedTensor.regroup(kts, groups)
+        pred0 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred0, labels).sum()
+        actual_kt_0_grad = torch.autograd.grad(
+            loss, kts[0].values(), retain_graph=True
+        )[0]
+        actual_kt_1_grad = torch.autograd.grad(
+            loss, kts[1].values(), retain_graph=True
+        )[0]
+
+        # clear grads before re-running through the fallback path
+        kts[0].values().grad = None
+        kts[1].values().grad = None
+
+        tensor_groups = _regroup_keyed_tensors(kts, groups)
+        pred1 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred1, labels).sum()
+        expected_kt_0_grad = torch.autograd.grad(
+            loss, kts[0].values(), retain_graph=True
+        )[0]
+        expected_kt_1_grad = torch.autograd.grad(
+            loss, kts[1].values(), retain_graph=True
+        )[0]
+
+        self.assertTrue(torch.allclose(actual_kt_0_grad, expected_kt_0_grad))
+        self.assertTrue(torch.allclose(actual_kt_1_grad, expected_kt_1_grad))
+
+    def test_regroup_backward(self) -> None:
+        kts = build_kts(
+            dense_features=20,
+            sparse_features=20,
+            dim_dense=64,
+            dim_sparse=128,
+            batch_size=128,
+            device=torch.device("cpu"),
+            run_backward=True,
+        )
+        groups = build_groups(kts=kts, num_groups=2, skips=False, duplicates=False)
+        labels = torch.randint(0, 1, (128,), device=torch.device("cpu")).float()
+
+        tensor_groups = KeyedTensor.regroup(kts, groups)
+        pred0 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred0, labels).sum()
+        actual_kt_0_grad = torch.autograd.grad(
+            loss, kts[0].values(), retain_graph=True
+        )[0]
+        actual_kt_1_grad = torch.autograd.grad(
+            loss, kts[1].values(), retain_graph=True
+        )[0]
+
+        # clear grads before re-running through the fallback path
+        kts[0].values().grad = None
+        kts[1].values().grad = None
+
+        tensor_groups = _regroup_keyed_tensors(kts, groups)
+        pred1 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred1, labels).sum()
+        expected_kt_0_grad = torch.autograd.grad(
+            loss, kts[0].values(), retain_graph=True
+        )[0]
+        expected_kt_1_grad = torch.autograd.grad(
+            loss, kts[1].values(), retain_graph=True
+        )[0]
+
+        self.assertTrue(torch.allclose(actual_kt_0_grad, expected_kt_0_grad))
+        self.assertTrue(torch.allclose(actual_kt_1_grad, expected_kt_1_grad))
+
+    def test_regroup_multiple_kt_duplicate_keys(self) -> None:
+        key_dim = 1
+        tensor_list_1 = [torch.randn(2, 4) for i in range(2)]
+        keys_1 = ["dense_0", "dense_1"]
+        kt_1 = KeyedTensor.from_tensor_list(keys_1, tensor_list_1, key_dim)
+        tensor_list_2 = [torch.randn(2, 3) for i in range(3)]
+        keys_2 = ["sparse_0", "sparse_1", "dense_2"]
+        kt_2 = KeyedTensor.from_tensor_list(keys_2, tensor_list_2, key_dim)
+        grouped_tensors = KeyedTensor.regroup(
+            [kt_1, kt_2], [["dense_0", "sparse_1"], ["dense_1", "sparse_0", "dense_0"]]
+        )
+        self.assertTrue(
+            torch.equal(
+                grouped_tensors[0],
+                torch.cat([tensor_list_1[0], tensor_list_2[1]], key_dim),
+            )
+        )
+        self.assertTrue(
+            torch.equal(
+                grouped_tensors[1],
+                torch.cat(
+                    [tensor_list_1[1], tensor_list_2[0], tensor_list_1[0]], key_dim
+                ),
+            )
+        )
+
     def test_regroup_scriptable(self) -> None:
         class MyModule(torch.nn.Module):
-            def forward(
-                self, inputs: List[KeyedTensor], groups: List[List[str]]
-            ) -> List[torch.Tensor]:
+            def forward(self, inputs: List[KeyedTensor]) -> List[torch.Tensor]:
+                # user provided, not model input
+                groups = [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]]
                 return KeyedTensor.regroup(inputs, groups)
 
         m = MyModule()
@@ -2232,6 +2340,43 @@ def forward(
         for result, traced_result in zip(results, traced_results):
             self.assertTrue(torch.equal(result, traced_result))
 
+    def test_regroup_as_dict_scriptable(self) -> None:
+        class MyModule(torch.nn.Module):
+            def forward(self, inputs: List[KeyedTensor]) -> Dict[str, torch.Tensor]:
+                groups = [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]]
+                keys = ["group_0", "group_1"]
+                return KeyedTensor.regroup_as_dict(inputs, groups, keys)
+
+        m = MyModule()
+        torch.jit.script(m)
+
+    def test_regroup_as_dict_fxable(self) -> None:
+        class MyModule(torch.nn.Module):
+            def forward(self, inputs: List[KeyedTensor]) -> Dict[str, torch.Tensor]:
+                groups = [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]]
+                keys = ["group_0", "group_1"]
+                return KeyedTensor.regroup_as_dict(inputs, groups, keys)
+
+        m = MyModule()
+
+        # input
+        key_dim = 1
+        tensor_list_1 = [torch.randn(2, 3) for i in range(3)]
+        keys_1 = ["dense_0", "dense_1", "dense_2"]
+        kt_1 = KeyedTensor.from_tensor_list(keys_1, tensor_list_1, key_dim)
+        tensor_list_2 = [torch.randn(2, 3) for i in range(2)]
+        keys_2 = ["sparse_0", "sparse_1"]
+        kt_2 = KeyedTensor.from_tensor_list(keys_2, tensor_list_2, key_dim)
+        inputs = [kt_1, kt_2]
+
+        # ensure that symbolic tracing works
+        gm = torch.fx.symbolic_trace(m)
+        results = m(inputs)
+        traced_results = gm(inputs)
+        self.assertEqual(len(results), len(traced_results))
+        for result, traced_result in zip(results.values(), traced_results.values()):
+            self.assertTrue(torch.equal(result, traced_result))
+
     def test_scriptable(self) -> None:
         class MyModule(torch.nn.Module):
             def forward(self, input: KeyedTensor) -> torch.Tensor:
diff --git a/torchrec/sparse/tests/test_jagged_tensor_gpu.py b/torchrec/sparse/tests/test_jagged_tensor_gpu.py
new file mode 100644
index 000000000..4e5975ae1
--- /dev/null
+++ b/torchrec/sparse/tests/test_jagged_tensor_gpu.py
@@ -0,0 +1,111 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import unittest
+
+import torch
+from torchrec.sparse.jagged_tensor import _regroup_keyed_tensors, KeyedTensor
+from torchrec.sparse.tests.utils import build_groups, build_kts
+from torchrec.test_utils import skip_if_asan_class
+
+
+@skip_if_asan_class
+class TestKeyedTensorGPU(unittest.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        self.device = torch.cuda.current_device()
+
+    # pyre-ignore
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_regroup_backward_skips_and_duplicates(self) -> None:
+        kts = build_kts(
+            dense_features=20,
+            sparse_features=20,
+            dim_dense=64,
+            dim_sparse=128,
+            batch_size=128,
+            device=self.device,
+            run_backward=True,
+        )
+        groups = build_groups(kts=kts, num_groups=2, skips=True, duplicates=True)
+        labels = torch.randint(0, 1, (128,), device=self.device).float()
+
+        tensor_groups = KeyedTensor.regroup(kts, groups)
+        pred0 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred0, labels).sum()
+        actual_kt_0_grad = torch.autograd.grad(
+            loss, kts[0].values(), retain_graph=True
+        )[0]
+        actual_kt_1_grad = torch.autograd.grad(
+            loss, kts[1].values(), retain_graph=True
+        )[0]
+
+        # clear grads before re-running through the fallback path
+        kts[0].values().grad = None
+        kts[1].values().grad = None
+
+        tensor_groups = _regroup_keyed_tensors(kts, groups)
+        pred1 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred1, labels).sum()
+        expected_kt_0_grad = torch.autograd.grad(
+            loss, kts[0].values(), retain_graph=True
+        )[0]
+        expected_kt_1_grad = torch.autograd.grad(
+            loss, kts[1].values(), retain_graph=True
+        )[0]
+
+        self.assertTrue(torch.allclose(actual_kt_0_grad, expected_kt_0_grad))
+        self.assertTrue(torch.allclose(actual_kt_1_grad, expected_kt_1_grad))
+
+    # pyre-ignore
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_regroup_backward(self) -> None:
+        kts = build_kts(
+            dense_features=20,
+            sparse_features=20,
+            dim_dense=64,
+            dim_sparse=128,
+            batch_size=128,
+            device=self.device,
+            run_backward=True,
+        )
+        groups = build_groups(kts=kts, num_groups=2, skips=False, duplicates=False)
+        labels = torch.randint(0, 1, (128,), device=self.device).float()
+
+        tensor_groups = KeyedTensor.regroup(kts, groups)
+        pred0 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred0, labels).sum()
+        actual_kt_0_grad = torch.autograd.grad(
+            loss, kts[0].values(), retain_graph=True
+        )[0]
+        actual_kt_1_grad = torch.autograd.grad(
+            loss, kts[1].values(), retain_graph=True
+        )[0]
+
+        # clear grads before re-running through the fallback path
+        kts[0].values().grad = None
+        kts[1].values().grad = None
+
+        tensor_groups = _regroup_keyed_tensors(kts, groups)
+        pred1 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred1, labels).sum()
+        expected_kt_0_grad = torch.autograd.grad(
+            loss, kts[0].values(), retain_graph=True
+        )[0]
+        expected_kt_1_grad = torch.autograd.grad(
+            loss, kts[1].values(), retain_graph=True
+        )[0]
+
+        self.assertTrue(torch.allclose(actual_kt_0_grad, expected_kt_0_grad))
+        self.assertTrue(torch.allclose(actual_kt_1_grad, expected_kt_1_grad))
diff --git a/torchrec/sparse/tests/utils.py b/torchrec/sparse/tests/utils.py
new file mode 100644
index 000000000..46899b6c6
--- /dev/null
+++ b/torchrec/sparse/tests/utils.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+from typing import List
+
+import torch
+from torchrec.sparse.jagged_tensor import KeyedTensor
+
+
+def build_kts(
+    dense_features: int,
+    sparse_features: int,
+    dim_dense: int,
+    dim_sparse: int,
+    batch_size: int,
+    device: torch.device,
+    run_backward: bool,
+) -> List[KeyedTensor]:
+    key_dim = 1
+    dense_embs = [
+        torch.randn(batch_size, dim_dense, device=device, requires_grad=run_backward)
+        for i in range(dense_features)
+    ]
+    dense_keys = [f"dense_{i}" for i in range(dense_features)]
+    dense_kt = KeyedTensor.from_tensor_list(dense_keys, dense_embs, key_dim)
+
+    sparse_embs = [
+        torch.randn(batch_size, dim_sparse, device=device, requires_grad=run_backward)
+        for i in range(sparse_features)
+    ]
+    sparse_keys = [f"sparse_{i}" for i in range(sparse_features)]
+    sparse_kt = KeyedTensor.from_tensor_list(sparse_keys, sparse_embs, key_dim)
+    return [dense_kt, sparse_kt]
+
+
+def build_groups(
+    kts: List[KeyedTensor],
+    num_groups: int,
+    skips: bool = False,
+    duplicates: bool = False,
+) -> List[List[str]]:
+    all_keys: List[str] = []
+    for kt in kts:
+        all_keys.extend(kt.keys())
+    # randomly assign each key to one of the groups
+    allocation = [random.randint(0, num_groups - 1) for _ in range(len(all_keys))]
+    groups: List[List[str]] = [[] for _ in range(num_groups)]
+    for i, group_idx in enumerate(allocation):
+        groups[group_idx].append(all_keys[i])
+    if skips:
+        # drop a random key from every group that has more than one key
+        for group in groups:
+            if len(group) > 1:
+                group.pop(random.randint(0, len(group) - 1))
+    if duplicates:
+        # append a random (possibly already used) key to every group
+        for group in groups:
+            group.append(random.choice(all_keys))
+    return groups
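# ---------------------------------------------------------------------------
# Illustrative usage of the helpers above (assumed values; not part of the
# patch, and assuming the fbgemm permute_pooled_embs ops are available).
# With the default flags build_groups uses every key exactly once, which keeps
# KeyedTensor.regroup on the fbgemm fast path; skips/duplicates drop or repeat
# keys, which typically forces the _regroup_keyed_tensors fallback instead.
import torch
from torchrec.sparse.jagged_tensor import KeyedTensor
from torchrec.sparse.tests.utils import build_groups, build_kts

kts = build_kts(
    dense_features=10,
    sparse_features=10,
    dim_dense=16,
    dim_sparse=32,
    batch_size=4,
    device=torch.device("cpu"),
    run_backward=False,
)
exact_groups = build_groups(kts, num_groups=2)  # eligible for the fast path
messy_groups = build_groups(kts, num_groups=2, skips=True, duplicates=True)
for groups in (exact_groups, messy_groups):
    out = KeyedTensor.regroup(kts, groups)
    assert len(out) == len(groups)  # one regrouped tensor per requested group
# ---------------------------------------------------------------------------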