import requests
import json
import os
import time
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from dotenv import load_dotenv
from token_manager import refresh_access_token, update_env_var

# Populate os.environ from a local .env file before any configuration is read.
load_dotenv()

# Tenant ID configured in the environment; None when unset (get_tenant_id() can
# discover it and write it back to .env).
TENANT_ID = os.environ.get("TENANT_ID")


def get_tenant_id(access_token):
    """
    Look up the tenant ID through Xero's ``/connections`` endpoint.

    When several tenants are connected, the first entry in the response is used.
    If the module-level TENANT_ID was empty at import time, the discovered ID is
    also persisted to .env via ``update_env_var``.

    Args:
        access_token: OAuth access token.

    Returns:
        The ``tenantId`` of the first connection.

    Raises:
        requests.exceptions.HTTPError: when the request returns an error status.
        ValueError: when there are no connections, or the first one has no tenantId.
    """
    response = requests.get(
        "https://api.xero.com/connections",
        headers={
            "Authorization": f"Bearer {access_token}",
            "Accept": "application/json",
        },
        timeout=30,
    )

    if response.status_code != 200:
        print(f"Error fetching connections: {response.status_code}")
        print(f"Response: {response.text}")
        response.raise_for_status()

    connections = response.json()
    if not connections:
        raise ValueError("No Xero connections found. Please authorize your app first.")

    primary = connections[0]
    tenant_id = primary.get("tenantId")
    if not tenant_id:
        raise ValueError("No tenant ID found in connections response")

    # tenant_id is known to be truthy here; only persist when .env had nothing.
    if not TENANT_ID:
        update_env_var("TENANT_ID", tenant_id)
        print(f"Saved tenant ID to .env: {primary.get('tenantName', 'Unknown')} ({tenant_id})")

    return tenant_id


def verify_tenant_access(access_token, tenant_id):
    """
    Confirm tenant access with a lightweight GET of the Organisation endpoint.

    Prints the organisation name when access is confirmed.

    Returns:
        True if Xero answered 200 with at least one organisation, otherwise False.
    """
    headers = {
        "Authorization": f"Bearer {access_token}",
        "xero-tenant-id": tenant_id,
        "Accept": "application/json",
    }
    response = requests.get("https://api.xero.com/api.xro/2.0/Organisation", headers=headers, timeout=30)
    if response.status_code != 200:
        return False

    organisations = response.json().get("Organisations", [])
    if not organisations:
        return False

    print(f"Verified access to organization: {organisations[0].get('Name', 'Unknown')}")
    return True


def get_profit_and_loss(access_token, tenant_id=None, from_date=None, to_date=None, periods=None, timeframe=None,
                        tracking_category_id=None, tracking_option_id=None):
    """
    Fetch Profit & Loss report from Xero API.

    On HTTP 429 the request is retried (3 attempts in total). Between attempts
    it sleeps for the server's ``Retry-After`` value, capped at 60 seconds.

    Args:
        access_token: OAuth access token
        tenant_id: Xero tenant ID (falls back to TENANT_ID, then get_tenant_id())
        from_date: Start date in YYYY-MM-DD format (optional)
        to_date: End date in YYYY-MM-DD format (optional)
        periods: Number of periods to include (optional)
        timeframe: MONTH, QUARTER, YEAR (optional)
        tracking_category_id: Filter by tracking category (department reports)
        tracking_option_id: Filter to specific tracking option (single department)

    Returns:
        JSON response from Xero API

    Raises:
        ValueError: if no tenant ID can be determined.
        requests.exceptions.HTTPError: on an error status. If the request was
            still rate-limited (429) after the last attempt, the exception
            carries a ``retry_after`` attribute: the last Retry-After, in seconds.
    """
    if not tenant_id:
        tenant_id = TENANT_ID or get_tenant_id(access_token)

    if not tenant_id:
        raise ValueError("TENANT_ID is required. Set it in .env or it will be fetched automatically.")

    url = "https://api.xero.com/api.xro/2.0/Reports/ProfitAndLoss"

    # Only send the filters the caller supplied (falsy values are omitted).
    candidate_params = {
        "fromDate": from_date,
        "toDate": to_date,
        "periods": periods,
        "timeframe": timeframe,
        "trackingCategoryID": tracking_category_id,
        "trackingOptionID": tracking_option_id,
    }
    params = {key: value for key, value in candidate_params.items() if value}

    headers = {
        "Authorization": f"Bearer {access_token}",
        "xero-tenant-id": tenant_id,
        "Accept": "application/json"
    }

    max_retries = 3
    max_wait_time = 60  # Cap each sleep at 60 seconds
    fallback_wait = 3600  # Assumed wait (1 hour) when Retry-After is missing or unparseable
    last_retry_after = None  # Retry-After (seconds) from the most recent 429

    def _parse_retry_after(raw):
        """Convert a Retry-After value (delta-seconds or HTTP-date) to whole seconds."""
        raw = raw.strip()
        if not raw:
            print("   (Xero did not send Retry-After header - assuming 1 hour. Daily limit resets at midnight UTC)")
            return fallback_wait
        try:
            if raw.isdigit():
                return int(raw)
            # HTTP-date format: "Wed, 21 Oct 2015 07:28:00 GMT"
            retry_until = parsedate_to_datetime(raw)
            if retry_until.tzinfo is None:
                retry_until = retry_until.replace(tzinfo=timezone.utc)
            return max(0, int((retry_until - datetime.now(timezone.utc)).total_seconds()))
        except (ValueError, TypeError):
            return fallback_wait

    def _fmt_time(sec):
        """Render a wait in seconds, adding an hours/minutes hint for long waits."""
        if sec >= 3600:
            return f"{sec}s ({sec/3600:.1f} hours)"
        if sec >= 60:
            return f"{sec}s ({sec/60:.0f} min)"
        return f"{sec}s"

    for attempt in range(max_retries):
        response = requests.get(url, headers=headers, params=params or None, timeout=30)
        if response.status_code != 429:
            break  # Success or a non-429 error

        rate_limit_problem = response.headers.get('X-Rate-Limit-Problem', 'unknown')
        last_retry_after = _parse_retry_after(response.headers.get('Retry-After') or '')
        wait_time = min(last_retry_after, max_wait_time)

        if attempt < max_retries - 1:
            # Always show the user when they can try again.
            if last_retry_after > max_wait_time:
                print(f"⏱️  Rate limited (429). Try again in {_fmt_time(last_retry_after)}. Waiting {wait_time}s before retry {attempt + 1}/{max_retries}...")
            else:
                print(f"⏱️  Rate limited (429). Try again in {_fmt_time(last_retry_after)}. Waiting before retry {attempt + 1}/{max_retries}...")
            time.sleep(wait_time)
            continue

        # Final attempt still rate-limited: clear message for the user.
        if last_retry_after >= 3600:
            print(f"❌ Rate limited (429) after {max_retries} attempts. Try again in {_fmt_time(last_retry_after)}. (Limit: {rate_limit_problem})")
        else:
            print(f"❌ Rate limited (429) after {max_retries} attempts. Try again in {_fmt_time(last_retry_after)}.")

    if response.status_code == 200:
        return response.json()

    error_detail = response.text

    # Still rate-limited after every retry: raise with the wait time attached.
    # Compare against None so that a Retry-After of 0 is still reported.
    if response.status_code == 429 and last_retry_after is not None:
        error = requests.exceptions.HTTPError(
            f"Xero API rate limit exceeded. Retry after {last_retry_after} seconds.", response=response
        )
        error.retry_after = last_retry_after  # Attach retry_after to the exception
        raise error

    # A 401 mentioning TokenExpired is left to the caller to handle, so skip
    # the diagnostic dump for it to reduce noise.
    token_expired = response.status_code == 401 and "TokenExpired" in error_detail
    if not token_expired:
        print(f"Error response status: {response.status_code}")
        print(f"Error response body: {error_detail}")
        print(f"Request URL: {url}")
        print(f"Tenant ID used: {tenant_id}")
        if response.status_code == 403:
            print("\n403 Forbidden usually means:")
            print("   - Tenant ID is incorrect or missing")
            print("   - Access token doesn't have required scopes")
            print("   - App hasn't been authorized for this organization")
    response.raise_for_status()

    # Non-200 status that raise_for_status() did not treat as an error.
    return response.json()


