Source code for slurminade.dispatcher

"""
The dispatcher distributes function calls to slurm or to the local machine.
It can be accessed with `get_dispatcher` and set with `set_dispatcher`.
This makes it possible to change how calls are distributed; e.g., it is used
for batching: a batch simply wraps the current dispatcher in a buffered version.
"""

from __future__ import annotations

import abc
import logging
import shlex
import shutil
import subprocess
from collections.abc import Iterable
from pathlib import Path
from typing import Any

import simple_slurm

from .conf import _get_conf
from .execute_cmds import create_slurminade_command
from .function_call import FunctionCall
from .function_map import FunctionMap, get_entry_point
from .guard import dispatch_guard
from .job_reference import JobReference
from .options import SlurmOptions

# MAX_ARG_STRLEN on a Linux system with PAGE_SIZE 4096 is 131072
DEFAULT_MAX_ARG_LENGTH = 100000

# Module-level logger for consistent logging
_logger = logging.getLogger("slurminade.dispatcher")


class Dispatcher(abc.ABC):
    """
    Abstract dispatcher to be inherited by all concrete dispatchers.

    For implementing a dispatcher you have to implement `_dispatch`, `srun`
    and `sbatch`. Dispatchers are invoked via `__call__`, which normalizes
    the function calls to a list and logs them before delegating to
    `_dispatch`.
    """

    @abc.abstractmethod
    def _dispatch(
        self,
        funcs: Iterable[FunctionCall],
        options: SlurmOptions,
        entry_point: Path,
        block: bool = False,
    ) -> JobReference:
        """
        Define how to dispatch a number of function calls.

        :param funcs: The function calls to be dispatched.
        :param options: The slurm options to be used.
        :param entry_point: The entry point (script) from which the function
            calls are to be executed.
        :param block: If True, the dispatcher should wait for the calls to
            complete (dispatchers that execute sequentially may ignore it).
        :return: A reference to the dispatched job. Dispatchers that do not
            create slurm jobs return a reference whose job id is ``None``.
        """

    @abc.abstractmethod
    def srun(
        self,
        command: str,
        conf: SlurmOptions | None = None,
        simple_slurm_kwargs: dict | None = None,
    ) -> JobReference:
        """
        Define how you want to execute an `srun` command.
        This command is directly executed and only terminates after completion.

        :param command: A system command, e.g. `echo hello world > foobar.txt`.
        :param conf: The slurm configuration.
        :param simple_slurm_kwargs: Additional options for simple_slurm.
        :return: A reference to the executed job.
        """

    @abc.abstractmethod
    def sbatch(
        self,
        command: str,
        conf: SlurmOptions | None = None,
        simple_slurm_kwargs: dict | None = None,
    ) -> JobReference:
        """
        Define how you want to execute an `sbatch` command.
        The command is scheduled and the function returns immediately.

        :param command: A system command, e.g. `echo hello world > foobar.txt`.
        :param conf: The slurm configuration.
        :param simple_slurm_kwargs: Additional options for simple_slurm.
        :return: A reference to the scheduled job.
        """

    def _log_dispatch(self, funcs: list[FunctionCall], options: SlurmOptions) -> None:
        """Log dispatching information with lazy formatting."""
        if len(funcs) == 1:
            _logger.info("Dispatching task with options %s: %s", options, funcs[0])
        else:
            _logger.info(
                "Dispatching task with %d function calls and options %s: %s",
                len(funcs),
                options,
                ", ".join(str(f) for f in funcs),
            )

    def __call__(
        self,
        funcs: FunctionCall | Iterable[FunctionCall],
        options: SlurmOptions,
        entry_point: Path,
        block: bool = False,
    ) -> JobReference:
        """
        Dispatches a function call or a number of function calls.

        :param funcs: A single function call or an iterable of function calls.
        :param options: The slurm options to be used.
        :param entry_point: The entry point (script) from which the function
            calls are to be executed.
        :param block: Forwarded to `_dispatch`; requests waiting for completion.
        :return: The job reference returned by `_dispatch`.
        """
        # Materialize once: the list is both logged and handed to `_dispatch`.
        funcs = [funcs] if isinstance(funcs, FunctionCall) else list(funcs)
        self._log_dispatch(funcs, options)
        return self._dispatch(funcs, options, entry_point, block)

    def is_sequential(self) -> bool:
        """
        Return true if the dispatcher works sequentially.
        In this case, the dependencies are trivially fulfilled.
        Slurm does not work sequentially, because this would destroy its purpose.
        In some cases however, you do not want to use slurm for compatibility
        reasons, without changing the script. In these cases, this function
        tells slurminade not to be too strict about dependencies.

        :return: True if tasks are executed sequentially, false if not.
        """
        return False

    def join(self) -> None:
        """
        Make subsequently dispatched jobs wait for all previously dispatched ones.

        Sequential dispatchers need no action. Non-sequential dispatchers must
        override this method.

        :raises NotImplementedError: If the dispatcher is not sequential and
            does not override this method.
        """
        if self.is_sequential():
            # Already sequential, nothing to do
            return
        msg = "Joining is not implemented for this dispatcher."
        raise NotImplementedError(msg)
