# Source code for ccat_workflow_manager.grouping.engine

"""Generic filter engine for resolving DataGroupings into sub-groups.

Replaces the polymorphic resolver registry with a declarative, data-driven
approach. Filter rules and group_by dimensions are JSON on the models.
"""

from typing import List, Optional

from sqlalchemy.orm import Session
from ccat_ops_db import models

from .resolver import SubGroup
from ..logging_utils import get_structured_logger
from ..exceptions import FilterEngineError

logger = get_structured_logger(__name__)

# Declarative join graph: maps (from_table, to_table) to a factory returning
# (target_model, join_condition). The engine walks this graph to build
# SQLAlchemy joins needed by filter_rules and group_by dimensions.
# NOTE: edges are directional (from -> to); lambdas defer attribute access on
# `models` until join time — presumably to avoid import-order issues (verify).
JOIN_GRAPH = {
    # RawDataPackage -> ObsUnit via RawDataPackage.obs_unit_id
    ("RawDataPackage", "ObsUnit"): lambda: (
        models.ObsUnit,
        models.RawDataPackage.obs_unit_id == models.ObsUnit.id,
    ),
    # ObsUnit -> Source via ObsUnit.source_id
    ("ObsUnit", "Source"): lambda: (
        models.Source,
        models.ObsUnit.source_id == models.Source.id,
    ),
    # RawDataPackage -> InstrumentModule via RawDataPackage.instrument_module_id
    ("RawDataPackage", "InstrumentModule"): lambda: (
        models.InstrumentModule,
        models.RawDataPackage.instrument_module_id == models.InstrumentModule.id,
    ),
    # ObsUnit -> ExecutedObsUnit via ExecutedObsUnit.obs_unit_id (reverse FK)
    ("ObsUnit", "ExecutedObsUnit"): lambda: (
        models.ExecutedObsUnit,
        models.ExecutedObsUnit.obs_unit_id == models.ObsUnit.id,
    ),
}

# Map table name strings (as they appear in filter_rules / group_by JSON)
# to SQLAlchemy model classes. Unknown names raise FilterEngineError in the
# engine rather than failing with an AttributeError deep in query building.
TABLE_MAP = {
    "RawDataPackage": models.RawDataPackage,
    "ObsUnit": models.ObsUnit,
    "Source": models.Source,
    "InstrumentModule": models.InstrumentModule,
    "ExecutedObsUnit": models.ExecutedObsUnit,
    "DataGrouping": models.DataGrouping,
}

# Supported filter-rule operators. Each entry maps the "operator" key of a
# filter rule to a callable (column, value) -> SQLAlchemy boolean expression.
FILTER_OPERATORS = {
    "eq": lambda col, val: col == val,
    "neq": lambda col, val: col != val,
    # "in"/"not_in" expect `val` to be a list/tuple of candidate values.
    "in": lambda col, val: col.in_(val),
    "not_in": lambda col, val: col.notin_(val),
    "gt": lambda col, val: col > val,
    "gte": lambda col, val: col >= val,
    "lt": lambda col, val: col < val,
    "lte": lambda col, val: col <= val,
    # SQL LIKE — `val` may contain % / _ wildcards.
    "like": lambda col, val: col.like(val),
}


