"""
The dispatcher distributes function calls to slurm or the local machine.
It can be accessed with `get_dispatcher` and set with `set_dispatcher`.
This allows changing the behaviour of the distribution, e.g., we use it for batch:
Batch simply wraps the dispatcher in a buffered version.
"""
from __future__ import annotations
import abc
import logging
import shlex
import shutil
import subprocess
from collections.abc import Iterable
from pathlib import Path
from typing import Any
import simple_slurm
from .conf import _get_conf
from .execute_cmds import create_slurminade_command
from .function_call import FunctionCall
from .function_map import FunctionMap, get_entry_point
from .guard import dispatch_guard
from .job_reference import JobReference
from .options import SlurmOptions
# Default for the `max_arg_length` attribute of the dispatchers below, which
# they pass on to `create_slurminade_command`.
# MAX_ARG_STRLEN on a Linux system with PAGE_SIZE 4096 is 131072, so this
# default stays below that limit.
DEFAULT_MAX_ARG_LENGTH = 100000
# Module-level logger for consistent logging within this module.
_logger = logging.getLogger("slurminade.dispatcher")
class Dispatcher(abc.ABC):
    """
    Abstract dispatcher to be inherited by all concrete dispatchers.
    For implementing a dispatcher you have to implement `_dispatch`, `srun` and `sbatch`.
    """

    @abc.abstractmethod
    def _dispatch(
        self,
        funcs: Iterable[FunctionCall],
        options: SlurmOptions,
        entry_point: Path,
        block: bool = False,
    ) -> JobReference:
        """
        Define how to dispatch a number of function calls.

        :param funcs: The function calls to be dispatched.
        :param options: The slurm options to be used.
        :param entry_point: The entry point the function calls are dispatched for.
        :param block: Request blocking execution, if the dispatcher supports it.
        :return: A reference to the dispatched job.
        """

    @abc.abstractmethod
    def srun(
        self,
        command: str,
        conf: SlurmOptions | None = None,
        simple_slurm_kwargs: dict | None = None,
    ) -> JobReference:
        """
        Define how you want to execute an `srun` command. This command is directly
        executed and only terminates after completion.

        :param command: A system command, e.g. `echo hello world > foobar.txt`.
        :param conf: The slurm configuration.
        :param simple_slurm_kwargs: Additional options for simple_slurm.
        :return: A reference to the executed job.
        """

    @abc.abstractmethod
    def sbatch(
        self,
        command: str,
        conf: SlurmOptions | None = None,
        simple_slurm_kwargs: dict | None = None,
    ) -> JobReference:
        """
        Define how you want to execute an `sbatch` command. The command is scheduled
        and the function returns immediately.

        :param command: A system command, e.g. `echo hello world > foobar.txt`.
        :param conf: The slurm configuration.
        :param simple_slurm_kwargs: Additional options for simple_slurm.
        :return: A reference to the scheduled job.
        """

    def _log_dispatch(self, funcs: list[FunctionCall], options: SlurmOptions) -> None:
        """
        Log which function calls are about to be dispatched.

        :param funcs: The function calls to be dispatched.
        :param options: The slurm options to be used.
        """
        # Joining the string representations of all calls is eager work that
        # %-style logging cannot defer, so skip it when INFO is disabled.
        if not _logger.isEnabledFor(logging.INFO):
            return
        if len(funcs) == 1:
            _logger.info("Dispatching task with options %s: %s", options, funcs[0])
            return
        _logger.info(
            "Dispatching task with %d function calls and options %s: %s",
            len(funcs),
            options,
            ", ".join(str(f) for f in funcs),
        )

    def __call__(
        self,
        funcs: FunctionCall | Iterable[FunctionCall],
        options: SlurmOptions,
        entry_point: Path,
        block: bool = False,
    ) -> JobReference:
        """
        Dispatches a function call or a number of function calls.

        :param funcs: The function calls to be distributed.
        :param options: The slurm options to be used.
        :param entry_point: The entry point, forwarded to `_dispatch`.
        :param block: Forwarded to `_dispatch`.
        :return: The job reference returned by `_dispatch`.
        """
        if isinstance(funcs, FunctionCall):
            funcs = [funcs]
        # Materialize once: the calls are both logged and dispatched.
        funcs = list(funcs)
        self._log_dispatch(funcs, options)
        return self._dispatch(funcs, options, entry_point, block)

    def is_sequential(self) -> bool:
        """
        Return true if the dispatcher works sequential. In this case, the dependencies
        are trivially fulfilled. Slurm does not work sequentially, because this
        would destroy its purpose. In some cases however, you do not want to use
        slurm for compatibility reasons, without changing the script. In these cases,
        this function tells slurminade not to be too strict about dependencies.

        :return: True if tasks are executed sequentially, false if not.
        """
        return False

    def join(self) -> None:
        """
        Join the jobs dispatched so far.

        A sequential dispatcher has nothing to wait for, so this is a no-op.
        Non-sequential dispatchers have to override this method.

        :raises NotImplementedError: If the dispatcher is not sequential and
            does not override this method.
        """
        if self.is_sequential():
            # Already sequential, nothing to do
            return
        msg = "Joining is not implemented for this dispatcher."
        raise NotImplementedError(msg)
