# Source code for ccat_ops_db.ccat_ops_db

import logging
from typing import Tuple, Optional
from sqlalchemy import create_engine, Engine, text
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy.pool import NullPool
from urllib.parse import quote_plus
from .config.config import ccat_ops_db_settings

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Explicitly keep propagation enabled (this is also the logging default) so
# records emitted here are passed on to handlers of ancestor loggers.
logger.propagate = True

# Declarative base class; its ``metadata`` is what ``init_ccat_ops_db`` uses
# to create and drop tables.
Base = declarative_base()


def get_database_url(
    database_type: str,
    database: Optional[str] = None,
    host: Optional[str] = None,
    port: Optional[str] = None,
    user: Optional[str] = None,
    password: Optional[str] = None,
    db_name: Optional[str] = None,
    async_driver: bool = False,
) -> str:
    """
    Generate database URL based on type and configuration.

    Every argument that is left unset (or is otherwise falsy) falls back to
    the corresponding ``ccat_ops_db_settings.DATABASE_<TYPE>_*`` value.

    Args:
        database_type: Type of database ('sqlite', 'mysql', 'postgresql')
        database: For sqlite, the database file path. For mysql/postgresql,
            a database name that is used only when ``db_name`` is not given.
        host: Optional host override (mysql/postgresql only)
        port: Optional port override (mysql/postgresql only)
        user: Optional user override (mysql/postgresql only)
        password: Optional password override (mysql/postgresql only)
        db_name: Optional database name override; takes precedence over
            ``database`` for mysql/postgresql and is ignored for sqlite
        async_driver: Whether to use the async driver (asyncpg). Only
            affects PostgreSQL; it is ignored for sqlite and mysql.

    Returns:
        The connection URL string for the requested backend.

    Raises:
        ValueError: If ``database_type`` is not one of the supported types.
    """
    if database_type == "sqlite":
        database = database or ccat_ops_db_settings.DATABASE_SQLITE_DATABASE
        return f"sqlite:///{database}"

    if database_type == "mysql":
        user = user or ccat_ops_db_settings.DATABASE_MYSQL_USER
        password = password or ccat_ops_db_settings.DATABASE_MYSQL_PASSWORD
        host = host or ccat_ops_db_settings.DATABASE_MYSQL_HOST
        port = port or ccat_ops_db_settings.DATABASE_MYSQL_PORT
        db = db_name or database or ccat_ops_db_settings.DATABASE_MYSQL_DATABASE
        scheme = "mysql+mysqldb"
    elif database_type == "postgresql":
        user = user or ccat_ops_db_settings.DATABASE_POSTGRESQL_USER
        password = password or ccat_ops_db_settings.DATABASE_POSTGRESQL_PASSWORD
        host = host or ccat_ops_db_settings.DATABASE_POSTGRESQL_HOST
        port = port or ccat_ops_db_settings.DATABASE_POSTGRESQL_PORT
        db = db_name or database or ccat_ops_db_settings.DATABASE_POSTGRESQL_DATABASE
        scheme = "postgresql+asyncpg" if async_driver else "postgresql"
    else:
        raise ValueError(f"Unsupported database type: {database_type}")

    # Both credentials must be percent-encoded: a raw '@', ':' or '/' in the
    # user name would otherwise be parsed as a URL delimiter. str() keeps
    # non-string user values formatting the way the f-string used to.
    credentials = f"{quote_plus(str(user))}:{quote_plus(password)}"
    return f"{scheme}://{credentials}@{host}:{port}/{db}"


def init_ccat_ops_db(
    database_type: Optional[str] = None,
    database: Optional[str] = None,
    drop: bool = False,
    null_pool: bool = False,
    host: Optional[str] = None,
    port: Optional[str] = None,
    user: Optional[str] = None,
    password: Optional[str] = None,
    db_name: Optional[str] = None,
) -> Tuple[scoped_session, Engine]:
    """
    Initialize a database connection based on the supplied configuration.

    Builds the connection URL, creates the engine, optionally drops all
    tables, creates every table registered on ``Base.metadata`` and returns
    a scoped session bound to the engine.

    Parameters:
    -----------
    database_type: str, optional
        Can be 'sqlite', 'mysql', or 'postgresql'. Defaults to config setting.
    database: str, optional
        Passed through to ``get_database_url``: the database file path for
        sqlite, or a database name used when ``db_name`` is not given for
        mysql/postgresql. Defaults to config setting.
    drop: bool, default False
        If True, drops all tables before creating them. On PostgreSQL the
        ``status`` enum type is dropped as well.
    null_pool: bool, default False
        If True, uses NullPool instead of the default connection pool.
    host: str, optional
        Database host override.
    port: str, optional
        Database port override.
    user: str, optional
        Database user override.
    password: str, optional
        Database password override.
    db_name: str, optional
        Database name override.

    Returns:
    --------
    Tuple[scoped_session, Engine]
        A tuple containing the database session and engine.

    Raises:
    -------
    ValueError
        If the database type is not supported (raised while building the URL).
    """
    database_type = database_type or ccat_ops_db_settings.DATABASE_TYPE

    # Log the host/port of the backend that is actually selected; sqlite has
    # no host/port settings, so only explicit overrides are reported there.
    if database_type == "postgresql":
        log_host = host or ccat_ops_db_settings.DATABASE_POSTGRESQL_HOST
        log_port = port or ccat_ops_db_settings.DATABASE_POSTGRESQL_PORT
    elif database_type == "mysql":
        log_host = host or ccat_ops_db_settings.DATABASE_MYSQL_HOST
        log_port = port or ccat_ops_db_settings.DATABASE_MYSQL_PORT
    else:
        log_host, log_port = host, port
    logger.info(
        "Using database_type %s host %s port %s",
        database_type,
        log_host,
        log_port,
    )

    url = get_database_url(database_type, database, host, port, user, password, db_name)

    engine_kwargs = {
        "echo": False,
        "pool_pre_ping": True,
    }
    if null_pool:
        engine_kwargs["poolclass"] = NullPool
    engine = create_engine(url, **engine_kwargs)

    is_postgresql = database_type == "postgresql"

    # NOTE: the enum DDL below runs inside ``engine.begin()`` so that it is
    # committed when the block exits. With a plain ``engine.connect()`` the
    # statements are rolled back on close unless ``commit()`` is called.
    if drop:
        logger.info("Dropping all tables and types")
        # For PostgreSQL, we need to drop enum types first
        if is_postgresql:
            with engine.begin() as conn:
                conn.execute(text("DROP TYPE IF EXISTS status CASCADE"))
        Base.metadata.drop_all(bind=engine)
    elif is_postgresql:
        # For PostgreSQL, we need to ensure the enum type exists
        with engine.begin() as conn:
            # Check if the enum type exists
            result = conn.execute(
                text("SELECT 1 FROM pg_type WHERE typname = 'status'")
            )
            if not result.scalar():
                # Create the enum type if it doesn't exist
                conn.execute(
                    text(
                        """
                        CREATE TYPE status AS ENUM (
                            'pending', 'scheduled', 'in_progress',
                            'completed', 'failed'
                        )
                        """
                    )
                )

    Base.metadata.create_all(bind=engine)

    session = scoped_session(sessionmaker(bind=engine))
    return session, engine