Skip to content

Commit

Permalink
feat: create mpiflags option
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Aug 14, 2024
1 parent b3f28d5 commit dca76ee
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 3 deletions.
12 changes: 10 additions & 2 deletions pysr/julia_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,22 @@ def _escape_filename(filename):
return str_repr


def _load_cluster_manager(cluster_manager: str):
def _load_cluster_manager(cluster_manager: str, mpi_flags: str):
if cluster_manager == "mpi":
jl.seval("using Distributed: addprocs")
jl.seval("using MPIClusterManagers: MPIWorkerManager")

return jl.seval(
"(np; exeflags=``, kws...) -> "
+ "addprocs(MPIWorkerManager(np); exeflags=`$exeflags --project=$(Base.active_project())`, kws...)"
+ "addprocs(MPIWorkerManager(np);"
+ ",".join(
[
"exeflags=`$exeflags --project=$(Base.active_project())`",
f"mpiflags=`{mpi_flags}`",
"kws...",
]
)
+ ")"
)
else:
jl.seval(f"using ClusterManagers: addprocs_{cluster_manager}")
Expand Down
1 change: 1 addition & 0 deletions pysr/param_groupings.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
- multithreading
- cluster_manager
- heap_size_hint_in_bytes
- mpi_flags
- batching
- batch_size
- precision
Expand Down
8 changes: 7 additions & 1 deletion pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
"htc", or "mpi". If set to one of these, PySR will run in distributed
mode, and use `procs` to figure out how many processes to launch.
Default is `None`.
mpi_flags : str
(Experimental API) String of options to pass to `mpiexec`.
For example, `"-host worker1,worker2"`.
Default is `""`.
heap_size_hint_in_bytes : int
For multiprocessing, this sets the `--heap-size-hint` parameter
for new Julia processes. This can be configured when using
Expand Down Expand Up @@ -775,6 +779,7 @@ def __init__(
cluster_manager: Optional[
Literal["slurm", "pbs", "lsf", "sge", "qrsh", "scyld", "htc", "mpi"]
] = None,
mpi_flags: str = "",
heap_size_hint_in_bytes: Optional[int] = None,
batching: bool = False,
batch_size: int = 50,
Expand Down Expand Up @@ -872,6 +877,7 @@ def __init__(
self.procs = procs
self.multithreading = multithreading
self.cluster_manager = cluster_manager
self.mpi_flags = mpi_flags
self.heap_size_hint_in_bytes = heap_size_hint_in_bytes
self.batching = batching
self.batch_size = batch_size
Expand Down Expand Up @@ -1751,7 +1757,7 @@ def _run(
)

if cluster_manager is not None:
cluster_manager = _load_cluster_manager(cluster_manager)
cluster_manager = _load_cluster_manager(cluster_manager, self.mpi_flags)

mutation_weights = SymbolicRegression.MutationWeights(
mutate_constant=self.weight_mutate_constant,
Expand Down

0 comments on commit dca76ee

Please sign in to comment.