class TestJobReference(JobReference):
    """Job reference handed out by `TestDispatcher`; it carries no real job data."""

    def get_job_id(self) -> None:
        """There is no job id for a test job."""
        return None

    def get_exit_code(self) -> None:
        """There is no exit code for a test job."""
        return None

    def get_info(self) -> dict[str, Any]:
        """Return a fixed marker dictionary."""
        info: dict[str, Any] = {"info": "test"}
        return info
class TestDispatcher(Dispatcher):
    """
    A dummy dispatcher, primarily for debugging and testing. It runs no command:
    dispatched calls and `srun`/`sbatch` commands are only logged and recorded
    on the instance.
    """

    def __init__(self):
        super().__init__()
        self.max_arg_length = DEFAULT_MAX_ARG_LENGTH
        # Everything that was handed to this dispatcher, in call order.
        self.calls = []
        self.sruns = []
        self.sbatches = []

    def _dispatch(
        self,
        funcs: Iterable[FunctionCall],
        options: SlurmOptions,  # noqa: ARG002
        entry_point: Path,  # noqa: ARG002
        block: bool = False,  # noqa: ARG002
    ) -> JobReference:
        """
        Record the function calls and log the command that would have been used.

        The command is built for `get_entry_point()`; the `entry_point`
        argument is not used.
        """
        dispatch_guard()
        recorded = list(funcs)
        command = create_slurminade_command(
            get_entry_point(), recorded, self.max_arg_length
        )
        log = logging.getLogger("slurminade")
        log.info("Command: %s", command)
        self.calls.append(recorded)
        self._cleanup(command)
        return TestJobReference()

    def _cleanup(self, command: list[str]) -> None:
        """
        Delete the file named by the last element of `command` if the command
        ends with `--fromfile <path>`.
        """
        # NOTE(review): presumably this is the argument file written by
        # `create_slurminade_command` for long commands -- confirm.
        ends_with_file = len(command) >= 2 and command[-2] == "--fromfile"
        if ends_with_file:
            arg_file = Path(command[-1])
            if arg_file.exists():
                arg_file.unlink()

    def srun(
        self,
        command: str,
        conf: SlurmOptions | None = None,  # noqa: ARG002
        simple_slurm_kwargs: dict | None = None,  # noqa: ARG002
    ) -> JobReference:
        """Record and log the `srun` command without running it."""
        dispatch_guard()
        self.sruns.append(command)
        log = logging.getLogger("slurminade")
        log.info("[test output] SRUN %s", command)
        return TestJobReference()

    def sbatch(
        self,
        command: str,
        conf: SlurmOptions | None = None,  # noqa: ARG002
        simple_slurm_kwargs: dict | None = None,  # noqa: ARG002
    ) -> JobReference:
        """Record and log the `sbatch` command without running it."""
        dispatch_guard()
        self.sbatches.append(command)
        log = logging.getLogger("slurminade")
        log.info("[test output] SBATCH %s", command)
        return TestJobReference()

    def is_sequential(self) -> bool:
        """This dispatcher counts as sequential."""
        return True