def get_balance_sheet(access_token, tenant_id=None, date_str=None):
    """
    Fetch the Balance Sheet report from the Xero API.

    Args:
        access_token: OAuth access token.
        tenant_id: Xero tenant ID; falls back to TENANT_ID, then get_tenant_id().
        date_str: Report date as YYYY-MM-DD; defaults to today's local date.

    Returns:
        Parsed JSON response with account balances as at ``date_str``.

    Raises:
        ValueError: if no tenant ID can be determined.
        requests.exceptions.HTTPError: when Xero returns an error status.
    """
    tenant_id = tenant_id or TENANT_ID or get_tenant_id(access_token)
    if not tenant_id:
        raise ValueError("TENANT_ID is required.")

    report_date = date_str or datetime.now().strftime("%Y-%m-%d")
    response = requests.get(
        "https://api.xero.com/api.xro/2.0/Reports/BalanceSheet",
        headers={
            "Authorization": f"Bearer {access_token}",
            "xero-tenant-id": tenant_id,
            "Accept": "application/json",
        },
        params={"date": report_date},
        timeout=30,
    )
    if response.status_code != 200:
        response.raise_for_status()
    return response.json()


def extract_cash_in_bank_from_balance_sheet(bs_data, account_names):
    """
    Sum Balance Sheet amounts for the configured cash-in-bank accounts.

    Walks the first report's ``Rows`` recursively:

    - A ``Row``/``SummaryRow`` with a matching label contributes its first
      data column (``Cells[1]``).
    - A ``Section`` with a matching title contributes the sum of its detail
      rows. The section's own ``SummaryRow`` is skipped because it repeats
      that sum and would double-count.
    - A non-matching section is searched recursively.

    Matching is case-insensitive after normalisation. Whitespace (including
    NBSP) is collapsed, ``&amp;`` becomes ``&``, and " and " becomes " & ".
    Besides exact matches against ``account_names``, three built-in patterns
    always match, whatever ``account_names`` contains:

    - "business trans" plus "acct" or "account"
    - "tax" plus "reserve"
    - "operating reserve"

    ``Cells[1]`` (the current period) is used, NOT ``Cells[-1]``. The last
    column is the comparison period and would be $0 for recently created orgs.

    Args:
        bs_data: Balance Sheet JSON from the Xero API.
        account_names: Iterable of account names. Falsy or blank entries are
            ignored, and ``None`` is treated as empty.

    Returns:
        float: The summed amount; 0.0 when nothing matches or bs_data is empty.
    """
    total = 0.0
    if not bs_data or "Reports" not in bs_data or not bs_data["Reports"]:
        return total
    rows = bs_data["Reports"][0].get("Rows", [])

    def _parse_amount(val):
        try:
            return float(str(val).replace(",", "")) if val else 0.0
        except (ValueError, TypeError):
            return 0.0

    def _normalise_label(s):
        """Lower-case, collapse whitespace, and unify '&amp;' / ' and ' to '&'."""
        if not s:
            return ""
        s = s.lower().replace("\xa0", " ").replace("&amp;", "&")
        s = " ".join(s.split())
        return s.replace(" and ", " & ")

    # Normalise configured names once (not per row). Drop names that are blank
    # after normalising, so they cannot match rows with an empty label.
    configured = {_normalise_label(n) for n in (account_names or []) if n}
    configured.discard("")

    def _matches_account(label):
        """True if label matches a configured name or one of the built-in patterns."""
        label_clean = _normalise_label(label)
        if label_clean in configured:
            return True
        # Flexible match for the standard 3 accounts (e.g. "Business Trans Acct").
        if "business trans" in label_clean and ("acct" in label_clean or "account" in label_clean):
            return True
        if "tax" in label_clean and "reserve" in label_clean:
            return True
        return "operating reserve" in label_clean

    def _search_rows(rows_in):
        nonlocal total
        for row in rows_in or []:
            row_type = row.get("RowType")
            if row_type == "Section":
                title = (row.get("Title") or "").strip()
                if title and _matches_account(title):
                    for sub in row.get("Rows") or []:
                        if sub.get("RowType") == "SummaryRow":
                            continue  # repeats the detail rows' sum
                        sub_cells = sub.get("Cells") or []
                        if len(sub_cells) > 1:
                            total += _parse_amount(sub_cells[1].get("Value", "0"))
                else:
                    _search_rows(row.get("Rows"))
            elif row_type in ("Row", "SummaryRow"):
                cells = row.get("Cells") or []
                if len(cells) >= 2 and _matches_account(cells[0].get("Value") or ""):
                    total += _parse_amount(cells[1].get("Value", "0"))

    _search_rows(rows)
    return total


