diff --git a/torchrec/modules/mc_modules.py b/torchrec/modules/mc_modules.py index a61b52c63..c002deb83 100644 --- a/torchrec/modules/mc_modules.py +++ b/torchrec/modules/mc_modules.py @@ -52,6 +52,10 @@ def dynamic_threshold_filter( id_counts: torch.Tensor, threshold_skew_multiplier: float = 10.0, ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Threshold is total_count / num_ids * threshold_skew_multiplier. An id is + added if its count is strictly greater than the threshold. + """ num_ids = id_counts.numel() total_count = id_counts.sum() @@ -69,6 +73,10 @@ def dynamic_threshold_filter( def average_threshold_filter( id_counts: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Threshold is average of id_counts. An id is added if its count is strictly + greater than the mean. + """ if id_counts.dtype != torch.float: id_counts = id_counts.float() threshold = id_counts.mean() @@ -77,6 +85,26 @@ def average_threshold_filter( return threshold_mask, threshold +@torch.no_grad() +def probabilistic_threshold_filter( + id_counts: torch.Tensor, + per_id_probability: float = 0.01, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Each id has probability per_id_probability of being added. For example, + if per_id_probability is 0.01 and an id appears 100 times, then it has a 60% + of being added. More precisely, the id score is 1 - (1 - per_id_probability) ^ id_count, + and for a randomly generated threshold, the id score is the chance of it being added. + """ + probability = torch.full_like(id_counts, 1 - per_id_probability, dtype=torch.float) + id_scores = 1 - torch.pow(probability, id_counts) + + threshold: torch.Tensor = torch.rand(id_counts.size(), device=id_counts.device) + threshold_mask = id_scores > threshold + + return threshold_mask, threshold + + class ManagedCollisionModule(nn.Module): """ Abstract base class for ManagedCollisionModule. diff --git a/torchrec/modules/tests/test_mc_modules.py b/torchrec/modules/tests/test_mc_modules.py index a534d78a3..abcfe687a 100644 --- a/torchrec/modules/tests/test_mc_modules.py +++ b/torchrec/modules/tests/test_mc_modules.py @@ -10,10 +10,13 @@ import torch from torchrec.modules.mc_modules import ( + average_threshold_filter, DistanceLFU_EvictionPolicy, + dynamic_threshold_filter, LFU_EvictionPolicy, LRU_EvictionPolicy, MCHManagedCollisionModule, + probabilistic_threshold_filter, ) from torchrec.sparse.jagged_tensor import JaggedTensor @@ -215,3 +218,128 @@ def test_distance_lfu_eviction_fast_decay(self) -> None: self.assertEqual(list(_mch_counts), [1, 1, 1, 1, torch.iinfo(torch.int64).max]) _mch_last_access_iter = mc_module._mch_last_access_iter self.assertEqual(list(_mch_last_access_iter), [2, 2, 3, 3, 3]) + + def test_dynamic_threshold_filter(self) -> None: + mc_module = MCHManagedCollisionModule( + zch_size=5, + device=torch.device("cpu"), + eviction_policy=LFU_EvictionPolicy( + threshold_filtering_func=lambda tensor: dynamic_threshold_filter( + tensor, threshold_skew_multiplier=0.75 + ) + ), + eviction_interval=1, + input_hash_size=100, + ) + + # check initial state + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + self.assertEqual(list(_mch_sorted_raw_ids), [torch.iinfo(torch.int64).max] * 5) + _mch_counts = mc_module._mch_counts + self.assertEqual(list(_mch_counts), [0] * 5) + + ids = [5, 5, 5, 5, 5, 4, 4, 4, 4, 3, 3, 3, 2, 2, 1] + # threshold is len(ids) / unique_count(ids) * threshold_skew_multiplier + # = 15 / 5 * 0.5 = 2.25 + features: Dict[str, JaggedTensor] = { + "f1": JaggedTensor( + values=torch.tensor(ids, dtype=torch.int64), + lengths=torch.tensor([1] * len(ids), dtype=torch.int64), + ) + } + mc_module.profile(features) + + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + self.assertEqual( + list(_mch_sorted_raw_ids), + [3, 4, 5, torch.iinfo(torch.int64).max, torch.iinfo(torch.int64).max], + ) + _mch_counts = mc_module._mch_counts + self.assertEqual(list(_mch_counts), [3, 4, 5, 0, torch.iinfo(torch.int64).max]) + + def test_average_threshold_filter(self) -> None: + mc_module = MCHManagedCollisionModule( + zch_size=5, + device=torch.device("cpu"), + eviction_policy=LFU_EvictionPolicy( + threshold_filtering_func=average_threshold_filter + ), + eviction_interval=1, + input_hash_size=100, + ) + + # check initial state + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + self.assertEqual(list(_mch_sorted_raw_ids), [torch.iinfo(torch.int64).max] * 5) + _mch_counts = mc_module._mch_counts + self.assertEqual(list(_mch_counts), [0] * 5) + + # insert some values to zch + # we have 10 counts of 4 and 1 count of 5 + mc_module._mch_sorted_raw_ids[0:2] = torch.tensor([4, 5]) + mc_module._mch_counts[0:2] = torch.tensor([10, 1]) + + ids = [3, 4, 5, 6, 6, 6, 7, 8, 8, 9, 10] + # threshold is 1.375 + features: Dict[str, JaggedTensor] = { + "f1": JaggedTensor( + values=torch.tensor(ids, dtype=torch.int64), + lengths=torch.tensor([1] * len(ids), dtype=torch.int64), + ) + } + mc_module.profile(features) + + # empty, empty will be evicted + # 6, 8 will be added + # 7 is not added because it's below the average threshold + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + self.assertEqual( + list(_mch_sorted_raw_ids), [4, 5, 6, 8, torch.iinfo(torch.int64).max] + ) + # count for 4 is not updated since it's below the average threshold + _mch_counts = mc_module._mch_counts + self.assertEqual(list(_mch_counts), [10, 1, 3, 2, torch.iinfo(torch.int64).max]) + + def test_probabilistic_threshold_filter(self) -> None: + mc_module = MCHManagedCollisionModule( + zch_size=5, + device=torch.device("cpu"), + eviction_policy=LFU_EvictionPolicy( + threshold_filtering_func=lambda tensor: probabilistic_threshold_filter( + tensor, + per_id_probability=0.01, + ) + ), + eviction_interval=1, + input_hash_size=100, + ) + + # check initial state + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + self.assertEqual(list(_mch_sorted_raw_ids), [torch.iinfo(torch.int64).max] * 5) + _mch_counts = mc_module._mch_counts + self.assertEqual(list(_mch_counts), [0] * 5) + + unique_ids = [5, 4, 3, 2, 1] + id_counts = [100, 80, 60, 40, 10] + ids = [id for id, count in zip(unique_ids, id_counts) for _ in range(count)] + # chance of being added is [0.63, 0.55, 0.45, 0.33] + features: Dict[str, JaggedTensor] = { + "f1": JaggedTensor( + values=torch.tensor(ids, dtype=torch.int64), + lengths=torch.tensor([1] * len(ids), dtype=torch.int64), + ) + } + + torch.manual_seed(42) + for _ in range(10): + mc_module.profile(features) + + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + print(f"henry {mc_module._mch_counts}") + self.assertEqual( + sorted(_mch_sorted_raw_ids.tolist()), + [2, 3, 4, 5, torch.iinfo(torch.int64).max], + ) + # _mch_counts is like + # [80, 180, 160, 800, 9223372036854775807]