import logging
from typing import Tuple, Optional
from sqlalchemy import create_engine, Engine, text
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy.pool import NullPool
from urllib.parse import quote_plus
from .config.config import ccat_ops_db_settings
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Explicitly let records bubble up to handlers configured on ancestor loggers
# (True is already the `logging` default; stated here to make the intent clear).
logger.propagate = True
# Declarative base class; its ``metadata`` is what ``init_ccat_ops_db`` passes
# to ``drop_all`` / ``create_all``.
Base = declarative_base()
[docs]
def get_database_url(
    database_type: str,
    database: Optional[str] = None,
    host: Optional[str] = None,
    port: Optional[str] = None,
    user: Optional[str] = None,
    password: Optional[str] = None,
    db_name: Optional[str] = None,
    async_driver: bool = False,
) -> str:
    """
    Generate a database URL based on type and configuration.

    Every argument left as ``None`` (or otherwise falsy) falls back to the
    matching ``ccat_ops_db_settings.DATABASE_<TYPE>_*`` value.

    Args:
        database_type: Type of database ('sqlite', 'mysql', 'postgresql')
        database: For 'sqlite', the database path placed after ``sqlite:///``.
            For 'mysql' / 'postgresql', a database *name* used only when
            ``db_name`` is not given.
        host: Optional host override (ignored for 'sqlite')
        port: Optional port override (ignored for 'sqlite')
        user: Optional user override (ignored for 'sqlite')
        password: Optional password override (ignored for 'sqlite')
        db_name: Optional database name override (ignored for 'sqlite');
            takes precedence over ``database``
        async_driver: Whether to use the async driver. Only affects
            'postgresql' (selects ``postgresql+asyncpg``); it is ignored for
            the other database types.

    Returns:
        The connection URL as a string. User and password are
        percent-encoded with ``quote_plus`` so that reserved characters such
        as ``@``, ``:`` or ``/`` cannot corrupt the URL structure.

    Raises:
        ValueError: If ``database_type`` is not one of the supported types.
    """
    if database_type == "sqlite":
        database = database or ccat_ops_db_settings.DATABASE_SQLITE_DATABASE
        return f"sqlite:///{database}"

    if database_type == "mysql":
        scheme = "mysql+mysqldb"
        user = user or ccat_ops_db_settings.DATABASE_MYSQL_USER
        password = password or ccat_ops_db_settings.DATABASE_MYSQL_PASSWORD
        host = host or ccat_ops_db_settings.DATABASE_MYSQL_HOST
        port = port or ccat_ops_db_settings.DATABASE_MYSQL_PORT
        db = db_name or database or ccat_ops_db_settings.DATABASE_MYSQL_DATABASE
    elif database_type == "postgresql":
        scheme = "postgresql+asyncpg" if async_driver else "postgresql"
        user = user or ccat_ops_db_settings.DATABASE_POSTGRESQL_USER
        password = password or ccat_ops_db_settings.DATABASE_POSTGRESQL_PASSWORD
        host = host or ccat_ops_db_settings.DATABASE_POSTGRESQL_HOST
        port = port or ccat_ops_db_settings.DATABASE_POSTGRESQL_PORT
        db = db_name or database or ccat_ops_db_settings.DATABASE_POSTGRESQL_DATABASE
    else:
        raise ValueError(f"Unsupported database type: {database_type}")

    # Escape the user name as well as the password: an unescaped '@' or ':'
    # in either one would shift the user/password/host boundaries of the URL.
    # str() mirrors the plain f-string interpolation previously used for user.
    credentials = quote_plus(str(user))
    # A missing password (nothing passed and nothing configured) is rendered
    # as "user@host" instead of failing inside quote_plus().
    if password is not None:
        credentials += f":{quote_plus(password)}"
    return f"{scheme}://{credentials}@{host}:{port}/{db}"
[docs]
def init_ccat_ops_db(
    database_type: Optional[str] = None,
    database: Optional[str] = None,
    drop: bool = False,
    null_pool: bool = False,
    host: Optional[str] = None,
    port: Optional[str] = None,
    user: Optional[str] = None,
    password: Optional[str] = None,
    db_name: Optional[str] = None,
) -> Tuple[scoped_session, Engine]:
    """
    Initialize a database connection based on the supplied configuration.

    Builds the connection URL with ``get_database_url``, creates the engine,
    optionally drops the existing tables, creates every table registered on
    ``Base.metadata`` and returns a scoped session bound to the engine.

    Parameters:
    -----------
    database_type: str, optional
        Can be 'sqlite', 'mysql', or 'postgresql'. Defaults to config setting.
    database: str, optional
        Passed through to ``get_database_url``: the SQLite database path, or
        a fallback database name for MySQL/PostgreSQL when ``db_name`` is not
        given. Defaults to config setting.
    drop: bool, default False
        If True, drops all tables before creating them. On PostgreSQL the
        ``status`` enum type is dropped first.
    null_pool: bool, default False
        If True, uses NullPool instead of the default connection pool.
    host: str, optional
        Database host override.
    port: str, optional
        Database port override.
    user: str, optional
        Database user override.
    password: str, optional
        Database password override.
    db_name: str, optional
        Database name override.

    Returns:
    --------
    Tuple[scoped_session, Engine]
        A tuple containing the database session and engine.

    Raises:
    -------
    ValueError
        If ``database_type`` is not supported (raised by
        ``get_database_url``).
    """
    database_type = database_type or ccat_ops_db_settings.DATABASE_TYPE
    url = get_database_url(database_type, database, host, port, user, password, db_name)

    engine_kwargs = {
        "echo": False,
        "pool_pre_ping": True,
    }
    if null_pool:
        engine_kwargs["poolclass"] = NullPool
    engine = create_engine(url, **engine_kwargs)

    # Log the host/port taken from the engine's own URL so the message is
    # correct for every backend, not only PostgreSQL (expected to be None for
    # URLs without a host part, such as SQLite).
    logger.info(
        "Using database_type %s host %s port %s",
        database_type,
        engine.url.host,
        engine.url.port,
    )

    if drop:
        logger.info("Dropping all tables and types")
        # For PostgreSQL, we need to drop enum types first
        if database_type == "postgresql":
            # engine.begin() commits when the block exits without error.
            # With a bare engine.connect() the DDL is never committed under
            # SQLAlchemy 2.0 and is rolled back when the connection closes.
            with engine.begin() as conn:
                conn.execute(text("DROP TYPE IF EXISTS status CASCADE"))
        Base.metadata.drop_all(bind=engine)
    elif database_type == "postgresql":
        # For PostgreSQL, we need to ensure the enum type exists.
        # Check and create share one transaction, committed on exit.
        with engine.begin() as conn:
            type_exists = conn.execute(
                text("SELECT 1 FROM pg_type WHERE typname = 'status'")
            ).scalar()
            if not type_exists:
                conn.execute(
                    text(
                        """
                        CREATE TYPE status AS ENUM (
                            'pending',
                            'scheduled',
                            'in_progress',
                            'completed',
                            'failed'
                        )
                        """
                    )
                )

    Base.metadata.create_all(bind=engine)
    session = scoped_session(sessionmaker(bind=engine))
    return session, engine