def get_accounts_by_code(access_token, tenant_id, account_codes):
    """
    Look up Xero accounts and return the UUIDs of the requested account codes.

    Args:
        access_token: OAuth access token.
        tenant_id: Xero tenant ID.
        account_codes: Collection of account codes to find (e.g. ["241.1", "241.2"]).
            Membership is tested with ``in``.

    Returns:
        Dict mapping each found code to {"uuid", "name", "type", "code"}.
        Codes not present in Xero are simply absent.

    Raises:
        requests.exceptions.HTTPError: when Xero returns an error status.
    """
    response = requests.get(
        "https://api.xero.com/api.xro/2.0/Accounts",
        headers={
            "Authorization": f"Bearer {access_token}",
            "xero-tenant-id": tenant_id,
            "Accept": "application/json",
        },
        timeout=30,
    )
    if response.status_code != 200:
        print(f"Error fetching accounts: {response.status_code}")
        response.raise_for_status()

    all_accounts = response.json().get("Accounts", [])
    return {
        acct.get("Code"): {
            "uuid": acct.get("AccountID"),
            "name": acct.get("Name"),
            "type": acct.get("Type"),
            "code": acct.get("Code"),
        }
        for acct in all_accounts
        if acct.get("Code") in account_codes
    }


def extract_accounts_by_name_patterns(pl_data, department_config):
    """
    Group P&L income accounts into departments by exact account-name match.

    Only ``Row`` entries inside an income section are considered, so cost and
    expense lines are never counted. A section counts as income when its title
    is "Income" or "Trading Income". Names are compared case-insensitively with
    whitespace collapsed.

    Each account is assigned to the first department, in ``department_config``
    order, whose list contains it; exact match only, which prevents
    over-counting (e.g. "Air Con - Builders Residential"). The amount is taken
    from the row's last cell. Rows whose amount does not parse are skipped.

    Args:
        pl_data: Profit & Loss JSON data
        department_config: Dictionary mapping department names to lists of account names
                          e.g., {
                              "Solar": ["Solar - Install Commercial", "Solar - Install Residential", ...],
                              "Residential": ["ELECTRICS - Electrical", ...],
                              ...
                          }

    Returns:
        Dict mapping each department name to
        {"total": float, "accounts": [{"name", "amount", "uuid"}, ...]}.
        Empty dict when pl_data has no report.
    """
    if not pl_data or "Reports" not in pl_data:
        return {}
    reports = pl_data.get("Reports", [])
    if not reports:
        return {}

    department_totals = {dept: {"total": 0.0, "accounts": []} for dept in department_config}

    def _normalise(text):
        return " ".join(text.lower().split())

    # Normalise every pattern once up front rather than once per P&L row.
    dept_patterns = [
        (dept, {_normalise(p) for p in patterns})
        for dept, patterns in department_config.items()
    ]
    income_titles = ("income", "trading income")

    for section in reports[0].get("Rows") or []:
        if section.get("RowType") != "Section":
            continue
        if (section.get("Title") or "").strip().lower() not in income_titles:
            continue
        for row in section.get("Rows") or []:
            if row.get("RowType") != "Row":
                continue
            cells = row.get("Cells") or []
            if len(cells) < 2:
                continue
            account_name = cells[0].get("Value") or ""
            amount = cells[-1].get("Value", "0")
            try:
                amount_float = float(str(amount).replace(",", "")) if amount else 0.0
            except (ValueError, TypeError):
                continue
            # Xero tags the account UUID as an attribute with Id "account".
            account_uuid = next(
                (a.get("Value") for a in (cells[0].get("Attributes") or []) if a.get("Id") == "account"),
                None,
            )
            account_normalised = _normalise(account_name)
            for dept, patterns in dept_patterns:
                if account_normalised in patterns:
                    bucket = department_totals[dept]
                    bucket["total"] += amount_float
                    bucket["accounts"].append({
                        "name": account_name,
                        "amount": amount_float,
                        "uuid": account_uuid
                    })
                    break  # first matching department wins
    return department_totals


def extract_gross_profit_metrics(pl_data):
    """
    Extract Total Cost of Sales and Gross Profit from Xero P&L report.

    Scans rows one level inside each top-level ``Section``. For every row
    whose label contains "total cost of sales" or "gross profit" (but not
    "margin"), the last cell's amount is recorded; later matches overwrite
    earlier ones. Amounts may contain thousands separators (e.g. "1,500.00"),
    consistent with the other P&L parsers in this module. Rows whose amount
    does not parse are skipped.

    Args:
        pl_data: JSON response from Xero Profit & Loss API

    Returns:
        Dictionary with:
        - total_cost_of_sales: float
        - gross_profit: float
        - gross_profit_margin: float (percentage, 0-100; 0.0 when revenue <= 0)
        - total_revenue: float (calculated as Cost of Sales + Gross Profit)
    """
    result = {
        "total_cost_of_sales": 0.0,
        "gross_profit": 0.0,
        "gross_profit_margin": 0.0,
        "total_revenue": 0.0
    }

    if not pl_data or "Reports" not in pl_data:
        return result

    reports = pl_data.get("Reports", [])
    if not reports:
        return result

    def _parse_amount(val):
        """Parse a Xero amount (commas allowed); 0.0 if empty, None if unparseable."""
        if not val:
            return 0.0
        try:
            return float(str(val).replace(",", ""))
        except (ValueError, TypeError):
            return None

    for section in reports[0].get("Rows") or []:
        if section.get("RowType") != "Section":
            continue
        for row in section.get("Rows") or []:
            cells = row.get("Cells") or []
            if not cells:
                continue
            label = (cells[0].get("Value") or "").strip().lower()
            amount = _parse_amount(cells[-1].get("Value", "0"))
            if amount is None:
                continue

            # "Total Cost of Sales" can be a SummaryRow or a Row.
            if "total cost of sales" in label:
                result["total_cost_of_sales"] = amount

            # "Gross Profit" is usually a Row after the Cost of Sales section.
            if "gross profit" in label and "margin" not in label:
                result["gross_profit"] = amount

    # Revenue - Cost of Sales = Gross Profit, so Revenue = Cost of Sales + Gross Profit.
    result["total_revenue"] = result["total_cost_of_sales"] + result["gross_profit"]

    if result["total_revenue"] > 0:
        result["gross_profit_margin"] = (result["gross_profit"] / result["total_revenue"]) * 100

    return result


