Add percentage based threshold function and unit tests #1679

28 changes: 28 additions & 0 deletions torchrec/modules/mc_modules.py
@@ -52,6 +52,10 @@ def dynamic_threshold_filter(
id_counts: torch.Tensor,
threshold_skew_multiplier: float = 10.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Threshold is total_count / num_ids * threshold_skew_multiplier. An id is
added if its count is strictly greater than the threshold.
"""

num_ids = id_counts.numel()
total_count = id_counts.sum()
@@ -69,6 +73,10 @@ def dynamic_threshold_filter(
def average_threshold_filter(
id_counts: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Threshold is average of id_counts. An id is added if its count is strictly
greater than the mean.
"""
if id_counts.dtype != torch.float:
id_counts = id_counts.float()
threshold = id_counts.mean()
@@ -77,6 +85,26 @@ def average_threshold_filter(
return threshold_mask, threshold


@torch.no_grad()
def probabilistic_threshold_filter(
id_counts: torch.Tensor,
per_id_probability: float = 0.01,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Each occurrence of an id independently gives it a per_id_probability chance
of being added. For example, if per_id_probability is 0.01 and an id appears
100 times, then it has roughly a 63% chance of being added. More precisely,
the id score is 1 - (1 - per_id_probability) ^ id_count, and for a randomly
generated threshold, the id score is the chance of the id being added.
"""
probability = torch.full_like(id_counts, 1 - per_id_probability, dtype=torch.float)
id_scores = 1 - torch.pow(probability, id_counts)

threshold: torch.Tensor = torch.rand(id_counts.size(), device=id_counts.device)
threshold_mask = id_scores > threshold

return threshold_mask, threshold


class ManagedCollisionModule(nn.Module):
"""
Abstract base class for ManagedCollisionModule.
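As a sanity check on the probability described in the new docstring, the following standalone sketch (not part of the PR) re-derives the id-score formula with plain torch ops, using the same counts as the unit test below:

import torch

# illustrative values matching test_probabilistic_threshold_filter below
per_id_probability = 0.01
id_counts = torch.tensor([100, 80, 60, 40, 10], dtype=torch.float)

# id score = 1 - (1 - p) ^ count: the chance that a uniform random threshold admits the id
probability = torch.full_like(id_counts, 1 - per_id_probability)
id_scores = 1 - torch.pow(probability, id_counts)
print(id_scores)  # approximately [0.63, 0.55, 0.45, 0.33, 0.10]

# one sampling round, mirroring the body of probabilistic_threshold_filter
threshold = torch.rand(id_counts.size(), device=id_counts.device)
threshold_mask = id_scores > threshold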
128 changes: 128 additions & 0 deletions torchrec/modules/tests/test_mc_modules.py
@@ -10,10 +10,13 @@

import torch
from torchrec.modules.mc_modules import (
average_threshold_filter,
DistanceLFU_EvictionPolicy,
dynamic_threshold_filter,
LFU_EvictionPolicy,
LRU_EvictionPolicy,
MCHManagedCollisionModule,
probabilistic_threshold_filter,
)
from torchrec.sparse.jagged_tensor import JaggedTensor

@@ -215,3 +218,128 @@ def test_distance_lfu_eviction_fast_decay(self) -> None:
self.assertEqual(list(_mch_counts), [1, 1, 1, 1, torch.iinfo(torch.int64).max])
_mch_last_access_iter = mc_module._mch_last_access_iter
self.assertEqual(list(_mch_last_access_iter), [2, 2, 3, 3, 3])

def test_dynamic_threshold_filter(self) -> None:
mc_module = MCHManagedCollisionModule(
zch_size=5,
device=torch.device("cpu"),
eviction_policy=LFU_EvictionPolicy(
threshold_filtering_func=lambda tensor: dynamic_threshold_filter(
tensor, threshold_skew_multiplier=0.75
)
),
eviction_interval=1,
input_hash_size=100,
)

# check initial state
_mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids
self.assertEqual(list(_mch_sorted_raw_ids), [torch.iinfo(torch.int64).max] * 5)
_mch_counts = mc_module._mch_counts
self.assertEqual(list(_mch_counts), [0] * 5)

ids = [5, 5, 5, 5, 5, 4, 4, 4, 4, 3, 3, 3, 2, 2, 1]
# threshold is len(ids) / unique_count(ids) * threshold_skew_multiplier
# = 15 / 5 * 0.75 = 2.25
features: Dict[str, JaggedTensor] = {
"f1": JaggedTensor(
values=torch.tensor(ids, dtype=torch.int64),
lengths=torch.tensor([1] * len(ids), dtype=torch.int64),
)
}
mc_module.profile(features)

_mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids
self.assertEqual(
list(_mch_sorted_raw_ids),
[3, 4, 5, torch.iinfo(torch.int64).max, torch.iinfo(torch.int64).max],
)
_mch_counts = mc_module._mch_counts
self.assertEqual(list(_mch_counts), [3, 4, 5, 0, torch.iinfo(torch.int64).max])

def test_average_threshold_filter(self) -> None:
mc_module = MCHManagedCollisionModule(
zch_size=5,
device=torch.device("cpu"),
eviction_policy=LFU_EvictionPolicy(
threshold_filtering_func=average_threshold_filter
),
eviction_interval=1,
input_hash_size=100,
)

# check initial state
_mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids
self.assertEqual(list(_mch_sorted_raw_ids), [torch.iinfo(torch.int64).max] * 5)
_mch_counts = mc_module._mch_counts
self.assertEqual(list(_mch_counts), [0] * 5)

# insert some values into the zch
# id 4 has a count of 10 and id 5 has a count of 1
mc_module._mch_sorted_raw_ids[0:2] = torch.tensor([4, 5])
mc_module._mch_counts[0:2] = torch.tensor([10, 1])

ids = [3, 4, 5, 6, 6, 6, 7, 8, 8, 9, 10]
# threshold is the mean of the batch id counts: 11 / 8 = 1.375
features: Dict[str, JaggedTensor] = {
"f1": JaggedTensor(
values=torch.tensor(ids, dtype=torch.int64),
lengths=torch.tensor([1] * len(ids), dtype=torch.int64),
)
}
mc_module.profile(features)

# the two empty slots will be evicted
# 6 and 8 will be added
# 7 is not added because its count in this batch is below the average threshold
_mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids
self.assertEqual(
list(_mch_sorted_raw_ids), [4, 5, 6, 8, torch.iinfo(torch.int64).max]
)
# the count for 4 is not updated since its count in this batch is below the average threshold
_mch_counts = mc_module._mch_counts
self.assertEqual(list(_mch_counts), [10, 1, 3, 2, torch.iinfo(torch.int64).max])

def test_probabilistic_threshold_filter(self) -> None:
mc_module = MCHManagedCollisionModule(
zch_size=5,
device=torch.device("cpu"),
eviction_policy=LFU_EvictionPolicy(
threshold_filtering_func=lambda tensor: probabilistic_threshold_filter(
tensor,
per_id_probability=0.01,
)
),
eviction_interval=1,
input_hash_size=100,
)

# check initial state
_mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids
self.assertEqual(list(_mch_sorted_raw_ids), [torch.iinfo(torch.int64).max] * 5)
_mch_counts = mc_module._mch_counts
self.assertEqual(list(_mch_counts), [0] * 5)

unique_ids = [5, 4, 3, 2, 1]
id_counts = [100, 80, 60, 40, 10]
ids = [id for id, count in zip(unique_ids, id_counts) for _ in range(count)]
# per profile call, the chance of each id being added is [0.63, 0.55, 0.45, 0.33, 0.10]
features: Dict[str, JaggedTensor] = {
"f1": JaggedTensor(
values=torch.tensor(ids, dtype=torch.int64),
lengths=torch.tensor([1] * len(ids), dtype=torch.int64),
)
}

torch.manual_seed(42)
for _ in range(10):
mc_module.profile(features)

_mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids
print(f"henry {mc_module._mch_counts}")
self.assertEqual(
sorted(_mch_sorted_raw_ids.tolist()),
[2, 3, 4, 5, torch.iinfo(torch.int64).max],
)
# with this seed, _mch_counts ends up approximately
# [80, 180, 160, 800, 9223372036854775807]
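For reference, the thresholds asserted in the two deterministic tests above can be reproduced with a few lines of plain torch (a sketch using the same inputs as the tests, not code from the PR):

import torch

# test_dynamic_threshold_filter: 15 occurrences over 5 unique ids, multiplier 0.75
id_counts = torch.tensor([5, 4, 3, 2, 1], dtype=torch.float)
dynamic_threshold = id_counts.sum() / id_counts.numel() * 0.75
print(dynamic_threshold)  # 2.25 -> only ids with count > 2.25 (ids 5, 4, 3) are admitted

# test_average_threshold_filter: batch ids [3, 4, 5, 6, 6, 6, 7, 8, 8, 9, 10]
batch_counts = torch.tensor([1, 1, 1, 3, 1, 2, 1, 1], dtype=torch.float)
average_threshold = batch_counts.mean()
print(average_threshold)  # 1.375 -> only ids with batch count > 1.375 (6 and 8) are admitted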