import csv
import io
import re
from datetime import datetime

from django.core.management.base import BaseCommand, CommandError
from django.db import transaction
from django.utils import timezone

from accounts.models import Permission, Role, User, UserPermission


INSERT_USERS_RE = re.compile(
    r"INSERT INTO\s+`users`\s*\((?P<columns>[^)]+)\)\s*VALUES\s*(?P<values>.*?);",
    re.IGNORECASE | re.DOTALL,
)
INSERT_PERMISSIONS_RE = re.compile(
    r"INSERT INTO\s+`permissions`\s*\((?P<columns>[^)]+)\)\s*VALUES\s*(?P<values>.*?);",
    re.IGNORECASE | re.DOTALL,
)
INSERT_USER_PERMISSIONS_RE = re.compile(
    r"INSERT INTO\s+`user_permissions`\s*\((?P<columns>[^)]+)\)\s*VALUES\s*(?P<values>.*?);",
    re.IGNORECASE | re.DOTALL,
)


class DryRunRollback(Exception):
    pass


def _parse_datetime(raw_value):
    if not raw_value:
        return None
    dt = datetime.fromisoformat(raw_value)
    if timezone.is_naive(dt):
        return timezone.make_aware(dt, timezone.get_current_timezone())
    return dt


def _split_row_tuples(values_blob):
    rows = []
    i = 0
    n = len(values_blob)
    while i < n:
        if values_blob[i] != "(":
            i += 1
            continue

        start = i + 1
        i += 1
        in_string = False
        while i < n:
            ch = values_blob[i]
            if ch == "'":
                # SQL escaping can be '' or \' - support both.
                next_ch = values_blob[i + 1] if i + 1 < n else ""
                prev_ch = values_blob[i - 1] if i > 0 else ""
                if in_string and next_ch == "'":
                    i += 2
                    continue
                if in_string and prev_ch == "\\":
                    i += 1
                    continue
                in_string = not in_string
                i += 1
                continue
            if ch == ")" and not in_string:
                rows.append(values_blob[start:i])
                i += 1
                break
            i += 1
    return rows


def _parse_row(row_text):
    reader = csv.reader(io.StringIO(row_text), delimiter=",", quotechar="'", skipinitialspace=True)
    values = next(reader)
    parsed = []
    for value in values:
        token = value.strip()
        if token.upper() == "NULL":
            parsed.append(None)
            continue
        if token.isdigit():
            parsed.append(int(token))
            continue
        parsed.append(token.replace("''", "'"))
    return parsed


