Source code for slurminade.guard

"""
Some security measures to warn you about common mistakes and prevent you from
accidentally DDoSing your slurm environment.

1. Preventing recursive distributions, i.e., slurm nodes also distributing tasks.
2. Limiting the number of distributed tasks.
3. Warn about multiple flushes of batches, often caused by wrong indentation.

You can disable these security mechanisms by
``allow_recursive_distribution``, ``set_dispatch_limit(None)``, and
``disable_warning_for_multiple_flushes``.
"""

import logging
import typing

_exec_flag = False


[docs] def on_slurm_node(): global _exec_flag # noqa: PLW0602 return _exec_flag
[docs] def guard_recursive_distribution(): if on_slurm_node(): msg = """ You tried to distribute a task recursively. This is not allowed by default, because it probably indicates a bug in your code. To save you from accidentally overloading your slurm environment, this feature has been disabled by default. The most common reason for this error is that you forgot to guard your script with 'if __name__==\"__main__\":'. If you are sure that you want to distribute tasks recursively, you can disable this security mechanism by calling `allow_recursive_distribution` before the first call of `distribute`. """ raise RuntimeError(msg)
[docs] def prevent_distribution(): global _exec_flag # noqa: PLW0603 _exec_flag = True
[docs] def allow_recursive_distribution() -> None: """ Allow recursive distribution. Dangerous! :return: None """ global _exec_flag # noqa: PLW0603 _exec_flag = False
[docs] class TooManyDispatchesError(RuntimeError): def __init__(self, n_calls): self.n_calls = n_calls def __str__(self): return ( f"Exceeded the dispatch limit of {self.n_calls} calls. " f"This limit has been introduced to prevent you from overloading your " f"slurm environment in case of a bug. You can increase it" f" using `set_dispatch_limit`." )
class _DispatchGuard: def __init__(self, max_calls): self.max_calls = max_calls self.remaining_calls = max_calls def __call__(self): if not self.max_calls: return None if self.remaining_calls <= 0: raise TooManyDispatchesError(self.max_calls) self.remaining_calls -= 1 return self.remaining_calls def set_limit(self, n): self.max_calls = n self.remaining_calls = n dispatch_guard = _DispatchGuard(100)
[docs] def set_dispatch_limit(n: typing.Optional[int]): """ Set a limit to the number of dispatches. This feature has been introduced to prevent you from accidentally DDoSing you Slurm environment due to a bug. :param n: The maximal number of dispatches. :return: None """ dispatch_guard.set_limit(n)
[docs] class BatchGuard: """ Warns you if you flush more than once, as putting the flush call in a loop is a common mistake, compared to the intended use of flushing once at the end of your context, to get the job ids for dependency management. """ already_warned = False def __init__(self) -> None: self._num_of_flushes = 0 def _get_error_msg(self) -> str: return """ You repeatedly flushed a batch. There are various scenarios where this is done on purpose, but we want to warn you, because it is a common mistake to put the flush call in a loop by wrong indentation, instead of calling it once at the end of your context. If used for dependency management, e.g., with `wait_for`, such a mistake can lead to a faulty execution order. This warning allows you to quickly call `scancel -u <username>` to cancel your jobs, before they do any harm. If you are sure that you want to flush your batch multiple times, you can disable this warning by calling `disable_warning_on_repeated_flushes` before the first call of `flush`. You can also just ignore this warning, as no action is taken by default. """
[docs] def report_flush(self, num_tasks: int) -> None: if num_tasks == 0: # ignore empty flushes return self._num_of_flushes += 1 if self._num_of_flushes == 2 and not self.already_warned: logging.getLogger("slurminade").warning(self._get_error_msg()) self.already_warned = True
[docs] def disable_warning_on_repeated_flushes(): """ Disable the warning on multiple flushes. This is useful if you want to flush multiple times in a loop, without getting a warning. """ BatchGuard.already_warned = True