diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index e0a19547c..2f894b8ae 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -9,7 +9,11 @@ import abc import copy +import inspect import itertools +import logging +import tempfile +from collections import OrderedDict from dataclasses import dataclass from typing import ( Any, @@ -37,6 +41,10 @@ SparseType, SplitTableBatchedEmbeddingBagsCodegen, ) +from fbgemm_gpu.ssd_split_table_batched_embeddings_ops import ( + ASSOC, + SSDTableBatchedEmbeddingBags, +) from torch import nn from torchrec.distributed.composable.table_batched_embedding_slice import ( TableBatchedEmbeddingSlice, @@ -66,6 +74,102 @@ ) from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +logger: logging.Logger = logging.getLogger(__name__) + + +def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]: + fused_params = config.fused_params or {} + + ssd_tbe_params: Dict[str, Any] = {} + + # drop the non-ssd tbe fused params + ssd_tbe_signature = inspect.signature( + SSDTableBatchedEmbeddingBags.__init__ + ).parameters.keys() + for key, value in fused_params.items(): + if key not in ssd_tbe_signature: + logger.warning(f"{key} is not a valid ssd tbe fused param, dropping now.") + else: + ssd_tbe_params[key] = value + + # populate number cache sets, aka number of rows of the cache space + if "cache_sets" not in ssd_tbe_params: + cache_load_factor = fused_params.get("cache_load_factor") + if cache_load_factor: + cache_load_factor = fused_params.get("cache_load_factor") + logger.info( + f"Using cache load factor from fused params dict: {cache_load_factor}" + ) + else: + cache_load_factor = 0.2 + + local_rows_sum: int = sum(table.local_rows for table in config.embedding_tables) + ssd_tbe_params["cache_sets"] = int(cache_load_factor * local_rows_sum / ASSOC) + + # populate init min and max + if ( + "ssd_uniform_init_lower" not in ssd_tbe_params + or "ssd_uniform_init_upper" not in ssd_tbe_params + ): + # Right now we do not support a per table init max and min. To use + # per table init max and min, either we allow it in SSD TBE, or we + # create one SSD TBE per table. + # TODO: Solve the init problem + mins = [table.get_weight_init_min() for table in config.embedding_tables] + maxs = [table.get_weight_init_max() for table in config.embedding_tables] + ssd_tbe_params["ssd_uniform_init_lower"] = sum(mins) / len( + config.embedding_tables + ) + ssd_tbe_params["ssd_uniform_init_upper"] = sum(maxs) / len( + config.embedding_tables + ) + + if "ssd_storage_directory" not in ssd_tbe_params: + ssd_tbe_params["ssd_storage_directory"] = tempfile.mkdtemp() + + if "weights_precision" not in ssd_tbe_params: + weights_precision = data_type_to_sparse_type(config.data_type) + ssd_tbe_params["weights_precision"] = weights_precision + + return ssd_tbe_params + + +class KeyValueEmbeddingFusedOptimizer(FusedOptimizer): + def __init__( + self, + config: GroupedEmbeddingConfig, + emb_module: SSDTableBatchedEmbeddingBags, + pg: Optional[dist.ProcessGroup] = None, + ) -> None: + """ + Fused optimizer for SSD TBE. Right now it only supports tuning learning + rate. + """ + self._emb_module: SSDTableBatchedEmbeddingBags = emb_module + self._pg = pg + + # TODO: access momentum1_dev of SSD TBE after figuring out if it helps with checkpointing. 
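+        # Unlike EmbeddingFusedOptimizer, nothing is registered here: the SSD TBE
+        # owns its weights and optimizer state internally, so this wrapper only
+        # exposes a param_group whose learning rate is pushed back into the TBE
+        # by step() / zero_grad() below.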
+ + # pyre-ignore [33] + state: Dict[Any, Any] = {} + param_group: Dict[str, Any] = { + "params": [], + "lr": emb_module.optimizer_args.learning_rate, + } + + params: Dict[str, Union[torch.Tensor, ShardedTensor]] = {} + + super().__init__(params, state, [param_group]) + + def zero_grad(self, set_to_none: bool = False) -> None: + # pyre-ignore [16] + self._emb_module.set_learning_rate(self.param_groups[0]["lr"]) + + # pyre-ignore [2] + def step(self, closure: Any = None) -> None: + # pyre-ignore [16] + self._emb_module.set_learning_rate(self.param_groups[0]["lr"]) + class EmbeddingFusedOptimizer(FusedOptimizer): def __init__( # noqa C901 @@ -376,6 +480,24 @@ def step(self, closure: Any = None) -> None: self._emb_module.set_learning_rate(self.param_groups[0]["lr"]) +def _gen_named_parameters_by_table_ssd( + emb_module: SSDTableBatchedEmbeddingBags, + table_name_to_count: Dict[str, int], + config: GroupedEmbeddingConfig, + pg: Optional[dist.ProcessGroup] = None, +) -> Iterator[Tuple[str, nn.Parameter]]: + """ + Return an empty tensor to indicate that the table is on remote device. + """ + for table in config.embedding_tables: + table_name = table.name + # placeholder + weight: nn.Parameter = nn.Parameter(torch.empty(0)) + # pyre-ignore + weight._in_backward_optimizers = [EmptyFusedOptimizer()] + yield (table_name, weight) + + def _gen_named_parameters_by_table_fused( emb_module: SplitTableBatchedEmbeddingBagsCodegen, table_name_to_count: Dict[str, int], @@ -563,6 +685,130 @@ def named_parameters_by_table( yield name, param +class KeyValueEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule): + def __init__( + self, + config: GroupedEmbeddingConfig, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + ) -> None: + super().__init__(config, pg, device) + + assert ( + len(config.embedding_tables) > 0 + ), "Expected to see at least one table in SSD TBE, but found 0." + assert ( + len({table.embedding_dim for table in config.embedding_tables}) == 1 + ), "Currently we expect all tables in SSD TBE to have the same embedding dimension." + + ssd_tbe_params = _populate_ssd_tbe_params(config) + compute_kernel = config.embedding_tables[0].compute_kernel + embedding_location = compute_kernel_to_embedding_location(compute_kernel) + + self._emb_module: SSDTableBatchedEmbeddingBags = SSDTableBatchedEmbeddingBags( + embedding_specs=list(zip(self._local_rows, self._local_cols)), + feature_table_map=self._feature_table_map, + ssd_cache_location=embedding_location, + pooling_mode=PoolingMode.NONE, + **ssd_tbe_params, + ).to(device) + + self._optim: KeyValueEmbeddingFusedOptimizer = KeyValueEmbeddingFusedOptimizer( + config, + self._emb_module, + pg, + ) + self._param_per_table: Dict[str, nn.Parameter] = dict( + _gen_named_parameters_by_table_ssd( + emb_module=self._emb_module, + table_name_to_count=self.table_name_to_count.copy(), + config=self._config, + pg=pg, + ) + ) + self.init_parameters() + + def init_parameters(self) -> None: + """ + An advantage of SSD TBE is that we don't need to init weights. Hence skipping. + """ + pass + + @property + def emb_module( + self, + ) -> SSDTableBatchedEmbeddingBags: + return self._emb_module + + @property + def fused_optimizer(self) -> FusedOptimizer: + """ + SSD Embedding fuses backward with backward. 
+ """ + return self._optim + + def state_dict( + self, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + keep_vars: bool = False, + ) -> Dict[str, Any]: + if destination is None: + destination = OrderedDict() + + return destination + + def named_parameters( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, nn.Parameter]]: + # WIP + """ + Only allowed ways to get state_dict. + """ + for name, tensor in self.named_split_embedding_weights( + prefix, recurse, remove_duplicate + ): + # hack before we support optimizer on sharded parameter level + # can delete after PEA deprecation + param = nn.Parameter(tensor) + # pyre-ignore + param._in_backward_optimizers = [EmptyFusedOptimizer()] + yield name, param + + def named_split_embedding_weights( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, torch.Tensor]]: + assert ( + remove_duplicate + ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights" + for config, tensor in zip( + self._config.embedding_tables, + self.split_embedding_weights(), + ): + key = append_prefix(prefix, f"{config.name}.weight") + yield key, tensor + + def flush(self) -> None: + """ + Flush the embeddings in cache back to SSD. Should be pretty expensive. + """ + self.emb_module.flush() + + def purge(self) -> None: + """ + Reset the cache space. This is needed when we load state dict. + """ + # TODO: move the following to SSD TBE. + self.emb_module.lxu_cache_weights.zero_() + self.emb_module.lxu_cache_state.fill_(-1) + + def split_embedding_weights(self) -> List[torch.Tensor]: + """ + Return fake tensors. + """ + return [param.data for param in self._param_per_table.values()] + + class BatchedFusedEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule): def __init__( self, @@ -856,6 +1102,132 @@ def named_parameters_by_table( yield name, param +class KeyValueEmbeddingBag(BaseBatchedEmbeddingBag[torch.Tensor], FusedOptimizerModule): + def __init__( + self, + config: GroupedEmbeddingConfig, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + sharding_type: Optional[ShardingType] = None, + ) -> None: + super().__init__(config, pg, device, sharding_type) + + assert ( + len(config.embedding_tables) > 0 + ), "Expected to see at least one table in SSD TBE, but found 0." + assert ( + len({table.embedding_dim for table in config.embedding_tables}) == 1 + ), "Currently we expect all tables in SSD TBE to have the same embedding dimension." 
+ + ssd_tbe_params = _populate_ssd_tbe_params(config) + compute_kernel = config.embedding_tables[0].compute_kernel + embedding_location = compute_kernel_to_embedding_location(compute_kernel) + + self._emb_module: SSDTableBatchedEmbeddingBags = SSDTableBatchedEmbeddingBags( + embedding_specs=list(zip(self._local_rows, self._local_cols)), + feature_table_map=self._feature_table_map, + ssd_cache_location=embedding_location, + pooling_mode=self._pooling, + **ssd_tbe_params, + ).to(device) + + self._optim: KeyValueEmbeddingFusedOptimizer = KeyValueEmbeddingFusedOptimizer( + config, + self._emb_module, + pg, + ) + self._param_per_table: Dict[str, nn.Parameter] = dict( + _gen_named_parameters_by_table_ssd( + emb_module=self._emb_module, + table_name_to_count=self.table_name_to_count.copy(), + config=self._config, + pg=pg, + ) + ) + self.init_parameters() + + def init_parameters(self) -> None: + """ + An advantage of SSD TBE is that we don't need to init weights. Hence + skipping. + """ + pass + + @property + def emb_module( + self, + ) -> SSDTableBatchedEmbeddingBags: + return self._emb_module + + @property + def fused_optimizer(self) -> FusedOptimizer: + """ + SSD Embedding fuses backward with backward. + """ + return self._optim + + def state_dict( + self, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + keep_vars: bool = False, + ) -> Dict[str, Any]: + if destination is None: + destination = OrderedDict() + + return destination + + def named_parameters( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, nn.Parameter]]: + # WIP + """ + Only allowed ways to get state_dict. + """ + for name, tensor in self.named_split_embedding_weights( + prefix, recurse, remove_duplicate + ): + # hack before we support optimizer on sharded parameter level + # can delete after PEA deprecation + param = nn.Parameter(tensor) + # pyre-ignore + param._in_backward_optimizers = [EmptyFusedOptimizer()] + yield name, param + + def named_split_embedding_weights( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, torch.Tensor]]: + assert ( + remove_duplicate + ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights" + for config, tensor in zip( + self._config.embedding_tables, + self.split_embedding_weights(), + ): + key = append_prefix(prefix, f"{config.name}.weight") + yield key, tensor + + def flush(self) -> None: + """ + Flush the embeddings in cache back to SSD. Should be pretty expensive. + """ + self.emb_module.flush() + + def purge(self) -> None: + """ + Reset the cache space. This is needed when we load state dict. + """ + # TODO: move the following to SSD TBE. + self.emb_module.lxu_cache_weights.zero_() + self.emb_module.lxu_cache_state.fill_(-1) + + def split_embedding_weights(self) -> List[torch.Tensor]: + """ + Return fake tensors. 
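+        The real rows live in the SSD backend; what is returned here are the empty
+        placeholder parameters that were generated per table.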
+ """ + return [param.data for param in self._param_per_table.values()] + + class BatchedFusedEmbeddingBag( BaseBatchedEmbeddingBag[torch.Tensor], FusedOptimizerModule ): diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index 412d0811a..74dd86ff9 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -12,7 +12,6 @@ import logging import warnings from collections import defaultdict, deque, OrderedDict -from dataclasses import dataclass, field from itertools import accumulate from typing import Any, cast, Dict, List, MutableMapping, Optional, Tuple, Type, Union @@ -654,6 +653,10 @@ def _initialize_torch_state(self) -> None: # noqa table_name, local_shards, ) in self._model_parallel_name_to_local_shards.items(): + if model_parallel_name_to_compute_kernel[table_name] in { + EmbeddingComputeKernel.SSD.value + }: + continue # for shards that don't exist on this rank, register with empty tensor if not hasattr(self.embeddings[table_name], "weight"): self.embeddings[table_name].register_parameter( @@ -702,6 +705,10 @@ def reset_parameters(self) -> None: return # Initialize embedding weights with init_fn for table_config in self._embedding_configs: + if self.module_sharding_plan[table_config.name].compute_kernel in { + EmbeddingComputeKernel.SSD.value, + }: + continue assert table_config.init_fn is not None param = self.embeddings[f"{table_config.name}"].weight # pyre-ignore diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py index f2476ec1d..b257f5e08 100644 --- a/torchrec/distributed/embedding_lookup.py +++ b/torchrec/distributed/embedding_lookup.py @@ -31,6 +31,8 @@ BatchedDenseEmbeddingBag, BatchedFusedEmbedding, BatchedFusedEmbeddingBag, + KeyValueEmbedding, + KeyValueEmbeddingBag, ) from torchrec.distributed.comm_ops import get_gradient_division from torchrec.distributed.composable.table_batched_embedding_slice import ( @@ -168,6 +170,14 @@ def _create_lookup( pg=pg, device=device, ) + elif config.compute_kernel in { + EmbeddingComputeKernel.SSD, + }: + return KeyValueEmbedding( + config=config, + pg=pg, + device=device, + ) else: raise ValueError( f"Compute kernel not supported {config.compute_kernel}" @@ -368,6 +378,15 @@ def _create_lookup( device=device, sharding_type=sharding_type, ) + elif config.compute_kernel in { + EmbeddingComputeKernel.SSD, + }: + return KeyValueEmbeddingBag( + config=config, + pg=pg, + device=device, + sharding_type=sharding_type, + ) else: raise ValueError( f"Compute kernel not supported {config.compute_kernel}" diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py index d15a1935c..c5d2cb888 100644 --- a/torchrec/distributed/embedding_sharding.py +++ b/torchrec/distributed/embedding_sharding.py @@ -9,7 +9,6 @@ import abc import copy -import uuid from collections import defaultdict from dataclasses import dataclass from itertools import filterfalse diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py index 61f5641f3..f5063805f 100644 --- a/torchrec/distributed/embedding_types.py +++ b/torchrec/distributed/embedding_types.py @@ -60,6 +60,7 @@ class EmbeddingComputeKernel(Enum): QUANT = "quant" QUANT_UVM = "quant_uvm" QUANT_UVM_CACHING = "quant_uvm_caching" + SSD = "SSD" def compute_kernel_to_embedding_location( @@ -69,6 +70,7 @@ def compute_kernel_to_embedding_location( EmbeddingComputeKernel.DENSE, EmbeddingComputeKernel.FUSED, EmbeddingComputeKernel.QUANT, + 
EmbeddingComputeKernel.SSD, # use hbm for cache ]: return EmbeddingLocation.DEVICE elif compute_kernel in [ @@ -410,6 +412,7 @@ def compute_kernels( ret += [ EmbeddingComputeKernel.FUSED_UVM.value, EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.SSD.value, ] else: # TODO re-enable model parallel and dense diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index a9c478969..bed37cd84 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -13,7 +13,6 @@ from functools import partial from typing import ( Any, - Callable, cast, Dict, Iterator, @@ -793,6 +792,10 @@ def _initialize_torch_state(self) -> None: # noqa table_name, local_shards, ) in self._model_parallel_name_to_local_shards.items(): + if model_parallel_name_to_compute_kernel[table_name] in { + EmbeddingComputeKernel.SSD.value + }: + continue # for shards that don't exist on this rank, register with empty tensor if not hasattr(self.embedding_bags[table_name], "weight"): self.embedding_bags[table_name].register_parameter( @@ -841,6 +844,10 @@ def reset_parameters(self) -> None: # Initialize embedding bags weights with init_fn for table_config in self._embedding_bag_configs: + if self.module_sharding_plan[table_config.name].compute_kernel in { + EmbeddingComputeKernel.SSD.value, + }: + continue assert table_config.init_fn is not None param = self.embedding_bags[f"{table_config.name}"].weight # pyre-ignore diff --git a/torchrec/distributed/planner/constants.py b/torchrec/distributed/planner/constants.py index c00192193..f02bb6c87 100644 --- a/torchrec/distributed/planner/constants.py +++ b/torchrec/distributed/planner/constants.py @@ -87,6 +87,7 @@ def kernel_bw_lookup( caching_ratio * hbm_mem_bw + (1 - caching_ratio) * ddr_mem_bw ) / 10, + ("cuda", EmbeddingComputeKernel.SSD.value): ddr_mem_bw, } if ( diff --git a/torchrec/distributed/planner/enumerators.py b/torchrec/distributed/planner/enumerators.py index 8f356018e..ad1537f3c 100644 --- a/torchrec/distributed/planner/enumerators.py +++ b/torchrec/distributed/planner/enumerators.py @@ -8,7 +8,7 @@ # pyre-strict import logging -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Set, Tuple, Union from torch import nn from torchrec.distributed.embedding_types import EmbeddingComputeKernel @@ -40,6 +40,9 @@ logger: logging.Logger = logging.getLogger(__name__) +# compute kernels that should only be used if users specified them +GUARDED_COMPUTE_KERNELS: Set[EmbeddingComputeKernel] = {EmbeddingComputeKernel.SSD} + class EmbeddingEnumerator(Enumerator): """ @@ -256,22 +259,29 @@ def _filter_compute_kernels( allowed_compute_kernels: List[str], sharding_type: str, ) -> List[str]: - # for the log message only - constrained_compute_kernels: List[str] = [ - compute_kernel.value for compute_kernel in EmbeddingComputeKernel - ] - if not self._constraints or not self._constraints.get(name): - filtered_compute_kernels = allowed_compute_kernels + # setup constrained_compute_kernels + if ( + self._constraints + and self._constraints.get(name) + and self._constraints[name].compute_kernels + ): + # pyre-ignore + constrained_compute_kernels: List[str] = self._constraints[ + name + ].compute_kernels else: - constraints: ParameterConstraints = self._constraints[name] - if not constraints.compute_kernels: - filtered_compute_kernels = allowed_compute_kernels - else: - constrained_compute_kernels = constraints.compute_kernels - filtered_compute_kernels = list( - 
set(constrained_compute_kernels) & set(allowed_compute_kernels) - ) + constrained_compute_kernels: List[str] = [ + compute_kernel.value + for compute_kernel in EmbeddingComputeKernel + if compute_kernel not in GUARDED_COMPUTE_KERNELS + ] + + # setup filtered_compute_kernels + filtered_compute_kernels = list( + set(constrained_compute_kernels) & set(allowed_compute_kernels) + ) + # special rules if EmbeddingComputeKernel.DENSE.value in filtered_compute_kernels: if ( EmbeddingComputeKernel.FUSED.value in filtered_compute_kernels diff --git a/torchrec/distributed/planner/shard_estimators.py b/torchrec/distributed/planner/shard_estimators.py index cdfd25219..093aebe6b 100644 --- a/torchrec/distributed/planner/shard_estimators.py +++ b/torchrec/distributed/planner/shard_estimators.py @@ -1051,9 +1051,12 @@ def calculate_shard_storages( if compute_kernel in { EmbeddingComputeKernel.FUSED_UVM_CACHING.value, EmbeddingComputeKernel.QUANT_UVM_CACHING.value, + EmbeddingComputeKernel.SSD.value, }: hbm_storage = round(ddr_storage * caching_ratio) table_cached = True + if compute_kernel in {EmbeddingComputeKernel.SSD.value}: + ddr_storage = 0 optimizer_class = getattr(tensor, "_optimizer_class", None) diff --git a/torchrec/distributed/test_utils/test_model_parallel.py b/torchrec/distributed/test_utils/test_model_parallel.py index de4923207..ca863afb3 100644 --- a/torchrec/distributed/test_utils/test_model_parallel.py +++ b/torchrec/distributed/test_utils/test_model_parallel.py @@ -40,8 +40,8 @@ def setUp(self, backend: str = "nccl") -> None: self.tables = [ EmbeddingBagConfig( - num_embeddings=(i + 1) * 10, - embedding_dim=(i + 2) * 8, + num_embeddings=(i + 1) * 1000, + embedding_dim=16, name="table_" + str(i), feature_names=["feature_" + str(i)], ) @@ -49,8 +49,8 @@ def setUp(self, backend: str = "nccl") -> None: ] shared_features_tables = [ EmbeddingBagConfig( - num_embeddings=(i + 1) * 10, - embedding_dim=(i + 2) * 8, + num_embeddings=(i + 1) * 1000, + embedding_dim=16, name="table_" + str(i + num_features), feature_names=["feature_" + str(i)], ) @@ -60,8 +60,8 @@ def setUp(self, backend: str = "nccl") -> None: self.mean_tables = [ EmbeddingBagConfig( - num_embeddings=(i + 1) * 10, - embedding_dim=(i + 2) * 8, + num_embeddings=(i + 1) * 1000, + embedding_dim=16, name="table_" + str(i), feature_names=["feature_" + str(i)], pooling=PoolingType.MEAN, @@ -71,8 +71,8 @@ def setUp(self, backend: str = "nccl") -> None: shared_features_tables_mean = [ EmbeddingBagConfig( - num_embeddings=(i + 1) * 10, - embedding_dim=(i + 2) * 8, + num_embeddings=(i + 1) * 1000, + embedding_dim=16, name="table_" + str(i + num_features), feature_names=["feature_" + str(i)], pooling=PoolingType.MEAN, @@ -83,8 +83,8 @@ def setUp(self, backend: str = "nccl") -> None: self.weighted_tables = [ EmbeddingBagConfig( - num_embeddings=(i + 1) * 10, - embedding_dim=(i + 2) * 4, + num_embeddings=(i + 1) * 1000, + embedding_dim=16, name="weighted_table_" + str(i), feature_names=["weighted_feature_" + str(i)], ) diff --git a/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py b/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py new file mode 100644 index 000000000..47a1d9547 --- /dev/null +++ b/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py @@ -0,0 +1,885 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest +from typing import cast, Dict, List, Optional, OrderedDict, Tuple, Union + +import torch +import torch.nn as nn +from fbgemm_gpu.split_embedding_configs import EmbOptimType +from hypothesis import given, settings, strategies as st, Verbosity +from torch import distributed as dist +from torchrec import distributed as trec_dist +from torchrec.distributed.batched_embedding_kernel import ( + KeyValueEmbedding, + KeyValueEmbeddingBag, +) +from torchrec.distributed.embedding_types import ( + EmbeddingComputeKernel, + ShardedEmbeddingTable, +) +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder +from torchrec.distributed.model_parallel import DistributedModelParallel +from torchrec.distributed.planner import ( + EmbeddingShardingPlanner, + ParameterConstraints, + Topology, +) +from torchrec.distributed.sharding_plan import get_default_sharders +from torchrec.distributed.test_utils.test_model import ModelInput, TestSparseNN +from torchrec.distributed.test_utils.test_model_parallel_base import ( + ModelParallelSingleRankBase, +) +from torchrec.distributed.test_utils.test_sharding import ( + copy_state_dict, + create_test_sharder, + SharderType, +) +from torchrec.distributed.tests.test_sequence_model import ( + TestEmbeddingCollectionSharder, + TestSequenceSparseNN, +) +from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType +from torchrec.modules.embedding_configs import ( + DataType, + EmbeddingBagConfig, + EmbeddingConfig, +) + + +def _load_split_embedding_weights( + emb_module: Union[KeyValueEmbedding, KeyValueEmbeddingBag], + weights: List[torch.Tensor], +) -> None: + embedding_tables: List[ShardedEmbeddingTable] = emb_module.config.embedding_tables + + assert len(weights) == len( + embedding_tables + ), "Expect length of weights to be equal to number of embedding tables. " + + cum_sum = 0 + for table_id, (table, weight) in enumerate(zip(embedding_tables, weights)): + # load weights for SSD TBE + height = weight.shape[0] + shard_shape = table.local_rows, table.local_cols + assert shard_shape == weight.shape, "Expect shard shape to match tensor shape." + assert weight.device == torch.device("cpu"), "Weight has to be on CPU." 
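+        # Write this shard's rows into the SSD backing store; rows are addressed
+        # by their running offset across the grouped tables.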
+ emb_module.emb_module.ssd_db.set_cuda( + torch.arange(cum_sum, cum_sum + height, dtype=torch.int64), + weight, + torch.as_tensor([height]), + table_id, + ) + cum_sum += height + + +class KeyValueModelParallelTest(ModelParallelSingleRankBase): + def setUp(self, backend: str = "nccl") -> None: + super().setUp(backend=backend) + + num_features = 4 + self.batch_size = 20 + self.num_float_features = 10 + + self.tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 1000, + embedding_dim=256, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_features) + ] + self.weighted_tables = [] + + def _generate_dmps_and_batch( + self, + sharders: Optional[List[ModuleSharder[nn.Module]]] = None, + constraints: Optional[Dict[str, trec_dist.planner.ParameterConstraints]] = None, + ) -> Tuple[List[DistributedModelParallel], ModelInput]: + + if constraints is None: + constraints = {} + if sharders is None: + sharders = get_default_sharders() + + _, local_batch = ModelInput.generate( + batch_size=self.batch_size, + world_size=1, + num_float_features=self.num_float_features, + tables=self.tables, + weighted_tables=self.weighted_tables, + ) + batch = local_batch[0].to(self.device) + + dmps = [] + pg = dist.GroupMember.WORLD + assert pg is not None, "Process group is not initialized" + env = ShardingEnv.from_process_group(pg) + + planner = EmbeddingShardingPlanner( + topology=Topology( + local_world_size=trec_dist.comm.get_local_size(env.world_size), + world_size=env.world_size, + compute_device=self.device.type, + ), + constraints=constraints, + ) + + for _ in range(2): + # Create two TestSparseNN modules, wrap both in DMP + m = TestSparseNN( + tables=self.tables, + num_float_features=self.num_float_features, + weighted_tables=self.weighted_tables, + dense_device=self.device, + sparse_device=torch.device("meta"), + ) + if pg is not None: + plan = planner.collective_plan(m, sharders, pg) + else: + plan = planner.plan(m, sharders) + + dmp = DistributedModelParallel( + module=m, + init_data_parallel=False, + device=self.device, + sharders=sharders, + plan=plan, + ) + + with torch.no_grad(): + dmp(batch) + dmp.init_data_parallel() + dmps.append(dmp) + return (dmps, batch) + + def _set_table_weights_precision(self, dtype: DataType) -> None: + for table in self.tables: + table.data_type = dtype + + @staticmethod + def _copy_ssd_emb_modules( + m1: DistributedModelParallel, m2: DistributedModelParallel + ) -> None: + """ + Util function to copy and set the SSD TBE modules of two models. It + requires both DMP modules to have the same sharding plan. + """ + for lookup1, lookup2 in zip( + m1.module.sparse.ebc._lookups, m2.module.sparse.ebc._lookups + ): + for emb_module1, emb_module2 in zip( + lookup1._emb_modules, lookup2._emb_modules + ): + ssd_emb_modules = {KeyValueEmbeddingBag, KeyValueEmbedding} + if type(emb_module1) in ssd_emb_modules: + assert type(emb_module1) is type(emb_module2), ( + "Expect two emb_modules to be of the same type, either both " + "SSDEmbeddingBag or SSDEmbeddingBag." + ) + + weights = emb_module1.emb_module.debug_split_embedding_weights() + # need to set emb_module1 as well, since otherwise emb_module1 would + # produce a random debug_split_embedding_weights everytime + _load_split_embedding_weights(emb_module1, weights) + _load_split_embedding_weights(emb_module2, weights) + + # purge after loading. This is needed, since we pass a batch + # through dmp when instantiating them. 
+ emb_module1.purge() + emb_module2.purge() + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.SSD.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + ShardingType.TABLE_COLUMN_WISE.value, + ] + ), + is_training=st.booleans(), + stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_ssd_load_state_dict( + self, + sharder_type: str, + kernel_type: str, + sharding_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + ) -> None: + """ + This test checks that if SSD TBE is deterministic. That is, if two SSD + TBEs start with the same state, they would produce the same output. + """ + self._set_table_weights_precision(dtype) + + fused_params = { + "learning_rate": 0.1, + "stochastic_rounding": stochastic_rounding, + } + is_deterministic = dtype == DataType.FP32 or not stochastic_rounding + constraints = { + table.name: ParameterConstraints( + sharding_types=[sharding_type], + compute_kernels=[kernel_type], + ) + for i, table in enumerate(self.tables) + } + sharders = [ + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + fused_params=fused_params, + ), + ] + + # pyre-ignore + models, batch = self._generate_dmps_and_batch(sharders, constraints=constraints) + m1, m2 = models + + # load state dict for dense modules + m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict())) + self._copy_ssd_emb_modules(m1, m2) + + if is_training: + self._train_models(m1, m2, batch) + self._eval_models(m1, m2, batch, is_deterministic=is_deterministic) + self._compare_models(m1, m2, is_deterministic=is_deterministic) + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.SSD.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + ShardingType.TABLE_COLUMN_WISE.value, + ] + ), + is_training=st.booleans(), + stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_ssd_tbe_numerical_accuracy( + self, + sharder_type: str, + kernel_type: str, + sharding_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + ) -> None: + """ + Make sure it produces same numbers as normal TBE. 
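+        A fused-kernel baseline and an SSD-kernel model are loaded with identical
+        weights and then evaluated (and optionally trained) on the same batch.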
+ """ + self._set_table_weights_precision(dtype) + + base_kernel_type = EmbeddingComputeKernel.FUSED.value + learning_rate = 0.1 + fused_params = { + "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD, + "learning_rate": learning_rate, + "stochastic_rounding": stochastic_rounding, + } + is_deterministic = dtype == DataType.FP32 or not stochastic_rounding + fused_sharders = [ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + base_kernel_type, # base kernel type + fused_params=fused_params, + ), + ), + ] + ssd_sharders = [ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + fused_params=fused_params, + ), + ), + ] + ssd_constraints = { + table.name: ParameterConstraints( + sharding_types=[sharding_type], + compute_kernels=[kernel_type], + ) + for i, table in enumerate(self.tables) + } + (fused_model, _), _ = self._generate_dmps_and_batch(fused_sharders) + (ssd_model, _), batch = self._generate_dmps_and_batch( + ssd_sharders, constraints=ssd_constraints + ) + + # load state dict for dense modules + copy_state_dict( + ssd_model.state_dict(), fused_model.state_dict(), exclude_predfix="sparse" + ) + + # for this to work, we expect the order of lookups to be the same + assert len(fused_model.module.sparse.ebc._lookups) == len( + ssd_model.module.sparse.ebc._lookups + ), "Expect same number of lookups" + + for fused_lookup, ssd_lookup in zip( + fused_model.module.sparse.ebc._lookups, ssd_model.module.sparse.ebc._lookups + ): + assert len(fused_lookup._emb_modules) == len( + ssd_lookup._emb_modules + ), "Expect same number of emb modules" + for fused_emb_module, ssd_emb_module in zip( + fused_lookup._emb_modules, ssd_lookup._emb_modules + ): + weights = fused_emb_module.split_embedding_weights() + weights = [weight.to("cpu") for weight in weights] + _load_split_embedding_weights(ssd_emb_module, weights) + + # purge after loading. This is needed, since we pass a batch + # through dmp when instantiating them. + ssd_emb_module.purge() + + if is_training: + self._train_models(fused_model, ssd_model, batch) + self._eval_models( + fused_model, ssd_model, batch, is_deterministic=is_deterministic + ) + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.SSD.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + ShardingType.TABLE_COLUMN_WISE.value, + ] + ), + is_training=st.booleans(), + stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_ssd_fused_optimizer( + self, + sharder_type: str, + kernel_type: str, + sharding_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + ) -> None: + """ + Purpose of this test is to make sure it works with warm up policy. 
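+        The learning rate of the test model is changed through its fused optimizer
+        after the first evaluation, and both models must agree afterwards.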
+ """ + self._set_table_weights_precision(dtype) + + is_deterministic = dtype == DataType.FP32 or not stochastic_rounding + + constraints = { + table.name: ParameterConstraints( + sharding_types=[sharding_type], + compute_kernels=[kernel_type], + ) + for i, table in enumerate(self.tables) + } + + base_sharders = [ + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + fused_params={ + "learning_rate": 0.2, + "stochastic_rounding": stochastic_rounding, + }, + ), + ] + models, batch = self._generate_dmps_and_batch( + base_sharders, # pyre-ignore + constraints=constraints, + ) + base_model, _ = models + + test_sharders = [ + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + fused_params={ + "learning_rate": 0.1, + "stochastic_rounding": stochastic_rounding, + }, + ), + ] + models, _ = self._generate_dmps_and_batch( + test_sharders, # pyre-ignore + constraints=constraints, + ) + test_model, _ = models + + # load state dict for dense modules + test_model.load_state_dict( + cast("OrderedDict[str, torch.Tensor]", base_model.state_dict()) + ) + self._copy_ssd_emb_modules(base_model, test_model) + + self._eval_models( + base_model, test_model, batch, is_deterministic=is_deterministic + ) + + # change learning rate for test_model + fused_opt = test_model.fused_optimizer + # pyre-ignore + fused_opt.param_groups[0]["lr"] = 0.2 + fused_opt.zero_grad() + + if is_training: + self._train_models(base_model, test_model, batch) + self._eval_models( + base_model, test_model, batch, is_deterministic=is_deterministic + ) + self._compare_models(base_model, test_model, is_deterministic=is_deterministic) + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.SSD.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + ShardingType.TABLE_COLUMN_WISE.value, + ] + ), + is_training=st.booleans(), + stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + fused_first=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_ssd_mixed_kernels( + self, + sharder_type: str, + kernel_type: str, + sharding_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + fused_first: bool, + ) -> None: + """ + Purpose of this test is to make sure it works with warm up policy. 
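+        Here, tables alternate between the fused kernel and the SSD kernel, and the
+        two identically initialized models are expected to stay in sync.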
+ """ + self._set_table_weights_precision(dtype) + + base_kernel_type = EmbeddingComputeKernel.FUSED.value + + is_deterministic = dtype == DataType.FP32 or not stochastic_rounding + + constraints = { + table.name: ParameterConstraints( + sharding_types=[sharding_type], + compute_kernels=( + [base_kernel_type] if i % 2 == fused_first else [kernel_type] + ), + ) + for i, table in enumerate(self.tables) + } + + fused_params = { + "learning_rate": 0.1, + "stochastic_rounding": stochastic_rounding, + } + sharders = [ + EmbeddingBagCollectionSharder(fused_params=fused_params), + ] + + # pyre-ignore + models, batch = self._generate_dmps_and_batch(sharders, constraints=constraints) + m1, m2 = models + + # load state dict for dense modules + m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict())) + self._copy_ssd_emb_modules(m1, m2) + + if is_training: + self._train_models(m1, m2, batch) + self._eval_models(m1, m2, batch, is_deterministic=is_deterministic) + self._compare_models(m1, m2, is_deterministic=is_deterministic) + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.SSD.value, + ] + ), + is_training=st.booleans(), + stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + table_wise_first=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_ssd_mixed_sharding_types( + self, + sharder_type: str, + kernel_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + table_wise_first: bool, + ) -> None: + """ + Purpose of this test is to make sure it works with warm up policy. 
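+        Here, tables alternate between table-wise and row-wise sharding while all of
+        them use the SSD kernel, and the two identically initialized models are
+        expected to stay in sync.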
+ """ + self._set_table_weights_precision(dtype) + + is_deterministic = dtype == DataType.FP32 or not stochastic_rounding + + constraints = { + table.name: ParameterConstraints( + sharding_types=( + [ShardingType.TABLE_WISE.value] + if i % 2 == table_wise_first + else [ShardingType.ROW_WISE.value] + ), + compute_kernels=[kernel_type], + ) + for i, table in enumerate(self.tables) + } + + fused_params = { + "learning_rate": 0.1, + "stochastic_rounding": stochastic_rounding, + } + sharders = [ + EmbeddingBagCollectionSharder(fused_params=fused_params), + ] + + # pyre-ignore + models, batch = self._generate_dmps_and_batch(sharders, constraints=constraints) + m1, m2 = models + + # load state dict for dense modules + m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict())) + self._copy_ssd_emb_modules(m1, m2) + + if is_training: + self._train_models(m1, m2, batch) + self._eval_models(m1, m2, batch, is_deterministic=is_deterministic) + self._compare_models(m1, m2, is_deterministic=is_deterministic) + + +class KeyValueSequenceModelParallelStateDictTest(ModelParallelSingleRankBase): + def setUp(self, backend: str = "nccl") -> None: + super().setUp(backend=backend) + + num_features = 4 + self.num_float_features = 16 + self.batch_size = 20 + shared_features = 2 + + initial_tables = [ + EmbeddingConfig( + num_embeddings=(i + 1) * 1000, + embedding_dim=16, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_features) + ] + + shared_features_tables = [ + EmbeddingConfig( + num_embeddings=(i + 1) * 11, + embedding_dim=16, + name="table_" + str(i + num_features), + feature_names=["feature_" + str(i)], + ) + for i in range(shared_features) + ] + + self.tables = initial_tables + shared_features_tables + self.shared_features = [f"feature_{i}" for i in range(shared_features)] + + self.embedding_groups = { + "group_0": [ + ( + f"{feature}@{table.name}" + if feature in self.shared_features + else feature + ) + for table in self.tables + for feature in table.feature_names + ] + } + + def _set_table_weights_precision(self, dtype: DataType) -> None: + for table in self.tables: + table.data_type = dtype + + def _generate_dmps_and_batch( + self, + sharders: Optional[List[ModuleSharder[nn.Module]]] = None, + constraints: Optional[Dict[str, trec_dist.planner.ParameterConstraints]] = None, + ) -> Tuple[List[DistributedModelParallel], ModelInput]: + """ + Generate two DMPs based on Sequence Sparse NN and one batch of data. 
+ """ + if constraints is None: + constraints = {} + if sharders is None: + sharders = get_default_sharders() + + _, local_batch = ModelInput.generate( + batch_size=self.batch_size, + world_size=1, + tables=self.tables, + num_float_features=self.num_float_features, + weighted_tables=[], + ) + batch = local_batch[0].to(self.device) + + dmps = [] + pg = dist.GroupMember.WORLD + assert pg is not None, "Process group is not initialized" + env = ShardingEnv.from_process_group(pg) + + planner = EmbeddingShardingPlanner( + topology=Topology( + local_world_size=trec_dist.comm.get_local_size(env.world_size), + world_size=env.world_size, + compute_device=self.device.type, + ), + constraints=constraints, + ) + + for _ in range(2): + # Create two TestSparseNN modules, wrap both in DMP + m = TestSequenceSparseNN( + tables=self.tables, + num_float_features=self.num_float_features, + embedding_groups=self.embedding_groups, + dense_device=self.device, + sparse_device=torch.device("meta"), + ) + if pg is not None: + plan = planner.collective_plan(m, sharders, pg) + else: + plan = planner.plan(m, sharders) + + dmp = DistributedModelParallel( + module=m, + init_data_parallel=False, + device=self.device, + sharders=sharders, + plan=plan, + ) + + with torch.no_grad(): + dmp(batch) + dmp.init_data_parallel() + dmps.append(dmp) + return (dmps, batch) + + @staticmethod + def _copy_ssd_emb_modules( + m1: DistributedModelParallel, m2: DistributedModelParallel + ) -> None: + """ + Util function to copy and set the SSD TBE modules of two models. It + requires both DMP modules to have the same sharding plan. + """ + for lookup1, lookup2 in zip( + m1.module.sparse.ec._lookups, m2.module.sparse.ec._lookups + ): + for emb_module1, emb_module2 in zip( + lookup1._emb_modules, lookup2._emb_modules + ): + ssd_emb_modules = {KeyValueEmbeddingBag, KeyValueEmbedding} + if type(emb_module1) in ssd_emb_modules: + assert type(emb_module1) is type(emb_module2), ( + "Expect two emb_modules to be of the same type, either both " + "SSDEmbeddingBag or SSDEmbeddingBag." + ) + + weights = emb_module1.emb_module.debug_split_embedding_weights() + # need to set emb_module1 as well, since otherwise emb_module1 would + # produce a random debug_split_embedding_weights everytime + _load_split_embedding_weights(emb_module1, weights) + _load_split_embedding_weights(emb_module2, weights) + + # purge after loading. This is needed, since we pass a batch + # through dmp when instantiating them. + emb_module1.purge() + emb_module2.purge() + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.ROW_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.SSD.value, + ] + ), + is_training=st.booleans(), + stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_ssd_load_state_dict( + self, + sharding_type: str, + kernel_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + ) -> None: + """ + This test checks that if SSD TBE is deterministic. That is, if two SSD + TBEs start with the same state, they would produce the same output. 
+ """ + self._set_table_weights_precision(dtype) + + fused_params = { + "learning_rate": 0.1, + "stochastic_rounding": stochastic_rounding, + } + is_deterministic = dtype == DataType.FP32 or not stochastic_rounding + sharders = [ + cast( + ModuleSharder[nn.Module], + TestEmbeddingCollectionSharder( + sharding_type=sharding_type, + kernel_type=kernel_type, + fused_params=fused_params, + ), + ), + ] + + constraints = { + table.name: ParameterConstraints( + sharding_types=[sharding_type], + compute_kernels=[kernel_type], + ) + for i, table in enumerate(self.tables) + } + + models, batch = self._generate_dmps_and_batch(sharders, constraints=constraints) + m1, m2 = models + + # load state dict for dense modules + m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict())) + self._copy_ssd_emb_modules(m1, m2) + + if is_training: + self._train_models(m1, m2, batch) + self._eval_models(m1, m2, batch, is_deterministic=is_deterministic) + self._compare_models(m1, m2, is_deterministic=is_deterministic) + + +# TODO: remove after development is done +def main() -> None: + unittest.main() + + +if __name__ == "__main__": + main()
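
For reviewers: the snippet below is not part of the patch, just a minimal sketch of how a model author could opt a table into the new SSD kernel. Because EmbeddingComputeKernel.SSD is in GUARDED_COMPUTE_KERNELS, the planner only proposes it for tables whose ParameterConstraints request it explicitly; cache_load_factor and ssd_storage_directory are optional fused params consumed by _populate_ssd_tbe_params (the storage directory falls back to tempfile.mkdtemp() when omitted). The table name, sizes, and storage path are made up, and an initialized process group plus a CUDA device are assumed.

import torch
import torch.distributed as dist

from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import (
    EmbeddingShardingPlanner,
    ParameterConstraints,
    Topology,
)
from torchrec.distributed.types import ShardingType
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

# hypothetical table that is too large to keep in HBM/DDR
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="large_table",
            num_embeddings=100_000_000,
            embedding_dim=128,
            feature_names=["large_feature"],
        )
    ],
    device=torch.device("meta"),
)

# SSD is a guarded kernel: it is only considered when requested via constraints
constraints = {
    "large_table": ParameterConstraints(
        sharding_types=[ShardingType.ROW_WISE.value],
        compute_kernels=[EmbeddingComputeKernel.SSD.value],
    )
}
sharder = EmbeddingBagCollectionSharder(
    fused_params={
        "cache_load_factor": 0.2,  # sizes the HBM cache (cache_sets) in _populate_ssd_tbe_params
        "ssd_storage_directory": "/tmp/ssd_emb",  # made-up path; defaults to tempfile.mkdtemp()
    }
)

planner = EmbeddingShardingPlanner(
    topology=Topology(world_size=dist.get_world_size(), compute_device="cuda"),
    constraints=constraints,
)
plan = planner.collective_plan(ebc, [sharder], dist.GroupMember.WORLD)
model = DistributedModelParallel(
    module=ebc,
    device=torch.device("cuda"),
    plan=plan,
    sharders=[sharder],
)

Note that _populate_ssd_tbe_params drops, with a warning, any fused param that SSDTableBatchedEmbeddingBags.__init__ does not accept, so the same fused_params dict can be shared between SSD and non-SSD tables.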