class FilterEngine:
    """Resolves a DataGrouping's filter_rules + group_by into SubGroups."""

    def resolve(
        self,
        session: Session,
        data_grouping: models.DataGrouping,
        group_by: Optional[List[str]] = None,
    ) -> List[SubGroup]:
        """Resolve ``data_grouping`` into a list of :class:`SubGroup`.

        Applies the grouping's JSON ``filter_rules`` to a query over
        ``RawDataPackage.id``, then partitions the matching package ids by
        the ``group_by`` dimensions (``"Table.column"`` or bare ``"column"``,
        the latter defaulting to RawDataPackage).

        :param session: SQLAlchemy session used to build and run the query.
        :param data_grouping: grouping whose ``filter_rules`` (and optional
            ``instrument_module_id``) drive the filtering.
        :param group_by: optional list of dimension strings; when empty, all
            matching packages are collapsed into a single ``"all"`` sub-group.
        :returns: list of SubGroups (empty when nothing matches and no
            grouping is requested).
        :raises FilterEngineError: on unknown table, column, or operator, or
            when no join path exists to a referenced table.
        """
        filter_rules = data_grouping.filter_rules or []
        group_by = group_by or []

        query = session.query(models.RawDataPackage.id)

        # Optionally restrict to the grouping's instrument module.
        if data_grouping.instrument_module_id is not None:
            query = query.filter(
                models.RawDataPackage.instrument_module_id
                == data_grouping.instrument_module_id
            )

        # Tables already present in the query; the root is always there.
        joined_tables = {"RawDataPackage"}
        # (dimension string, column) pairs collected while extending the query.
        group_by_columns = []

        for rule in filter_rules:
            table_name = rule.get("table", "RawDataPackage")
            column_name = rule["column"]
            operator = rule.get("operator", "eq")
            value = rule["value"]
            json_path = rule.get("json_path")

            # Join in whatever table the rule references (multi-hop paths are
            # resolved through JOIN_GRAPH — see _ensure_joined).
            query, joined_tables = self._ensure_joined(
                query, "RawDataPackage", table_name, joined_tables
            )

            model_cls = TABLE_MAP.get(table_name)
            if model_cls is None:
                raise FilterEngineError(f"Unknown table in filter rule: {table_name}")
            col = getattr(model_cls, column_name, None)
            if col is None:
                raise FilterEngineError(
                    f"Unknown column {column_name} on {table_name}"
                )

            # Drill into a JSON column via successive key subscripts.
            if json_path:
                for key in json_path.split("."):
                    col = col[key]

            op_fn = FILTER_OPERATORS.get(operator)
            if op_fn is None:
                raise FilterEngineError(f"Unknown filter operator: {operator}")
            query = query.filter(op_fn(col, value))

        if not group_by:
            # No grouping — aggregate all matches into one sub-group.
            package_ids = [row[0] for row in query.all()]
            if not package_ids:
                return []
            return [
                SubGroup(
                    key="all",
                    metadata={},
                    raw_data_package_ids=package_ids,
                )
            ]

        # Extend the SELECT list with one column per group_by dimension.
        for dim in group_by:
            parts = dim.split(".")
            table_name = parts[0] if len(parts) > 1 else "RawDataPackage"
            col_name = parts[-1]
            query, joined_tables = self._ensure_joined(
                query, "RawDataPackage", table_name, joined_tables
            )
            model_cls = TABLE_MAP.get(table_name)
            if model_cls is None:
                raise FilterEngineError(f"Unknown table in group_by: {table_name}")
            col = getattr(model_cls, col_name, None)
            if col is None:
                raise FilterEngineError(
                    f"Unknown column {col_name} on {table_name} in group_by"
                )
            group_by_columns.append((dim, col))
            query = query.add_columns(col)

        # Execute and bucket rows by the tuple of group_by values.
        # NOTE(review): keys use only the last path segment, so two dimensions
        # ending in the same column name (e.g. "Source.name" and
        # "InstrumentModule.name") would collide — confirm callers avoid that.
        rows = query.all()
        groups = {}
        for row in rows:
            package_id = row[0]
            dim_values = row[1:]
            key_parts = []
            meta = {}
            for (dim, _), val in zip(group_by_columns, dim_values):
                key_parts.append(f"{dim.split('.')[-1]}={val}")
                meta[dim.split(".")[-1]] = val
            key = "|".join(key_parts)
            if key not in groups:
                groups[key] = SubGroup(
                    key=key, metadata=meta, raw_data_package_ids=[]
                )
            groups[key].raw_data_package_ids.append(package_id)

        result = list(groups.values())
        logger.info(
            "filter_engine_resolved",
            grouping_id=data_grouping.id,
            sub_group_count=len(result),
            total_packages=sum(len(sg.raw_data_package_ids) for sg in result),
        )
        return result

    def _ensure_joined(self, query, from_table, to_table, joined_tables):
        """Ensure ``to_table`` is joined into ``query`` via the join graph.

        Fix over the previous version: joins are resolved by a breadth-first
        path search over JOIN_GRAPH, so multi-hop targets work regardless of
        rule order (e.g. RawDataPackage -> ObsUnit -> Source even when no
        ObsUnit rule was processed first). Only edges on the found path are
        joined, so no unrelated inner joins are introduced.

        :returns: ``(query, joined_tables)`` with the join(s) applied and the
            joined-table set extended.
        :raises FilterEngineError: when no path to ``to_table`` exists.
        """
        if to_table in joined_tables:
            return query, joined_tables

        path = self._find_join_path(joined_tables, to_table)
        if path is None:
            raise FilterEngineError(
                f"No join path from {from_table} to {to_table}"
            )
        for edge in path:
            target, condition = JOIN_GRAPH[edge]()
            query = query.join(target, condition)
            joined_tables = joined_tables | {edge[1]}
        return query, joined_tables

    def _find_join_path(self, joined_tables, to_table):
        """BFS over JOIN_GRAPH edges from any already-joined table.

        :returns: the list of ``(from, to)`` edges forming the shortest path
            to ``to_table``, or ``None`` when unreachable.
        """
        # Plain-list BFS frontier (graph is tiny; no deque needed).
        frontier = [(table, []) for table in joined_tables]
        seen = set(joined_tables)
        i = 0
        while i < len(frontier):
            node, path = frontier[i]
            i += 1
            for a, b in JOIN_GRAPH:
                if a == node and b not in seen:
                    edge_path = path + [(a, b)]
                    if b == to_table:
                        return edge_path
                    seen.add(b)
                    frontier.append((b, edge_path))
        return None