From 7f3bb63261048eb2eafae2d193f29665f974eb93 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 18 Mar 2021 14:35:43 -0700 Subject: [PATCH] make config.FLAGS.jax_enable_foo an error --- jax/config.py | 24 +++++++++++++----------- tests/api_test.py | 2 +- tests/debug_nans_test.py | 4 ++-- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/jax/config.py b/jax/config.py index 650877bab07c..7d748d7fc55b 100644 --- a/jax/config.py +++ b/jax/config.py @@ -51,6 +51,7 @@ def __init__(self): self.meta = {} self.FLAGS = NameSpace(self.read) self.use_absl = False + self._contextmanager_flags = set() # TODO(mattjj): delete these when only omnistaging is available self.omnistaging_enabled = bool_env('JAX_OMNISTAGING', True) @@ -71,6 +72,13 @@ def update(self, name, val): lib.jax_jit.global_state().enable_x64 = val def read(self, name): + if name in self._contextmanager_flags: + raise AttributeError( + "For flags with a corresponding contextmanager, read their value " + f"via e.g. `config.{name}` rather than `config.FLAGS.{name}`.") + return self._read(name) + + def _read(self, name): if self.use_absl: return getattr(self.absl_flags.FLAGS, name) else: @@ -193,23 +201,17 @@ def define_bool_state(self, name: str, default: bool, help: str): with enable_foo(True): ... - Accessing ``config.FLAGS.jax_enable_foo`` is different from accessing the - thread-local state value via ``config.jax_enable_foo``: the former reads the - flag value determined set by the environment variable or command-line flag - and does not read the thread-local state, whereas the latter reads the - thread-local state value managed by the contextmanager. Think of the - contextmanager state as a layer on top of the flag value: if no - contextmanager is in use then ``config.jax_enable_foo`` reflects the flag - value ``config.FLAGS.jax_enable_foo``, whereas if a contextmanager is in use - then only ``config.jax_enable_foo`` is updated. So in general using - ``config.jax_enable_foo`` is best. + The value of the thread-local state or flag can be accessed via + ``config.jax_enable_foo``. Reading it via ``config.FLAGS.jax_enable_foo`` is + an error. """ name = name.lower() self.DEFINE_bool(name, bool_env(name.upper(), default), help) + self._contextmanager_flags.add(name) def get_state(self): val = getattr(_thread_local_state, name, unset) - return val if val is not unset else self.read(name) + return val if val is not unset else self._read(name) setattr(Config, name, property(get_state)) @contextlib.contextmanager diff --git a/tests/api_test.py b/tests/api_test.py index f66d74bbeb0d..04451597308c 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -2334,7 +2334,7 @@ def test_leak_checker_catches_a_sublevel_leak(self): if not config.omnistaging_enabled: raise unittest.SkipTest("test only works with omnistaging") - with core.checking_leaks(): + with jax.checking_leaks(): @jit def f(x): lst = [] diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py index 99785ff0f869..dba67072aa88 100644 --- a/tests/debug_nans_test.py +++ b/tests/debug_nans_test.py @@ -30,7 +30,7 @@ class DebugNaNsTest(jtu.JaxTestCase): def setUp(self): - self.cfg = config.read("jax_debug_nans") + self.cfg = config._read("jax_debug_nans") config.update("jax_debug_nans", True) def tearDown(self): @@ -144,7 +144,7 @@ def testPjit(self): class DebugInfsTest(jtu.JaxTestCase): def setUp(self): - self.cfg = config.read("jax_debug_infs") + self.cfg = config._read("jax_debug_infs") config.update("jax_debug_infs", True) def tearDown(self):