class SlurmJobReference(JobReference):
    """Reference to a job that was handed to slurm via `srun` or `sbatch`."""

    def __init__(self, job_id: int | None, exit_code: int | None, mode: str):
        """
        :param job_id: The job id, or None if not available.
        :param exit_code: The exit code, or None if not available.
        :param mode: How the job was submitted (this module passes "srun" or "sbatch").
        """
        self.mode = mode
        self.job_id = job_id
        self.exit_code = exit_code

    def get_job_id(self) -> int | None:
        """Return the stored job id."""
        return self.job_id

    def get_exit_code(self) -> int | None:
        """Return the stored exit code."""
        return self.exit_code

    def get_info(self) -> dict[str, Any]:
        """Return the stored values as a dictionary, marked as running on slurm."""
        info: dict[str, Any] = {}
        info["job_id"] = self.job_id
        info["exit_code"] = self.exit_code
        info["on_slurm"] = True
        info["mode"] = self.mode
        return info
class SlurmDispatcher(Dispatcher):
    """
    The most important dispatcher: Distributing function calls to slurm.
    """

    def __init__(self):
        """
        :raises RuntimeError: If no `sbatch` executable can be found.
        """
        super().__init__()
        if not shutil.which("sbatch"):
            msg = "Slurm could not be found."
            raise RuntimeError(msg)
        self.max_arg_length = DEFAULT_MAX_ARG_LENGTH
        # Ids of all jobs submitted with sbatch through this dispatcher.
        self._all_job_ids = []
        # Job ids that jobs dispatched after `join` have to wait for.
        self._join_dependencies = []

    def _create_slurm_api(self, special_slurm_opts: SlurmOptions) -> simple_slurm.Slurm:
        """Create the simple_slurm interface for the given options."""
        conf = _get_conf(special_slurm_opts)
        return simple_slurm.Slurm(**conf)

    def _job_name(self, funcs: list[FunctionCall]) -> str:
        """
        Derive a job name from the readable names of the dispatched functions.

        :param funcs: The function calls of the job.
        :return: `slurminade:<name>` for a single call, otherwise a batch name
            based on the first function's name.
        """
        # dict.fromkeys removes duplicates but keeps first-seen order, so the
        # chosen name is deterministic. A set would yield an arbitrary element,
        # which can change between interpreter runs.
        func_names = list(
            dict.fromkeys(FunctionMap.get_readable_name(f.func_id) for f in funcs)
        )
        if not func_names:
            # Nothing to name the job after; avoid an IndexError.
            return "slurminade[batch]"
        if len(funcs) == 1:
            return f"slurminade:{func_names[0]}"
        return f"slurminade[batch]:{func_names[0]}..."

    def _dispatch(
        self,
        funcs: Iterable[FunctionCall],
        options: SlurmOptions,
        entry_point: Path,
        block: bool = False,
    ) -> SlurmJobReference:
        """
        Submit the function calls as one slurm job.

        :param funcs: The function calls to be dispatched.
        :param options: The slurm options to be used. A `job_name` is derived
            from the functions if none is given.
        :param entry_point: The entry point used to build the command.
        :param block: If true, run with `srun` and wait; otherwise schedule
            with `sbatch`.
        :return: The job reference (exit code for `srun`, job id for `sbatch`).
        """
        dispatch_guard()
        # Materialize once: `funcs` may be a one-shot iterable and is used
        # both for naming the job and for building the command.
        funcs = list(funcs)
        if "job_name" not in options:
            # This is complicated to prevent warnings about the type
            options = SlurmOptions(**options.as_dict())
            options["job_name"] = self._job_name(funcs)
        # Always continue on a copy, so that adding the join dependencies
        # below does not modify the options object passed by the caller.
        options = SlurmOptions(**options)
        if self._join_dependencies:
            options.add_dependencies(self._join_dependencies, "afterany")
        slurm = self._create_slurm_api(options)
        command = create_slurminade_command(entry_point, funcs, self.max_arg_length)
        logging.getLogger("slurminade").debug(command)
        shell_command = " ".join(shlex.quote(c) for c in command)
        if block:
            ret = slurm.srun(shell_command)
            logging.getLogger("slurminade").info(
                "Returned from srun with exit code %s", ret
            )
            return SlurmJobReference(None, ret, "srun")
        jid = slurm.sbatch(shell_command)
        self._all_job_ids.append(jid)
        return SlurmJobReference(jid, None, "sbatch")

    def sbatch(
        self,
        command: str,
        conf: SlurmOptions | None = None,
        simple_slurm_kwargs: dict | None = None,
    ) -> SlurmJobReference:
        """
        Schedule a system command with `sbatch` and remember its job id.

        :param command: A system command, e.g. `echo hello world > foobar.txt`.
        :param conf: The slurm configuration.
        :param simple_slurm_kwargs: Additional options for simple_slurm.
        :return: A job reference carrying the job id.
        """
        dispatch_guard()
        conf_ = _get_conf(conf)
        slurm = simple_slurm.Slurm(**conf_)
        logging.getLogger("slurminade").debug("SBATCH %s", command)
        if simple_slurm_kwargs:
            jid = slurm.sbatch(command, **simple_slurm_kwargs)
        else:
            jid = slurm.sbatch(command)
        self._all_job_ids.append(jid)
        return SlurmJobReference(jid, None, "sbatch")

    def join(self) -> None:
        """
        Make function calls dispatched from now on depend on all jobs submitted
        with sbatch so far. This does not wait; the recorded ids are added as
        `afterany` dependencies in `_dispatch`.
        """
        if not self._all_job_ids:
            return
        self._join_dependencies = list(set(self._all_job_ids))

    def srun(
        self,
        command: str,
        conf: SlurmOptions | None = None,
        simple_slurm_kwargs: dict | None = None,
    ) -> SlurmJobReference:
        """
        Execute a system command with `srun`.

        :param command: A system command, e.g. `echo hello world > foobar.txt`.
        :param conf: The slurm configuration.
        :param simple_slurm_kwargs: Additional options for simple_slurm.
        :return: A job reference carrying the value returned by `srun`.
        """
        dispatch_guard()
        conf_ = _get_conf(conf)
        slurm = simple_slurm.Slurm(**conf_)
        logging.getLogger("slurminade").debug("SRUN %s", command)
        if simple_slurm_kwargs:
            ret = slurm.srun(command, **simple_slurm_kwargs)
        else:
            ret = slurm.srun(command)
        return SlurmJobReference(None, ret, "srun")