def _find_total_trading_income_in_rows(rows):
    """
    Depth-first search for the "Total Trading Income" (or "Total Income") amount.

    A row matches when its first cell contains "total trading income", or
    contains "total income" without "other" (so "Total Other Income" is not
    a match). The amount is read from the row's last cell, with commas
    removed. A matching row whose amount does not parse is skipped. A row's
    nested ``Rows`` are searched before its following siblings.

    Returns:
        The first matching amount as a float, or None when nothing matches.
    """
    for row in rows or []:
        cells = row.get("Cells", [])
        if cells:
            text = cells[0].get("Value", "").strip().lower()
            raw = cells[-1].get("Value", "0")
            is_total = "total trading income" in text or (
                "total income" in text and "other" not in text
            )
            if is_total:
                try:
                    return float(str(raw).replace(",", "")) if raw else 0.0
                except (ValueError, TypeError):
                    pass
        # Descend into nested rows (Section, SummaryRow, ...).
        hit = _find_total_trading_income_in_rows(row.get("Rows", []))
        if hit is not None:
            return hit
    return None


def extract_total_trading_income(pl_data):
    """
    Return "Total Trading Income" (or "Total Income") from a department P&L report.

    Matches custom reports like "Profit and Loss - Air" (Total Trading Income =
    $96,288.71). The search itself is delegated to
    ``_find_total_trading_income_in_rows`` over the first report's rows.

    Args:
        pl_data: JSON response from Xero Profit & Loss API (department-filtered)

    Returns:
        float: The Total Trading Income amount, or 0.0 if no report or no match.
    """
    reports = pl_data.get("Reports", []) if pl_data and "Reports" in pl_data else []
    if not reports:
        return 0.0
    amount = _find_total_trading_income_in_rows(reports[0].get("Rows", []))
    return amount if amount is not None else 0.0


def extract_business_groups_revenue(pl_data, department_to_option):
    """
    Read per-Business-Group Total Trading Income from one multi-column P&L call.

    When P&L is requested with trackingCategoryID only (no trackingOptionID),
    Xero returns every tracking option as its own column. The first Header row
    maps option names (case-insensitive) to column indices. The first row
    labelled "Total Trading Income", or "Total Income" without "Other", found
    depth-first supplies the amounts.

    Args:
        pl_data: P&L response with trackingCategoryID (multi-column)
        department_to_option: Dict mapping dept name -> tracking option name, or a
            list of option names whose columns are summed (e.g. {"Air": "Air Conditioning"})

    Returns:
        Dict mapping dept name -> revenue (e.g. {"Air": 96288.71}). Every dept maps
        to 0.0 if no total row is found; {} if pl_data has no report.
    """
    if not pl_data or "Reports" not in pl_data:
        return {}
    reports = pl_data.get("Reports", [])
    if not reports:
        return {}
    rows = reports[0].get("Rows", [])

    # Column index per tracking option (lower-cased header text). Column 0 is the label.
    header = next((r for r in rows if r.get("RowType") == "Header"), None)
    column_map = {}
    if header is not None:
        for idx, cell in enumerate(header.get("Cells", [])[1:], start=1):
            name = (cell.get("Value") or "").strip()
            if name:
                column_map[name.lower()] = idx

    def _to_float(value):
        try:
            return float(str(value).replace(",", "")) if value else 0.0
        except (ValueError, TypeError):
            return 0.0

    def _sum_columns(cells):
        per_dept = {}
        for dept, spec in department_to_option.items():
            option_names = [spec] if isinstance(spec, str) else spec
            amount = 0.0
            for option in option_names:
                col = column_map.get(option.strip().lower())
                if col is not None and col < len(cells):
                    amount += _to_float(cells[col].get("Value", "0"))
            per_dept[dept] = amount
        return per_dept

    def _search(rows_in):
        for row in rows_in or ():
            cells = row.get("Cells", [])
            if cells:
                text = cells[0].get("Value", "").strip().lower()
                if "total trading income" in text or ("total income" in text and "other" not in text):
                    return _sum_columns(cells)
            nested = row.get("Rows", [])
            if nested:
                hit = _search(nested)
                if hit:
                    return hit
        return None

    found = _search(rows)
    return found or {dept: 0.0 for dept in department_to_option}


def extract_business_groups_gross_profit(pl_data, department_to_option):
    """
    Read per-Business-Group Gross Profit from one multi-column P&L call.

    When P&L is requested with trackingCategoryID only (no trackingOptionID),
    Xero returns every tracking option as its own column. The first Header row
    maps option names (case-insensitive) to column indices. The first row
    whose label contains "gross profit" but not "margin", found depth-first,
    supplies the amounts.

    Args:
        pl_data: P&L response with trackingCategoryID (multi-column)
        department_to_option: Dict mapping dept name -> tracking option name, or a
            list of option names whose columns are summed (e.g. {"Air": "Air Conditioning"})

    Returns:
        Dict mapping dept name -> gross profit amount. Every dept maps to 0.0 if
        no Gross Profit row is found; {} if pl_data has no report.
    """
    if not pl_data or "Reports" not in pl_data:
        return {}
    reports = pl_data.get("Reports", [])
    if not reports:
        return {}
    rows = reports[0].get("Rows", [])

    # Column index per tracking option (lower-cased header text). Column 0 is the label.
    header = next((r for r in rows if r.get("RowType") == "Header"), None)
    column_map = {}
    if header is not None:
        for idx, cell in enumerate(header.get("Cells", [])[1:], start=1):
            name = (cell.get("Value") or "").strip()
            if name:
                column_map[name.lower()] = idx

    def _to_float(value):
        try:
            return float(str(value).replace(",", "")) if value else 0.0
        except (ValueError, TypeError):
            return 0.0

    def _sum_columns(cells):
        per_dept = {}
        for dept, spec in department_to_option.items():
            option_names = [spec] if isinstance(spec, str) else spec
            amount = 0.0
            for option in option_names:
                col = column_map.get(option.strip().lower())
                if col is not None and col < len(cells):
                    amount += _to_float(cells[col].get("Value", "0"))
            per_dept[dept] = amount
        return per_dept

    def _search(rows_in):
        for row in rows_in or ():
            cells = row.get("Cells", [])
            if cells:
                text = cells[0].get("Value", "").strip().lower()
                if "gross profit" in text and "margin" not in text:
                    return _sum_columns(cells)
            nested = row.get("Rows", [])
            if nested:
                hit = _search(nested)
                if hit:
                    return hit
        return None

    found = _search(rows)
    return found or {dept: 0.0 for dept in department_to_option}


