From dca76eeb4518a61401f80301ceee193a390d63f0 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 14 Aug 2024 22:15:10 +0100 Subject: [PATCH] feat: create `mpiflags` option --- pysr/julia_helpers.py | 12 ++++++++++-- pysr/param_groupings.yml | 1 + pysr/sr.py | 8 +++++++- 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/pysr/julia_helpers.py b/pysr/julia_helpers.py index 2591ec330..b9bf6c5a5 100644 --- a/pysr/julia_helpers.py +++ b/pysr/julia_helpers.py @@ -27,14 +27,22 @@ def _escape_filename(filename): return str_repr -def _load_cluster_manager(cluster_manager: str): +def _load_cluster_manager(cluster_manager: str, mpi_flags: str): if cluster_manager == "mpi": jl.seval("using Distributed: addprocs") jl.seval("using MPIClusterManagers: MPIWorkerManager") return jl.seval( "(np; exeflags=``, kws...) -> " - + "addprocs(MPIWorkerManager(np); exeflags=`$exeflags --project=$(Base.active_project())`, kws...)" + + "addprocs(MPIWorkerManager(np);" + + ",".join( + [ + "exeflags=`$exeflags --project=$(Base.active_project())`", + f"mpiflags=`{mpi_flags}`", + "kws...", + ] + ) + + ")" ) else: jl.seval(f"using ClusterManagers: addprocs_{cluster_manager}") diff --git a/pysr/param_groupings.yml b/pysr/param_groupings.yml index 0ff9d63da..fcec5a6fc 100644 --- a/pysr/param_groupings.yml +++ b/pysr/param_groupings.yml @@ -70,6 +70,7 @@ - multithreading - cluster_manager - heap_size_hint_in_bytes + - mpi_flags - batching - batch_size - precision diff --git a/pysr/sr.py b/pysr/sr.py index adb4fa95a..63d4c9553 100644 --- a/pysr/sr.py +++ b/pysr/sr.py @@ -499,6 +499,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator): "htc", or "mpi". If set to one of these, PySR will run in distributed mode, and use `procs` to figure out how many processes to launch. Default is `None`. + mpi_flags : str + (Experimental API) String of options to pass to `mpiexec`. + For example, `"-host worker1,worker2"`. + Default is `None`. heap_size_hint_in_bytes : int For multiprocessing, this sets the `--heap-size-hint` parameter for new Julia processes. This can be configured when using @@ -775,6 +779,7 @@ def __init__( cluster_manager: Optional[ Literal["slurm", "pbs", "lsf", "sge", "qrsh", "scyld", "htc", "mpi"] ] = None, + mpi_flags: str = "", heap_size_hint_in_bytes: Optional[int] = None, batching: bool = False, batch_size: int = 50, @@ -872,6 +877,7 @@ def __init__( self.procs = procs self.multithreading = multithreading self.cluster_manager = cluster_manager + self.mpi_flags = mpi_flags self.heap_size_hint_in_bytes = heap_size_hint_in_bytes self.batching = batching self.batch_size = batch_size @@ -1751,7 +1757,7 @@ def _run( ) if cluster_manager is not None: - cluster_manager = _load_cluster_manager(cluster_manager) + cluster_manager = _load_cluster_manager(cluster_manager, self.mpi_flags) mutation_weights = SymbolicRegression.MutationWeights( mutate_constant=self.weight_mutate_constant,