diff --git a/fuse/plugin.py b/fuse/plugin.py index 75bf56c0..ff651c40 100644 --- a/fuse/plugin.py +++ b/fuse/plugin.py @@ -29,6 +29,11 @@ class FuseBasePlugin(strax.Plugin): help="Set the random seed from lineage and run_id, or pull the seed from the OS.", ) + user_defined_random_seed = straxen.URLConfig( + default=None, + help="Define the random seed manually. You need to set deterministic_seed to False.", + ) + def setup(self): super().setup() @@ -41,13 +46,35 @@ def setup(self): log.setLevel("INFO") if self.deterministic_seed: + + if self.user_defined_random_seed is not None: + log.warning( + "deterministic_seed is set to True. " + "The provided user_defined_random_seed will not be used!" + ) + hash_string = strax.deterministic_hash((self.run_id, self.lineage)) self.seed = int(hash_string.encode().hex(), 16) self.rng = np.random.default_rng(seed=self.seed) - log.debug(f"Generating random numbers from seed {self.seed}") + log.debug(f"Generating random numbers from deterministic seed {self.seed}") else: - self.rng = np.random.default_rng() - log.debug("Generating random numbers with seed pulled from OS") + + if self.user_defined_random_seed is not None: + + assert ( + isinstance(self.user_defined_random_seed, int) + and self.user_defined_random_seed > 0 + ), "user_defined_random_seed must be a positive integer!" + + self.seed = self.user_defined_random_seed + self.rng = np.random.default_rng(self.user_defined_random_seed) + log.info( + "Generating random numbers with user" + f"defined seed {self.user_defined_random_seed}" + ) + else: + self.rng = np.random.default_rng() + log.debug("Generating random numbers with seed pulled from OS") class FuseBaseDownChunkingPlugin(strax.DownChunkingPlugin, FuseBasePlugin): diff --git a/fuse/plugins/detector_physics/s1_photon_propagation.py b/fuse/plugins/detector_physics/s1_photon_propagation.py index 2b43272c..b930b462 100644 --- a/fuse/plugins/detector_physics/s1_photon_propagation.py +++ b/fuse/plugins/detector_physics/s1_photon_propagation.py @@ -140,7 +140,7 @@ class S1PhotonPropagationBase(FuseBasePlugin): def setup(self): super().setup() - if self.deterministic_seed: + if self.deterministic_seed or (self.user_defined_random_seed is not None): # Dont know but nestpy seems to have a problem with large seeds self.short_seed = int(repr(self.seed)[-8:]) log.debug(f"Generating nestpy random numbers from seed {self.short_seed}") diff --git a/fuse/plugins/micro_physics/yields.py b/fuse/plugins/micro_physics/yields.py index 11e2431f..90d3687c 100644 --- a/fuse/plugins/micro_physics/yields.py +++ b/fuse/plugins/micro_physics/yields.py @@ -41,7 +41,7 @@ class NestYields(FuseBasePlugin): def setup(self): super().setup() - if self.deterministic_seed: + if self.deterministic_seed or (self.user_defined_random_seed is not None): # Dont know but nestpy seems to have a problem with large seeds self.short_seed = int(repr(self.seed)[-8:]) log.debug(f"Generating nest random numbers starting with seed {self.short_seed}") diff --git a/tests/test_plugin_random_seed.py b/tests/test_plugin_random_seed.py new file mode 100644 index 00000000..e232926e --- /dev/null +++ b/tests/test_plugin_random_seed.py @@ -0,0 +1,138 @@ +import os +import shutil +import unittest +import tempfile +import timeout_decorator +import fuse +import straxen +from _utils import test_root_file_name + +TIMEOUT = 60 + + +class TestPluginRandomSeeds(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.temp_dir = tempfile.TemporaryDirectory() + cls.run_number = "TestRun_00000" + + cls.test_context = fuse.context.full_chain_context( + cls.temp_dir.name, run_without_proper_corrections=True + ) + cls.test_context.set_config( + { + "path": cls.temp_dir.name, + "file_name": test_root_file_name, + "entry_start": 0, + "entry_stop": 10, + } + ) + + # Get all registered fuse plutins. + cls.all_registered_fuse_plugins = {} + for key, value in cls.test_context._plugin_class_registry.items(): + if "fuse" in str(value): + cls.all_registered_fuse_plugins[key] = value + + @classmethod + def tearDownClass(cls): + cls.temp_dir.cleanup() + + def setUp(self): + downloader = straxen.MongoDownloader(store_files_at=(self.temp_dir.name,)) + downloader.download_single(test_root_file_name, human_readable_file_name=True) + + assert os.path.exists(os.path.join(self.temp_dir.name, test_root_file_name)) + + def tearDown(self): + # self.temp_dir.cleanup() + shutil.rmtree(self.temp_dir.name) + os.makedirs(self.temp_dir.name) + + @timeout_decorator.timeout(TIMEOUT, exception_message="test_if_plugins_get_user_seed timed out") + def test_if_plugins_get_user_seed(self): + self.test_context.set_config( + { + "deterministic_seed": False, + "user_defined_random_seed": 42, + } + ) + + # Lets check the random seed for all fuse plugins + for key in self.all_registered_fuse_plugins.keys(): + plugin = self.test_context.get_single_plugin(self.run_number, key) + + if hasattr(plugin, "seed"): + assert ( + plugin.seed == 42 + ), f"Expecting seed to be 42, but got {plugin.seed} for {key} plugin" + + @timeout_decorator.timeout( + TIMEOUT, exception_message="test_if_plugins_with_rng_have_a_proper_seed timed out" + ) + def test_if_plugins_with_rng_have_a_proper_seed(self): + + # Lets check the random seed for all fuse plugins + for key in self.all_registered_fuse_plugins.keys(): + plugin = self.test_context.get_single_plugin(self.run_number, key) + + if hasattr(plugin, "rng"): + if not hasattr(plugin, "seed"): + raise ValueError(f"Plugin {key} has rng but no seed") + + @timeout_decorator.timeout( + TIMEOUT, exception_message="test_if_negative_seeds_are_intercepted timed out" + ) + def test_if_negative_seeds_are_intercepted(self): + self.test_context.set_config( + { + "deterministic_seed": False, + "user_defined_random_seed": -42, + } + ) + + # Lets check the random seed for all fuse plugins + for key in self.all_registered_fuse_plugins.keys(): + + with self.assertRaises(AssertionError): + plugin = self.test_context.get_single_plugin(self.run_number, key) + + # Some plugins have no seed, so we can't check for negative seeds. + if not hasattr(plugin, "seed"): + raise AssertionError(f"Plugin {key} has no seed") + + @timeout_decorator.timeout( + TIMEOUT, exception_message="test_if_run_number_changes_deterministic_seed timed out" + ) + def test_if_run_number_changes_deterministic_seed(self): + + self.test_context.set_config({"deterministic_seed": True}) + + # Lets check the random seed for all fuse plugins + for key in self.all_registered_fuse_plugins.keys(): + + plugin = self.test_context.get_single_plugin("00000", key) + + if hasattr(plugin, "seed"): + + seed_0 = self.test_context.get_single_plugin("00000", key).seed + seed_1 = self.test_context.get_single_plugin("00001", key).seed + + assert ( + seed_0 != seed_1 + ), f"Expecting seed to be different for different run numbers for {key} plugin" + + @timeout_decorator.timeout( + TIMEOUT, exception_message="test_if_tracked_config_changes_deterministic_seed timed out" + ) + def test_if_tracked_config_changes_deterministic_seed(self): + + self.test_context.set_config({"deterministic_seed": True}) + + seed_0 = self.test_context.get_single_plugin(self.run_number, "raw_records").seed + + # Change some tracked config argument. + self.test_context.set_config({"entry_stop": 20}) + seed_1 = self.test_context.get_single_plugin(self.run_number, "raw_records").seed + + assert seed_0 != seed_1, "Expecting seed to be different for different config args!"