def extract_business_groups_net_profit(pl_data, department_to_option):
    """
    Read per-Business-Group Net Profit from one multi-column P&L call.

    The first Header row maps tracking-option names (case-insensitive) to
    column indices. The first row found depth-first whose label contains
    "net profit", but neither "margin" nor "loss", supplies the amounts.

    NOTE(review): labels containing "loss" (e.g. "Net Profit/(Loss)") are
    skipped here, whereas extract_consolidated_net_profit accepts them.
    Confirm this difference is intended.

    Returns:
        Dict mapping dept name -> net profit. Every dept maps to 0.0 if no
        Net Profit row is found; {} if pl_data has no report.
    """
    if not pl_data or "Reports" not in pl_data:
        return {}
    reports = pl_data.get("Reports", [])
    if not reports:
        return {}
    rows = reports[0].get("Rows", [])

    # Column index per tracking option (lower-cased header text). Column 0 is the label.
    header = next((r for r in rows if r.get("RowType") == "Header"), None)
    column_map = {}
    if header is not None:
        for idx, cell in enumerate(header.get("Cells", [])[1:], start=1):
            name = (cell.get("Value") or "").strip()
            if name:
                column_map[name.lower()] = idx

    def _to_float(value):
        try:
            return float(str(value).replace(",", "")) if value else 0.0
        except (ValueError, TypeError):
            return 0.0

    def _sum_columns(cells):
        per_dept = {}
        for dept, spec in department_to_option.items():
            option_names = [spec] if isinstance(spec, str) else spec
            amount = 0.0
            for option in option_names:
                col = column_map.get(option.strip().lower())
                if col is not None and col < len(cells):
                    amount += _to_float(cells[col].get("Value", "0"))
            per_dept[dept] = amount
        return per_dept

    def _search(rows_in):
        for row in rows_in or ():
            cells = row.get("Cells", [])
            if cells:
                text = cells[0].get("Value", "").strip().lower()
                # An exact "net profit" label also satisfies this test.
                if "net profit" in text and "margin" not in text and "loss" not in text:
                    return _sum_columns(cells)
            nested = row.get("Rows", [])
            if nested:
                hit = _search(nested)
                if hit:
                    return hit
        return None

    found = _search(rows)
    return found or {dept: 0.0 for dept in department_to_option}


def extract_consolidated_net_profit(pl_data):
    """
    Return the consolidated "Net Profit" from a standard P&L (no tracking breakdown).

    This is the true net profit including all operating expenses, not just those
    allocated to specific Business Groups. Rows are visited in pre-order (a
    row's nested ``Rows`` before its following siblings). The first row whose
    label contains "net profit" but not "margin" supplies its last cell's
    amount; this also matches "Net Profit/(Loss)".

    Returns:
        float: The net profit. Returns 0.0 when there is no report or no match;
        an unparseable amount on the matching row also yields 0.0.
    """
    if not pl_data or "Reports" not in pl_data:
        return 0.0
    reports = pl_data.get("Reports", [])
    if not reports:
        return 0.0

    def _to_float(value):
        try:
            return float(str(value).replace(",", "")) if value else 0.0
        except (ValueError, TypeError):
            return 0.0

    # Explicit stack: children are pushed reversed so they pop in original order.
    pending = list(reversed(reports[0].get("Rows", []) or []))
    while pending:
        row = pending.pop()
        cells = row.get("Cells", [])
        if cells:
            text = cells[0].get("Value", "").strip().lower()
            if "net profit" in text and "margin" not in text:
                return _to_float(cells[-1].get("Value", "0"))
        nested = row.get("Rows", [])
        if nested:
            pending.extend(reversed(nested))
    return 0.0


def get_tracking_categories(access_token, tenant_id):
    """
    Fetch tracking categories and their options from Xero.
    Use this to discover TrackingCategoryID and TrackingOptionID for department reports.

    On HTTP 429 the request is attempted up to 3 times in total.

    * The wait comes from the Retry-After header, given as either delta-seconds or an
      HTTP date, and is capped at 60 seconds per wait.
    * A missing or unparseable header is treated as 3600 seconds.

    Returns:
        List of {TrackingCategoryID, Name, Options: [{TrackingOptionID, Name}]}

    Raises:
        requests.exceptions.HTTPError: for a non-success response. When the final
            response is a 429, the error carries a ``retry_after`` attribute (seconds).
    """
    url = "https://api.xero.com/api.xro/2.0/TrackingCategories"
    headers = {
        "Authorization": f"Bearer {access_token}",
        "xero-tenant-id": tenant_id,
        "Accept": "application/json"
    }

    # Retry logic for 429 rate limiting
    max_retries = 3
    max_wait_time = 60  # Cap wait time at 60 seconds
    last_retry_after = None

    def parse_retry_after(raw):
        """Convert a Retry-After header value to whole seconds from now (3600 if unknown)."""
        raw = raw or ''
        try:
            if raw.isdigit():
                return int(raw)
            if raw:
                retry_until = parsedate_to_datetime(raw)
                if retry_until.tzinfo is None:
                    retry_until = retry_until.replace(tzinfo=timezone.utc)
                now = datetime.now(timezone.utc)
                return max(0, int((retry_until - now).total_seconds()))
        except (ValueError, TypeError):
            pass
        return 3600

    def fmt_time(sec):
        """Render seconds with a human-friendly hours/minutes hint."""
        if sec >= 3600:
            return f"{sec}s ({sec/3600:.1f} hours)"
        if sec >= 60:
            return f"{sec}s ({sec/60:.0f} min)"
        return f"{sec}s"

    for attempt in range(max_retries):
        response = requests.get(url, headers=headers, timeout=30)
        if response.status_code != 429:
            break

        rate_limit_problem = response.headers.get('X-Rate-Limit-Problem', 'unknown')
        last_retry_after = parse_retry_after(response.headers.get('Retry-After'))
        wait_time = min(last_retry_after, max_wait_time)

        if attempt < max_retries - 1:
            if last_retry_after > max_wait_time:
                print(f"⏱️  Rate limited (429) on TrackingCategories. Try again in {fmt_time(last_retry_after)}. Waiting {wait_time}s...")
            else:
                print(f"⏱️  Rate limited (429) on TrackingCategories. Try again in {fmt_time(last_retry_after)}. Waiting...")
            time.sleep(wait_time)
            continue

        # Final attempt still rate limited: report and fall through to the error below.
        if last_retry_after >= 3600:
            print(f"❌ Rate limited (429) on TrackingCategories. Try again in {fmt_time(last_retry_after)}. (Limit: {rate_limit_problem})")
        else:
            print(f"❌ Rate limited (429) on TrackingCategories. Try again in {fmt_time(last_retry_after)}.")

    if response.status_code != 200:
        # Handle 429 specially - attach Retry-After info. Use an explicit None check so
        # a Retry-After of 0 seconds (e.g. an HTTP date already in the past) is still attached.
        if response.status_code == 429 and last_retry_after is not None:
            error_msg = f"Xero API rate limit exceeded (TrackingCategories). Retry after {last_retry_after} seconds."
            error = requests.exceptions.HTTPError(error_msg, response=response)
            error.retry_after = last_retry_after
            raise error
        response.raise_for_status()
    data = response.json()
    return data.get("TrackingCategories", [])


