Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optional random seed as user input #179

Merged
merged 9 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions fuse/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ class FuseBasePlugin(strax.Plugin):
help="Set the random seed from lineage and run_id, or pull the seed from the OS.",
)

user_defined_random_seed = straxen.URLConfig(
default=None,
help="Define the random seed manually. You need to set deterministic_seed to False.",
)

def setup(self):
super().setup()

Expand All @@ -41,13 +46,35 @@ def setup(self):
log.setLevel("INFO")

if self.deterministic_seed:

if self.user_defined_random_seed is not None:
log.warning(
"deterministic_seed is set to True. "
"The provided user_defined_random_seed will not be used!"
)

hash_string = strax.deterministic_hash((self.run_id, self.lineage))
self.seed = int(hash_string.encode().hex(), 16)
self.rng = np.random.default_rng(seed=self.seed)
log.debug(f"Generating random numbers from seed {self.seed}")
log.debug(f"Generating random numbers from deterministic seed {self.seed}")
else:
self.rng = np.random.default_rng()
log.debug("Generating random numbers with seed pulled from OS")

if self.user_defined_random_seed is not None:

assert (
isinstance(self.user_defined_random_seed, int)
and self.user_defined_random_seed > 0
), "user_defined_random_seed must be a positive integer!"

self.seed = self.user_defined_random_seed
self.rng = np.random.default_rng(self.user_defined_random_seed)
log.info(
"Generating random numbers with user"
f"defined seed {self.user_defined_random_seed}"
)
else:
self.rng = np.random.default_rng()
log.debug("Generating random numbers with seed pulled from OS")


class FuseBaseDownChunkingPlugin(strax.DownChunkingPlugin, FuseBasePlugin):
Expand Down
2 changes: 1 addition & 1 deletion fuse/plugins/detector_physics/s1_photon_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class S1PhotonPropagationBase(FuseBasePlugin):
def setup(self):
super().setup()

if self.deterministic_seed:
if self.deterministic_seed or (self.user_defined_random_seed is not None):
# Dont know but nestpy seems to have a problem with large seeds
self.short_seed = int(repr(self.seed)[-8:])
log.debug(f"Generating nestpy random numbers from seed {self.short_seed}")
Expand Down
2 changes: 1 addition & 1 deletion fuse/plugins/micro_physics/yields.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class NestYields(FuseBasePlugin):
def setup(self):
super().setup()

if self.deterministic_seed:
if self.deterministic_seed or (self.user_defined_random_seed is not None):
# Dont know but nestpy seems to have a problem with large seeds
self.short_seed = int(repr(self.seed)[-8:])
log.debug(f"Generating nest random numbers starting with seed {self.short_seed}")
Expand Down
138 changes: 138 additions & 0 deletions tests/test_plugin_random_seed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import os
import shutil
import unittest
import tempfile
import timeout_decorator
import fuse
import straxen
from _utils import test_root_file_name

TIMEOUT = 60


class TestPluginRandomSeeds(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.temp_dir = tempfile.TemporaryDirectory()
cls.run_number = "TestRun_00000"

cls.test_context = fuse.context.full_chain_context(
cls.temp_dir.name, run_without_proper_corrections=True
)
cls.test_context.set_config(
{
"path": cls.temp_dir.name,
"file_name": test_root_file_name,
"entry_start": 0,
"entry_stop": 10,
}
)

# Get all registered fuse plutins.
cls.all_registered_fuse_plugins = {}
for key, value in cls.test_context._plugin_class_registry.items():
if "fuse" in str(value):
cls.all_registered_fuse_plugins[key] = value

@classmethod
def tearDownClass(cls):
cls.temp_dir.cleanup()

def setUp(self):
downloader = straxen.MongoDownloader(store_files_at=(self.temp_dir.name,))
downloader.download_single(test_root_file_name, human_readable_file_name=True)

assert os.path.exists(os.path.join(self.temp_dir.name, test_root_file_name))

def tearDown(self):
# self.temp_dir.cleanup()
shutil.rmtree(self.temp_dir.name)
os.makedirs(self.temp_dir.name)

@timeout_decorator.timeout(TIMEOUT, exception_message="test_if_plugins_get_user_seed timed out")
def test_if_plugins_get_user_seed(self):
self.test_context.set_config(
{
"deterministic_seed": False,
"user_defined_random_seed": 42,
}
)

# Lets check the random seed for all fuse plugins
for key in self.all_registered_fuse_plugins.keys():
plugin = self.test_context.get_single_plugin(self.run_number, key)

if hasattr(plugin, "seed"):
assert (
plugin.seed == 42
), f"Expecting seed to be 42, but got {plugin.seed} for {key} plugin"

@timeout_decorator.timeout(
TIMEOUT, exception_message="test_if_plugins_with_rng_have_a_proper_seed timed out"
)
def test_if_plugins_with_rng_have_a_proper_seed(self):

# Lets check the random seed for all fuse plugins
for key in self.all_registered_fuse_plugins.keys():
plugin = self.test_context.get_single_plugin(self.run_number, key)

if hasattr(plugin, "rng"):
if not hasattr(plugin, "seed"):
raise ValueError(f"Plugin {key} has rng but no seed")

@timeout_decorator.timeout(
TIMEOUT, exception_message="test_if_negative_seeds_are_intercepted timed out"
)
def test_if_negative_seeds_are_intercepted(self):
self.test_context.set_config(
{
"deterministic_seed": False,
"user_defined_random_seed": -42,
}
)

# Lets check the random seed for all fuse plugins
for key in self.all_registered_fuse_plugins.keys():

with self.assertRaises(AssertionError):
plugin = self.test_context.get_single_plugin(self.run_number, key)

# Some plugins have no seed, so we can't check for negative seeds.
if not hasattr(plugin, "seed"):
raise AssertionError(f"Plugin {key} has no seed")

@timeout_decorator.timeout(
TIMEOUT, exception_message="test_if_run_number_changes_deterministic_seed timed out"
)
def test_if_run_number_changes_deterministic_seed(self):

self.test_context.set_config({"deterministic_seed": True})

# Lets check the random seed for all fuse plugins
for key in self.all_registered_fuse_plugins.keys():

plugin = self.test_context.get_single_plugin("00000", key)

if hasattr(plugin, "seed"):

seed_0 = self.test_context.get_single_plugin("00000", key).seed
seed_1 = self.test_context.get_single_plugin("00001", key).seed

assert (
seed_0 != seed_1
), f"Expecting seed to be different for different run numbers for {key} plugin"

@timeout_decorator.timeout(
TIMEOUT, exception_message="test_if_tracked_config_changes_deterministic_seed timed out"
)
def test_if_tracked_config_changes_deterministic_seed(self):

self.test_context.set_config({"deterministic_seed": True})

seed_0 = self.test_context.get_single_plugin(self.run_number, "raw_records").seed

# Change some tracked config argument.
self.test_context.set_config({"entry_stop": 20})
seed_1 = self.test_context.get_single_plugin(self.run_number, "raw_records").seed

assert seed_0 != seed_1, "Expecting seed to be different for different config args!"