#!/usr/bin/env python3
"""
Apply handbook-style grading to «Q1» in kpis.xlsx:

  • Scores 3–5 per KPI per month (from PDF targets + assumed «exceeds» bands).
  • Weight % per KPI per ruleset (Electrics / Air / Commercial — from PDF).
  • Payout % per month: 4 or 5 → 100% (a 5 pays the same as a 4), 3 → 0% (PDF).
  • Department summary rows: weighted index = SUMPRODUCT(weight%, payout%) / 100.

Weights for KPIs not on the sheet (e.g. VMVs, due dates) are omitted and remaining
weights are renormalized to sum to 100% within that department block.

Run: python3 scripts/apply_kpi_grades_kpis_xlsx.py
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path

try:
    import openpyxl
    from openpyxl.styles import Alignment, Font, PatternFill
    from openpyxl.utils import get_column_letter
except ImportError:
    print("pip install openpyxl", file=sys.stderr)
    sys.exit(1)

# --- Grading: map one KPI value to a score of 3, 4 or 5 (None when missing) ---


def _r_high(v: float | None, meet: float, exceed: float) -> int | None:
    if v is None:
        return None
    if v >= exceed:
        return 5
    if v >= meet:
        return 4
    return 3


def _r_low(v: float | None, meet: float, exceed: float) -> int | None:
    if v is None:
        return None
    if v <= exceed:
        return 5
    if v <= meet:
        return 4
    return 3


def _r_workload(v: float | None, _meet: float, _exceed: float) -> int | None:
    if v is None:
        return None
    if v >= 1.0:
        return 4
    return 3


def rules_elec() -> dict[str, tuple]:
    """Electrics grading bands: KPI fragment -> (mode, meet target, exceeds target).

    "high" rows are scored by _r_high (higher is better) and "low" rows by
    _r_low (lower is better). A new dict is returned on every call.
    """
    bands = (
        ("quote conversion", "high", 60.0, 70.0),
        ("client reviews", "high", 4.0, 4.5),
        ("positive client reviews", "high", 10.0, 13.0),
        ("avg. time to first attendance", "low", 6.5, 5.0),
        ("quote creation to acceptance", "low", 6.0, 4.5),
    )
    return {fragment: (mode, meet, exceed) for fragment, mode, meet, exceed in bands}


def rules_air() -> dict[str, tuple]:
    """Air-conditioning grading bands: KPI fragment -> (mode, meet, exceeds).

    Same shape as the Electrics table, with different quote-conversion and
    time-based targets. A new dict is returned on every call.
    """
    bands = (
        ("quote conversion", "high", 50.0, 60.0),
        ("client reviews", "high", 4.0, 4.5),
        ("positive client reviews", "high", 10.0, 13.0),
        ("avg. time to first attendance", "low", 10.0, 7.0),
        ("quote creation to acceptance", "low", 5.0, 3.5),
    )
    return {fragment: (mode, meet, exceed) for fragment, mode, meet, exceed in bands}


def rules_comm() -> dict[str, tuple]:
    """Commercial grading bands: KPI fragment -> (mode, meet, exceeds).

    The "workload" mode is scored by _r_workload, which ignores the two
    numeric targets. A new dict is returned on every call.
    """
    bands = (
        ("jobs completed", "high", 230.0, 245.0),
        ("avg. time to first attendance", "low", 4.0, 3.0),
        ("total pending jobs", "low", 75.0, 55.0),
        ("weekly workload form completion", "workload", 1.0, 1.0),
    )
    return {fragment: (mode, meet, exceed) for fragment, mode, meet, exceed in bands}


# Department label (column A of Q1) -> ruleset key in RULESETS / HANDBOOK_WEIGHTS.
# Labels are compared verbatim against the stripped column-A text.
# NOTE(review): "Residental" presumably mirrors the spelling used in kpis.xlsx;
# correcting it here alone would stop that department from being graded.
DEPT_RULESET: dict[str, str] = {
    dept: ruleset
    for ruleset, depts in (
        (
            "ELEC",
            (
                "Bunbury Residental",
                "Bunbury Solar",
                "Busselton Electrical",
                "Mandurah Electrical",
            ),
        ),
        ("AIR", ("Bunbury Air", "Busselton Air", "Mandurah Air")),
        ("COMM", ("Bunbury Commercial",)),
    )
    for dept in depts
}

# Ruleset key -> grading-band table (built once at import time).
RULESETS = dict(
    ELEC=rules_elec(),
    AIR=rules_air(),
    COMM=rules_comm(),
)

# Handbook weightings (PDF page 3), in percent. Keys are lowercase fragments
# tested as substrings of the KPI label in column B.
#   Electrics / Air: Quote 30, VMVs 10, Reviews 10, TTFA 25, Q2A 25
#   Commercial:      Jobs 35, TTFA 10, Pending 30, Workload 5, Due dates 15, VMVs 5
# "alignment" (VMVs) is always excluded by handbook_weight_for_kpi(); "due date"
# only counts if a matching row exists on the sheet. Weights of KPIs present are
# renormalized to 100 by renormalize_weights().
# NOTE(review): "client reviews" and "positive client reviews" each carry 10,
# while the PDF lists Reviews 10 in total. If both rows are on the sheet, reviews
# weigh 20 before renormalization — confirm this is intended.
HANDBOOK_WEIGHTS: dict[str, dict[str, float]] = {
    # Electrics and Air share one weighting; each ruleset gets its own dict.
    ruleset: {
        "quote conversion": 30.0,
        "alignment": 10.0,
        "client reviews": 10.0,
        "positive client reviews": 10.0,
        "avg. time to first attendance": 25.0,
        "quote creation to acceptance": 25.0,
    }
    for ruleset in ("ELEC", "AIR")
}
HANDBOOK_WEIGHTS["COMM"] = {
    "jobs completed": 35.0,
    "avg. time to first attendance": 10.0,
    "total pending jobs": 30.0,
    "weekly workload form completion": 5.0,
    "due date": 15.0,
    "alignment": 5.0,
}


def _normalize_kpi(s: str) -> str:
    """Canonical form of a KPI label for fragment matching: trimmed, then lower-cased."""
    trimmed = s.strip()
    return trimmed.lower()


def _find_rule(ruleset: str, kpi_cell: str) -> tuple[str, float, float] | None:
    """Return the ``(mode, meet, exceed)`` grading spec for a KPI label, or None.

    Each key of the ruleset's table in RULESETS is a lowercase fragment, tested
    as a substring of the normalized label. When several fragments match, the
    longest one wins. This matters because "client reviews" is a substring of
    "positive client reviews" and comes first in the ELEC/AIR tables. A
    first-match scan would therefore grade "Positive Client Reviews" against the
    4.0/4.5 "client reviews" bands instead of its own 10/13 targets.

    Returns None for an unknown ruleset or when no fragment matches.
    """
    label = _normalize_kpi(kpi_cell)
    table = RULESETS.get(ruleset)
    if not table:
        return None
    # The longest matching fragment is the most specific rule; ties keep table order.
    best = max((frag for frag in table if frag in label), key=len, default=None)
    return None if best is None else table[best]


def handbook_weight_for_kpi(ruleset: str, kpi_cell: str) -> float | None:
    """Return PDF weight % for this KPI row, or None if unmapped.

    Fragments in HANDBOOK_WEIGHTS[ruleset] are tested as substrings of the
    normalized label, and the longest match wins (consistent with _find_rule).
    This keeps the result independent of dict order when one fragment contains
    another, e.g. "client reviews" / "positive client reviews".
    The "alignment" (VMVs) entry is deliberately reported as unmapped (None).
    Also returns None for an unknown ruleset or when no fragment matches.
    """
    label = _normalize_kpi(kpi_cell)
    wmap = HANDBOOK_WEIGHTS.get(ruleset)
    if not wmap:
        return None
    best = max((frag for frag in wmap if frag in label), key=len, default=None)
    if best is None or best == "alignment":
        return None
    return wmap[best]


def renormalize_weights(
    ruleset: str, kpi_names_in_block: list[str]
) -> dict[str, float]:
    """Rescale handbook weights of the KPIs present in one department block to sum to 100.

    KPIs without a handbook weight (handbook_weight_for_kpi returns None) are
    left out of the result entirely. If the mapped weights sum to zero or less,
    every mapped KPI gets 0.0.
    """
    looked_up = (
        (name, handbook_weight_for_kpi(ruleset, name)) for name in kpi_names_in_block
    )
    present = {name: weight for name, weight in looked_up if weight is not None}
    total = sum(present.values())
    if total <= 0:
        return dict.fromkeys(present, 0.0)
    return {name: weight * 100.0 / total for name, weight in present.items()}


def grade_one(
    value: float | None, ruleset: str, kpi_cell: str
) -> int | str | None:
    """Score one monthly KPI value under *ruleset*.

    Returns:
        None  when *value* is None (no data for the month).
        "—"   when the KPI has no rule in the ruleset or the rule's mode is
              unrecognised. Also returned for a Commercial "jobs completed"
              value below 200 (NOTE(review): presumably branches whose job
              targets differ; confirm).
        3/4/5 otherwise, from _r_high / _r_low / _r_workload.
    """
    if value is None:
        return None
    label = _normalize_kpi(kpi_cell)
    if ruleset == "COMM" and "jobs completed" in label and float(value) < 200:
        return "—"
    spec = _find_rule(ruleset, kpi_cell)
    if spec is None:
        return "—"
    mode, meet, exceed = spec
    scorer = {"high": _r_high, "low": _r_low, "workload": _r_workload}.get(mode)
    if scorer is None:
        return "—"
    return scorer(float(value), meet, exceed)


def _forward_fill_dept(ws, start_row: int, max_row: int) -> dict[int, str]:
    """Map each row in [start_row, max_row] to the latest non-blank column-A label.

    Column A names a department only on some rows; rows below inherit the most
    recent label (stripped of surrounding whitespace). Rows before the first
    label map to "".
    """
    filled: dict[int, str] = {}
    label = ""
    for row in range(start_row, max_row + 1):
        raw = ws.cell(row, 1).value
        candidate = "" if raw is None else str(raw).strip()
        if candidate:
            label = candidate
        filled[row] = label
    return filled


def _dept_blocks(
    ws, dept_at_row: dict[int, str], data_start: int, max_row: int
) -> list[tuple[str, int, int]]:
    """Group KPI rows into ``(dept_name, first_row, last_row)`` runs.

    A KPI row is one whose column-B value is truthy and not just whitespace.
    Consecutive KPI rows (in row order, skipping non-KPI rows) that share a
    department in *dept_at_row* form one run. Runs whose department is "" are
    dropped.
    """
    kpi_rows: list[int] = []
    for row in range(data_start, max_row + 1):
        label = ws.cell(row, 2).value
        if label and str(label).strip():
            kpi_rows.append(row)

    spans: list[list] = []  # each entry: [dept, first_row, last_row]
    for row in kpi_rows:
        dept = dept_at_row.get(row, "")
        if spans and spans[-1][0] == dept:
            spans[-1][2] = row
        else:
            spans.append([dept, row, row])
    return [(dept, first, last) for dept, first, last in spans if dept]


def payout_formula(row: int, grade_col: int, pay_col: int) -> str:
    """Excel: payout % = 100 if grade >= 4, 0 if grade 3, blank if no grade.

    The formula yields 100 when the grade cell holds a number >= 4 and 0 when it
    holds a smaller number (a 3). Otherwise it yields "": an empty cell, "—", or
    any other text. ISNUMBER is used instead of testing only for blank/"—"
    because Excel ranks text above every number in comparisons. A stray entry
    such as "n/a" would otherwise satisfy ``>=4`` and pay 100%.

    *pay_col* is kept for call-site compatibility; the formula does not refer
    to the payout cell itself.
    """
    grade_ref = f"{get_column_letter(grade_col)}{row}"
    # e.g. =IF(ISNUMBER(F3),IF(F3>=4,100,0),"")
    return f'=IF(ISNUMBER({grade_ref}),IF({grade_ref}>=4,100,0),"")'


def ensure_rulebook_sheet(wb) -> None:
    """(Re)create the "Rulebook (KPI)" sheet documenting the handbook weights.

    Any existing sheet with that name is deleted first. Two note lines go in
    A1/A2, and the weight table starts at row 4 (header row, then one row per
    KPI). Weights are written as numbers, not text, so they can be summed or
    referenced in Excel without "number stored as text" problems.
    """
    name = "Rulebook (KPI)"
    if name in wb.sheetnames:
        del wb[name]
    ws = wb.create_sheet(name)
    ws["A1"] = "Handbook parity: weights (PDF p.3), payout 4/5=100% 3=0%, weighted SUMPRODUCT/100"
    ws["A2"] = (
        "VMVs / due dates not on Q1 — those weights renormalized onto KPIs present. "
        "Exceeds bands: see apply_kpi_grades_kpis_xlsx.py"
    )
    rows = [
        ("Ruleset", "KPI", "PDF weight %", "Notes"),
        ("ELEC/AIR", "Quote Conversion", 30, ""),
        ("ELEC/AIR", "VMVs", 10, "Not on Q1 — renorm"),
        ("ELEC/AIR", "Reviews", 10, ""),
        ("ELEC/AIR", "TTFA", 25, ""),
        ("ELEC/AIR", "Quote→Accept", 25, ""),
        ("COMM", "Jobs", 35, "Branch targets may differ from 230–250"),
        ("COMM", "TTFA", 10, ""),
        ("COMM", "Pending", 30, ""),
        ("COMM", "Workload form", 5, ""),
        ("COMM", "Due dates", 15, "Not on Q1 — renorm"),
        ("COMM", "VMVs", 5, "Not on Q1 — renorm"),
    ]
    for i, row in enumerate(rows, start=4):
        for j, val in enumerate(row, start=1):
            ws.cell(i, j, val)
    ws.column_dimensions["A"].width = 14
    ws.column_dimensions["D"].width = 36


def _drop_previous_summary(ws) -> None:
    """Delete a department summary written by an earlier run (title row to sheet end).

    The summary is found by its title text in column A within the first 299 rows.
    Merged ranges starting on or below the title row are unmerged first.
    openpyxl's ``delete_rows`` moves cells but not merged-cell ranges, so the old
    A:L title merge would otherwise stay at its old row. If the data area has
    grown since the last run, it could then cover KPI rows.
    """
    for r in range(1, min(ws.max_row + 1, 300)):
        v = ws.cell(r, 1).value
        if v and "Department summary (handbook-style" in str(v):
            for rng in list(ws.merged_cells.ranges):
                if rng.min_row >= r:
                    ws.unmerge_cells(str(rng))
            ws.delete_rows(r, ws.max_row - r + 1)
            return


def _ensure_header_columns(
    ws, header_row: int, first_col: int, names: list[str]
) -> None:
    """Insert *names* as bold, centred header columns at *first_col* unless present.

    Presence is detected by looking for ``names[0]`` in columns A–S of
    *header_row*. After inserting, the row-1 title merge anchored at A1 is
    replaced by one spanning A1 to the last inserted column.
    """
    headers = [ws.cell(header_row, c).value for c in range(1, 20)]
    if names[0] in headers:
        return
    ws.insert_cols(first_col, len(names))
    for c, label in enumerate(names, start=first_col):
        cell = ws.cell(header_row, c, label)
        cell.font = Font(bold=True)
        cell.alignment = Alignment(horizontal="center")
    for mrange in list(ws.merged_cells.ranges):
        if str(mrange).startswith("A1:") and mrange.min_row == 1:
            ws.unmerge_cells(str(mrange))
            break
    ws.merge_cells(
        start_row=1, start_column=1, end_row=1, end_column=first_col + len(names) - 1
    )


def apply_grades(path: Path, *, dry_run: bool = False) -> None:
    """Grade the «Q1» sheet of the workbook at *path* in place.

    Steps:
      1. Remove any department summary left by a previous run (re-runnable).
      2. Add Grade Jan–Mar (F–H) and Weight % / Pay% Jan–Mar (I–L) columns if
         header row 2 does not already have them.
      3. Per department, renormalize handbook weights over the KPIs present.
      4. Per KPI row, write month grades (3–5, "—" or blank), the weight, and
         payout-% formulas. Fills are reset on non-numeric grades so colours
         from an earlier run do not linger.
      5. Append a department summary with SUMPRODUCT index formulas.

    Exits the process with status 1 if there is no «Q1» sheet. With *dry_run*
    the edits happen in memory only and the workbook is not saved.
    """
    wb = openpyxl.load_workbook(path)
    if "Q1" not in wb.sheetnames:
        print("No Q1 sheet", file=sys.stderr)
        sys.exit(1)
    ws = wb["Q1"]
    header_row = 2
    data_start = 3

    # Remove the old summary before measuring the data area.
    _drop_previous_summary(ws)

    data_rows = [
        r
        for r in range(data_start, min(ws.max_row + 1, 200))
        if ws.cell(r, 2).value and str(ws.cell(r, 2).value).strip()
    ]
    max_row = max(data_rows) if data_rows else data_start

    # Final layout: C–E monthly values, F–H grades, I weight %, J–L payout %.
    col_jan, col_feb, col_mar = 3, 4, 5
    col_gj, col_gf, col_gm = 6, 7, 8
    col_w = 9
    col_pj, col_pf, col_pm = 10, 11, 12

    _ensure_header_columns(ws, header_row, col_gj, ["Grade Jan", "Grade Feb", "Grade Mar"])
    _ensure_header_columns(
        ws, header_row, col_w, ["Weight %", "Pay% Jan", "Pay% Feb", "Pay% Mar"]
    )

    dept_at = _forward_fill_dept(ws, data_start, max_row)

    # --- Pass 1: collect KPI names per department for renormalization ---
    kpi_by_dept: dict[str, list[str]] = {}
    for r in range(data_start, max_row + 1):
        kpi = ws.cell(r, 2).value
        if kpi is None or str(kpi).strip() == "":
            continue
        d = dept_at.get(r, "")
        if d in DEPT_RULESET:
            kpi_by_dept.setdefault(d, []).append(str(kpi).strip())

    weight_lookup: dict[tuple[str, str], float] = {}
    for dept, names in kpi_by_dept.items():
        renorm = renormalize_weights(DEPT_RULESET[dept], names)
        for n in names:
            weight_lookup[(dept, n)] = renorm.get(n, 0.0)

    center = Alignment(horizontal="center")
    no_fill = PatternFill()  # default (empty) fill, clears colours from earlier runs
    grade_fills = {
        5: PatternFill("solid", fgColor="BDD7EE"),
        4: PatternFill("solid", fgColor="C6EFCE"),
        3: PatternFill("solid", fgColor="FFC7CE"),
    }
    month_to_grade = ((col_jan, col_gj), (col_feb, col_gf), (col_mar, col_gm))
    grade_to_pay = ((col_gj, col_pj), (col_gf, col_pf), (col_gm, col_pm))

    # --- Pass 2: grades + weights + payout formulas ---
    for r in range(data_start, max_row + 1):
        kpi = ws.cell(r, 2).value
        if kpi is None or str(kpi).strip() == "":
            continue
        dept = dept_at.get(r, "")
        ruleset = DEPT_RULESET.get(dept)

        w_pct = weight_lookup.get((dept, str(kpi).strip()))
        if w_pct is not None and w_pct > 0:
            ws.cell(r, col_w, round(w_pct, 2)).alignment = center
        elif ruleset:
            ws.cell(r, col_w, 0).alignment = center
        else:
            ws.cell(r, col_w, "")

        if not ruleset:
            for c in (col_gj, col_gf, col_gm):
                ws.cell(r, c, "—").fill = no_fill
            continue

        for col_month, col_g in month_to_grade:
            raw = ws.cell(r, col_month).value
            try:
                val = float(raw) if raw is not None and str(raw).strip() != "" else None
            except (TypeError, ValueError):
                val = None
            g = grade_one(val, ruleset, str(kpi))
            if isinstance(g, int):
                cell = ws.cell(r, col_g, g)
                cell.fill = grade_fills.get(g, no_fill)
                cell.alignment = center
            elif g == "—":
                cell = ws.cell(r, col_g, "—")
                cell.fill = no_fill
                cell.alignment = center
            else:
                ws.cell(r, col_g, "").fill = no_fill

        for col_g, col_p in grade_to_pay:
            ws.cell(r, col_p, payout_formula(r, col_g, col_p))

    # --- Department summary formulas (below data) ---
    blocks = _dept_blocks(ws, dept_at, data_start, max_row)
    summary_start = max_row + 3
    title_row = summary_start - 1
    title = ws.cell(
        title_row, 1, "Department summary (handbook-style weighted payout index 0–100)"
    )
    title.font = Font(bold=True, italic=True)
    ws.merge_cells(start_row=title_row, start_column=1, end_row=title_row, end_column=12)

    hdr = summary_start
    for c, label in enumerate(
        ("Department", "Jan index", "Feb index", "Mar index", "Q1 avg"), start=1
    ):
        ws.cell(hdr, c, label).font = Font(bold=True)

    w_letter = get_column_letter(col_w)
    row_out = hdr + 1
    for dept, start_r, end_r in blocks:
        if dept not in DEPT_RULESET:
            continue
        ws.cell(row_out, 1, dept)
        # =IFERROR(SUMPRODUCT(I3:I6,J3:J6)/100,"")  (weights × pay% / 100)
        for offset, pay_col in enumerate((col_pj, col_pf, col_pm)):
            p_letter = get_column_letter(pay_col)
            ws.cell(
                row_out,
                2 + offset,
                f"=IFERROR(SUMPRODUCT({w_letter}{start_r}:{w_letter}{end_r},"
                f'{p_letter}{start_r}:{p_letter}{end_r})/100,"")',
            )
        # Q1 average only once all three monthly indices are numeric.
        months = f"B{row_out}:D{row_out}"
        ws.cell(row_out, 5, f'=IF(COUNT({months})<3,"",AVERAGE({months}))')
        row_out += 1

    ws.cell(row_out + 1, 1, "Index = Σ(weight% × pay%) / 100. Pay% = 100 if grade ≥4, else 0. Missing KPIs (—) excluded from grade; weights renorm.")

    if not dry_run:
        ensure_rulebook_sheet(wb)
        wb.save(path)
    else:
        wb.close()


def main() -> None:
    """Command-line entry point.

    Grades ``kpis.xlsx`` in the directory one level above this script's
    directory. ``--dry-run`` performs the grading in memory without saving.
    Exits with status 1 if the workbook does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dry-run", action="store_true")
    opts = parser.parse_args()
    workbook_path = Path(__file__).resolve().parents[1] / "kpis.xlsx"
    if not workbook_path.exists():
        print(f"Missing {workbook_path}", file=sys.stderr)
        sys.exit(1)
    apply_grades(workbook_path, dry_run=opts.dry_run)
    if opts.dry_run:
        return
    print(f"Updated {workbook_path}")


if __name__ == "__main__":
    main()