def _match_option(option_by_name, opt_name_clean):
    """
    Find the tracking option ID for a configured option name.

    Matching rules:

    * An exact match on the lower-cased, stripped configured name wins.
    * Otherwise the first option whose name contains the configured name, or is
      contained in it, is used (e.g. 'Electrical' vs 'Electrical Dept').
    * Options with an empty name never take part in partial matching, because the
      empty string is a substring of every name.

    Args:
        option_by_name: Mapping of option name -> TrackingOptionID. Keys are
            compared as-is against the lower-cased config name, so they are
            expected to be lower-cased already.
        opt_name_clean: Configured option name (any case; surrounding whitespace ignored).

    Returns:
        The matching TrackingOptionID, or None when the name is blank or nothing matches.
    """
    key = (opt_name_clean or "").strip().lower()
    if not key:
        return None
    # Exact match first
    if key in option_by_name:
        return option_by_name[key]
    # Partial: config name contained in option name or vice versa (skip blank names)
    for name, opt_id in option_by_name.items():
        if name and (key in name or name in key):
            return opt_id
    return None


def resolve_tracking_ids_by_name(access_token, tenant_id, tracking_category_names, department_to_option_name):
    """
    Map tracking category/option names to Xero IDs using the TrackingCategories API.

    Resolution order:
      1. The first category whose name is in ``tracking_category_names`` and that
         has at least one matched option is returned immediately.
      2. Otherwise, the category with the most matched options (first one wins on ties).
      3. Otherwise, if a named category exists, its ID is returned together with
         {dept: "ok"} for every department. The "ok" values are placeholders, not
         option IDs: P&L multi-column reports use the option names as column headers.
      4. Otherwise the available categories/options are logged and (None, {}) is returned.

    Args:
        access_token: OAuth access token
        tenant_id: Xero tenant ID
        tracking_category_names: One name or a list/tuple of names to try
            (e.g. ["Business Groups", "Department"])
        department_to_option_name: Dict mapping our dept name -> Xero option name (e.g. {"Air": "Air"})

    Returns:
        Tuple of (tracking_category_id, {dept: tracking_option_id}), or (None, {})
        when nothing could be resolved or the categories could not be fetched.
    """
    import logging
    logger = logging.getLogger(__name__)

    try:
        categories = get_tracking_categories(access_token, tenant_id)
    except Exception as exc:
        logger.warning("Could not fetch TrackingCategories (scope/401?): %s", exc)
        return (None, {})
    if not categories:
        logger.debug("No tracking categories returned from Xero")
        return (None, {})

    if isinstance(tracking_category_names, (list, tuple)):
        candidates = tracking_category_names
    elif tracking_category_names:
        candidates = [tracking_category_names]
    else:
        candidates = []
    wanted = {c.strip().lower() for c in candidates if c}

    def _category_name(cat):
        # Normalised (stripped, lower-cased) category name for comparison.
        return (cat.get("Name") or "").strip().lower()

    def _resolve(cat):
        # Only ACTIVE options are eligible; keys are lower-cased for _match_option.
        active = {
            (opt.get("Name") or "").strip().lower(): opt.get("TrackingOptionID")
            for opt in cat.get("Options", [])
            if opt.get("Status") == "ACTIVE"
        }
        matched = {}
        for dept, wanted_option in department_to_option_name.items():
            option_id = _match_option(active, wanted_option)
            if option_id:
                matched[dept] = option_id
        return cat.get("TrackingCategoryID"), matched

    best_id, best_map = None, {}
    for cat in categories:
        cname = _category_name(cat)
        cat_id, matched = _resolve(cat)
        if len(matched) > len(best_map):
            best_id, best_map = cat_id, matched
        if cname in wanted and matched:
            return (cat_id, matched)
    if best_map:
        return (best_id, best_map)

    # Fallback: named category found but no option matched - return its ID with placeholders.
    for cat in categories:
        if _category_name(cat) not in wanted:
            continue
        cat_id = cat.get("TrackingCategoryID")
        if cat_id and department_to_option_name:
            return (cat_id, dict.fromkeys(department_to_option_name, "ok"))

    # Nothing resolved: log what is available to help fix the configuration.
    for cat in categories:
        active_names = [o.get("Name") for o in cat.get("Options", []) if o.get("Status") == "ACTIVE"]
        logger.info("Xero tracking: category '%s' has options: %s", cat.get("Name"), active_names)
    return (None, {})


def display_department_totals(department_totals, report_period="Current Period"):
    """
    Print per-department account amounts and totals, followed by a grand total.

    Departments are printed in sorted order by name. Each section lists its
    accounts (or "(No accounts found)") and a "Total <dept> Revenue" line.

    Args:
        department_totals: Dict of dept name -> {"total": float, "accounts": [{"name", "amount"}]}
            (as produced by extract_accounts_by_name_patterns)
        report_period: Label for the period, shown upper-cased in the heading
            (e.g. "Current Month", "Current Day")
    """
    rule = "=" * 60
    print("\n" + rule)
    print(f"DEPARTMENT TOTALS - {report_period.upper()}")
    print(rule)

    running_total = 0.0
    for dept_name, info in sorted(department_totals.items()):
        dept_total = info["total"]
        account_rows = info["accounts"]
        running_total += dept_total

        print(f"\n{dept_name.upper()}:")
        print("-" * 60)

        if not account_rows:
            print("  (No accounts found)")
        else:
            for acc in account_rows:
                print(f"  {acc['name']:.<45} ${acc['amount']:>10,.2f}")

        label = "Total " + dept_name + " Revenue"
        print(f"\n  {label:.<45} ${dept_total:>10,.2f}")

    print("\n" + rule)
    print(f"{'GRAND TOTAL (All Departments)':.<45} ${running_total:>10,.2f}")
    print(rule)


