# Copyright 2022-2025 The Ramble Authors
#
# Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
# https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
# <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
# option. This file may not be copied, modified, or distributed
# except according to those terms.
"""Define base classes for modifier definitions"""
import io
import re
import textwrap
import fnmatch
from typing import List
from ramble.language.modifier_language import ModifierMeta, mode
from ramble.language.shared_language import SharedMeta
from ramble.error import RambleError
import ramble.util.directives
import ramble.util.class_attributes
from ramble.util.logger import logger
from ramble.util.naming import NS_SEPARATOR
class ModifierBase(metaclass=ModifierMeta):
    """Base class for modifier definitions.

    Concrete modifiers subclass this and use modifier-language directives
    (such as ``mode``) in their class body. An instance tracks the selected
    usage mode, the executables it applies to, and an expander copied from
    an application.
    """

    name = None
    _builtin_name = NS_SEPARATOR.join(("modifier_builtin", "{obj_name}", "{name}"))
    _mod_prefix_builtin = f"modifier_builtin{NS_SEPARATOR}"
    _language_classes = [ModifierMeta, SharedMeta]
    _pipelines = ["analyze", "archive", "mirror", "setup", "pushtocache", "execute", "logs"]

    modifier_class = "ModifierBase"

    #: List of GitHub usernames of this modifier's maintainers.
    #: Do not include @ here in order not to unnecessarily ping the users.
    maintainers: List[str] = []
    #: List of tag strings attached to this modifier.
    tags: List[str] = []

    #: Whether this modifier is disabled. ``set_usage_mode`` sets this to
    #: True on the instance when the selected usage mode is "disabled".
    disabled = False

    mode("disabled", description="Mode to disable all modifier functionality")

    def __init__(self, file_path):
        """Initialize the modifier instance.

        Args:
            file_path: Path stored as ``_file_path`` (presumably the file
                defining this modifier -- confirm with callers)
        """
        super().__init__()

        ramble.util.class_attributes.convert_class_attributes(self)

        self._file_path = file_path

        # By default a modifier applies to every executable.
        self._on_executables = ["*"]
        self.expander = None
        self._usage_mode = None
        self._verbosity = "short"

        ramble.util.directives.define_directive_methods(self)

    def copy(self):
        """Create a new instance of this modifier with the same configuration.

        The executables list, usage mode, verbosity and disabled state are
        carried over. The copy's expander is left as ``__init__`` sets it
        (None).

        Returns:
            A new instance of ``type(self)``
        """
        new_copy = type(self)(self._file_path)
        new_copy._on_executables = self._on_executables.copy()
        new_copy._usage_mode = self._usage_mode
        new_copy._verbosity = self._verbosity
        # Without this, a copy of a disabled modifier would report
        # _usage_mode == "disabled" while still having disabled == False.
        new_copy.disabled = self.disabled
        return new_copy

    def set_usage_mode(self, mode):
        """Set the usage mode for this modifier.

        If ``mode`` is empty or None, the mode is auto-detected: the class's
        ``_default_usage_mode`` is used when defined. Otherwise the single mode
        other than "disabled" is used. Selecting "disabled" also sets
        ``self.disabled`` to True.

        Args:
            mode (str | None): Usage mode to select, or a falsy value to
                auto-detect

        Raises:
            InvalidModeError: If the mode must be auto-detected and the
                modifier does not define exactly one mode besides "disabled"
        """
        if mode:
            self._usage_mode = mode
        else:
            if hasattr(self, "_default_usage_mode"):
                self._usage_mode = self._default_usage_mode
            else:
                # "disabled" is declared on this base class, so it is never a
                # candidate for auto-detection.
                candidates = set(self.modes)
                candidates.discard("disabled")
                if len(candidates) != 1:
                    raise InvalidModeError(
                        "Cannot auto determine usage " f"mode for modifier {self.name}"
                    )
                self._usage_mode = candidates.pop()

            if len(logger.log_stack) >= 1:
                logger.msg(
                    f" Using default usage mode {self._usage_mode} on modifier {self.name}"
                )

        if self._usage_mode == "disabled":
            self.disabled = True

    def set_on_executables(self, on_executables):
        """Set the executables this modifier applies to.

        If given an empty list or a value of None, the default of '*' is used.

        Args:
            on_executables (list | None): fnmatch-style executable name patterns

        Raises:
            ModifierError: If ``on_executables`` is non-empty and not a list
        """
        if not on_executables:
            self._on_executables = ["*"]
            return

        if not isinstance(on_executables, list):
            raise ModifierError(
                f"Modifier {self.name} given an unsupported on_executables "
                f"type of {type(on_executables)}"
            )
        # Store a private copy so later changes to the caller's list do not leak in.
        self._on_executables = list(on_executables)

    def inherit_from_application(self, app):
        """Copy the application's expander and apply this modifier's variables.

        Args:
            app: Application instance providing ``expander`` and ``variables``
        """
        self.expander = app.expander.copy()
        modded_vars = self.modded_variables(app)
        self.expander._variables.update(modded_vars)

    def define_variable(self, var_name, var_value):
        """Define a variable within this modifier's expander instance"""
        self.expander._variables[var_name] = var_value

    def modify_experiment(self, app):
        """Stubbed method to allow modification of experiment variables before
        an experiment is completely defined.

        This can be used to define things like n_ranks and have it influence
        the name of the resulting experiment.
        """
        pass

    def __str__(self):
        return self.name

    def modded_variables(self, app, extra_vars=None):
        """Compute the variable values produced by the current usage mode.

        Modifications are processed in definition order. "append" and
        "prepend" build on the variable's most recent value. That value is
        looked up first in the modifications computed so far, then in
        ``extra_vars``, then in ``app.variables``, and defaults to "" (None is
        also treated as ""). The configured separator is only inserted when
        that previous value is non-empty. Any other method sets the value
        outright.

        Args:
            app: Application instance whose ``variables`` supply base values
            extra_vars (dict | None): Additional base values, consulted before
                ``app.variables``

        Returns:
            (dict): Variable name -> modified value. Empty when the current
            usage mode defines no variable modifications.
        """
        mods = {}
        if extra_vars is None:
            extra_vars = {}

        if self._usage_mode not in self.variable_modifications:
            return mods

        for var, var_mods in self.variable_modifications[self._usage_mode].items():
            for var_mod in var_mods:
                method = var_mod["method"]
                if method not in ("append", "prepend"):  # method == set
                    mods[var] = var_mod["modification"]
                    continue

                if var in mods:
                    prev_val = mods[var]
                elif var in extra_vars:
                    prev_val = extra_vars[var]
                elif var in app.variables:
                    prev_val = app.variables[var]
                else:
                    prev_val = ""

                # A None value is "unset": render it as empty instead of
                # splicing the text "None" into the result.
                if prev_val is None:
                    prev_val = ""

                sep = var_mod["separator"] if prev_val != "" else ""

                if method == "append":
                    mods[var] = f'{prev_val}{sep}{var_mod["modification"]}'
                else:  # method == prepend
                    mods[var] = f'{var_mod["modification"]}{sep}{prev_val}'

        return mods

    def applies_to_executable(self, executable):
        """Return whether this modifier applies to the named executable.

        It applies when the name matches any of the fnmatch patterns set via
        ``set_on_executables``. It also applies when the name starts with
        this modifier's builtin prefix
        (``modifier_builtin<sep><name><sep>``).

        Args:
            executable (str): Executable name

        Returns:
            (bool): True if the modifier applies
        """
        if any(fnmatch.fnmatch(executable, pattern) for pattern in self._on_executables):
            return True

        # Exact prefix test: the modifier name is literal text, so it must not
        # be interpreted as a regular expression.
        builtin_prefix = f"{self._mod_prefix_builtin}{self.name}{NS_SEPARATOR}"
        return executable.startswith(builtin_prefix)

    def apply_executable_modifiers(self, executable_name, executable, app_inst=None):
        """Run every registered executable modifier and collect its results.

        Each name in ``self.executable_modifiers`` is looked up as a method on
        this instance. It is called with ``(executable_name, executable,
        app_inst=app_inst)`` and must return a ``(pre, post)`` pair of lists.

        Returns:
            (tuple): (pre_execs, post_execs), the concatenation of all pairs
        """
        pre_execs = []
        post_execs = []
        for exec_mod in self.executable_modifiers:
            mod_func = getattr(self, exec_mod)
            pre_exec, post_exec = mod_func(executable_name, executable, app_inst=app_inst)
            pre_execs.extend(pre_exec)
            post_execs.extend(post_exec)
        return pre_execs, post_execs

    def all_env_var_modifications(self):
        """Yield the (key, value) items of the current mode's env-var modifications.

        Yields nothing if the current usage mode defines none.
        """
        if self._usage_mode in self.env_var_modifications:
            yield from self.env_var_modifications[self._usage_mode].items()

    def all_package_manager_requirements(self):
        """Yield the package manager requirements of the current usage mode."""
        if self._usage_mode in self.package_manager_requirements:
            yield from self.package_manager_requirements[self._usage_mode]

    def all_pipeline_phases(self, pipeline):
        """Yield the (phase name, definition) items defined for ``pipeline``."""
        if pipeline in self.phase_definitions:
            yield from self.phase_definitions[pipeline].items()

    def no_expand_vars(self):
        """Iterator over non-expandable variables in current mode

        Yields:
            (str): Variable name
        """
        if self._usage_mode in self.modifier_variables:
            for var, var_conf in self.modifier_variables[self._usage_mode].items():
                if not var_conf.expandable:
                    yield var

    def mode_variables(self):
        """Return a dict of variables that should be defined for the current mode"""
        if self._usage_mode in self.modifier_variables:
            return self.modifier_variables[self._usage_mode]
        return {}

    def run_phase_hook(self, workspace, pipeline, hook_name):
        """Run a modifier hook.

        Hooks are internal functions named _{hook_name}.
        This is a wrapper to extract the hook function, and execute it
        properly.

        Hooks are only executed if they are not defined as a phase from the
        modifier.
        """
        if pipeline in self.phase_definitions and hook_name in self.phase_definitions[pipeline]:
            return

        hook_func_name = f"_{hook_name}"
        if hasattr(self, hook_func_name):
            getattr(self, hook_func_name)(workspace)

    def artifact_inventory(self, workspace, app_inst=None):
        """Return an inventory of modifier artifacts

        Artifact inventories are up to the individual modifier to define the
        format of.

        This will then show up in an experiment inventory.

        Returns:
            (Any) Artifact inventory for this modifier
        """
        return None

    def _prepare_analysis(self, workspace):
        """Hook to perform analysis that a modifier defines.

        This function allows modifier definitions to inject their own
        processing to output files, before FOMs are extracted.
        """
        pass

    def get_required_variables(self):
        """Get all the required variables based on the mode.

        A required variable is included when its "modes" property is missing
        or None, or when it contains the current usage mode.

        Returns:
            (dict): Variable name -> its properties, minus the "modes" entry.
            The properties keep their original insertion order.
        """
        required_vars = self.required_vars
        filtered_vars = {}
        if not required_vars:
            return filtered_vars

        usage_mode = self._usage_mode
        for var_name, var_props in required_vars.items():
            modes = var_props.get("modes")
            if modes is None or usage_mode in modes:
                # Exclude the extra modes prop, preserving key order
                filtered_vars[var_name] = {k: v for k, v in var_props.items() if k != "modes"}
        return filtered_vars
class ModifierError(RambleError):
    """Base exception for errors raised by modifier definitions.

    For example, ``ModifierBase.set_on_executables`` raises it when given an
    unsupported ``on_executables`` type.
    """
class InvalidModeError(ModifierError):
    """Exception raised when a modifier's usage mode is invalid.

    ``ModifierBase.set_usage_mode`` raises it when no mode is given and one
    cannot be auto-determined.
    """