Skip to content

Commit

Permalink
Merge pull request #765 from DilumAluthge/dpa/eager-julia-registry
Browse files Browse the repository at this point in the history
Fall back to `eager` registry when needed
  • Loading branch information
MilesCranmer authored Dec 6, 2024
2 parents 9d976e9 + 25c8639 commit 2b00ada
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 4 deletions.
9 changes: 7 additions & 2 deletions pysr/julia_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Literal

from .julia_import import Pkg, jl
from .julia_registry_helpers import try_with_registry_fallback
from .logger_specs import AbstractLoggerSpec, TensorBoardLoggerSpec


Expand Down Expand Up @@ -47,8 +48,12 @@ def isinstalled(uuid_s: str):

def load_package(package_name: str, uuid_s: str) -> None:
if not isinstalled(uuid_s):
Pkg.add(name=package_name, uuid=uuid_s)
Pkg.resolve()

def _add_package():
Pkg.add(name=package_name, uuid=uuid_s)
Pkg.resolve()

try_with_registry_fallback(_add_package)

# TODO: Protect against loading the same symbol from two packages,
# maybe with a @gensym here.
Expand Down
10 changes: 10 additions & 0 deletions pysr/julia_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from types import ModuleType
from typing import cast

from .julia_registry_helpers import try_with_registry_fallback

# Check if JuliaCall is already loaded, and if so, warn the user
# about the relevant environment variables. If not loaded,
# set up sensible defaults.
Expand Down Expand Up @@ -42,6 +44,14 @@
# Deprecated; so just pass to juliacall
os.environ["PYTHON_JULIACALL_AUTOLOAD_IPYTHON_EXTENSION"] = autoload_extensions


def _import_juliacall():
import juliacall # type: ignore


try_with_registry_fallback(_import_juliacall)


from juliacall import AnyValue # type: ignore
from juliacall import VectorValue # type: ignore
from juliacall import Main as jl # type: ignore
Expand Down
44 changes: 44 additions & 0 deletions pysr/julia_registry_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Utilities for managing Julia registry preferences during package operations."""

import os
import warnings
from collections.abc import Callable
from typing import TypeVar

T = TypeVar("T")

PREFERENCE_KEY = "JULIA_PKG_SERVER_REGISTRY_PREFERENCE"


def try_with_registry_fallback(f: Callable[..., T], *args, **kwargs) -> T:
"""Execute function with modified Julia registry preference.
First tries with existing registry preference. If that fails with a Julia registry error,
temporarily modifies the registry preference to 'eager'. Restores original preference after
execution.
"""
try:
return f(*args, **kwargs)
except Exception as initial_error:
# Check if this is a Julia registry error by looking at the error message
if "JuliaError" not in str(
type(initial_error)
) or "Unsatisfiable requirements detected" not in str(initial_error):
raise initial_error

old_value = os.environ.get(PREFERENCE_KEY, None)
if old_value == "eager":
raise initial_error

warnings.warn(
"Initial Julia registry operation failed. Attempting to use the `eager` registry flavor of the Julia "
+ f"General registry from the Julia Pkg server (via the `{PREFERENCE_KEY}` environment variable)."
)
os.environ[PREFERENCE_KEY] = "eager"
try:
return f(*args, **kwargs)
finally:
if old_value is not None:
os.environ[PREFERENCE_KEY] = old_value
else:
del os.environ[PREFERENCE_KEY]
70 changes: 68 additions & 2 deletions pysr/test/test_startup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@

import numpy as np

from pysr import PySRRegressor
from pysr import PySRRegressor, jl
from pysr.julia_import import jl_version
from pysr.julia_registry_helpers import PREFERENCE_KEY, try_with_registry_fallback

from .params import DEFAULT_NITERATIONS, DEFAULT_POPULATIONS

Expand Down Expand Up @@ -159,8 +160,73 @@ def test_notebook(self):
self.assertEqual(result.returncode, 0)


class TestRegistryHelper(unittest.TestCase):
"""Test the custom Julia registry preference handling."""

def setUp(self):
self.old_value = os.environ.get(PREFERENCE_KEY, None)
self.recorded_env_vars = []
self.hits = 0

def failing_operation():
self.recorded_env_vars.append(os.environ[PREFERENCE_KEY])
self.hits += 1
# Just add some package I know will not exist and also not be in the dependency chain:
jl.Pkg.add(name="AirspeedVelocity", version="100.0.0")

self.failing_operation = failing_operation

def tearDown(self):
if self.old_value is not None:
os.environ[PREFERENCE_KEY] = self.old_value
else:
os.environ.pop(PREFERENCE_KEY, None)

def test_successful_operation(self):
self.assertEqual(try_with_registry_fallback(lambda s: s, "success"), "success")

def test_non_julia_errors_reraised(self):
with self.assertRaises(SyntaxError) as context:
try_with_registry_fallback(lambda: exec("invalid syntax !@#$"))
self.assertNotIn("JuliaError", str(context.exception))

def test_julia_error_triggers_fallback(self):
os.environ[PREFERENCE_KEY] = "conservative"

with self.assertWarns(Warning) as warn_context:
with self.assertRaises(Exception) as error_context:
try_with_registry_fallback(self.failing_operation)

self.assertIn(
"Unsatisfiable requirements detected", str(error_context.exception)
)
self.assertIn(
"Initial Julia registry operation failed. Attempting to use the `eager` registry flavor of the Julia",
str(warn_context.warning),
)

# Verify both modes are tried in order
self.assertEqual(self.recorded_env_vars, ["conservative", "eager"])
self.assertEqual(self.hits, 2)

# Verify environment is restored
self.assertEqual(os.environ[PREFERENCE_KEY], "conservative")

def test_eager_mode_fails_directly(self):
os.environ[PREFERENCE_KEY] = "eager"

with self.assertRaises(Exception) as context:
try_with_registry_fallback(self.failing_operation)

self.assertIn("Unsatisfiable requirements detected", str(context.exception))
self.assertEqual(
self.recorded_env_vars, ["eager"]
) # Should only try eager mode
self.assertEqual(self.hits, 1)


def runtests(just_tests=False):
tests = [TestStartup]
tests = [TestStartup, TestRegistryHelper]
if just_tests:
return tests
suite = unittest.TestSuite()
Expand Down

0 comments on commit 2b00ada

Please sign in to comment.