Standardize access of ShardingOption cache load factor (#1634)
Summary:
Pull Request resolved: #1634

Cache params can be `None`, so different call sites each re-implement a None-safe retrieval of the cache load factor. Make this a property of `ShardingOption` to standardize that access.
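
For illustration, a minimal self-contained sketch of the pattern this diff introduces; the stripped-down `CacheParams` and `ShardingOption` classes below are stand-ins for the real torchrec types, not their full definitions:

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class CacheParams:
        load_factor: Optional[float] = None


    @dataclass
    class ShardingOption:
        cache_params: Optional[CacheParams] = None

        @property
        def cache_load_factor(self) -> Optional[float]:
            # The None check on cache_params lives in one place
            # instead of being repeated at every call site.
            if self.cache_params is not None:
                return self.cache_params.load_factor
            return None


    # Callers read the property directly and only handle the Optional result.
    assert ShardingOption().cache_load_factor is None
    assert ShardingOption(CacheParams(load_factor=0.2)).cache_load_factor == 0.2

Call sites such as the estimators below then read `sharding_option.cache_load_factor` and fall back to the sharder's fused params only when it is `None`.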

Reviewed By: henrylhtsang

Differential Revision: D52804086

fbshipit-source-id: 2949c31d37281dd96cf4d9d167a478b267fbfd40
sarckk authored and facebook-github-bot committed Jan 18, 2024
1 parent 4a5dd65 commit 1e6b49d
Showing 4 changed files with 11 additions and 41 deletions.
19 changes: 2 additions & 17 deletions torchrec/distributed/planner/proposers.py
@@ -410,9 +410,7 @@ def none_to_zero(x: Optional[float]) -> float:
             * none_to_zero(
                 EmbeddingOffloadScaleupProposer.get_expected_lookups(sharding_option)
             )
-            * none_to_zero(
-                EmbeddingOffloadScaleupProposer.get_load_factor(sharding_option)
-            )
+            * none_to_zero(sharding_option.cache_load_factor)
             > 0
         ]
         # Nothing to scale
@@ -423,10 +421,7 @@ def none_to_zero(x: Optional[float]) -> float:
             cache_tables, enumerator
         )
         clfs = torch.tensor(
-            [
-                EmbeddingOffloadScaleupProposer.get_load_factor(sharding_option)
-                for sharding_option in cache_tables
-            ]
+            [sharding_option.cache_load_factor for sharding_option in cache_tables]
         )
         # cooked_cacheability is cacheability scaled by the expected number of cache
         # lookups.
@@ -483,16 +478,6 @@ def get_expected_lookups(sharding_option: ShardingOption) -> Optional[float]:
             return None
         return sharding_option.cache_params.stats.expected_lookups
 
-    @staticmethod
-    def get_load_factor(sharding_option: ShardingOption) -> Optional[float]:
-        # helper to appease pyre type checker, as cache_params is Optional it maybe None
-        if (
-            sharding_option.cache_params is None
-            or sharding_option.cache_params.stats is None
-        ):
-            return None
-        return sharding_option.cache_params.load_factor
-
     # The relationship between clf and shard memory usage is non-linear due to non-clf
     # overheads like optimization stats and input/output storage. We model it as an
     # affine relationship: bytes = clf * A + B where B is fixed overhead independent of
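
As a hedged aside on the comment that closes the hunk above: the affine memory model it describes, bytes = clf * A + B, can be made concrete with invented numbers (A and B below are illustrative values only, not taken from the planner):

    # Affine model from the comment above: bytes = clf * A + B,
    # where B is overhead that does not change with the cache load factor.
    # A and B are made-up values purely for illustration.
    A = 400 * 1024 * 1024  # hypothetical clf-dependent bytes at clf = 1.0
    B = 50 * 1024 * 1024   # hypothetical fixed overhead in bytes


    def estimated_shard_bytes(clf: float) -> int:
        return int(clf * A + B)


    # Doubling clf does not double shard size because B is unaffected:
    print(estimated_shard_bytes(0.1))  # 94371840 bytes  (~90 MiB)
    print(estimated_shard_bytes(0.2))  # 136314880 bytes (~130 MiB)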
12 changes: 2 additions & 10 deletions torchrec/distributed/planner/shard_estimators.py
@@ -74,11 +74,7 @@ def estimate(
             sharder_key = sharder_name(type(sharding_option.module[1]))
             sharder = sharder_map[sharder_key]
 
-            caching_ratio = (
-                sharding_option.cache_params.load_factor
-                if sharding_option.cache_params
-                else None
-            )
+            caching_ratio = sharding_option.cache_load_factor
             # TODO: remove after deprecating fused_params in sharder
             if caching_ratio is None:
                 caching_ratio = (
@@ -819,11 +815,7 @@ def estimate(
             sharder_key = sharder_name(type(sharding_option.module[1]))
             sharder = sharder_map[sharder_key]
 
-            caching_ratio = (
-                sharding_option.cache_params.load_factor
-                if sharding_option.cache_params
-                else None
-            )
+            caching_ratio = sharding_option.cache_load_factor
             # TODO: remove after deprecating fused_params in sharder
             if caching_ratio is None:
                 caching_ratio = (
15 changes: 1 addition & 14 deletions torchrec/distributed/planner/stats.py
@@ -329,10 +329,7 @@ def log(
                 or so.sharding_type == ShardingType.TABLE_COLUMN_WISE.value
                 else f"{so.tensor.shape[1]}"
             )
-            cache_load_factor = _get_cache_load_factor(sharding_option=so)
-            cache_load_factor = (
-                str(cache_load_factor) if cache_load_factor is not None else "None"
-            )
+            cache_load_factor = str(so.cache_load_factor)
             hash_size = so.tensor.shape[0]
             param_table.append(
                 [
@@ -598,16 +595,6 @@ def _generate_max_text(perfs: List[float]) -> str:
     return f"{round(max_perf, 3)} ms on {max_perf_ranks}"
 
 
-def _get_cache_load_factor(
-    sharding_option: ShardingOption,
-) -> Optional[float]:
-    return (
-        sharding_option.cache_params.load_factor
-        if sharding_option.cache_params
-        else None
-    )
-
-
 def _get_sharding_type_abbr(sharding_type: str) -> str:
     if sharding_type == ShardingType.DATA_PARALLEL.value:
         return "DP"
6 changes: 6 additions & 0 deletions torchrec/distributed/planner/types.py
@@ -320,6 +320,12 @@ def module(self) -> Tuple[str, nn.Module]:
     def fqn(self) -> str:
         return self.module[0] + "." + self.name
 
+    @property
+    def cache_load_factor(self) -> Optional[float]:
+        if self.cache_params is not None:
+            return self.cache_params.load_factor
+        return None
+
     @property
     def path(self) -> str:
         return self.module[0]
