# Copyright 2022-2026 The Ramble Authors
#
# Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
# https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
# <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
# option. This file may not be copied, modified, or distributed
# except according to those terms.
import itertools
import re
from typing import List, Optional, Set
import ramble.config
from ramble.error import RambleError
from ramble.util.logger import logger
# Default phase filter used by ``Filters`` when no phase filters are given.
# The single "*" entry presumably acts as a wildcard matching every phase
# (confirm against the phase-selection code). This is a shared module-level
# list: treat it as read-only.
ALL_PHASES: List[str] = ["*"]
[docs]
class Filters:
    """Object containing filters for limiting various operations in Ramble.

    Attributes:
        phases: Phase filter names. Defaults to a copy of ``ALL_PHASES``.
        include_where: Flattened list of include "where" expressions, or
            ``None`` when no include filters were given.
        exclude_where: Flattened list of exclude "where" expressions, or
            ``None`` when no exclude filters were given.
        tags: Set of tags collected from all tag groups.
    """

    def __init__(
        self,
        phase_filters: Optional[List[str]] = None,
        include_where_filters: Optional[List[List[str]]] = None,
        exclude_where_filters: Optional[List[List[str]]] = None,
        tags: Optional[List[List[str]]] = None,
    ) -> None:
        """Create a new filter instance.

        Args:
            phase_filters: Phase names to operate on; ``None`` selects the
                ``ALL_PHASES`` default.
            include_where_filters: Groups of include expressions; the groups
                are flattened into a single list.
            exclude_where_filters: Groups of exclude expressions; flattened
                the same way.
            tags: Groups of tags; flattened and de-duplicated into a set.
        """
        # Always copy: without this, the default case aliases the shared
        # module-level ALL_PHASES list (and otherwise the caller's list), so
        # mutating ``self.phases`` would leak into every other user of it.
        self.phases: List[str] = list(ALL_PHASES if phase_filters is None else phase_filters)
        self.include_where: Optional[List[str]] = None
        self.exclude_where: Optional[List[str]] = None
        self.tags: Set[str] = set()

        # Empty or missing filter groups leave the attributes at their
        # defaults (None / empty set) rather than an empty list.
        if include_where_filters:
            self.include_where = list(itertools.chain.from_iterable(include_where_filters))
        if exclude_where_filters:
            self.exclude_where = list(itertools.chain.from_iterable(exclude_where_filters))
        if tags:
            self.tags = set(itertools.chain.from_iterable(tags))
[docs]
def translate_group_to_predicate(group_def: dict) -> str:
    """Build a predicate string from a single filter group definition.

    Every entry of ``group_def["where"]`` must hold and no entry of
    ``group_def["exclude_where"]`` may hold. Each entry is parenthesized, and
    the include and exclude sections are each wrapped in their own parentheses
    before being joined with ``and``.

    Args:
        group_def: Filter group definition; only the ``where`` and
            ``exclude_where`` keys are read.

    Returns:
        The predicate string, or ``"True"`` when the definition is empty or has
        no ``where``/``exclude_where`` entries.
    """
    if not group_def:
        return "True"

    # (entries, per-entry template) pairs, in include-then-exclude order.
    sections = (
        (group_def.get("where", []), "({})"),
        (group_def.get("exclude_where", []), "not ({})"),
    )
    clauses = [
        "(" + " and ".join(template.format(entry) for entry in entries) + ")"
        for entries, template in sections
        if entries
    ]
    return " and ".join(clauses) if clauses else "True"
[docs]
def expand_filter_groups(expression: str, filter_groups_defs: Optional[dict]) -> str:
    """Expand a logical expression of filter groups into a predicate string.

    Group names in ``expression`` are replaced by the predicate produced by
    ``translate_group_to_predicate`` for their definition; ``and``/``or``/``not``
    (case-insensitive) and parentheses are kept as operators.

    Args:
        expression: Boolean expression over filter group names, e.g.
            ``"gpu and not (debug or slow)"``.
        filter_groups_defs: Mapping of group name to group definition, or
            ``None`` for no groups.

    Returns:
        The expanded predicate string, or ``"True"`` for an empty or
        whitespace-only expression.

    Raises:
        RambleError: If the expression contains disallowed characters,
            references an unknown group, or is not well formed (missing or
            dangling operators, empty or unbalanced parentheses).
    """
    # Whitespace-only input would otherwise tokenize to nothing and expand to
    # an empty predicate string.
    if not expression or not expression.strip():
        return "True"
    if filter_groups_defs is None:
        filter_groups_defs = {}

    # Validate expression only contains allowed characters to prevent injection or silent failures
    invalid_chars = re.sub(r"[a-zA-Z0-9_\s\(\)-]", "", expression)
    if invalid_chars:
        raise RambleError(
            f"Invalid characters {repr(invalid_chars)} in filter group expression '{expression}'"
        )

    def malformed(reason: str) -> RambleError:
        return RambleError(f"Malformed filter group expression '{expression}': {reason}")

    tokens = re.findall(r"[a-zA-Z0-9_-]+|\(|\)", expression)

    # Grammar: expr := term (("and" | "or") term)*
    #          term := "not"* (GROUP | "(" expr ")")
    # Tracked with a two-state machine (expecting an operand vs. an operator)
    # plus the parenthesis depth, so the emitted string is always a
    # well-formed boolean expression.
    expanded_tokens = []
    expect_operand = True
    depth = 0
    for token in tokens:
        keyword = token.lower()
        if keyword in ("and", "or"):
            if expect_operand:
                raise malformed(f"unexpected operator '{token}'")
            expect_operand = True
            expanded_tokens.append(keyword)
        elif keyword == "not":
            if not expect_operand:
                raise malformed("'not' must precede an operand")
            expanded_tokens.append(keyword)
        elif token == "(":
            if not expect_operand:
                raise malformed("missing operator before '('")
            depth += 1
            expanded_tokens.append(token)
        elif token == ")":
            if expect_operand:
                raise malformed("unexpected ')'")
            depth -= 1
            if depth < 0:
                raise malformed("unbalanced ')'")
            expanded_tokens.append(token)
        elif token in filter_groups_defs:
            if not expect_operand:
                raise malformed(f"missing operator before '{token}'")
            group_expr = translate_group_to_predicate(filter_groups_defs[token])
            expanded_tokens.append(f"( {group_expr} )")
            expect_operand = False
        else:
            raise RambleError(f"Unknown filter group '{token}' in expression '{expression}'")

    if expect_operand:
        raise malformed("expression ends without an operand")
    if depth:
        raise malformed("unbalanced '('")
    return " ".join(expanded_tokens)
