# Source code for slurminade.function_map

"""
The internal datastructure to save all the slurmified functions.
Not relevant for endusers.
"""

from __future__ import annotations

import inspect
import logging
from collections.abc import Callable, Iterable
from pathlib import Path
from typing import Any, ClassVar

from .execute_cmds import call_slurminade_to_get_function_ids

# Module-level logger of this module. FunctionMap.register uses it to report
# duplicate registrations (error) and successful registrations (debug).
_logger = logging.getLogger("slurminade.function_map")


class FunctionMap:
    """
    The function map assigns functions an id and stores them to be called later.
    This id is reproducible such that the slurm node can retrieve the function it is
    supposed to call just by the id.
    """

    # Needed to save the entry point at the slurm node.
    # The slurm node just executes the file content of a script, so the file name is lost.
    # slurminade will set this value in the beginning to reconstruct it.
    entry_point: str | None = None
    # Registered functions, keyed by the id produced by ``get_id``.
    _data: ClassVar[dict[str, Callable]] = {}
    # Function ids last obtained for the entry point by ``check_id``.
    # ``None`` means the lookup failed once and verification is disabled.
    _ids: ClassVar[set[str] | None] = set()

    @staticmethod
    def get_id(func: Callable) -> str:
        """
        Returns the unique id of a function. Necessary to slurmify it.
        Probably not needed by the end user.

        The id is built from the resolved path of the file defining the function
        and the function name, which should be sufficient unless you overwrite a
        function (which is bad anyway).

        :param func: The function you want the id of.
        :return: The id as string.
        :raises RuntimeError: If the function is defined in ``<string>`` and no
            entry point has been set.
        """
        file = inspect.getfile(func)
        if file == "<string>":
            # on the slurm node, the functions in the entry point
            # are named `<string>`.
            if not FunctionMap.entry_point:
                msg = "No entry point known."
                raise RuntimeError(msg)
            file = FunctionMap.entry_point
        path = Path(file).resolve()
        return f"{path}:{getattr(func, '__name__', repr(func))}"

    @staticmethod
    def get_readable_name(func_id: str) -> str:
        """
        Returns the part of a function id after the last ``:``.

        :param func_id: An id as created by ``get_id``.
        :return: The trailing name component of the id.
        """
        return func_id.rsplit(":", maxsplit=1)[-1]

    @staticmethod
    def check_compatibility(func: Callable) -> None:
        """
        Throw if the function cannot be assigned an id.

        :param func: The function to be checked.
        :return: None
        :raises ValueError: If the object has no name, is a lambda, or has no
            source file that ``inspect.getfile`` can determine.
        """
        msg = "Can only slurmify proper functions."
        name = getattr(func, "__name__", "")
        if not name or name == "<lambda>":
            raise ValueError(msg)
        try:
            file = inspect.getfile(func)
        except (TypeError, OSError) as err:
            # ``inspect.getfile`` raises for builtins and similar objects.
            # Report that with the same error as every other incompatible input.
            raise ValueError(msg) from err
        if not file:
            raise ValueError(msg)

    @staticmethod
    def register(func: Callable[..., Any], allow_overwrite: bool = False) -> str:
        """
        Register a function, allowing it to be called just by its id.

        :param func: The function to be stored (must be a proper function).
        :param allow_overwrite: Whether an existing registration with the same
            id may be replaced.
        :return: The function's unique id.
        :raises ValueError: If the function cannot be assigned an id.
        :raises RuntimeError: If the id is already registered and
            ``allow_overwrite`` is false.
        """
        FunctionMap.check_compatibility(func)
        func_id = FunctionMap.get_id(func)
        if func_id in FunctionMap._data and not allow_overwrite:
            _logger.error("Attempted to register duplicate function: %s", func_id)
            msg = "Multiple function definitions!"
            raise RuntimeError(msg)
        FunctionMap._data[func_id] = func
        _logger.debug("Registered slurmified function: %s", func_id)
        _logger.debug("Total registered functions: %d", len(FunctionMap._data))
        return func_id

    @staticmethod
    def exists(func_id: str) -> bool:
        """
        Tells whether a function has been registered under the given id.

        :param func_id: The id to look up.
        :return: True if the id is registered.
        """
        return func_id in FunctionMap._data

    @staticmethod
    def call(func_id: str, args: Iterable, kwargs: dict[str, Any]) -> Any:
        """
        Calls a function by its id.

        :param func_id: The id of the function to be called.
        :param args: The positional arguments.
        :param kwargs: The keyword arguments.
        :return: The return value of the function.
        :raises KeyError: If no function is registered under the id.
        """
        if func_id not in FunctionMap._data:
            msg = f"Function '{func_id}' unknown!"
            raise KeyError(msg)
        return FunctionMap._data[func_id](*args, **kwargs)

    @staticmethod
    def check_id(func_id: str, entry_point: Path) -> bool:
        """
        Checks whether the entry point provides the given function id.

        The known ids are refreshed through
        ``call_slurminade_to_get_function_ids`` whenever the id is not among
        them. If that lookup fails, a warning is logged and this and all later
        checks return True.

        :param func_id: The id to verify.
        :param entry_point: The entry point whose function ids are queried.
        :return: False only if the ids could be obtained and do not contain
            ``func_id``; True otherwise.
        """
        if FunctionMap._ids is None:
            # An earlier lookup failed; verification is disabled.
            return True
        if func_id in FunctionMap._ids:
            return True
        try:
            # Try to identify the available function ids when the entry point is known.
            # This can help detect some bugs early on.
            FunctionMap._ids = call_slurminade_to_get_function_ids(entry_point)
        except Exception as e:
            # Deliberately best-effort: failing to verify must not block submission.
            logging.getLogger("slurminade").warning(
                "Cannot verify function ids before submitting to slurm: %s. This is not critical, things will just be more difficult to debug in case you make an error.",
                e,
            )
            FunctionMap._ids = None
            return True
        logging.getLogger("slurminade").info(
            "Entry point '%s' has functions %s",
            entry_point,
            list(FunctionMap._ids),
        )
        return func_id in FunctionMap._ids

    @staticmethod
    def get_all_ids() -> list[str]:
        """
        Returns the ids of all registered functions.

        :return: A new list with the registered ids.
        """
        return list(FunctionMap._data)