class TestJobReference(JobReference):
    """Job reference handed out by `TestDispatcher`; no real job exists behind it."""

    def get_job_id(self) -> None:
        """Nothing is submitted, so there is no job id."""
        return None

    def get_exit_code(self) -> None:
        """Nothing is executed, so there is no exit code."""
        return None

    def get_info(self) -> dict[str, Any]:
        """Return a fixed marker dictionary identifying a test reference."""
        info: dict[str, Any] = {"info": "test"}
        return info
class TestDispatcher(Dispatcher):
    """
    A dummy dispatcher that only logs the commands instead of executing them.
    Primarily for debugging and testing.

    Dispatched function-call batches are recorded in ``calls``; raw commands
    passed to `sbatch`/`srun` are recorded in ``sbatches``/``sruns``.
    """

    def __init__(self):
        super().__init__()
        self.calls = []
        self.sbatches = []
        self.sruns = []
        self.max_arg_length = DEFAULT_MAX_ARG_LENGTH

    def _dispatch(
        self,
        funcs: Iterable[FunctionCall],
        options: SlurmOptions,  # noqa: ARG002
        entry_point: Path,
        block: bool = False,  # noqa: ARG002
    ) -> JobReference:
        """
        Build (but do not run) the command for ``funcs`` and record the calls.

        The command is built from the given ``entry_point``, exactly like
        `SlurmDispatcher` and `SubprocessDispatcher` do, so the logged command
        matches what a real dispatcher would execute.
        """
        dispatch_guard()
        funcs = list(funcs)
        command = create_slurminade_command(entry_point, funcs, self.max_arg_length)
        logging.getLogger("slurminade").info("Command: %s", command)
        self.calls.append(funcs)
        self._cleanup(command)
        return TestJobReference()

    def _cleanup(self, command: list[str]) -> None:
        """
        Delete the argument file if the command passes arguments via ``--fromfile``.

        The command is never executed here, so such a file (presumably written
        by `create_slurminade_command`) would otherwise be left behind.
        """
        if len(command) < 2 or command[-2] != "--fromfile":
            return
        filename = Path(command[-1])
        if filename.exists():
            filename.unlink()

    def srun(
        self,
        command: str,
        conf: SlurmOptions | None = None,  # noqa: ARG002
        simple_slurm_kwargs: dict | None = None,  # noqa: ARG002
    ) -> JobReference:
        """Record and log ``command`` without executing it."""
        dispatch_guard()
        self.sruns.append(command)
        logging.getLogger("slurminade").info("[test output] SRUN %s", command)
        return TestJobReference()

    def sbatch(
        self,
        command: str,
        conf: SlurmOptions | None = None,  # noqa: ARG002
        simple_slurm_kwargs: dict | None = None,  # noqa: ARG002
    ) -> JobReference:
        """Record and log ``command`` without scheduling it."""
        dispatch_guard()
        self.sbatches.append(command)
        logging.getLogger("slurminade").info("[test output] SBATCH %s", command)
        return TestJobReference()

    def is_sequential(self) -> bool:
        """Nothing runs concurrently, so dependencies are trivially fulfilled."""
        return True
class SlurmJobReference(JobReference):
    """Reference to a job that was handed to slurm via ``sbatch`` or ``srun``."""

    def __init__(self, job_id: int | None, exit_code: int | None, mode: str):
        # `mode` names the slurm command that produced this reference
        # ("sbatch" or "srun" in this module).
        self.job_id, self.exit_code, self.mode = job_id, exit_code, mode

    def get_job_id(self) -> int | None:
        """Job id assigned by ``sbatch``; ``None`` for ``srun`` references."""
        return self.job_id

    def get_exit_code(self) -> int | None:
        """Exit code returned by ``srun``; ``None`` for ``sbatch`` references."""
        return self.exit_code

    def get_info(self) -> dict[str, Any]:
        """Summarize the reference as a dictionary."""
        return dict(
            job_id=self.job_id,
            exit_code=self.exit_code,
            on_slurm=True,
            mode=self.mode,
        )