def extract_accounts_by_uuid(pl_data, account_uuids):
    """
    Extract specific accounts from P&L data by their UUIDs.

    Only account "Row" entries directly inside top-level "Section" rows are
    inspected. An account's UUID comes from the first-cell attribute with
    Id == "account", and its amount comes from the last cell.

    * Amounts may contain thousands separators ("1,234.56").
    * A blank amount counts as 0.0.
    * Rows whose amount cannot be parsed are skipped.

    Args:
        pl_data: Profit & Loss JSON data
        account_uuids: Iterable of account UUIDs to extract (a single UUID string
            is also accepted)

    Returns:
        List of {"name", "amount", "uuid"} dictionaries, in report order.
    """
    accounts_found = []

    if not pl_data or "Reports" not in pl_data:
        return accounts_found

    reports = pl_data.get("Reports", [])
    if not reports:
        return accounts_found

    # Build the lookup once (O(1) per row). A bare string is treated as one UUID,
    # not as a collection of characters.
    if isinstance(account_uuids, str):
        wanted = {account_uuids}
    else:
        wanted = set(account_uuids)

    def _parse_amount(value):
        """Parse "1,234.56"-style values; 0.0 when blank, None when unparseable."""
        if not value:
            return 0.0
        try:
            return float(str(value).replace(",", ""))
        except (ValueError, TypeError):
            return None

    for row in reports[0].get("Rows", []):
        if row.get("RowType") != "Section":
            continue
        for sub_row in row.get("Rows", []):
            if sub_row.get("RowType") != "Row":
                continue
            cells = sub_row.get("Cells", [])
            if not cells:
                continue

            # Get account UUID from attributes
            account_uuid = next(
                (attr.get("Value") for attr in cells[0].get("Attributes") or []
                 if attr.get("Id") == "account"),
                None,
            )
            if account_uuid not in wanted:
                continue

            amount_float = _parse_amount(cells[-1].get("Value", "0"))
            if amount_float is None:
                continue
            accounts_found.append({
                "name": cells[0].get("Value", ""),
                "amount": amount_float,
                "uuid": account_uuid
            })

    return accounts_found


def find_account_uuids_by_name_pattern(pl_data, name_patterns):
    """
    Search P&L data for accounts matching name patterns and return their UUIDs.
    Useful for finding account UUIDs when you know part of the account name.

    Only account "Row" entries directly inside top-level "Section" rows are
    inspected.

    * A pattern matches when it is a case-insensitive substring of the account name.
    * Amounts may contain thousands separators; a blank amount counts as 0.0.
    * When a row matches but its amount cannot be parsed, the pattern still gets a
      (possibly empty) list entry, but the row itself is not appended.

    Args:
        pl_data: Profit & Loss JSON data
        name_patterns: List of strings to search for in account names (case-insensitive)

    Returns:
        Dictionary mapping patterns to lists of {name, uuid, amount}
    """
    found = {}

    if not pl_data or "Reports" not in pl_data:
        return found

    reports = pl_data.get("Reports", [])
    if not reports:
        return found

    # Lower-case each pattern once instead of per row.
    lowered_patterns = [(pattern, pattern.lower()) for pattern in name_patterns]

    def _parse_amount(value):
        """Parse "1,234.56"-style values; 0.0 when blank, None when unparseable."""
        if not value:
            return 0.0
        try:
            return float(str(value).replace(",", ""))
        except (ValueError, TypeError):
            return None

    for row in reports[0].get("Rows", []):
        if row.get("RowType") != "Section":
            continue
        for sub_row in row.get("Rows", []):
            if sub_row.get("RowType") != "Row":
                continue
            cells = sub_row.get("Cells", [])
            if not cells:
                continue

            account_name = cells[0].get("Value", "")
            # Null names in the JSON are treated as empty for matching purposes.
            name_lower = (account_name or "").lower()
            amount_float = _parse_amount(cells[-1].get("Value", "0"))

            # Get account UUID
            account_uuid = next(
                (attr.get("Value") for attr in cells[0].get("Attributes") or []
                 if attr.get("Id") == "account"),
                None,
            )

            # Check if account name matches any pattern
            for pattern, pattern_lower in lowered_patterns:
                if pattern_lower not in name_lower:
                    continue
                matches = found.setdefault(pattern, [])
                if amount_float is not None:
                    matches.append({
                        "name": account_name,
                        "uuid": account_uuid,
                        "amount": amount_float
                    })

    return found


def test_account_extraction(pl_data, test_uuids):
    """
    Print the accounts in a P&L report that match a set of labelled UUIDs.

    For every (label, uuid) pair, prints the matching account's name, UUID and
    amount, or a NOT FOUND line, and then the sum of the amounts found. If none
    of the UUIDs match, a single notice is printed instead.

    Args:
        pl_data: Profit & Loss JSON data
        test_uuids: Dictionary mapping test names to UUIDs
                   e.g., {"Solar Account 1": "uuid1", "Solar Account 2": "uuid2"}
    """
    banner = "=" * 60
    print("\n" + banner)
    print("TEST: EXTRACTING ACCOUNTS BY UUID")
    print(banner)

    found = extract_accounts_by_uuid(pl_data, list(test_uuids.values()))
    if not found:
        print("No accounts found with the provided UUIDs")
        return

    print(f"\nFound {len(found)} account(s):\n")

    net_value = 0
    for label, uuid in test_uuids.items():
        # First extracted account carrying this UUID, if any.
        acc = next((a for a in found if a["uuid"] == uuid), None)
        if acc is None:
            print(f"{label}: NOT FOUND (UUID: {uuid})")
        else:
            print(f"{label}:")
            print(f"  Name: {acc['name']}")
            print(f"  UUID: {acc['uuid']}")
            print(f"  Amount: ${acc['amount']:>10,.2f}")
            net_value += acc['amount']
        print()

    print("-" * 60)
    print(f"TOTAL (Net Value): ${net_value:>10,.2f}")
    print(banner)


