Integrate SSD TBE stage 1
Summary:
# Plan
Stage 1 aims to ensure that SSD TBE can run and won't break under normal operations (e.g. checkpointing).

Checkpointing (i.e. state_dict and load_state_dict) is still a work in progress. We also need to guarantee checkpointing for optimizer states.

Stage 2: save state_dict (mostly on fbgemm side)
* the current hope is that we can rely on flush to save the state dict

Stage 3: load_state_dict (need more thoughts)
* the solution should be similar to that of PS (parameter server)

Stage 4: optimizer states checkpointing (torchrec side, should be pretty standard)
* should be straightforward
* needs fbgemm to support the split_embedding_weights API

# Outstanding issues:
* weight init is not the same as before
* SSD TBE doesn't support mixed dim


# design doc

TODO: 

# tests should cover
* state dict and load state dict (done)
  * should copy dense parts and not break 
* deterministic output (done)
* numerical equivalence to normal TBE (done)
* changing learning rate and warm-up policy (done)
* work for different sharding types (done)
* work with mixed kernel (done)
* work with mixed sharding types
* multi-gpu training (todo)

# OSS
NOTE: SSD TBE won't work in an OSS environment due to a rocksdb issue.

# ad hoc
* the SSD kernel is guarded; users must specify it in planner constraints to use it (see the sketch below)
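
A minimal sketch of how a table could opt into the SSD kernel through planner constraints; the table name, topology sizes, and surrounding planner wiring are illustrative assumptions, not part of this diff:

```python
# Sketch only: requesting the guarded SSD kernel for one table via constraints.
# "table_0", the topology values, and the planner setup below are assumptions.
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints

constraints = {
    # The SSD kernel is guarded, so it is only considered when listed explicitly.
    "table_0": ParameterConstraints(
        compute_kernels=[EmbeddingComputeKernel.SSD.value],
    ),
}

planner = EmbeddingShardingPlanner(
    topology=Topology(world_size=2, compute_device="cuda"),
    constraints=constraints,
)
# plan = planner.collective_plan(model, sharders, pg)
# model = DistributedModelParallel(model, plan=plan, ...)
```

Tables without an explicit `compute_kernels` constraint never get the SSD kernel proposed, since the enumerator now drops guarded kernels from the default candidate set (see the enumerators.py change below).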

Differential Revision: D57452256
henrylhtsang authored and facebook-github-bot committed Jun 5, 2024
1 parent da49f44 commit 8ac8c0e
Showing 11 changed files with 1,334 additions and 28 deletions.
372 changes: 372 additions & 0 deletions torchrec/distributed/batched_embedding_kernel.py

Large diffs are not rendered by default.

9 changes: 8 additions & 1 deletion torchrec/distributed/embedding.py
@@ -12,7 +12,6 @@
import logging
import warnings
from collections import defaultdict, deque, OrderedDict
from dataclasses import dataclass, field
from itertools import accumulate
from typing import Any, cast, Dict, List, MutableMapping, Optional, Tuple, Type, Union

@@ -654,6 +653,10 @@ def _initialize_torch_state(self) -> None: # noqa
table_name,
local_shards,
) in self._model_parallel_name_to_local_shards.items():
if model_parallel_name_to_compute_kernel[table_name] in {
EmbeddingComputeKernel.SSD.value
}:
continue
# for shards that don't exist on this rank, register with empty tensor
if not hasattr(self.embeddings[table_name], "weight"):
self.embeddings[table_name].register_parameter(
@@ -702,6 +705,10 @@ def reset_parameters(self) -> None:
return
# Initialize embedding weights with init_fn
for table_config in self._embedding_configs:
if self.module_sharding_plan[table_config.name].compute_kernel in {
EmbeddingComputeKernel.SSD.value,
}:
continue
assert table_config.init_fn is not None
param = self.embeddings[f"{table_config.name}"].weight
# pyre-ignore
19 changes: 19 additions & 0 deletions torchrec/distributed/embedding_lookup.py
@@ -31,6 +31,8 @@
BatchedDenseEmbeddingBag,
BatchedFusedEmbedding,
BatchedFusedEmbeddingBag,
KeyValueEmbedding,
KeyValueEmbeddingBag,
)
from torchrec.distributed.comm_ops import get_gradient_division
from torchrec.distributed.composable.table_batched_embedding_slice import (
@@ -168,6 +170,14 @@ def _create_lookup(
pg=pg,
device=device,
)
elif config.compute_kernel in {
EmbeddingComputeKernel.SSD,
}:
return KeyValueEmbedding(
config=config,
pg=pg,
device=device,
)
else:
raise ValueError(
f"Compute kernel not supported {config.compute_kernel}"
@@ -368,6 +378,15 @@ def _create_lookup(
device=device,
sharding_type=sharding_type,
)
elif config.compute_kernel in {
EmbeddingComputeKernel.SSD,
}:
return KeyValueEmbeddingBag(
config=config,
pg=pg,
device=device,
sharding_type=sharding_type,
)
else:
raise ValueError(
f"Compute kernel not supported {config.compute_kernel}"
1 change: 0 additions & 1 deletion torchrec/distributed/embedding_sharding.py
@@ -9,7 +9,6 @@

import abc
import copy
import uuid
from collections import defaultdict
from dataclasses import dataclass
from itertools import filterfalse
3 changes: 3 additions & 0 deletions torchrec/distributed/embedding_types.py
@@ -60,6 +60,7 @@ class EmbeddingComputeKernel(Enum):
QUANT = "quant"
QUANT_UVM = "quant_uvm"
QUANT_UVM_CACHING = "quant_uvm_caching"
SSD = "SSD"


def compute_kernel_to_embedding_location(
@@ -69,6 +70,7 @@ def compute_kernel_to_embedding_location(
EmbeddingComputeKernel.DENSE,
EmbeddingComputeKernel.FUSED,
EmbeddingComputeKernel.QUANT,
EmbeddingComputeKernel.SSD, # use hbm for cache
]:
return EmbeddingLocation.DEVICE
elif compute_kernel in [
@@ -410,6 +412,7 @@ def compute_kernels(
ret += [
EmbeddingComputeKernel.FUSED_UVM.value,
EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
EmbeddingComputeKernel.SSD.value,
]
else:
# TODO re-enable model parallel and dense
9 changes: 8 additions & 1 deletion torchrec/distributed/embeddingbag.py
@@ -13,7 +13,6 @@
from functools import partial
from typing import (
Any,
Callable,
cast,
Dict,
Iterator,
@@ -793,6 +792,10 @@ def _initialize_torch_state(self) -> None: # noqa
table_name,
local_shards,
) in self._model_parallel_name_to_local_shards.items():
if model_parallel_name_to_compute_kernel[table_name] in {
EmbeddingComputeKernel.SSD.value
}:
continue
# for shards that don't exist on this rank, register with empty tensor
if not hasattr(self.embedding_bags[table_name], "weight"):
self.embedding_bags[table_name].register_parameter(
@@ -841,6 +844,10 @@ def reset_parameters(self) -> None:

# Initialize embedding bags weights with init_fn
for table_config in self._embedding_bag_configs:
if self.module_sharding_plan[table_config.name].compute_kernel in {
EmbeddingComputeKernel.SSD.value,
}:
continue
assert table_config.init_fn is not None
param = self.embedding_bags[f"{table_config.name}"].weight
# pyre-ignore
1 change: 1 addition & 0 deletions torchrec/distributed/planner/constants.py
@@ -87,6 +87,7 @@ def kernel_bw_lookup(
caching_ratio * hbm_mem_bw + (1 - caching_ratio) * ddr_mem_bw
)
/ 10,
("cuda", EmbeddingComputeKernel.SSD.value): ddr_mem_bw,
}

if (
40 changes: 25 additions & 15 deletions torchrec/distributed/planner/enumerators.py
@@ -8,7 +8,7 @@
# pyre-strict

import logging
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Set, Tuple, Union

from torch import nn
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
@@ -40,6 +40,9 @@

logger: logging.Logger = logging.getLogger(__name__)

# compute kernels that should only be used if users specified them
GUARDED_COMPUTE_KERNELS: Set[EmbeddingComputeKernel] = {EmbeddingComputeKernel.SSD}


class EmbeddingEnumerator(Enumerator):
"""
@@ -256,22 +259,29 @@ def _filter_compute_kernels(
allowed_compute_kernels: List[str],
sharding_type: str,
) -> List[str]:
# for the log message only
constrained_compute_kernels: List[str] = [
compute_kernel.value for compute_kernel in EmbeddingComputeKernel
]
if not self._constraints or not self._constraints.get(name):
filtered_compute_kernels = allowed_compute_kernels
# setup constrained_compute_kernels
if (
self._constraints
and self._constraints.get(name)
and self._constraints[name].compute_kernels
):
# pyre-ignore
constrained_compute_kernels: List[str] = self._constraints[
name
].compute_kernels
else:
constraints: ParameterConstraints = self._constraints[name]
if not constraints.compute_kernels:
filtered_compute_kernels = allowed_compute_kernels
else:
constrained_compute_kernels = constraints.compute_kernels
filtered_compute_kernels = list(
set(constrained_compute_kernels) & set(allowed_compute_kernels)
)
constrained_compute_kernels: List[str] = [
compute_kernel.value
for compute_kernel in EmbeddingComputeKernel
if compute_kernel not in GUARDED_COMPUTE_KERNELS
]

# setup filtered_compute_kernels
filtered_compute_kernels = list(
set(constrained_compute_kernels) & set(allowed_compute_kernels)
)

# special rules
if EmbeddingComputeKernel.DENSE.value in filtered_compute_kernels:
if (
EmbeddingComputeKernel.FUSED.value in filtered_compute_kernels
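
For readability, here is a standalone sketch that mirrors the guarded-kernel filtering added above. It is a simplified restatement, not a drop-in replacement: the real `_filter_compute_kernels` also applies the DENSE/FUSED special rules that follow it, and the free function name here is hypothetical.

```python
# Simplified sketch of the filtering logic added in enumerators.py.
# Guarded kernels (currently only SSD) are dropped from the default candidate
# set and survive only when the user lists them in ParameterConstraints.
from typing import List, Optional, Set

from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner.types import ParameterConstraints

GUARDED_COMPUTE_KERNELS: Set[EmbeddingComputeKernel] = {EmbeddingComputeKernel.SSD}


def filter_compute_kernels(
    allowed_compute_kernels: List[str],
    constraints: Optional[ParameterConstraints],
) -> List[str]:
    # Explicit user constraints win; otherwise default to all non-guarded kernels.
    if constraints and constraints.compute_kernels:
        constrained = list(constraints.compute_kernels)
    else:
        constrained = [
            kernel.value
            for kernel in EmbeddingComputeKernel
            if kernel not in GUARDED_COMPUTE_KERNELS
        ]
    # Keep only kernels that are both constrained and allowed for this sharding type.
    return list(set(constrained) & set(allowed_compute_kernels))
```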
3 changes: 3 additions & 0 deletions torchrec/distributed/planner/shard_estimators.py
@@ -1051,9 +1051,12 @@ def calculate_shard_storages(
if compute_kernel in {
EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
EmbeddingComputeKernel.QUANT_UVM_CACHING.value,
EmbeddingComputeKernel.SSD.value,
}:
hbm_storage = round(ddr_storage * caching_ratio)
table_cached = True
if compute_kernel in {EmbeddingComputeKernel.SSD.value}:
ddr_storage = 0

optimizer_class = getattr(tensor, "_optimizer_class", None)

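
The shard_estimators.py change above means an SSD table reserves HBM only for its cache (like the UVM caching kernels) and no DDR at all, since the full weights live on SSD. A rough worked example, using made-up sizes and ignoring the input/output and optimizer components the real estimator also tracks:

```python
# Illustrative numbers only; row count, dim, dtype size, and caching_ratio are assumptions.
num_embeddings, embedding_dim, dtype_bytes = 1_000_000, 16, 4
caching_ratio = 0.2  # expected fraction of rows resident in the HBM cache

ddr_storage = num_embeddings * embedding_dim * dtype_bytes  # 64_000_000 bytes if weights sat in DDR
hbm_storage = round(ddr_storage * caching_ratio)  # 12_800_000 bytes reserved for the HBM cache
ddr_storage = 0  # SSD kernel: planner reserves no DDR, the full weights live on SSD

print(hbm_storage, ddr_storage)  # 12800000 0
```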
20 changes: 10 additions & 10 deletions torchrec/distributed/test_utils/test_model_parallel.py
@@ -40,17 +40,17 @@ def setUp(self, backend: str = "nccl") -> None:

self.tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
embedding_dim=(i + 2) * 8,
num_embeddings=(i + 1) * 1000,
embedding_dim=16,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
)
for i in range(num_features)
]
shared_features_tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
embedding_dim=(i + 2) * 8,
num_embeddings=(i + 1) * 1000,
embedding_dim=16,
name="table_" + str(i + num_features),
feature_names=["feature_" + str(i)],
)
@@ -60,8 +60,8 @@ def setUp(self, backend: str = "nccl") -> None:

self.mean_tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
embedding_dim=(i + 2) * 8,
num_embeddings=(i + 1) * 1000,
embedding_dim=16,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
pooling=PoolingType.MEAN,
@@ -71,8 +71,8 @@ def setUp(self, backend: str = "nccl") -> None:

shared_features_tables_mean = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
embedding_dim=(i + 2) * 8,
num_embeddings=(i + 1) * 1000,
embedding_dim=16,
name="table_" + str(i + num_features),
feature_names=["feature_" + str(i)],
pooling=PoolingType.MEAN,
@@ -83,8 +83,8 @@ def setUp(self, backend: str = "nccl") -> None:

self.weighted_tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
embedding_dim=(i + 2) * 4,
num_embeddings=(i + 1) * 1000,
embedding_dim=16,
name="weighted_table_" + str(i),
feature_names=["weighted_feature_" + str(i)],
)
(1 more changed file not rendered)