class SlurmDispatcher(Dispatcher):
    """
    The most important dispatcher: Distributing function calls to slurm.

    Requires the ``sbatch`` executable on ``PATH``. The ids of all jobs
    submitted via ``sbatch`` are recorded. After `join`, every job submitted
    afterwards (dispatched functions as well as raw `sbatch`/`srun` commands)
    depends on them with ``afterany``.
    """

    def __init__(self):
        super().__init__()
        if not shutil.which("sbatch"):
            msg = "Slurm could not be found."
            raise RuntimeError(msg)
        self.max_arg_length = DEFAULT_MAX_ARG_LENGTH
        # Ids of all jobs submitted via sbatch so far (consumed by `join`).
        self._all_job_ids = []
        # Job ids every newly submitted job has to wait for (set by `join`).
        self._join_dependencies = []

    def _create_slurm_api(self, special_slurm_opts: SlurmOptions) -> simple_slurm.Slurm:
        conf = _get_conf(special_slurm_opts)
        return simple_slurm.Slurm(**conf)

    def _with_join_dependencies(self, conf: SlurmOptions | None) -> SlurmOptions | None:
        """
        Return ``conf`` extended by the dependencies established via `join`.

        The caller's options are copied, never mutated. Without pending join
        dependencies, ``conf`` is returned unchanged (possibly ``None``).
        """
        if not self._join_dependencies:
            return conf
        conf = SlurmOptions(**conf) if conf is not None else SlurmOptions()
        conf.add_dependencies(self._join_dependencies, "afterany")
        return conf

    def _job_name(self, funcs: list[FunctionCall]) -> str:
        # Derive the name from the first call. Picking an element of a set of
        # strings would make batch names depend on hash randomization.
        first_name = FunctionMap.get_readable_name(funcs[0].func_id)
        if len(funcs) == 1:
            return f"slurminade:{first_name}"
        return f"slurminade[batch]:{first_name}..."

    def _dispatch(
        self,
        funcs: Iterable[FunctionCall],
        options: SlurmOptions,
        entry_point: Path,
        block: bool = False,
    ) -> SlurmJobReference:
        """
        Submit ``funcs`` as one slurm job.

        :return: An ``srun`` reference carrying the exit code if ``block`` is
            True, otherwise an ``sbatch`` reference carrying the job id.
        """
        dispatch_guard()
        # Materialize once; the calls are used for the job name and the command.
        funcs = list(funcs)
        if "job_name" not in options:
            # This is complicated to prevent warnings about the type
            options = SlurmOptions(**options.as_dict())
            options["job_name"] = self._job_name(funcs)
        # Copy so that adding dependencies does not touch the caller's options.
        options = SlurmOptions(**options)
        if self._join_dependencies:
            options.add_dependencies(self._join_dependencies, "afterany")
        slurm = self._create_slurm_api(options)
        command = create_slurminade_command(entry_point, funcs, self.max_arg_length)
        logging.getLogger("slurminade").debug(command)
        command_str = " ".join(shlex.quote(c) for c in command)
        if block:
            ret = slurm.srun(command_str)
            logging.getLogger("slurminade").info(
                "Returned from srun with exit code %s", ret
            )
            return SlurmJobReference(None, ret, "srun")
        jid = slurm.sbatch(command_str)
        self._all_job_ids.append(jid)
        return SlurmJobReference(jid, None, "sbatch")

    def sbatch(
        self,
        command: str,
        conf: SlurmOptions | None = None,
        simple_slurm_kwargs: dict | None = None,
    ) -> SlurmJobReference:
        """Schedule ``command`` via sbatch, respecting dependencies from `join`."""
        dispatch_guard()
        conf = self._with_join_dependencies(conf)
        slurm = simple_slurm.Slurm(**_get_conf(conf))
        logging.getLogger("slurminade").debug("SBATCH %s", command)
        jid = slurm.sbatch(command, **(simple_slurm_kwargs or {}))
        self._all_job_ids.append(jid)
        return SlurmJobReference(jid, None, "sbatch")

    def join(self) -> None:
        """Make all jobs submitted from now on depend on every job submitted so far."""
        if not self._all_job_ids:
            return
        # Deduplicate while keeping submission order.
        self._join_dependencies = list(dict.fromkeys(self._all_job_ids))

    def srun(
        self,
        command: str,
        conf: SlurmOptions | None = None,
        simple_slurm_kwargs: dict | None = None,
    ) -> SlurmJobReference:
        """Run ``command`` via srun (blocking), respecting dependencies from `join`."""
        dispatch_guard()
        conf = self._with_join_dependencies(conf)
        slurm = simple_slurm.Slurm(**_get_conf(conf))
        logging.getLogger("slurminade").debug("SRUN %s", command)
        ret = slurm.srun(command, **(simple_slurm_kwargs or {}))
        return SlurmJobReference(None, ret, "srun")
