From 6cd1d2f53c2c91c2b8434562e11d587909cc8940 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Mon, 5 Feb 2024 11:28:46 -0800 Subject: [PATCH] Add percentage based threshold function and unit tests (#1679) Summary: Add percentage based threshold function to mc_modules. Also: 1. Add unit tests for all threshold functions. 2. Add one liner documentation for all threshold functions. 3. Add loggings for threshold. Note that regardless of eviction policy, threshold functions only look at eviction counts. Reviewed By: dstaay-fb Differential Revision: D53030312 --- torchrec/modules/mc_modules.py | 28 +++++ torchrec/modules/tests/test_mc_modules.py | 128 ++++++++++++++++++++++ 2 files changed, 156 insertions(+) diff --git a/torchrec/modules/mc_modules.py b/torchrec/modules/mc_modules.py index a61b52c63..c002deb83 100644 --- a/torchrec/modules/mc_modules.py +++ b/torchrec/modules/mc_modules.py @@ -52,6 +52,10 @@ def dynamic_threshold_filter( id_counts: torch.Tensor, threshold_skew_multiplier: float = 10.0, ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Threshold is total_count / num_ids * threshold_skew_multiplier. An id is + added if its count is strictly greater than the threshold. + """ num_ids = id_counts.numel() total_count = id_counts.sum() @@ -69,6 +73,10 @@ def dynamic_threshold_filter( def average_threshold_filter( id_counts: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Threshold is average of id_counts. An id is added if its count is strictly + greater than the mean. + """ if id_counts.dtype != torch.float: id_counts = id_counts.float() threshold = id_counts.mean() @@ -77,6 +85,26 @@ def average_threshold_filter( return threshold_mask, threshold +@torch.no_grad() +def probabilistic_threshold_filter( + id_counts: torch.Tensor, + per_id_probability: float = 0.01, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Each id has probability per_id_probability of being added. For example, + if per_id_probability is 0.01 and an id appears 100 times, then it has a 60% + of being added. More precisely, the id score is 1 - (1 - per_id_probability) ^ id_count, + and for a randomly generated threshold, the id score is the chance of it being added. + """ + probability = torch.full_like(id_counts, 1 - per_id_probability, dtype=torch.float) + id_scores = 1 - torch.pow(probability, id_counts) + + threshold: torch.Tensor = torch.rand(id_counts.size(), device=id_counts.device) + threshold_mask = id_scores > threshold + + return threshold_mask, threshold + + class ManagedCollisionModule(nn.Module): """ Abstract base class for ManagedCollisionModule. diff --git a/torchrec/modules/tests/test_mc_modules.py b/torchrec/modules/tests/test_mc_modules.py index a534d78a3..abcfe687a 100644 --- a/torchrec/modules/tests/test_mc_modules.py +++ b/torchrec/modules/tests/test_mc_modules.py @@ -10,10 +10,13 @@ import torch from torchrec.modules.mc_modules import ( + average_threshold_filter, DistanceLFU_EvictionPolicy, + dynamic_threshold_filter, LFU_EvictionPolicy, LRU_EvictionPolicy, MCHManagedCollisionModule, + probabilistic_threshold_filter, ) from torchrec.sparse.jagged_tensor import JaggedTensor @@ -215,3 +218,128 @@ def test_distance_lfu_eviction_fast_decay(self) -> None: self.assertEqual(list(_mch_counts), [1, 1, 1, 1, torch.iinfo(torch.int64).max]) _mch_last_access_iter = mc_module._mch_last_access_iter self.assertEqual(list(_mch_last_access_iter), [2, 2, 3, 3, 3]) + + def test_dynamic_threshold_filter(self) -> None: + mc_module = MCHManagedCollisionModule( + zch_size=5, + device=torch.device("cpu"), + eviction_policy=LFU_EvictionPolicy( + threshold_filtering_func=lambda tensor: dynamic_threshold_filter( + tensor, threshold_skew_multiplier=0.75 + ) + ), + eviction_interval=1, + input_hash_size=100, + ) + + # check initial state + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + self.assertEqual(list(_mch_sorted_raw_ids), [torch.iinfo(torch.int64).max] * 5) + _mch_counts = mc_module._mch_counts + self.assertEqual(list(_mch_counts), [0] * 5) + + ids = [5, 5, 5, 5, 5, 4, 4, 4, 4, 3, 3, 3, 2, 2, 1] + # threshold is len(ids) / unique_count(ids) * threshold_skew_multiplier + # = 15 / 5 * 0.5 = 2.25 + features: Dict[str, JaggedTensor] = { + "f1": JaggedTensor( + values=torch.tensor(ids, dtype=torch.int64), + lengths=torch.tensor([1] * len(ids), dtype=torch.int64), + ) + } + mc_module.profile(features) + + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + self.assertEqual( + list(_mch_sorted_raw_ids), + [3, 4, 5, torch.iinfo(torch.int64).max, torch.iinfo(torch.int64).max], + ) + _mch_counts = mc_module._mch_counts + self.assertEqual(list(_mch_counts), [3, 4, 5, 0, torch.iinfo(torch.int64).max]) + + def test_average_threshold_filter(self) -> None: + mc_module = MCHManagedCollisionModule( + zch_size=5, + device=torch.device("cpu"), + eviction_policy=LFU_EvictionPolicy( + threshold_filtering_func=average_threshold_filter + ), + eviction_interval=1, + input_hash_size=100, + ) + + # check initial state + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + self.assertEqual(list(_mch_sorted_raw_ids), [torch.iinfo(torch.int64).max] * 5) + _mch_counts = mc_module._mch_counts + self.assertEqual(list(_mch_counts), [0] * 5) + + # insert some values to zch + # we have 10 counts of 4 and 1 count of 5 + mc_module._mch_sorted_raw_ids[0:2] = torch.tensor([4, 5]) + mc_module._mch_counts[0:2] = torch.tensor([10, 1]) + + ids = [3, 4, 5, 6, 6, 6, 7, 8, 8, 9, 10] + # threshold is 1.375 + features: Dict[str, JaggedTensor] = { + "f1": JaggedTensor( + values=torch.tensor(ids, dtype=torch.int64), + lengths=torch.tensor([1] * len(ids), dtype=torch.int64), + ) + } + mc_module.profile(features) + + # empty, empty will be evicted + # 6, 8 will be added + # 7 is not added because it's below the average threshold + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + self.assertEqual( + list(_mch_sorted_raw_ids), [4, 5, 6, 8, torch.iinfo(torch.int64).max] + ) + # count for 4 is not updated since it's below the average threshold + _mch_counts = mc_module._mch_counts + self.assertEqual(list(_mch_counts), [10, 1, 3, 2, torch.iinfo(torch.int64).max]) + + def test_probabilistic_threshold_filter(self) -> None: + mc_module = MCHManagedCollisionModule( + zch_size=5, + device=torch.device("cpu"), + eviction_policy=LFU_EvictionPolicy( + threshold_filtering_func=lambda tensor: probabilistic_threshold_filter( + tensor, + per_id_probability=0.01, + ) + ), + eviction_interval=1, + input_hash_size=100, + ) + + # check initial state + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + self.assertEqual(list(_mch_sorted_raw_ids), [torch.iinfo(torch.int64).max] * 5) + _mch_counts = mc_module._mch_counts + self.assertEqual(list(_mch_counts), [0] * 5) + + unique_ids = [5, 4, 3, 2, 1] + id_counts = [100, 80, 60, 40, 10] + ids = [id for id, count in zip(unique_ids, id_counts) for _ in range(count)] + # chance of being added is [0.63, 0.55, 0.45, 0.33] + features: Dict[str, JaggedTensor] = { + "f1": JaggedTensor( + values=torch.tensor(ids, dtype=torch.int64), + lengths=torch.tensor([1] * len(ids), dtype=torch.int64), + ) + } + + torch.manual_seed(42) + for _ in range(10): + mc_module.profile(features) + + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + print(f"henry {mc_module._mch_counts}") + self.assertEqual( + sorted(_mch_sorted_raw_ids.tolist()), + [2, 3, 4, 5, torch.iinfo(torch.int64).max], + ) + # _mch_counts is like + # [80, 180, 160, 800, 9223372036854775807]