Add imbalance stats for total perf, hbm and ddr (pytorch#2040)
Summary:

Add and log 4 ways to measure the imbalance of a generated sharding plan.

Suppose we have $k$ GPUs. Let $s$ be a vector whose $i^{th}$ component is the total size currently allocated to the $i^{th}$ device.

First we normalize $s$ to obtain a vector $p$ whose elements sum to 1. We can view $p$ as a probability distribution, and to get a measure of imbalance, we measure its deviation from the uniform distribution in one of the following ways (a short usage sketch follows the list):

- total variation
- total distance
- chi divergence
- KL divergence
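
For illustration, a minimal sketch (not part of this diff) of how the helpers added below could be applied to a hypothetical per-device HBM allocation; the function names come from torchrec/distributed/planner/stats.py in this change, while the sample values are made up:

# Hypothetical example: measure the imbalance of a per-device HBM allocation
# using the helpers introduced in this commit.
from torchrec.distributed.planner.stats import (
    _chi_divergence,
    _kl_divergence,
    _normalize_int,
    _total_distance,
    _total_variation,
)

used_hbm = [4, 2, 1, 1]  # made-up bytes allocated per device (k = 4)
p = _normalize_int(used_hbm)  # normalize so the components sum to 1

print(f"Total Variation: {_total_variation(p):.3f}")  # max_i |p_i - 1/k|
print(f"Total Distance:  {_total_distance(p):.3f}")   # sum_i |p_i - 1/k|
print(f"Chi Divergence:  {_chi_divergence(p):.3f}")   # alpha defaults to 1.0
print(f"KL Divergence:   {_kl_divergence(p):.3f}")    # sum_i p_i * log(k * p_i)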

Reviewed By: henrylhtsang

Differential Revision: D57465383
sarckk authored and facebook-github-bot committed May 30, 2024
1 parent f8b15d4 commit 07cb78f
Showing 2 changed files with 227 additions and 1 deletion.
137 changes: 137 additions & 0 deletions torchrec/distributed/planner/stats.py
@@ -51,6 +51,67 @@
MIN_WIDTH = 90


def _normalize_float(p: List[float]) -> List[float]:
p_total = sum(p)
return [p_i / p_total for p_i in p]


def _normalize_int(p: List[int]) -> List[float]:
p_total = sum(p)
return [p_i * 1.0 / p_total for p_i in p]


def _total_variation(p: List[float]) -> float:
k = len(p)
if not k:
return -1.0
return max(abs(pi - 1.0 / k) for pi in p)


def _total_distance(p: List[float]) -> float:
k = len(p)
if not k:
return -1.0
return sum(abs(pi - 1.0 / k) for pi in p)


def _chi_divergence(p: List[float], alpha: float = 1.0) -> float:
assert alpha >= 1
k = len(p)
if not k:
return -1.0
return sum(abs(pi - 1.0 / k) ** alpha * k ** (alpha - 1.0) for pi in p)


def _kl_divergence(p: List[float]) -> float:
k = len(p)
if not k:
return -1.0
return sum(pi * math.log(k * pi) for pi in p if pi > 0)


def _calc_max_chi_divergence(N: int, alpha: float) -> float:
assert N > 0
# Upper bound for chi divergence in our case given sample size of distribution (N) and alpha
return (N - 1) ** alpha * (1.0 / N) + (N - 1) * (1.0 / N)


def _calc_max_kl_divergence(N: int) -> float:
assert N > 0
# Upper bound for KL divergence in our case given sample size of distribution (N)
return math.log(N)


CHI_DIVERGENCE_ALPHA = 1.0

IMBALANCE_STAT_MEASURE: Dict[str, Tuple[Callable[..., float], Dict[str, Any]]] = {
"Total Variation": (_total_variation, {}),
"Total Distance": (_total_distance, {}),
"Chi Divergence": (_chi_divergence, {"alpha": CHI_DIVERGENCE_ALPHA}),
"KL Divergence": (_kl_divergence, {}),
}


class EmbeddingStats(Stats):
"""
Stats for a sharding planner execution.
@@ -436,6 +497,14 @@ def log(

if debug:
if sharding_plan.plan:
# Plan imbalance stats for perf and storage
self._log_plan_imbalance_stats(
perf,
used_hbm,
used_ddr,
)

# Max perf and HBM to help root cause imbalance
self._log_max_perf_and_max_hbm(perf, used_hbm)
self._log_storage_reservation_stats(
storage_reservation,
@@ -509,6 +578,74 @@ def _get_shard_stats(

return ranks, input_sizes, output_sizes

def _log_dist_imbalance_stats(
self,
normalized_dist: List[float],
) -> None:
for name, (measure, kwargs) in IMBALANCE_STAT_MEASURE.items():
result_txt = f"{name}: {measure(normalized_dist, **kwargs):.3f}"
self._stats_table.append(f"# {result_txt : <{self._width-3}}#")

def _log_plan_imbalance_stats(
self, perf: List[Perf], used_hbm: List[int], used_ddr: List[int]
) -> None:
imbalance_logged = False
total_perfs = [perf_i.total for perf_i in perf]

# Can extend with fwd/bwd perfs if needed
perf_dists = [
("Total", total_perfs),
]

for name, perf_dist in perf_dists:
if sum(perf_dist) > 0:
imbalance_logged = True
self._stats_table.append(f"#{'' : ^{self._width-2}}#")
self._stats_table.append(
f"# {name + ' Perf Imbalance Statistics' : <{self._width-3}}#"
)
normalized_perf_dist = _normalize_float(perf_dist)
self._log_dist_imbalance_stats(normalized_perf_dist)

if sum(used_hbm) > 0:
imbalance_logged = True
normalized_used_hbm = _normalize_int(used_hbm)
self._stats_table.append(f"#{'' : ^{self._width-2}}#")
self._stats_table.append(
f"# {'HBM Imbalance Statistics' : <{self._width-3}}#"
)
self._log_dist_imbalance_stats(normalized_used_hbm)

if sum(used_ddr) > 0:
imbalance_logged = True
normalized_used_ddr = _normalize_int(used_ddr)
self._stats_table.append(f"#{'' : ^{self._width-2}}#")
self._stats_table.append(
f"# {'DDR Imbalance Statistics' : <{self._width-3}}#"
)
self._log_dist_imbalance_stats(normalized_used_ddr)

if imbalance_logged:
self._stats_table.append(f"#{'' : ^{self._width-2}}#")
self._stats_table.append(
f"# {'Total Variation: higher means more imbalanced (ranges 0 to 1)' : <{self._width-3}}#"
)
self._stats_table.append(
f"# {'Total Distance: higher means more imbalanced (ranges 0 to 1)' : <{self._width-3}}#"
)
N = len(perf) # world size
if N > 0:
max_chi_divergence = _calc_max_chi_divergence(
N=N, alpha=CHI_DIVERGENCE_ALPHA
)
self._stats_table.append(
f"# {f'Chi Divergence: higher means more imbalanced (ranges 0 to {max_chi_divergence:.3f})' : <{self._width-3}}#"
)
max_kl_divergence = _calc_max_kl_divergence(N)
self._stats_table.append(
f"# {f'KL Divergence: higher means more imbalanced (ranges 0 to {max_kl_divergence:.3f})' : <{self._width-3}}#"
)

def _log_max_perf_and_max_hbm(self, perfs: List[Perf], used_hbm: List[int]) -> None:

max_total_perf_text = f"Longest Critical Path (Maximum of Total Perf): {_generate_max_text([perf.total for perf in perfs])}"
91 changes: 90 additions & 1 deletion torchrec/distributed/planner/tests/test_stats.py
@@ -7,15 +7,30 @@

# pyre-strict

import math
import unittest
from typing import List

import hypothesis.strategies as st

import torch
from hypothesis import given, settings
from torch import nn
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner.planners import EmbeddingShardingPlanner
from torchrec.distributed.planner.stats import EmbeddingStats, NoopEmbeddingStats
from torchrec.distributed.planner.stats import (
_calc_max_chi_divergence,
_calc_max_kl_divergence,
_chi_divergence,
_kl_divergence,
_normalize_float,
_normalize_int,
_total_distance,
_total_variation,
EmbeddingStats,
NoopEmbeddingStats,
)
from torchrec.distributed.planner.types import Topology
from torchrec.distributed.test_utils.test_model import TestSparseNN
from torchrec.distributed.types import ModuleSharder, ShardingType
@@ -86,3 +101,77 @@ def test_embedding_stats_output_with_top_hbm_usage(self) -> None:
if top_hbm_usage_keyword in row:
top_hbm_mem_usage = float(row.split(" ")[6])
self.assertIsNotNone(top_hbm_mem_usage)

def test_normalize_float(self) -> None:
p = [2.0, 2.0]
self.assertEqual(_normalize_float(p), [0.5, 0.5])

def test_normalize_int(self) -> None:
p = [2, 2]
self.assertEqual(_normalize_int(p), [0.5, 0.5])

def test_total_variation(self) -> None:
p_1 = [0.5, 0.5]
self.assertEqual(_total_variation(p_1), 0.0)

p_2 = [0.0, 1.0]
self.assertEqual(_total_variation(p_2), 0.5)

def test_total_distance(self) -> None:
p_1 = [0.5, 0.5]
self.assertEqual(_total_distance(p_1), 0.0)

p_2 = [0.0, 1.0]
self.assertEqual(_total_distance(p_2), 1.0)

def test_chi_divergence(self) -> None:
p_1 = [0.5, 0.5]
self.assertEqual(_chi_divergence(p_1), 0.0)

p_2 = [0.0, 1.0]
self.assertEqual(_chi_divergence(p_2), 1.0)

def test_kl_divergence(self) -> None:
p_1 = [0.5, 0.5]
self.assertEqual(_kl_divergence(p_1), 0.0)

p_2 = [0.1, 0.9]
self.assertAlmostEqual(_kl_divergence(p_2), 0.368, 3)

# pyre-ignore
@given(
N=st.integers(min_value=10, max_value=200),
)
@settings(max_examples=4, deadline=None)
def test_kl_divergence_upper_bound(self, N: int) -> None:
# Generate most imbalanced distribution
normalized_p = [
1.0,
] + [
0.0
] * (N - 1)
N = len(normalized_p)
self.assertEqual(_kl_divergence(normalized_p), _calc_max_kl_divergence(N))

# pyre-ignore
@given(
N=st.integers(min_value=10, max_value=200),
alpha=st.floats(min_value=1.0, max_value=5.0),
)
@settings(max_examples=4, deadline=None)
def test_chi_divergence_upper_bound(self, N: int, alpha: float) -> None:
# Generate most imbalanced distribution
normalized_p = [
1.0,
] + [
0.0
] * (N - 1)
N = len(normalized_p)

self.assertTrue(
math.isclose(
_chi_divergence(normalized_p, alpha),
_calc_max_chi_divergence(N, alpha),
abs_tol=1e-10,
)
)
