Skip to content

Commit

Permalink
Create BoundsCheckMode fused_param for inference
Browse files Browse the repository at this point in the history
Summary: Introduce a BoundsCheckMode fused_param for the TBE BoundsCheckMode. There is really no reason to run bounds_check_indices in the inference use case (AIMP has it off by default: https://fburl.com/code/q8zhundg), and it causes issues with the PT2 IR because bounds_check_indices is a mutating op.

Reviewed By: ZhengkaiZ

Differential Revision: D56743992

fbshipit-source-id: 1a474170bd5a16f6ef9ddc8d047eca6739d213fc
  • Loading branch information
PaulZhang12 authored and facebook-github-bot committed May 1, 2024
1 parent 185ad37 commit a649b4e
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 3 deletions.
13 changes: 13 additions & 0 deletions torchrec/distributed/fused_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
IntNBitTableBatchedEmbeddingBagsCodegen,
)
from torchrec.distributed.embedding_types import GroupedEmbeddingConfig
from torchrec.distributed.types import BoundsCheckMode

FUSED_PARAM_REGISTER_TBE_BOOL: str = "__register_tbes_in_named_modules"
FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS: str = (
"__register_quant_state_dict_split_scale_bias"
)
FUSED_PARAM_TBE_ROW_ALIGNMENT: str = "__register_tbe_row_alignment"
FUSED_PARAM_IS_WEIGHTED: str = "__register_tbe_is_weighted"
FUSED_PARAM_BOUNDS_CHECK_MODE: str = "__register_tbe_bounds_check_mode"


