diff --git a/pylammpsmpi/wrapper/concurrent.py b/pylammpsmpi/wrapper/concurrent.py index 5f3c5278..83b394ad 100644 --- a/pylammpsmpi/wrapper/concurrent.py +++ b/pylammpsmpi/wrapper/concurrent.py @@ -22,7 +22,13 @@ def _initialize_socket( - interface, cmdargs, cwd, cores, oversubscribe=False, enable_flux_backend=False + interface, + cmdargs, + cwd, + cores, + oversubscribe=False, + enable_flux_backend=False, + enable_slurm_backend=False, ): port_selected = interface.bind_to_random_port() executable = os.path.join( @@ -30,6 +36,8 @@ def _initialize_socket( ) if enable_flux_backend: cmds = ["flux", "run"] + elif enable_slurm_backend: + cmds = ["srun"] else: cmds = ["mpiexec"] if oversubscribe: @@ -42,7 +50,7 @@ def _initialize_socket( "--zmqport", str(port_selected), ] - if enable_flux_backend: + if enable_flux_backend or enable_slurm_backend: cmds += [ "--host", socket.gethostname(), @@ -59,6 +67,7 @@ def execute_async( cores, oversubscribe=False, enable_flux_backend=False, + enable_slurm_backend=False, cwd=None, queue_adapter=None, queue_adapter_kwargs=None, @@ -71,6 +80,7 @@ def execute_async( cwd=cwd, cores=cores, enable_flux_backend=enable_flux_backend, + enable_slurm_backend=enable_slurm_backend, oversubscribe=oversubscribe, ) while True: @@ -90,6 +100,7 @@ def __init__( cores=8, oversubscribe=False, enable_flux_backend=False, + enable_slurm_backend=False, working_directory=".", cmdargs=None, queue_adapter=None, @@ -101,6 +112,7 @@ def __init__( self._process = None self._oversubscribe = oversubscribe self._enable_flux_backend = enable_flux_backend + self._enable_slurm_backend = enable_slurm_backend self._cmdargs = cmdargs self._queue_adapter = queue_adapter self._queue_adapter_kwargs = queue_adapter_kwargs @@ -115,6 +127,7 @@ def _start_process(self): "cores": self.cores, "oversubscribe": self._oversubscribe, "enable_flux_backend": self._enable_flux_backend, + "enable_slurm_backend": self._enable_slurm_backend, "cwd": self.working_directory, "queue_adapter": self._queue_adapter, "queue_adapter_kwargs": self._queue_adapter_kwargs, diff --git a/pylammpsmpi/wrapper/extended.py b/pylammpsmpi/wrapper/extended.py index cc421bc9..e27b1cad 100644 --- a/pylammpsmpi/wrapper/extended.py +++ b/pylammpsmpi/wrapper/extended.py @@ -245,6 +245,7 @@ def __init__( cores=1, oversubscribe=False, enable_flux_backend=False, + enable_slurm_backend=False, working_directory=".", client=None, mode="local", @@ -256,12 +257,14 @@ def __init__( self.working_directory = working_directory self.oversubscribe = oversubscribe self.enable_flux_backend = enable_flux_backend + self.enable_slurm_backend = enable_slurm_backend self.client = client self.mode = mode self.lmp = LammpsConcurrent( cores=self.cores, oversubscribe=self.oversubscribe, enable_flux_backend=self.enable_flux_backend, + enable_slurm_backend=self.enable_slurm_backend, working_directory=self.working_directory, cmdargs=cmdargs, queue_adapter=queue_adapter,