[docs]
def resolve_and_apply_filter_groups(args, include_where_filters):
    """Resolve filter groups from args/env and apply them to include_where_filters.

    The include expression comes from ``args.filter_group`` and the exclude
    expression from ``args.exclude_filter_group``. Only when neither attribute
    is set (``None`` or missing) is the ``RAMBLE_ACTIVE_FILTER_GROUP``
    environment variable used as the include expression. The combined
    expression is expanded against the ``filter_groups`` config section and
    appended to ``include_where_filters`` as a single-element group. Expansion
    errors are reported through ``logger.die``.

    Args:
        args: Parsed command-line arguments (attributes read with ``getattr``).
        include_where_filters: Existing include "where" filter groups, or
            ``None``. Modified in place when a predicate is added.

    Returns:
        ``include_where_filters`` (a new list if it was ``None`` and a predicate
        was added); returned unchanged when no non-blank filter group
        expression is given.
    """
    import os

    filter_group_expr = getattr(args, "filter_group", None)
    exclude_filter_group_expr = getattr(args, "exclude_filter_group", None)

    # If not specified on CLI, check environment
    if filter_group_expr is None and exclude_filter_group_expr is None:
        active_group = os.environ.get("RAMBLE_ACTIVE_FILTER_GROUP")
        if active_group:
            filter_group_expr = active_group

    # Treat whitespace-only expressions as unset. Otherwise an include value of
    # "   " is expanded to the empty group "( )", which, if evaluated as a
    # Python expression, is an empty (falsy) tuple and silently filters out
    # everything.
    if isinstance(filter_group_expr, str):
        filter_group_expr = filter_group_expr.strip()
    if isinstance(exclude_filter_group_expr, str):
        exclude_filter_group_expr = exclude_filter_group_expr.strip()

    if not filter_group_expr and not exclude_filter_group_expr:
        return include_where_filters

    filter_groups_defs = ramble.config.get("filter_groups")

    clauses = []
    if filter_group_expr:
        clauses.append(f"( {filter_group_expr} )")
    if exclude_filter_group_expr:
        clauses.append(f"not ( {exclude_filter_group_expr} )")
    final_expr = " and ".join(clauses)

    # Only the expansion can raise RambleError; keep the try body minimal.
    try:
        predicate_expr = expand_filter_groups(final_expr, filter_groups_defs)
    except RambleError as e:
        logger.die(str(e))
    else:
        if include_where_filters is None:
            include_where_filters = []
        include_where_filters.append([predicate_expr])
    return include_where_filters
[docs]
def validate_filter_group_name(name: str):
    """Validate a filter group name.

    Names must not collide with the expression keywords understood by
    ``expand_filter_groups`` and may only use the characters its tokenizer
    recognises as part of a name.

    Args:
        name: Candidate filter group name.

    Raises:
        RambleError: If ``name`` is a reserved keyword (``and``/``or``/``not``,
            case-insensitive) or is empty or contains characters other than
            ASCII letters, digits, underscores, and hyphens.
    """
    if name.lower() in ("and", "or", "not"):
        raise RambleError(f"Filter group name '{name}' is a reserved keyword.")
    # fullmatch instead of match with "^...$": "$" also matches just before a
    # trailing newline, which would let names such as "foo\n" through.
    if not re.fullmatch(r"[a-zA-Z0-9_-]+", name):
        raise RambleError(
            f"Filter group name '{name}' is invalid. "
            "It can only contain alphanumeric characters, underscores, and hyphens."
        )