class TBEToRegisterMixIn:
Expand Down Expand Up @@ -65,6 +67,15 @@ def is_fused_param_weighted(fused_params: Optional[Dict[str, Any]]) -> Optional[
return fused_params[FUSED_PARAM_IS_WEIGHTED]


def fused_param_bounds_check_mode(
    fused_params: Optional[Dict[str, Any]]
) -> Optional[BoundsCheckMode]:
    """Return the TBE bounds-check mode carried in ``fused_params``.

    Args:
        fused_params: optional fused-parameter dict that may contain the
            ``FUSED_PARAM_BOUNDS_CHECK_MODE`` key.

    Returns:
        The stored ``BoundsCheckMode``, or ``None`` when ``fused_params``
        is absent or does not carry the key.
    """
    if fused_params is None:
        return None
    # dict.get yields None for a missing key, matching the "not set" case.
    return fused_params.get(FUSED_PARAM_BOUNDS_CHECK_MODE)


def is_fused_param_quant_state_dict_split_scale_bias(
fused_params: Optional[Dict[str, Any]]
) -> bool:
Expand All @@ -90,5 +101,7 @@ def tbe_fused_params(
fused_params_for_tbe.pop(FUSED_PARAM_TBE_ROW_ALIGNMENT)
if FUSED_PARAM_IS_WEIGHTED in fused_params_for_tbe:
fused_params_for_tbe.pop(FUSED_PARAM_IS_WEIGHTED)
if FUSED_PARAM_BOUNDS_CHECK_MODE in fused_params_for_tbe:
fused_params_for_tbe.pop(FUSED_PARAM_BOUNDS_CHECK_MODE)

return fused_params_for_tbe
10 changes: 9 additions & 1 deletion torchrec/distributed/quant_embedding_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@
GroupedEmbeddingConfig,
)
from torchrec.distributed.fused_params import (
get_fused_param_tbe_row_alignment,
fused_param_bounds_check_mode,
is_fused_param_quant_state_dict_split_scale_bias,
is_fused_param_register_tbe,
is_fused_param_weighted,
tbe_fused_params,
TBEToRegisterMixIn,
)
from torchrec.distributed.types import BoundsCheckMode
from torchrec.distributed.utils import append_prefix
from torchrec.modules.embedding_configs import (
DATA_TYPE_NUM_BITS,
Expand Down Expand Up @@ -197,6 +198,10 @@ def __init__(
self._quant_state_dict_split_scale_bias: bool = (
is_fused_param_quant_state_dict_split_scale_bias(fused_params)
)
bounds_check_mode: Optional[BoundsCheckMode] = fused_param_bounds_check_mode(
fused_params
)

index_remapping = [
table.pruning_indices_remapping for table in config.embedding_tables
]
Expand Down Expand Up @@ -233,6 +238,9 @@ def __init__(
feature_table_map=self._feature_table_map,
row_alignment=self._tbe_row_alignment,
uvm_host_mapped=True, # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
bounds_check_mode=(
bounds_check_mode if bounds_check_mode else BoundsCheckMode.WARNING
),
**(tbe_fused_params(fused_params) or {}),
)
)
Expand Down
18 changes: 17 additions & 1 deletion torchrec/distributed/tests/test_pt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from fbgemm_gpu import sparse_ops # noqa: F401, E402
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.fused_params import FUSED_PARAM_BOUNDS_CHECK_MODE
from torchrec.distributed.shard import _shard_modules
from torchrec.distributed.test_utils.infer_utils import (
assert_close,
Expand All @@ -35,7 +36,7 @@
replace_sharded_quant_modules_tbes_with_mock_tbes,
TestQuantEBCSharder,
)
from torchrec.distributed.types import ShardingEnv, ShardingType
from torchrec.distributed.types import BoundsCheckMode, ShardingEnv, ShardingType
from torchrec.sparse.jagged_tensor import ComputeKJTToJTDict, KeyedJaggedTensor


Expand Down Expand Up @@ -82,17 +83,22 @@ def _sharded_quant_ebc_model(

sharding_type: ShardingType = ShardingType.TABLE_WISE

fused_params = {
FUSED_PARAM_BOUNDS_CHECK_MODE: BoundsCheckMode.NONE,
}
if feature_processor:
sharder = TestQuantFPEBCSharder(
sharding_type=sharding_type.value,
kernel_type=EmbeddingComputeKernel.QUANT.value,
shardable_params=[table.name for table in mi.tables],
fused_params=fused_params,
)
else:
sharder = TestQuantEBCSharder(
sharding_type=sharding_type.value,
kernel_type=EmbeddingComputeKernel.QUANT.value,
shardable_params=[table.name for table in mi.tables],
fused_params=fused_params,
)
# pyre-ignore
plan = mi.planner.plan(
Expand Down Expand Up @@ -308,6 +314,11 @@ def test_sharded_quant_ebc_non_strict_export(self) -> None:

ep.module()(kjt.values(), kjt.lengths())

# PT2 IR autofunctionalizes mutation funcs (bounds_check_indices)
# ensure such node isn't present, as it causes issues with IR
for n in ep.graph_module.graph.nodes:
self.assertFalse("auto_functionalized" in str(n.name))

# TODO: Fix Unflatten
# torch.export.unflatten(ep)

Expand Down Expand Up @@ -338,6 +349,11 @@ def test_sharded_quant_fpebc_non_strict_export(self) -> None:
)
ep.module()(kjt.values(), kjt.lengths())

# PT2 IR autofunctionalizes mutation funcs (bounds_check_indices)
# ensure such node isn't present, as it causes issues with IR
for n in ep.graph_module.graph.nodes:
self.assertFalse("auto_functionalized" in str(n.name))

# TODO: Fix Unflatten
# torch.export.unflatten(ep)

Expand Down
9 changes: 8 additions & 1 deletion torchrec/inference/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from torchrec import distributed as trec_dist, inference as trec_infer
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.fused_params import (
FUSED_PARAM_BOUNDS_CHECK_MODE,
FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
FUSED_PARAM_REGISTER_TBE_BOOL,
)
Expand All @@ -36,7 +37,12 @@
QuantFeatureProcessedEmbeddingBagCollectionSharder,
)
from torchrec.distributed.shard import _shard_modules
from torchrec.distributed.types import ModuleSharder, ShardingPlan, ShardingType
from torchrec.distributed.types import (
BoundsCheckMode,
ModuleSharder,
ShardingPlan,
ShardingType,
)

from torchrec.modules.embedding_configs import QuantConfig
from torchrec.modules.embedding_modules import (
Expand Down Expand Up @@ -375,6 +381,7 @@ def shard_quant_model(
_fused_param: Dict[str, Any] = {
FUSED_PARAM_REGISTER_TBE_BOOL: True,
FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS: True,
FUSED_PARAM_BOUNDS_CHECK_MODE: BoundsCheckMode.NONE,
}

_sharders: List[ModuleSharder[torch.nn.Module]] = [
Expand Down

0 comments on commit a649b4e

Please sign in to comment.