def parse_profit_and_loss(pl_data):
    """
    Parse and display key metrics from the Profit & Loss report.

    Prints the report name/type/date, then walks the top-level rows:

    * Header rows print their Title followed by a rule.
    * Section rows print their Title (if any), every non-zero account "Row", and
      any "SummaryRow" nested inside the section (e.g. "Total Income").
    * Top-level SummaryRows are printed as well.

    Amounts come from each row's last cell and may contain thousands separators
    ("1,234.56"). Rows with unparseable amounts are skipped.

    Args:
        pl_data: Profit & Loss JSON data (expects a "Reports" list).
    """
    if not pl_data or "Reports" not in pl_data:
        print("⚠️  Unexpected data format")
        return

    reports = pl_data.get("Reports", [])
    if not reports:
        print("⚠️  No reports found in response")
        return

    report = reports[0]
    report_type = report.get("ReportType", "Unknown")
    report_name = report.get("ReportName", "Unknown")
    report_date = report.get("ReportDate", "Unknown")

    print(f"\n{'='*60}")
    print(f"Report: {report_name}")
    print(f"Type: {report_type}")
    print(f"Date: {report_date}")
    print(f"{'='*60}\n")

    def _amount(cells):
        """Last-cell amount as float; 0.0 when blank, None when unparseable."""
        value = cells[-1].get("Value", "0")
        if not value:
            return 0.0
        try:
            return float(str(value).replace(",", ""))
        except (ValueError, TypeError):
            return None

    def _print_summary(cells):
        """Print a summary line such as "Total Income" or "Net Profit"."""
        amount = _amount(cells)
        if amount is not None:
            title = cells[0].get("Value") or ""
            print(f"\n{title:.<50} ${amount:>10,.2f}")

    # Extract rows (revenue, expenses, etc.)
    for row in report.get("Rows", []):
        row_type = row.get("RowType", "")
        if row_type == "Header":
            # Print section headers
            if row.get("Cells", []):
                print(f"\n{row.get('Title', '')}")
                print("-" * 60)
        elif row_type == "Section":
            # Print section titles (Revenue, Expenses, etc.)
            title = row.get("Title", "")
            if title:
                print(f"\n{title}")
                print("-" * 60)

            # Print rows within this section; Xero nests section totals here as SummaryRows.
            for sub_row in row.get("Rows", []):
                cells = sub_row.get("Cells", [])
                if not cells:
                    continue
                sub_type = sub_row.get("RowType")
                if sub_type == "Row":
                    amount = _amount(cells)
                    if amount:  # skips both zero and unparseable (None) amounts
                        account_name = cells[0].get("Value") or ""
                        print(f"  {account_name:.<50} ${amount:>10,.2f}")
                elif sub_type == "SummaryRow":
                    _print_summary(cells)
        elif row_type == "SummaryRow":
            # Print summary rows (Total Revenue, Net Profit, etc.)
            cells = row.get("Cells", [])
            if cells:
                _print_summary(cells)


def main():
    """
    Main function to fetch and display Profit & Loss data from Xero.

    Steps:
      1. Refresh the OAuth access token via refresh_access_token().
      2. Resolve the tenant ID. TENANT_ID from the environment is used unless it
         is missing or a known placeholder. Access is then verified, and the ID
         is re-fetched from Xero once if verification fails.
      3. For the current month (no date params) and for today, fetch the P&L, print
         it, save it to a JSON file, and print per-department totals.

    Errors are printed (with a traceback for tenant/report failures) rather than raised.
    """
    print("Refreshing access token...")
    try:
        access_token = refresh_access_token()
        print("Access token obtained\n")
    except Exception as e:
        print(f"Failed to get access token: {e}")
        return

    # Get tenant ID (fetch if not in .env or if it's a placeholder)
    print("Getting tenant ID...")
    try:
        tenant_id = TENANT_ID
        placeholder_ids = ("your_tenant_id", "placeholder", "xxx")
        # Check if tenant_id is missing or is a placeholder
        if not tenant_id or tenant_id.lower() in placeholder_ids:
            print("   Tenant ID not found or is a placeholder, fetching from Xero...")
            tenant_id = get_tenant_id(access_token)
        else:
            print(f"Using tenant ID from .env: {tenant_id[:8]}...")

        # Verify we have access to this tenant
        print("Verifying tenant access...")
        if not verify_tenant_access(access_token, tenant_id):
            print("Could not verify tenant access. Trying to fetch fresh tenant ID...")
            tenant_id = get_tenant_id(access_token)
            if verify_tenant_access(access_token, tenant_id):
                print("Successfully verified access with fresh tenant ID")
            else:
                # Don't continue silently with a tenant we could not verify.
                print("⚠️  Could not verify access with the fresh tenant ID either; "
                      "continuing, but API calls may fail.")
        print()
    except Exception as e:
        print(f"Failed to get tenant ID: {e}")
        import traceback
        traceback.print_exc()
        return

    print("Fetching Profit & Loss data from Xero...")
    try:
        # Get current (local) date for today's report
        today = datetime.now()
        today_str = today.strftime("%Y-%m-%d")

        # Report 1: Current month (default - no date params)
        print("\n" + "="*60)
        print("REPORT 1: CURRENT MONTH")
        print("="*60)
        pl_data_month = get_profit_and_loss(access_token, tenant_id=tenant_id)
        parse_profit_and_loss(pl_data_month)

        # Save current month report
        output_file_month = "profit_and_loss_month.json"
        with open(output_file_month, "w", encoding="utf-8") as f:
            json.dump(pl_data_month, f, indent=2)
        print(f"\nCurrent month report saved to {output_file_month}")

        # Define department configurations (dept name -> account names to match)
        department_config = {
            "Solar": [
                "Solar - Install Commercial",
                "Solar - Install Residential",
                "Solar - Service & Maintenance"
            ],
            "Residential": [
                "ELECTRICS - Electrical Fault Finding",
                "ELECTRICS - Electrical Installation",
                "ELECTRICS - Electrical Maintenance",
                "ELECTRICS - Electrical Reports & Certificates",
                "ELECTRICS - Property Management"
            ],
            "Air": [
                "Air Con - Install (Ducted)",
                "Air Con - Install (Splits)",
                "Air Con - Service & Maintenance",
                "Air Con - Referral to Electrical",
                "Subscription - Annual AC Service"
            ],
            "Commercial": [
                "COMMERCIAL DEPT - Installation",
                "COMMERCIAL DEPT - Maintenance"
            ]
        }

        # Extract and display department totals for current month
        department_totals_month = extract_accounts_by_name_patterns(pl_data_month, department_config)
        display_department_totals(department_totals_month, "Current Month")

        # Report 2: Current day (today only)
        print("\n\n" + "="*60)
        print(f"REPORT 2: CURRENT DAY ({today_str})")
        print("="*60)
        pl_data_today = get_profit_and_loss(access_token, tenant_id=tenant_id,
                                            from_date=today_str, to_date=today_str)
        parse_profit_and_loss(pl_data_today)

        # Save current day report
        output_file_today = "profit_and_loss_today.json"
        with open(output_file_today, "w", encoding="utf-8") as f:
            json.dump(pl_data_today, f, indent=2)
        print(f"\nCurrent day report saved to {output_file_today}")

        # Extract and display department totals for current day
        department_totals_today = extract_accounts_by_name_patterns(pl_data_today, department_config)
        display_department_totals(department_totals_today, "Current Day")

    except Exception as e:
        print(f"Failed to fetch P&L data: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    # Run the P&L fetch/report only when executed as a script, not when imported.
    main()