class SubprocessJobReference(JobReference):
    """Job reference returned by `SubprocessDispatcher`; no job data is stored."""

    def __init__(self):
        """Nothing to initialize."""

    def get_job_id(self) -> int | None:
        """Always None: no job id is tracked."""
        return None

    def get_exit_code(self) -> int | None:
        """Always None: no exit code is stored."""
        return None

    def get_info(self) -> dict[str, Any]:
        """Report that the job did not run on slurm."""
        info: dict[str, Any] = {"on_slurm": False}
        return info
class SubprocessDispatcher(Dispatcher):
    """
    A dispatcher for debugging that distributes function calls using subprocesses.
    Thus, it uses the same serialization mechanisms, but without a slurm dependency.
    Completely useless for productive purposes. Use `DirectCallDispatcher` if you
    don't want to use slurm.
    Despite using subprocesses, it does not parallelize but works sequential.
    """

    def __init__(self):
        super().__init__()
        self.max_arg_length = DEFAULT_MAX_ARG_LENGTH

    def _dispatch(
        self,
        funcs: Iterable[FunctionCall],
        options: SlurmOptions,  # noqa: ARG002
        entry_point: Path,
        block: bool = False,  # noqa: ARG002
    ) -> SubprocessJobReference:
        """Build the slurminade command and run it in a subprocess until it finishes."""
        dispatch_guard()
        argv = create_slurminade_command(entry_point, funcs, self.max_arg_length)
        # check=True: a non-zero exit status raises CalledProcessError.
        subprocess.run(argv, shell=False, check=True)
        return SubprocessJobReference()

    def srun(
        self,
        command: str,
        conf: SlurmOptions | None = None,  # noqa: ARG002
        simple_slurm_kwargs: dict | None = None,  # noqa: ARG002
    ) -> SubprocessJobReference:
        """Run the system command through the shell and wait for it to finish."""
        dispatch_guard()
        log = logging.getLogger("slurminade")
        log.debug("SRUN %s", command)
        subprocess.run(command, shell=True, check=True)
        return SubprocessJobReference()

    def sbatch(
        self,
        command: str,
        conf: SlurmOptions | None = None,  # noqa: ARG002
        simple_slurm_kwargs: dict | None = None,  # noqa: ARG002
    ) -> SubprocessJobReference:
        """Same as `srun`: the command is executed directly."""
        reference = self.srun(command)
        return reference

    def is_sequential(self) -> bool:
        """This dispatcher works sequentially."""
        return True
