import calendar
import logging
import os
import time
from dataclasses import dataclass
from datetime import date
from typing import Any, Dict, List, Optional, Tuple

import requests

# Module-level logger; WhatConvertsService reports failed API requests through it.
logger = logging.getLogger(__name__)


@dataclass
class _MonthCacheEntry:
    created_at: float
    signature: str
    rows: List[Dict[str, Any]]


class WhatConvertsService:
    """
    Pulls monthly lead data from WhatConverts and aggregates by source/medium.

    Caching strategy:
    - Return cached month aggregation for a short TTL window (no API call).
    - After TTL, do a lightweight "change check" call (1 lead) and compare signature.
    - Only fetch all paginated leads again when signature changed.

    Failure handling: request, HTTP and JSON errors are logged and never raised
    to callers. The last good aggregation for the month (or ``[]``) is returned
    instead. A month is only cached after *every* page has been downloaded, so a
    failed request mid-pagination cannot pin partial totals into the cache.
    """

    BASE_URL = "https://app.whatconverts.com/api/v1/leads"
    CACHE_TTL_SECONDS = 300
    CHECK_TTL_SECONDS = 120
    MAX_LEADS_PER_PAGE = 2500
    REQUEST_TIMEOUT_SECONDS = 25
    SUPPORTED_BRANCHES = ("Bunbury", "Busselton", "Mandurah")

    def __init__(self) -> None:
        # branch name -> {"token", "secret", "profile_id"}, read once from the environment.
        self._branch_credentials: Dict[str, Dict[str, str]] = self._load_branch_credentials()
        # "<branch>:<YYYY>-<MM>" -> aggregated rows plus the signature they were built from.
        self._cache: Dict[str, _MonthCacheEntry] = {}
        # "<branch>:<YYYY>-<MM>" -> (checked_at, signature) from the most recent change check.
        self._check_cache: Dict[str, Tuple[float, str]] = {}

    # ------------------------------------------------------------------ config

    @staticmethod
    def _env_first(*names: str) -> str:
        """Return the first of ``names`` set to a non-blank env value (stripped), else ``""``."""
        for name in names:
            value = (os.getenv(name) or "").strip()
            if value:
                return value
        return ""

    def _load_branch_credentials(self) -> Dict[str, Dict[str, str]]:
        """
        Build per-branch API credentials from environment variables.

        For a branch such as ``Bunbury`` two spellings are accepted per value,
        e.g. ``WHATCONVERTS_API_TOKEN_BUNBURY`` or ``WHATCONVERTS_BUNBURY_API_TOKEN``
        (likewise ``API_SECRET`` and ``PROFILE_ID``).

        A branch without a complete token+secret pair falls back to the global
        ``WHATCONVERTS_API_TOKEN`` / ``WHATCONVERTS_API_SECRET`` pair, when both
        are set. It also takes ``WHATCONVERTS_PROFILE_ID`` if it has no profile of
        its own. The pair is always taken as a unit. Token and secret are sent
        together as basic-auth credentials, so a lone branch token is never
        combined with the global secret (or vice versa).

        Returns:
            Mapping of branch name to ``{"token", "secret", "profile_id"}``;
            unset values are ``""``.
        """
        global_token = self._env_first("WHATCONVERTS_API_TOKEN")
        global_secret = self._env_first("WHATCONVERTS_API_SECRET")
        global_profile = self._env_first("WHATCONVERTS_PROFILE_ID")
        has_global_pair = bool(global_token and global_secret)

        out: Dict[str, Dict[str, str]] = {}
        for branch in self.SUPPORTED_BRANCHES:
            key = branch.upper()
            token = self._env_first(f"WHATCONVERTS_API_TOKEN_{key}", f"WHATCONVERTS_{key}_API_TOKEN")
            secret = self._env_first(f"WHATCONVERTS_API_SECRET_{key}", f"WHATCONVERTS_{key}_API_SECRET")
            profile = self._env_first(f"WHATCONVERTS_PROFILE_ID_{key}", f"WHATCONVERTS_{key}_PROFILE_ID")

            if has_global_pair and not (token and secret):
                # Replace the whole pair; mixing halves from two sources can't authenticate.
                token, secret = global_token, global_secret
                profile = profile or global_profile

            out[branch] = {"token": token, "secret": secret, "profile_id": profile}
        return out

    def enabled(self, branch_name: str) -> bool:
        """Return True when ``branch_name`` has both an API token and secret configured."""
        creds = self._branch_credentials.get(branch_name, {})
        return bool(creds.get("token") and creds.get("secret"))

    # ------------------------------------------------------------- public API

    def get_monthly_channel_summary(self, branch_name: str, year: int, month: int) -> List[Dict[str, Any]]:
        """
        Return per-channel lead totals for one branch and calendar month.

        Args:
            branch_name: Branch name; branches without credentials yield ``[]``.
            year: Calendar year.
            month: Month number 1-12; invalid values raise ``ValueError`` once
                an API call is attempted.

        Returns:
            Rows as produced by ``_aggregate_rows``. The list object is shared
            with the cache, so callers should not mutate it.
        """
        if not self.enabled(branch_name):
            return []

        cache_key = self._cache_key(branch_name, year, month)
        now = time.time()
        cached = self._cache.get(cache_key)
        if cached is not None and now - cached.created_at < self.CACHE_TTL_SECONDS:
            return cached.rows

        # Best-effort result for every failure path below: last good rows, else nothing.
        fallback: List[Dict[str, Any]] = cached.rows if cached is not None else []

        signature = self._get_month_signature(branch_name, year, month)
        if not signature:
            return fallback

        if cached is not None and cached.signature == signature:
            # Unchanged upstream: restart the TTL window without refetching.
            self._cache[cache_key] = _MonthCacheEntry(now, signature, cached.rows)
            return cached.rows

        leads = self._fetch_all_month_leads(branch_name, year, month)
        if leads is None:
            # A page request failed. Caching the partial month under the new
            # signature would serve wrong totals until the signature changes again.
            return fallback

        rows = self._aggregate_rows(leads)
        self._cache[cache_key] = _MonthCacheEntry(now, signature, rows)
        return rows

    # ------------------------------------------------------------ API helpers

    @staticmethod
    def _cache_key(branch_name: str, year: int, month: int) -> str:
        """Key shared by the month cache and the change-check cache."""
        return f"{branch_name}:{year:04d}-{month:02d}"

    def _get_month_signature(self, branch_name: str, year: int, month: int) -> Optional[str]:
        """
        Return a cheap fingerprint of the month's leads, or ``None`` if the request failed.

        The fingerprint is ``"<total_leads>:<last_updated or date_created>"``, taken
        from the first lead of a 1-lead ``order=desc`` request. It is cached for
        ``CHECK_TTL_SECONDS``.

        NOTE(review): edits to leads other than the one returned first (e.g. a
        changed quote value on an older lead) presumably leave this fingerprint
        unchanged. Such edits are only picked up once the count or that lead
        changes; confirm this staleness is acceptable.
        """
        check_key = self._cache_key(branch_name, year, month)
        now = time.time()
        cached_check = self._check_cache.get(check_key)
        if cached_check is not None and now - cached_check[0] < self.CHECK_TTL_SECONDS:
            return cached_check[1]

        params = self._month_params(branch_name, year, month)
        params.update({"leads_per_page": 1, "page_number": 1, "order": "desc"})
        payload = self._get(branch_name, params)
        if not payload:
            return None

        total_leads = self._as_int(payload.get("total_leads"))
        leads = payload.get("leads") or []
        first = leads[0] if isinstance(leads, list) and leads else None
        newest: Dict[str, Any] = first if isinstance(first, dict) else {}
        last_updated = str(newest.get("last_updated") or newest.get("date_created") or "")
        signature = f"{total_leads}:{last_updated}"
        self._check_cache[check_key] = (now, signature)
        return signature

    def _fetch_all_month_leads(self, branch_name: str, year: int, month: int) -> Optional[List[Dict[str, Any]]]:
        """
        Download every lead for the month, page by page (``order=asc``).

        Returns:
            All leads, or ``None`` if any page request failed. Callers must not
            treat a partial download as the complete month.
        """
        leads: List[Dict[str, Any]] = []
        page = 1

        while True:
            params = self._month_params(branch_name, year, month)
            params.update(
                {
                    "leads_per_page": self.MAX_LEADS_PER_PAGE,
                    "page_number": page,
                    "order": "asc",
                }
            )
            payload = self._get(branch_name, params)
            if payload is None:
                return None

            page_leads = payload.get("leads") or []
            if not page_leads:
                break

            leads.extend(page_leads)
            total_pages = self._as_int(payload.get("total_pages"), default=1)
            if page >= total_pages:
                break
            page += 1

        return leads

    def _month_params(self, branch_name: str, year: int, month: int) -> Dict[str, Any]:
        """Query params covering the whole month (inclusive ISO dates), plus the branch profile_id if set."""
        start = date(year, month, 1)
        last_day = calendar.monthrange(year, month)[1]
        end = date(year, month, last_day)
        params: Dict[str, Any] = {
            "start_date": start.isoformat(),
            "end_date": end.isoformat(),
        }
        profile_id = (self._branch_credentials.get(branch_name, {}) or {}).get("profile_id")
        if profile_id:
            params["profile_id"] = profile_id
        return params

    def _get(self, branch_name: str, params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """
        GET ``BASE_URL`` with the branch's basic-auth credentials.

        Returns:
            The decoded JSON object, or ``None`` when credentials are missing,
            the request/HTTP status fails, the body is not JSON, or the JSON is
            not an object. Failures are logged as warnings.
        """
        creds = self._branch_credentials.get(branch_name, {})
        token = creds.get("token", "")
        secret = creds.get("secret", "")
        if not token or not secret:
            return None
        try:
            response = requests.get(
                self.BASE_URL,
                params=params,
                auth=(token, secret),
                timeout=self.REQUEST_TIMEOUT_SECONDS,
            )
            response.raise_for_status()
            payload = response.json()
        # requests' own errors derive from RequestException; invalid JSON raises a ValueError subclass.
        except (requests.RequestException, ValueError) as exc:
            logger.warning("WhatConverts request failed for %s: %s", branch_name, exc)
            return None
        if not isinstance(payload, dict):
            logger.warning(
                "WhatConverts returned unexpected payload type for %s: %s",
                branch_name,
                type(payload).__name__,
            )
            return None
        return payload

    # -------------------------------------------------------- value coercion

    @staticmethod
    def _as_float(value: Any) -> float:
        """Coerce to float; missing or unparseable values count as 0.0."""
        try:
            return float(value or 0)
        except (ValueError, TypeError):
            return 0.0

    @staticmethod
    def _as_int(value: Any, default: int = 0) -> int:
        """Coerce to int; falsy or unparseable values yield ``default``."""
        try:
            return int(value or default)
        except (ValueError, TypeError):
            return default

    @staticmethod
    def _as_bool(value: Any) -> bool:
        """
        Interpret a flag that may arrive as a bool, number or string.

        Strings are matched case-insensitively against "yes"/"true"/"1"/"y",
        so "false"/"no"/"" are False (plain ``bool("false")`` would be True).
        """
        if isinstance(value, str):
            return value.strip().lower() in ("yes", "true", "1", "y")
        return bool(value)

    # ------------------------------------------------------------ aggregation

    def _aggregate_rows(self, leads: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Group non-spam, non-duplicate leads by lower-cased "source / medium".

        Blank sources become "(direct)" and blank mediums "(none)". Each row has
        ``channel``, ``total_leads``, ``quotable_leads``, ``total_quote_value`` and
        ``total_sales_value``. Rows are sorted by ``total_leads`` descending; ties
        keep first-seen order.
        """
        grouped: Dict[str, Dict[str, Any]] = {}

        for lead in leads:
            if not isinstance(lead, dict):
                continue
            if self._as_bool(lead.get("spam")) or self._as_bool(lead.get("duplicate")):
                continue

            source = (str(lead.get("lead_source") or "").strip() or "(direct)").lower()
            medium = (str(lead.get("lead_medium") or "").strip() or "(none)").lower()
            key = f"{source} / {medium}"

            row = grouped.setdefault(
                key,
                {
                    "channel": key,
                    "total_leads": 0,
                    "quotable_leads": 0,
                    "total_quote_value": 0.0,
                    "total_sales_value": 0.0,
                },
            )

            row["total_leads"] += 1
            if self._as_bool(lead.get("quotable")):
                row["quotable_leads"] += 1
            row["total_quote_value"] += self._as_float(lead.get("quote_value"))
            row["total_sales_value"] += self._as_float(lead.get("sales_value"))

        rows = list(grouped.values())
        rows.sort(key=lambda x: x["total_leads"], reverse=True)
        return rows


# Shared module-level instance. Credentials are read from environment variables
# once, when this module is first imported (WhatConvertsService.__init__ loads them).
# Its in-memory month and change-check caches are therefore shared by every importer.
# NOTE(review): those caches are plain dicts with no locking; confirm callers do not
# use this instance concurrently from multiple threads.
whatconverts_service = WhatConvertsService()
