"""
Database connection and session management.
Supports SQLite (dev) and MySQL (production via DATABASE_URL).
"""
import logging
import os
from typing import Iterator

from sqlalchemy import create_engine, text, inspect
from sqlalchemy.orm import sessionmaker, Session

from app.config import settings
from app.models import Base

logger = logging.getLogger(__name__)

# Ensure the local "data" directory exists (it is used for SQLite).
# This runs at import time regardless of the configured backend.
os.makedirs("data", exist_ok=True)

# SQLite needs check_same_thread=False; other backends (e.g. MySQL) get no
# extra driver arguments.
_connect_args = (
    {"check_same_thread": False}
    if settings.DATABASE_URL.startswith("sqlite")
    else {}
)

engine = create_engine(settings.DATABASE_URL, connect_args=_connect_args)

# Session factory bound to the engine above; sessions neither autocommit
# nor autoflush.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)


def _add_column_if_missing(conn, table: str, column: str, spec: str):
    """Add ``column`` to ``table`` unless the table already has it.

    The existence check runs on ``conn`` itself, so the check and the
    ``ALTER TABLE`` share one connection instead of checking out a second
    one from the module-level engine.

    Args:
        conn: Open SQLAlchemy connection used for both the check and the DDL.
        table: Name of the table to alter.
        column: Name of the column to add.
        spec: SQL placed after the column name in the ``ALTER TABLE``
            statement; may be DB-specific (e.g. ``VARCHAR(64)``).

    A failing ``ALTER TABLE`` is rolled back and logged as a warning rather
    than raised. Errors raised while inspecting the table are not caught
    here and propagate to the caller.
    """
    existing = {col["name"] for col in inspect(conn).get_columns(table)}
    if column in existing:
        return
    try:
        # Identifiers cannot be sent as bound parameters, so the statement is
        # assembled from the arguments; pass trusted names only.
        conn.execute(text(f"ALTER TABLE {table} ADD COLUMN {column} {spec}"))
        conn.commit()
        logger.info("Added column %s.%s", table, column)
    except Exception as e:
        conn.rollback()
        logger.warning("Could not add %s.%s: %s", table, column, e)


def _migrate_custom_page_layout_title():
    """Add the ``title`` column to ``custom_page_layouts`` if it is missing.

    Opens a connection from the module-level engine for the duration of the
    check/alter and closes it afterwards.
    """
    with engine.connect() as conn:
        # The same column type is used for every backend, so no
        # dialect check is needed.
        _add_column_if_missing(conn, "custom_page_layouts", "title", "VARCHAR(200)")


def _migrate_custom_page_view_permission():
    """Add ``view_permission_key`` to ``custom_page_layouts`` if it is missing.

    Opens a connection from the module-level engine for the duration of the
    check/alter and closes it afterwards.
    """
    with engine.connect() as conn:
        # The same column type is used for every backend, so no
        # dialect check is needed.
        _add_column_if_missing(
            conn, "custom_page_layouts", "view_permission_key", "VARCHAR(80)"
        )


def _migrate_invite_otp_columns():
    """Add invite, OTP, and password reset columns to existing tables.

    Each (table, column, type) entry below is added only when the column is
    not already present. The column types are the same for every backend.
    All additions share a single connection and are attempted in the order
    listed.
    """
    columns = (
        ("users", "invite_token", "VARCHAR(64)"),
        ("users", "invite_expires_at", "DATETIME"),
        ("users", "reset_token", "VARCHAR(64)"),
        ("users", "reset_expires_at", "DATETIME"),
        ("sessions", "otp_verified_at", "DATETIME"),
        ("sessions", "otp_code_hash", "VARCHAR(255)"),
        ("sessions", "otp_expires_at", "DATETIME"),
    )
    with engine.connect() as conn:
        for table, column, spec in columns:
            _add_column_if_missing(conn, table, column, spec)


def init_db():
    """Initialize the database: create tables, migrate columns, seed data.

    Steps, in order:
      1. Create all tables declared on ``Base.metadata``.
      2. Run the column migrations; a failing migration is logged as a
         warning and does not stop the remaining steps.
      3. When ``settings.DATABASE_URL`` contains ``"mysql"``, migrate the
         JSON config into the database.
      4. Seed roles and permissions.

    Steps 3 and 4 are best-effort as well: any exception (including a failed
    import of the helper module) is logged as a warning and swallowed.
    """
    Base.metadata.create_all(bind=engine)

    # Each migration is paired with the warning template logged if it fails.
    column_migrations = (
        (_migrate_invite_otp_columns, "Invite/OTP column migration skipped: %s"),
        (
            _migrate_custom_page_layout_title,
            "custom_page_layouts.title migration skipped: %s",
        ),
        (
            _migrate_custom_page_view_permission,
            "custom_page_layouts.view_permission_key migration skipped: %s",
        ),
    )
    for migrate, skipped_message in column_migrations:
        try:
            migrate()
        except Exception as e:
            logger.warning(skipped_message, e)
    print("✅ Database initialized")

    if settings.DATABASE_URL and "mysql" in settings.DATABASE_URL:
        try:
            # Imported here rather than at module level
            # (presumably to avoid an import cycle; verify before moving).
            from app.config.db_config_storage import migrate_json_config_to_db_if_present
            migrate_json_config_to_db_if_present()
        except Exception as e:
            logger.warning("Config migration skipped: %s", e)

    try:
        from app.seed_roles_permissions import seed_roles_permissions
        seed_roles_permissions()
        print("✅ Roles and permissions seeded")
    except Exception as e:
        logger.warning("Roles/permissions seed skipped: %s", e)


def get_db() -> Iterator[Session]:
    """Dependency function that provides a database session.

    This is a generator: it yields one session created from ``SessionLocal``
    and closes it when the generator is resumed or finalized, including when
    the consumer raises.

    Yields:
        Session: A new session bound to the module-level engine.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