def set_entry_point(entry_point: str | Path) -> None:
    """
    Set a manual entry point.

    This function usually is not necessary for end users. It can allow you to
    use slurmify from the interactive interpreter.

    :param entry_point: A path to the entry point file. It has to be an
        existing file with the suffix ``.py``.
    :return: None
    :raises ValueError: If the path is not an existing ``.py`` file.
    """
    path = Path(entry_point)
    # Check the suffix the same way ``get_entry_point`` does, so both
    # functions agree on what counts as a Python file.
    if not path.is_file() or path.suffix != ".py":
        msg = f"Illegal entry point ({path})."
        raise ValueError(msg)
    # Store the resolved path as a string.
    FunctionMap.entry_point = str(path.resolve())
def get_entry_point() -> Path:
    """
    Return the entry point as a path.

    If no entry point has been set yet, the file of the ``__main__`` module is
    used and stored via ``set_entry_point``.

    :return: The path of the entry point.
    :raises FileNotFoundError: If no entry point is set and ``__main__`` has no
        usable ``.py`` file, or if the stored entry point does not exist.
    """
    if FunctionMap.entry_point is None:
        import __main__

        # ``__file__`` may be missing or present but ``None``. Either way
        # there is no file that could serve as the entry point.
        main_file = getattr(__main__, "__file__", None)
        if main_file is None:
            msg = "No entry point known."
            raise FileNotFoundError(msg)
        candidate = Path(main_file)
        if not candidate.is_file() or candidate.suffix != ".py":
            msg = "No entry point known."
            raise FileNotFoundError(msg)
        set_entry_point(main_file)
    # For type checkers only: the entry point is set at this point.
    assert FunctionMap.entry_point is not None
    path = Path(FunctionMap.entry_point)
    if not path.exists():
        msg = f"Entry point {path} does not exist."
        raise FileNotFoundError(msg)
    return path