"""Generic filter engine for resolving DataGroupings into sub-groups.
Replaces the polymorphic resolver registry with a declarative, data-driven
approach. Filter rules and group_by dimensions are JSON on the models.
"""
from typing import List, Optional
from sqlalchemy.orm import Session
from ccat_ops_db import models
from .resolver import SubGroup
from ..logging_utils import get_structured_logger
from ..exceptions import FilterEngineError
logger = get_structured_logger(__name__)
# Declarative join graph: maps (from_table, to_table) to join condition.
# The engine walks this graph to build SQLAlchemy queries from filter_rules.
JOIN_GRAPH = {
("RawDataPackage", "ObsUnit"): lambda: (
models.ObsUnit,
models.RawDataPackage.obs_unit_id == models.ObsUnit.id,
),
("ObsUnit", "Source"): lambda: (
models.Source,
models.ObsUnit.source_id == models.Source.id,
),
("RawDataPackage", "InstrumentModule"): lambda: (
models.InstrumentModule,
models.RawDataPackage.instrument_module_id == models.InstrumentModule.id,
),
("ObsUnit", "ExecutedObsUnit"): lambda: (
models.ExecutedObsUnit,
models.ExecutedObsUnit.obs_unit_id == models.ObsUnit.id,
),
}
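# Example: a filter on Source is reachable from RawDataPackage only through
# ObsUnit, so the ("RawDataPackage", "ObsUnit") edge must be applied before
# the ("ObsUnit", "Source") edge; _ensure_joined resolves one hop at a time.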
# Map table name strings to SQLAlchemy model classes
TABLE_MAP = {
"RawDataPackage": models.RawDataPackage,
"ObsUnit": models.ObsUnit,
"Source": models.Source,
"InstrumentModule": models.InstrumentModule,
"ExecutedObsUnit": models.ExecutedObsUnit,
"DataGrouping": models.DataGrouping,
}
FILTER_OPERATORS = {
"eq": lambda col, val: col == val,
"neq": lambda col, val: col != val,
"in": lambda col, val: col.in_(val),
"not_in": lambda col, val: col.notin_(val),
"gt": lambda col, val: col > val,
"gte": lambda col, val: col >= val,
"lt": lambda col, val: col < val,
"lte": lambda col, val: col <= val,
"like": lambda col, val: col.like(val),
}
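# Illustrative filter_rules payload (the keys match those read in
# FilterEngine.resolve below; the column names and values here are
# hypothetical examples, not guaranteed schema):
#
#   [
#       {"table": "Source", "column": "name", "operator": "eq",
#        "value": "NGC 253"},
#       {"column": "quality_flags", "json_path": "pipeline.status",
#        "operator": "in", "value": ["ok", "warn"]},
#   ]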
class FilterEngine:
"""Resolves a DataGrouping's filter_rules + group_by into SubGroups."""
def resolve(
self,
session: Session,
data_grouping: models.DataGrouping,
group_by: Optional[List[str]] = None,
) -> List[SubGroup]:
filter_rules = data_grouping.filter_rules or []
group_by = group_by or []
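        # Base query selects candidate RawDataPackage ids; group_by columns
        # are appended to this same query further down.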
query = session.query(models.RawDataPackage.id)
# Optionally filter by instrument module if set on the grouping
if data_grouping.instrument_module_id is not None:
query = query.filter(
models.RawDataPackage.instrument_module_id
== data_grouping.instrument_module_id
)
# Track which tables we've already joined
joined_tables = {"RawDataPackage"}
# Collect group_by columns for later
group_by_columns = []
# Process filter rules
for rule in filter_rules:
table_name = rule.get("table", "RawDataPackage")
column_name = rule["column"]
operator = rule.get("operator", "eq")
value = rule["value"]
json_path = rule.get("json_path")
# Ensure the required table is joined
query, joined_tables = self._ensure_joined(
query, "RawDataPackage", table_name, joined_tables
)
model_cls = TABLE_MAP.get(table_name)
if model_cls is None:
raise FilterEngineError(f"Unknown table in filter rule: {table_name}")
col = getattr(model_cls, column_name, None)
if col is None:
raise FilterEngineError(
f"Unknown column {column_name} on {table_name}"
)
            # Drill into a JSON column: json_path "a.b" becomes col["a"]["b"]
if json_path:
for key in json_path.split("."):
col = col[key]
op_fn = FILTER_OPERATORS.get(operator)
if op_fn is None:
raise FilterEngineError(f"Unknown filter operator: {operator}")
query = query.filter(op_fn(col, value))
# Handle group_by dimensions
if not group_by:
# No grouping — aggregate all into one sub-group
package_ids = [row[0] for row in query.all()]
if not package_ids:
return []
return [
SubGroup(
key="all",
metadata={},
raw_data_package_ids=package_ids,
)
]
# Add group_by columns to query
for dim in group_by:
parts = dim.split(".")
table_name = parts[0] if len(parts) > 1 else "RawDataPackage"
col_name = parts[-1]
query, joined_tables = self._ensure_joined(
query, "RawDataPackage", table_name, joined_tables
)
model_cls = TABLE_MAP.get(table_name)
if model_cls is None:
raise FilterEngineError(f"Unknown table in group_by: {table_name}")
col = getattr(model_cls, col_name, None)
if col is None:
raise FilterEngineError(
f"Unknown column {col_name} on {table_name} in group_by"
)
group_by_columns.append((dim, col))
query = query.add_columns(col)
# Execute and group results
rows = query.all()
groups = {}
for row in rows:
package_id = row[0]
dim_values = row[1:]
key_parts = []
meta = {}
for (dim, _), val in zip(group_by_columns, dim_values):
key_parts.append(f"{dim.split('.')[-1]}={val}")
meta[dim.split(".")[-1]] = val
key = "|".join(key_parts)
if key not in groups:
groups[key] = SubGroup(
key=key, metadata=meta, raw_data_package_ids=[]
)
groups[key].raw_data_package_ids.append(package_id)
result = list(groups.values())
logger.info(
"filter_engine_resolved",
grouping_id=data_grouping.id,
sub_group_count=len(result),
total_packages=sum(len(sg.raw_data_package_ids) for sg in result),
)
return result
    def _ensure_joined(self, query, from_table, to_table, joined_tables):
        """Ensure ``to_table`` is joined into the query via the join graph.

        Joins are resolved one hop at a time: an edge is only usable once its
        source table is already joined, so multi-hop targets (e.g. Source via
        ObsUnit) require an earlier rule or dimension to have joined the
        intermediate table first.
        """
if to_table in joined_tables:
return query, joined_tables
join_key = (from_table, to_table)
if join_key not in JOIN_GRAPH:
            # Fall back to any graph edge that reaches to_table from a table
            # that is already joined (single hop only; reverse edges are not
            # followed)
for (a, b), join_fn in JOIN_GRAPH.items():
if b == to_table and a in joined_tables:
target, condition = join_fn()
query = query.join(target, condition)
joined_tables = joined_tables | {to_table}
return query, joined_tables
raise FilterEngineError(
f"No join path from {from_table} to {to_table}"
)
target, condition = JOIN_GRAPH[join_key]()
query = query.join(target, condition)
joined_tables = joined_tables | {to_table}
return query, joined_tables
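

# Usage sketch (illustrative; assumes a configured Session and a persisted
# DataGrouping whose filter_rules/group_by follow the shapes shown above):
#
#   engine = FilterEngine()
#   sub_groups = engine.resolve(session, grouping, group_by=["Source.name"])
#   for sg in sub_groups:
#       logger.info("sub_group", key=sg.key, n=len(sg.raw_data_package_ids))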