class Command(BaseCommand):
    help = "Import legacy FastAPI users from an SQL dump into Django users table."

    def add_arguments(self, parser):
        parser.add_argument(
            "--sql-file",
            required=True,
            help="Path to SQL dump containing INSERT INTO `users` rows.",
        )
        parser.add_argument(
            "--preserve-ids",
            action="store_true",
            help="Preserve legacy user IDs as primary keys when possible.",
        )
        parser.add_argument(
            "--dry-run",
            action="store_true",
            help="Parse and report changes without writing to the database.",
        )
        parser.add_argument(
            "--skip-permissions",
            action="store_true",
            help="Skip importing permissions and user_permissions from SQL.",
        )

    def handle(self, *args, **options):
        sql_file = options["sql_file"]
        preserve_ids = options["preserve_ids"]
        dry_run = options["dry_run"]
        skip_permissions = options["skip_permissions"]

        try:
            with open(sql_file, "r", encoding="utf-8") as fh:
                sql_text = fh.read()
        except OSError as exc:
            raise CommandError(f"Unable to read SQL file: {exc}") from exc

        match = INSERT_USERS_RE.search(sql_text)
        if not match:
            raise CommandError("Could not find INSERT INTO `users` block in SQL file.")

        columns = [c.strip().strip("`") for c in match.group("columns").split(",")]
        rows = _split_row_tuples(match.group("values"))
        if not rows:
            raise CommandError("No user rows found in users INSERT statement.")

        permissions_rows = []
        permissions_columns = []
        user_permissions_rows = []
        user_permissions_columns = []
        if not skip_permissions:
            permission_match = INSERT_PERMISSIONS_RE.search(sql_text)
            if permission_match:
                permissions_columns = [c.strip().strip("`") for c in permission_match.group("columns").split(",")]
                permissions_rows = _split_row_tuples(permission_match.group("values"))
            user_permission_match = INSERT_USER_PERMISSIONS_RE.search(sql_text)
            if user_permission_match:
                user_permissions_columns = [
                    c.strip().strip("`") for c in user_permission_match.group("columns").split(",")
                ]
                user_permissions_rows = _split_row_tuples(user_permission_match.group("values"))

        role_map = {r.id: r for r in Role.objects.all()}
        role_slug_map = {r.slug: r for r in Role.objects.exclude(slug__isnull=True)}

        created = 0
        updated = 0
        skipped = 0
        permissions_upserted = 0
        user_permissions_upserted = 0

        @transaction.atomic
        def do_import():
            nonlocal created, updated, skipped, permissions_upserted, user_permissions_upserted
            for row_text in rows:
                values = _parse_row(row_text)
                record = dict(zip(columns, values))
                username = record.get("username")
                password_hash = record.get("password_hash")
                if not username or not password_hash:
                    skipped += 1
                    continue

                role_value = (record.get("role") or "user").lower()
                role_obj = role_map.get(record.get("role_id")) or role_slug_map.get(role_value)
                defaults = {
                    "password": password_hash,
                    "role": role_value,
                    "role_obj": role_obj,
                    "created_at": _parse_datetime(record.get("created_at")) or timezone.now(),
                    "invite_token": record.get("invite_token"),
                    "invite_expires_at": _parse_datetime(record.get("invite_expires_at")),
                    "reset_token": record.get("reset_token"),
                    "reset_expires_at": _parse_datetime(record.get("reset_expires_at")),
                    "is_active": True,
                }

                user_id = record.get("id")
                if preserve_ids and user_id is not None:
                    existing = User.objects.filter(id=user_id).first()
                    if existing and existing.username != username:
                        raise CommandError(
                            f"ID collision for legacy id={user_id}: "
                            f"existing username={existing.username}, incoming={username}"
                        )
                    obj, was_created = User.objects.update_or_create(
                        id=user_id,
                        defaults={**defaults, "username": username},
                    )
                else:
                    obj, was_created = User.objects.update_or_create(
                        username=username,
                        defaults=defaults,
                    )

                if was_created:
                    created += 1
                else:
                    updated += 1

                if not dry_run:
                    # Ensure Django auth flags track role for admin/director users.
                    should_staff = obj.role in {"admin", "director"}
                    patch = {}
                    if obj.is_staff != should_staff:
                        patch["is_staff"] = should_staff
                    if obj.is_superuser != (obj.role == "admin"):
                        patch["is_superuser"] = obj.role == "admin"
                    if patch:
                        User.objects.filter(id=obj.id).update(**patch)

            if not skip_permissions and permissions_rows:
                for row_text in permissions_rows:
                    values = _parse_row(row_text)
                    record = dict(zip(permissions_columns, values))
                    key = record.get("key")
                    name = record.get("name")
                    category = record.get("category")
                    if not key or not name or not category:
                        continue
                    permission_id = record.get("id")
                    if preserve_ids and permission_id is not None:
                        Permission.objects.update_or_create(
                            id=permission_id,
                            defaults={"key": key, "name": name, "category": category},
                        )
                    else:
                        Permission.objects.update_or_create(
                            key=key,
                            defaults={"name": name, "category": category},
                        )
                    permissions_upserted += 1

            if not skip_permissions and user_permissions_rows:
                user_exists = set(User.objects.values_list("id", flat=True))
                permission_exists = set(Permission.objects.values_list("id", flat=True))
                for row_text in user_permissions_rows:
                    values = _parse_row(row_text)
                    record = dict(zip(user_permissions_columns, values))
                    user_id = record.get("user_id")
                    permission_id = record.get("permission_id")
                    if user_id not in user_exists or permission_id not in permission_exists:
                        continue
                    UserPermission.objects.update_or_create(user_id=user_id, permission_id=permission_id)
                    user_permissions_upserted += 1

            if dry_run:
                raise DryRunRollback()

        try:
            do_import()
        except DryRunRollback:
            # Expected rollback signal for dry-run.
            pass

        self.stdout.write(
            self.style.SUCCESS(
                "Legacy import complete: "
                f"users(created={created}, updated={updated}, skipped={skipped}) "
                f"permissions(upserted={permissions_upserted}) "
                f"user_permissions(upserted={user_permissions_upserted}) "
                f"dry_run={dry_run}"
            )
        )