class SubprocessJobReference(JobReference):
    """Job reference returned by `SubprocessDispatcher`; it holds no job data."""

    def __init__(self):
        """Create an empty reference (there is no state to store)."""

    def get_job_id(self) -> int | None:
        """Child processes have no slurm job id."""
        return None

    def get_exit_code(self) -> int | None:
        """The exit code is not recorded."""
        return None

    def get_info(self) -> dict[str, Any]:
        """Report that the job did not run on slurm."""
        return dict(on_slurm=False)
class SubprocessDispatcher(Dispatcher):
    """
    Debugging dispatcher that executes function calls in child processes.

    It builds the same serialized command as the slurm dispatcher (via
    `create_slurminade_command`) but needs no slurm installation. It is not
    meant for production use; use `DirectCallDispatcher` to run without slurm.
    Every call waits for its child process to finish, so nothing runs in
    parallel.
    """

    def __init__(self):
        super().__init__()
        self.max_arg_length = DEFAULT_MAX_ARG_LENGTH

    def _dispatch(
        self,
        funcs: Iterable[FunctionCall],
        options: SlurmOptions,  # noqa: ARG002
        entry_point: Path,
        block: bool = False,  # noqa: ARG002
    ) -> SubprocessJobReference:
        dispatch_guard()
        argv = create_slurminade_command(entry_point, funcs, self.max_arg_length)
        # Blocking run; a non-zero exit status raises CalledProcessError.
        subprocess.run(argv, shell=False, check=True)
        return SubprocessJobReference()

    def srun(
        self,
        command: str,
        conf: SlurmOptions | None = None,  # noqa: ARG002
        simple_slurm_kwargs: dict | None = None,  # noqa: ARG002
    ) -> SubprocessJobReference:
        dispatch_guard()
        logging.getLogger("slurminade").debug("SRUN %s", command)
        # ``command`` is a user-supplied shell command line, hence shell=True.
        subprocess.run(command, shell=True, check=True)
        return SubprocessJobReference()

    def sbatch(
        self,
        command: str,
        conf: SlurmOptions | None = None,  # noqa: ARG002
        simple_slurm_kwargs: dict | None = None,  # noqa: ARG002
    ) -> SubprocessJobReference:
        # There is no scheduler; scheduling degrades to a blocking run.
        return self.srun(command)

    def is_sequential(self) -> bool:
        return True
class LocalJobReference(JobReference):
    """Job reference returned by `DirectCallDispatcher` for local execution."""

    def get_job_id(self) -> None:
        """Local execution has no job id."""
        return None

    def get_exit_code(self) -> None:
        """The exit code is not recorded."""
        return None

    def get_info(self) -> dict[str, Any]:
        """Report that the job did not run on slurm."""
        return dict(on_slurm=False)
class DirectCallDispatcher(Dispatcher):
    """
    Dispatcher that calls the functions in-process, as if slurminade were not
    used at all. Scripts therefore also work on machines without slurm.
    """

    def _dispatch(
        self,
        funcs: Iterable[FunctionCall],
        options: SlurmOptions,  # noqa: ARG002
        entry_point: Path,  # noqa: ARG002
        block: bool = False,  # noqa: ARG002
    ) -> LocalJobReference:
        dispatch_guard()
        # Execute the calls one after another, in the given order.
        for call in funcs:
            FunctionMap.call(call.func_id, call.args, call.kwargs)
        return LocalJobReference()

    def srun(
        self,
        command: str,
        conf: SlurmOptions | None = None,  # noqa: ARG002
        simple_slurm_kwargs: dict | None = None,  # noqa: ARG002
    ) -> LocalJobReference:
        dispatch_guard()
        # ``command`` is a user-supplied shell command line, hence shell=True.
        subprocess.run(command, shell=True, check=True)
        return LocalJobReference()

    def sbatch(
        self,
        command: str,
        conf: SlurmOptions | None = None,  # noqa: ARG002
        simple_slurm_kwargs: dict | None = None,  # noqa: ARG002
    ) -> LocalJobReference:
        # Without slurm, scheduling degrades to a blocking local run.
        return self.srun(command)

    def is_sequential(self) -> bool:
        return True
# The currently active dispatcher. It is ``None`` until it is first requested
# via `get_dispatcher` or replaced via `set_dispatcher`.
__dispatcher: Dispatcher | None = None