class LocalJobReference(JobReference):
    """Job reference returned by `DirectCallDispatcher`; no job data is stored."""

    def get_job_id(self) -> None:
        """Always None: there is no job id."""
        return None

    def get_exit_code(self) -> None:
        """Always None: there is no exit code."""
        return None

    def get_info(self) -> dict[str, Any]:
        """Report that the job did not run on slurm."""
        info: dict[str, Any] = {"on_slurm": False}
        return info
class DirectCallDispatcher(Dispatcher):
    """
    A dispatcher that calls functions as if we would not use slurminade.
    This allows compatibility of scripts also on computers not integrated into
    the slurm network.
    """

    def _dispatch(
        self,
        funcs: Iterable[FunctionCall],
        options: SlurmOptions,  # noqa: ARG002
        entry_point: Path,  # noqa: ARG002
        block: bool = False,  # noqa: ARG002
    ) -> LocalJobReference:
        """Call every function directly via `FunctionMap.call`, one after another."""
        dispatch_guard()
        for call in funcs:
            FunctionMap.call(call.func_id, call.args, call.kwargs)
        return LocalJobReference()

    def srun(
        self,
        command: str,
        conf: SlurmOptions | None = None,  # noqa: ARG002
        simple_slurm_kwargs: dict | None = None,  # noqa: ARG002
    ) -> LocalJobReference:
        """Run the system command through the shell and wait for it to finish."""
        dispatch_guard()
        # check=True: a non-zero exit status raises CalledProcessError.
        subprocess.run(command, shell=True, check=True)
        return LocalJobReference()

    def sbatch(
        self,
        command: str,
        conf: SlurmOptions | None = None,  # noqa: ARG002
        simple_slurm_kwargs: dict | None = None,  # noqa: ARG002
    ) -> LocalJobReference:
        """Same as `srun`: the command is executed directly."""
        reference = self.srun(command)
        return reference

    def is_sequential(self) -> bool:
        """This dispatcher works sequentially."""
        return True
# The current dispatcher. Use with `get_dispatcher` and `set_dispatcher`.
# Stays None until `get_dispatcher` creates a default dispatcher or
# `set_dispatcher` installs one.
__dispatcher: Dispatcher | None = None
def get_dispatcher() -> Dispatcher:
    """
    Return the current dispatcher, creating a default one on first use.

    The default is a `SlurmDispatcher` (the primary purpose of slurminade).
    If its construction fails with a `RuntimeError`, a `DirectCallDispatcher`
    is used instead to allow compatibility.

    :return: The dispatcher.
    """
    global __dispatcher  # noqa: PLW0603
    if __dispatcher is not None:
        return __dispatcher
    _logger.debug("No dispatcher set, creating default dispatcher")
    try:
        __dispatcher = SlurmDispatcher()
    except RuntimeError as err:
        _logger.warning("Slurm environment not found: %s", err)
        _logger.warning("Using DirectCallDispatcher (local execution)")
        __dispatcher = DirectCallDispatcher()
    else:
        _logger.debug("Using SlurmDispatcher (Slurm environment detected)")
    return __dispatcher