def get_dispatcher() -> Dispatcher:
    """
    Return the active dispatcher, creating a default one on first use.

    Slurm is the primary target, so a `SlurmDispatcher` is tried first. If its
    construction fails with a `RuntimeError` (no slurm environment), a
    `DirectCallDispatcher` is used instead so that scripts keep working.

    :return: The dispatcher.
    """
    global __dispatcher  # noqa: PLW0603
    if __dispatcher is not None:
        return __dispatcher
    _logger.debug("No dispatcher set, creating default dispatcher")
    try:
        __dispatcher = SlurmDispatcher()
    except RuntimeError as err:
        _logger.warning("Slurm environment not found: %s", err)
        _logger.warning("Using DirectCallDispatcher (local execution)")
        __dispatcher = DirectCallDispatcher()
    else:
        _logger.debug("Using SlurmDispatcher (Slurm environment detected)")
    return __dispatcher


def set_dispatcher(dispatcher: Dispatcher) -> None:
    """
    Replace the active dispatcher, e.g., to enforce a specific one.

    :param dispatcher: The dispatcher to be used from now on.
    :raises RuntimeError: If `get_dispatcher` does not return ``dispatcher``
        afterwards.
    :return: None
    """
    global __dispatcher  # noqa: PLW0603
    _logger.debug("Setting dispatcher to %s", dispatcher.__class__.__name__)
    __dispatcher = dispatcher
    # Sanity check: the new dispatcher must be the one handed out from now on.
    if get_dispatcher() is dispatcher:
        return
    msg = "Failed to set dispatcher."
    raise RuntimeError(msg)
def dispatch(
    funcs: FunctionCall | Iterable[FunctionCall],
    options: SlurmOptions,
    entry_point: Path,
    block: bool = False,
) -> JobReference:
    """
    Distribute function calls with the current dispatcher.

    :param funcs: A single function call or an iterable of function calls.
    :param options: The slurm options to be used.
    :param entry_point: The entry point (script) from which the functions are
        to be executed; every function must be callable from it.
    :param block: Forwarded to the dispatcher. If True, dispatchers that
        support it (e.g., `SlurmDispatcher`, which then uses srun) wait for
        completion.
    :return: The job reference returned by the dispatcher.
    :raises KeyError: If a function cannot be called from ``entry_point``.
    """
    calls = [funcs] if isinstance(funcs, FunctionCall) else list(funcs)
    # Validate every call up front so that nothing is dispatched on failure.
    for call in calls:
        if FunctionMap.check_id(call.func_id, entry_point):
            continue
        msg = f"Function '{call.func_id}' cannot be called from the given entry point."
        raise KeyError(msg)
    return get_dispatcher()(calls, options, entry_point, block)
def _as_slurm_options(conf: SlurmOptions | dict | None) -> SlurmOptions:
    """Coerce ``conf`` into `SlurmOptions`; ``None`` means no options."""
    if isinstance(conf, SlurmOptions):
        return conf
    return SlurmOptions(**({} if conf is None else conf))


def _as_shell_command(command: str | list[str]) -> str:
    """Return a string command unchanged; shell-quote and join an argument list."""
    if isinstance(command, str):
        return command
    return " ".join(map(shlex.quote, command))


def srun(
    command: str | list[str],
    conf: SlurmOptions | dict | None = None,
    simple_slurm_kwargs: dict | None = None,
) -> JobReference:
    """
    Call srun with the current dispatcher.
    The command is executed directly and this call returns only after completion.

    :param command: A system command, e.g. `echo hello world > foobar.txt`,
        or a list of arguments that will be shell-quoted.
    :param conf: The slurm configuration.
    :param simple_slurm_kwargs: Additional options for simple_slurm.
    :return: The job reference returned by the dispatcher.
    """
    options = _as_slurm_options(conf)
    return get_dispatcher().srun(_as_shell_command(command), options, simple_slurm_kwargs)


def sbatch(
    command: str | list[str],
    conf: SlurmOptions | dict | None = None,
    simple_slurm_kwargs: dict | None = None,
) -> JobReference:
    """
    Schedule a command with the current dispatcher and return immediately.

    :param command: A system command, e.g. `echo hello world > foobar.txt`,
        or a list of arguments that will be shell-quoted.
    :param conf: The slurm configuration.
    :param simple_slurm_kwargs: Additional options for simple_slurm.
    :return: The job reference returned by the dispatcher.
    """
    options = _as_slurm_options(conf)
    return get_dispatcher().sbatch(_as_shell_command(command), options, simple_slurm_kwargs)
def join() -> None:
    """
    Join all jobs that have been dispatched so far.

    Delegates to the `join` method of the current dispatcher.

    :return: None
    """
    dispatcher = get_dispatcher()
    dispatcher.join()