def set_dispatcher(dispatcher: Dispatcher) -> None:
    """
    Install `dispatcher` as the current dispatcher, e.g., to enforce a specific one.

    :param dispatcher: The dispatcher to be used.
    :raises RuntimeError: If `get_dispatcher` does not return the given
        dispatcher afterwards.
    :return: None
    """
    global __dispatcher  # noqa: PLW0603
    _logger.debug("Setting dispatcher to %s", dispatcher.__class__.__name__)
    __dispatcher = dispatcher
    # Verify that the replacement is what `get_dispatcher` hands out.
    if get_dispatcher() is dispatcher:
        return
    msg = "Failed to set dispatcher."
    raise RuntimeError(msg)
def dispatch(
    funcs: FunctionCall | Iterable[FunctionCall],
    options: SlurmOptions,
    entry_point: Path,
    block: bool = False,
) -> JobReference:
    """
    Distribute function calls with the current dispatcher.

    :param funcs: The function call(s) to be distributed.
    :param options: The slurm options to be used.
    :param entry_point: The entry point; every function id is checked against it.
    :param block: Forwarded to the dispatcher.
    :raises KeyError: If `FunctionMap.check_id` rejects a function for the entry point.
    :return: The job reference returned by the dispatcher.
    """
    calls = [funcs] if isinstance(funcs, FunctionCall) else list(funcs)
    for call in calls:
        if FunctionMap.check_id(call.func_id, entry_point):
            continue
        msg = f"Function '{call.func_id}' cannot be called from the given entry point."
        raise KeyError(msg)
    dispatcher = get_dispatcher()
    return dispatcher(calls, options, entry_point, block)
def srun(
    command: str | list[str],
    conf: SlurmOptions | dict | None = None,
    simple_slurm_kwargs: dict | None = None,
) -> JobReference:
    """
    Call srun with the current dispatcher. This command is directly
    executed and only terminates after completion.

    :param command: A system command, e.g. `echo hello world > foobar.txt`.
        A list of arguments is shell-quoted and joined with spaces.
    :param conf: The slurm configuration, as `SlurmOptions` or as a dict.
    :param simple_slurm_kwargs: Additional options for simple_slurm.
    :return: The job reference returned by the dispatcher.
    """
    if not isinstance(conf, SlurmOptions):
        conf = SlurmOptions(**({} if conf is None else conf))
    if isinstance(command, str):
        shell_command = command
    else:
        shell_command = " ".join(shlex.quote(part) for part in command)
    dispatcher = get_dispatcher()
    return dispatcher.srun(shell_command, conf, simple_slurm_kwargs)
def sbatch(
    command: str | list[str],
    conf: SlurmOptions | dict | None = None,
    simple_slurm_kwargs: dict | None = None,
) -> JobReference:
    """
    Call sbatch with the current dispatcher. The command is scheduled and the
    function returns immediately.

    :param command: A system command, e.g. `echo hello world > foobar.txt`.
        A list of arguments is shell-quoted and joined with spaces.
    :param conf: The slurm configuration, as `SlurmOptions` or as a dict.
    :param simple_slurm_kwargs: Additional options for simple_slurm.
    :return: The job reference returned by the dispatcher.
    """
    if not isinstance(conf, SlurmOptions):
        conf = SlurmOptions(**({} if conf is None else conf))
    if isinstance(command, str):
        shell_command = command
    else:
        shell_command = " ".join(shlex.quote(part) for part in command)
    dispatcher = get_dispatcher()
    return dispatcher.sbatch(shell_command, conf, simple_slurm_kwargs)
def join() -> None:
    """
    Join all jobs that have been dispatched so far by calling `join` on the
    current dispatcher.

    :return: None
    """
    dispatcher = get_dispatcher()
    dispatcher.join()