from __future__ import annotations

from datetime import datetime, timedelta, date
from pathlib import Path
from typing import List, Dict, Optional

import pandas as pd
from concurrent.futures import ThreadPoolExecutor, as_completed
import os
from flask import request, jsonify


# ======================================================
# CONFIG
# ======================================================
# Directory that holds one parquet option-chain file per symbol and day.
BASE_DATA_DIR = Path("/var/www/html") / "flask_project" / "chains"
# File-name template for a day's chain; the date parts are filled in
# zero-padded by find_parquet_for_day.
FILE_PATTERN = "optionChain_{symbol}_{yyyy}-{mm}-{dd}.parquet"
# Tickers whose chain files use a "$"-prefixed symbol
# (see normalize_symbol_for_chain).
INDEX_SYMBOLS = {"SPX", "RUT", "XSP"}


def _float_or_none(value) -> Optional[float]:
    """Return ``value`` as a float, or None when it is None, NaN or not numeric.

    Used for the optional underlying price so the per-day payload never
    carries NaN (``json.dumps`` emits the non-standard token ``NaN`` for it).
    """
    if value is None:
        return None
    try:
        out = float(value)
    except (TypeError, ValueError):
        return None
    return None if pd.isna(out) else out


def _run_single_day(
    d: date,
    chain_symbol: str,
    entry_time: str,
    strategy: str,
    target_delta_abs: float,
    width: float,
    contracts: int,
    slippage: float,
    TP: float,
    SL: float,
    return_intraday: bool = False,
    full_day_path: bool = False,
    confirm_n: int = 3,
):
    """
    Run a single day of an Iron Condor / Iron Butterfly backtest.

    Steps:
    - Load the day's chain parquet.
    - Pick the short strikes by delta at the snapshot closest to ``entry_time``.
    - Place the wings ``width`` points beyond the short strikes.
    - Mark the four legs to mid on every later snapshot.

    Exit rules:
    - ``confirm_n`` is the number of CONSECUTIVE snapshots beyond TP or SL
      required to confirm an exit:
        * TP is confirmed when pnl >= TP for confirm_n snapshots in a row.
        * SL is confirmed when pnl <= -SL for confirm_n snapshots in a row.
        * A snapshot that does not meet the condition resets that counter.
      A non-positive (or falsy) TP/SL disables that rule.
    - ``full_day_path``: when True, keep walking to the last snapshot of the
      day and record the full path, even after the trade exited
      (``trade_timestamp``).
    - ``return_intraday``: when True, include ``pnl_path`` in the result.

    ``strategy`` is accepted for interface compatibility but is not read here.
    The IC/IB difference arrives through ``target_delta_abs``.

    Returns:
        ``{"_missing": <iso date>}`` when no parquet exists for the day.
        ``{"_missing": ..., "_reason": "no_entry_snapshot"}`` when the entry
        snapshot is empty.
        Otherwise the per-day result dict.
    Exceptions raised by the helper functions propagate to the caller.
    """

    # -----------------------
    # Load the day's parquet
    # -----------------------
    p = find_parquet_for_day(chain_symbol, d)
    if not p:
        return {"_missing": d.isoformat()}

    df = pd.read_parquet(p)

    # -------------------------------------------------
    # SAFE MID PRICES (vectorized equivalent of _safe_mid):
    # a mid is 0.0 unless both bid and ask are numeric and > 0.
    # -------------------------------------------------
    for c in ["bid_call", "ask_call", "bid_put", "ask_put"]:
        df[c] = pd.to_numeric(df[c], errors="coerce")

    call_ok = (df["bid_call"] > 0) & (df["ask_call"] > 0)
    put_ok = (df["bid_put"] > 0) & (df["ask_put"] > 0)

    df["call_mid"] = ((df["bid_call"] + df["ask_call"]) * 0.5).where(call_ok, 0.0).fillna(0.0)
    df["put_mid"] = ((df["bid_put"] + df["ask_put"]) * 0.5).where(put_ok, 0.0).fillna(0.0)

    # Normalize: rows with unparseable timestamps are dropped, strikes become floats.
    df["timestamp"] = pd.to_datetime(df["timestamp"], errors="coerce")
    df = df.dropna(subset=["timestamp"])
    df["strike"] = df["strike"].astype(float)

    # -----------------------
    # ENTRY SNAPSHOT
    # -----------------------
    snap_entry = find_snapshot_by_time(df, entry_time)
    if snap_entry is None or len(snap_entry) == 0:
        return {"_missing": d.isoformat(), "_reason": "no_entry_snapshot"}

    sells = find_sell_strikes_by_delta(snap_entry, target_delta_abs)
    call_sell, put_sell = sells["call_sell"], sells["put_sell"]

    call_buy = pick_wing_strike(snap_entry, call_sell["strike"], width, "call")
    put_buy = pick_wing_strike(snap_entry, put_sell["strike"], width, "put")

    legs = {
        "call_sell": call_sell,
        "call_buy": call_buy,
        "put_sell": put_sell,
        "put_buy": put_buy,
    }

    entry_gross = calc_ic_credit_gross(call_sell, put_sell, call_buy, put_buy, contracts)
    entry_net = calc_entry_credit_net(entry_gross["credit_gross"], slippage)

    entry_ts = pd.to_datetime(snap_entry["timestamp"].iloc[0])

    # Optional underlying price at entry; NaN / non-numeric becomes None.
    underlying_entry = None
    if "underlying_price" in snap_entry.columns:
        underlying_entry = _float_or_none(snap_entry["underlying_price"].iloc[0])

    # -----------------------
    # INTRADAY ITERATION SETUP
    # -----------------------
    df_idx = df.set_index(["timestamp", "strike"]).sort_index()

    # Snapshot timestamps strictly after entry, in ascending order (index is sorted).
    ts_all = df_idx.index.get_level_values(0).unique()
    ts_all = ts_all[ts_all > entry_ts]

    # Path rows (only filled when requested).
    pnl_path = []

    # Trade state: stays "EOD" unless TP/SL gets confirmed.
    exit_reason = "EOD"

    # Trade exit (set when TP/SL is confirmed).
    exit_ts_trade = None
    pnl_trade_exit = None

    # Path exit: last timestamp visited (end of day when full_day_path=True).
    exit_ts_path = entry_ts

    # Last computed gross credit (starts at the entry credit).
    final_credit_now = float(entry_gross["credit_gross"])

    # Consecutive-hit counters.
    tp_hits = 0
    sl_hits = 0

    # Defensive coercion of the thresholds.
    TP = float(TP or 0.0)
    SL = float(SL or 0.0)
    confirm_n = int(confirm_n or 1)
    if confirm_n < 1:
        confirm_n = 1

    # Optional timestamp -> underlying price lookup (first row per timestamp).
    u_map = None
    if "underlying_price" in df.columns:
        try:
            u_map = df.drop_duplicates("timestamp").set_index("timestamp")["underlying_price"]
        except Exception:
            u_map = None

    # Last pnl seen (used when no path is stored).
    last_pnl = 0.0

    # -----------------------
    # CACHE STRIKES (once)
    # -----------------------
    def _leg_mids(leg: str, mid_col: str) -> pd.Series:
        # Mid-price series of one leg, indexed by timestamp. Duplicate
        # (timestamp, strike) rows keep the first occurrence so .get() always
        # yields a scalar instead of a Series (which would break float()).
        rows = df_idx.xs(legs[leg]["strike"], level="strike")[mid_col]
        return rows[~rows.index.duplicated(keep="first")]

    cs_mid = _leg_mids("call_sell", "call_mid")
    ps_mid = _leg_mids("put_sell", "put_mid")
    cb_mid = _leg_mids("call_buy", "call_mid")
    pb_mid = _leg_mids("put_buy", "put_mid")

    # -----------------------
    # LOOP
    # -----------------------
    for ts in ts_all:
        # Cost to close now; a leg missing at this timestamp counts as 0.0.
        credit_now = (
            cs_mid.get(ts, 0.0)
            + ps_mid.get(ts, 0.0)
            - cb_mid.get(ts, 0.0)
            - pb_mid.get(ts, 0.0)
        ) * 100 * contracts

        pnl = float(entry_net["credit_net"] - credit_now)
        last_pnl = pnl

        underlying_now = _float_or_none(u_map.get(ts)) if u_map is not None else None

        # PATH (only when intraday or full-day path is requested)
        if return_intraday or full_day_path:
            pnl_path.append({
                "ts": str(ts),
                "credit_now_gross": round(credit_now, 2),
                "pnl": round(pnl, 2),
                "underlying": underlying_now,
            })

        final_credit_now = credit_now
        exit_ts_path = ts

        # Trade already closed and we only keep walking for the path: skip TP/SL.
        if full_day_path and exit_reason != "EOD":
            continue

        # Consecutive TP hits
        if TP > 0:
            if pnl >= TP:
                tp_hits += 1
            else:
                tp_hits = 0

        # Consecutive SL hits
        if SL > 0:
            if pnl <= -SL:
                sl_hits += 1
            else:
                sl_hits = 0

        # Confirmation (only once)
        if exit_reason == "EOD":
            if TP > 0 and tp_hits >= confirm_n:
                exit_reason = "TP"
                exit_ts_trade = ts
                pnl_trade_exit = pnl
                if not full_day_path:
                    break

            elif SL > 0 and sl_hits >= confirm_n:
                exit_reason = "SL"
                exit_ts_trade = ts
                pnl_trade_exit = pnl
                if not full_day_path:
                    break

    # -----------------------
    # FINALS (no None / no NaN in numeric fields)
    # -----------------------
    # EOD pnl: the last path row when a path was stored, otherwise last_pnl.
    # Without a full path this is the pnl at the last visited snapshot.
    pnl_eod = float(pnl_path[-1]["pnl"]) if pnl_path else float(last_pnl)

    # Trade pnl: the pnl at the confirmed TP/SL snapshot, otherwise pnl_eod.
    final_pnl_trade = float(pnl_trade_exit) if pnl_trade_exit is not None else float(pnl_eod)

    # Legacy timestamp used by the frontend.
    exit_ts_legacy = exit_ts_trade if exit_ts_trade is not None else exit_ts_path

    # final_pnl for the frontend: always a number, consistent with the reason.
    # TP/SL exits report the configured threshold rather than the observed pnl.
    if exit_reason == "TP":
        final_pnl_legacy = float(TP)
    elif exit_reason == "SL":
        final_pnl_legacy = -float(SL)
    else:
        final_pnl_legacy = float(final_pnl_trade)

    out = {
        "date": d.isoformat(),
        "file": str(p),
        "entry_time": entry_time,
        "underlying_entry": underlying_entry,
        "legs": legs,
        "entry": {
            "timestamp": str(entry_ts),
            "credit_points": float(entry_gross.get("credit_points", 0.0)),
            "credit_gross": float(entry_gross.get("credit_gross", 0.0)),
            "slippage": float(entry_net.get("slippage", 0.0)),
            "credit_net": float(entry_net.get("credit_net", 0.0)),
        },
        "exit": {
            "reason": exit_reason,

            # LEGACY fields (expected by backtesting.html)
            "timestamp": str(exit_ts_legacy),
            "time_hhmm": exit_ts_legacy.strftime("%H:%M"),
            "final_pnl": round(float(final_pnl_legacy), 2),

            # Newer, more detailed fields
            "trade_timestamp": str(exit_ts_trade) if exit_ts_trade is not None else None,
            "trade_time_hhmm": exit_ts_trade.strftime("%H:%M") if exit_ts_trade is not None else None,
            "path_timestamp": str(exit_ts_path),
            "path_time_hhmm": exit_ts_path.strftime("%H:%M"),

            "final_credit_now_gross": round(float(final_credit_now), 2),
            "final_pnl_trade": round(float(final_pnl_trade), 2),
            "final_pnl_eod": round(float(pnl_eod), 2),

            "tp": float(TP),
            "sl": float(SL),
            "confirm_n": int(confirm_n),
        },
        "result": {
            "label": "WIN" if final_pnl_trade > 0 else "LOSS" if final_pnl_trade < 0 else "FLAT",
            "is_positive": final_pnl_trade > 0,
            "pnl": round(float(final_pnl_trade), 2),
        },
    }

    if return_intraday:
        out["pnl_path"] = pnl_path

    return out


# ======================================================
# HELPERS
# ======================================================
def normalize_symbol_for_chain(symbol: str) -> str:
    """Map a user-facing ticker to the symbol used in chain file names.

    Tickers listed in INDEX_SYMBOLS get a leading ``$`` (``"SPX"`` becomes
    ``"$SPX"``); any other symbol is returned unchanged.
    """
    is_index = symbol in INDEX_SYMBOLS
    return "$" + symbol if is_index else symbol


def _parse_date(s: str) -> date:
    """Parse a strict ``YYYY-MM-DD`` value (``str()``-ed first) into a ``date``.

    Raises:
        ValueError: if the text does not match ``%Y-%m-%d``.
    """
    text = str(s)
    parsed = datetime.strptime(text, "%Y-%m-%d")
    return parsed.date()


def list_weekdays(d0: date, d1: date) -> List[date]:
    """Return every Monday-Friday date from ``d0`` to ``d1`` inclusive, in order.

    An inverted range (``d1 < d0``) yields an empty list.
    """
    span = (d1 - d0).days
    calendar = (d0 + timedelta(days=offset) for offset in range(span + 1))
    return [day for day in calendar if day.weekday() < 5]


def find_parquet_for_day(symbol: str, day: date) -> Optional[Path]:
    """Return the resolved chain parquet path for ``symbol`` on ``day``, or None.

    The file name is FILE_PATTERN filled with the symbol and the zero-padded
    date, looked up inside BASE_DATA_DIR.

    ``symbol`` is taken from the backtest ``params`` dict, which may be
    user-supplied. A symbol that would turn the file name into more than one
    path component (e.g. ``"../x"``) is treated as "not found". This stops
    it from reaching files outside BASE_DATA_DIR.
    """
    fname = FILE_PATTERN.format(
        symbol=symbol,
        yyyy=f"{day.year:04d}",
        mm=f"{day.month:02d}",
        dd=f"{day.day:02d}",
    )
    # A legitimate name is a single path component; separators / ".." are rejected.
    if Path(fname).name != fname:
        return None
    p = (BASE_DATA_DIR / fname).resolve()
    return p if p.exists() else None


def _safe_mid(bid, ask) -> float:
    """Return the bid/ask midpoint as a float.

    Returns 0.0 when either quote is missing (None/NaN) or when either
    quote is not strictly positive.
    """
    if pd.isna(bid) or pd.isna(ask):
        return 0.0
    quotes = (float(bid), float(ask))
    if any(q <= 0 for q in quotes):
        return 0.0
    return sum(quotes) / 2.0


# ======================================================
# SNAPSHOT POR HORA (OPTIMIZADO, MISMO RESULTADO)
# ======================================================
def find_snapshot_by_time(df: pd.DataFrame, hhmm: str) -> pd.DataFrame:
    """Return all rows of the snapshot whose time of day is closest to ``hhmm``.

    Args:
        df: chain rows with a ``timestamp`` column (parsed with
            ``pd.to_datetime``; unparseable rows are ignored).
        hhmm: target time as ``"HH:MM"``. ``"H:MM"`` and ``"HH:MM:SS"`` are
            also accepted; seconds are ignored (matching is per minute).

    Returns:
        The rows whose ``timestamp`` equals the closest snapshot's timestamp.
        On a tie, the first row in frame order wins.

    Raises:
        ValueError: if every timestamp is invalid or ``hhmm`` is malformed.
    """
    ts = pd.to_datetime(df["timestamp"], errors="coerce")
    good = ts.notna().to_numpy()
    if not good.any():
        raise ValueError("All timestamps invalid")

    # Boolean numpy masks select by position, so duplicate index labels are harmless.
    df = df[good]
    ts = ts[good]

    try:
        hh, mm = str(hhmm).strip().split(":")[:2]
        target_minutes = int(hh) * 60 + int(mm)
    except ValueError as err:
        raise ValueError(f"entry time must be 'HH:MM', got {hhmm!r}") from err

    minutes = ts.dt.hour * 60 + ts.dt.minute

    # Positional argmin: first row with the smallest distance to the target.
    pos = int((minutes - target_minutes).abs().to_numpy().argmin())
    snap_ts = df["timestamp"].iloc[pos]

    return df[(df["timestamp"] == snap_ts).to_numpy()]


# ======================================================
# STRIKES POR DELTA (MISMO)
# ======================================================
def find_sell_strikes_by_delta(snap: pd.DataFrame, target_delta_abs: float) -> Dict:
    """Pick the short call and short put whose |delta| is closest to the target.

    Args:
        snap: rows of one snapshot, with ``strike``, ``delta_call``,
            ``delta_put`` and the bid/ask columns of both sides.
        target_delta_abs: target absolute delta (e.g. 0.10).

    Returns:
        ``{"call_sell": {...}, "put_sell": {...}}``. Each leg holds strike,
        delta, mid (via ``_safe_mid``), bid and ask as floats. On ties the
        first row in frame order wins.

    Raises:
        ValueError: if ``delta_call`` or ``delta_put`` has no valid values.

    ``snap`` is not modified.
    """
    # Local Series instead of helper columns, so the caller's frame (often a
    # filtered slice) is neither mutated nor triggers SettingWithCopyWarning.
    call_diff = (snap["delta_call"] - target_delta_abs).abs()
    put_diff = (snap["delta_put"].abs() - target_delta_abs).abs()

    if not call_diff.notna().any():
        raise ValueError("no valid delta_call values in snapshot")
    if not put_diff.notna().any():
        raise ValueError("no valid delta_put values in snapshot")

    # Positional selection (argmin skips NaN) works even with duplicate index labels.
    call_row = snap.iloc[int(call_diff.argmin())]
    put_row = snap.iloc[int(put_diff.argmin())]

    return {
        "call_sell": {
            "strike": float(call_row["strike"]),
            "delta": float(call_row["delta_call"]),
            "mid": _safe_mid(call_row["bid_call"], call_row["ask_call"]),
            "bid": float(call_row["bid_call"]),
            "ask": float(call_row["ask_call"]),
        },
        "put_sell": {
            "strike": float(put_row["strike"]),
            "delta": float(put_row["delta_put"]),
            "mid": _safe_mid(put_row["bid_put"], put_row["ask_put"]),
            "bid": float(put_row["bid_put"]),
            "ask": float(put_row["ask_put"]),
        },
    }


def pick_wing_strike(snap: pd.DataFrame, sell_strike: float, width: float, side: str) -> Dict:
    """Pick the long (wing) strike ``width`` points beyond a short strike.

    Selection rules:
    - ``side == "call"``: the smallest strike >= ``sell_strike + width``. If
      none exists, the highest strike in the snapshot.
    - Any other ``side`` is treated as put: the largest strike
      <= ``sell_strike - width``. If none exists, the lowest strike.

    Returns:
        Dict with strike, delta, mid (via ``_safe_mid``), bid and ask of that
        side, all as floats.

    ``snap`` is not modified, and the fallback no longer assumes the rows are
    sorted by strike.
    """
    # Local float copy, re-labelled 0..n-1 so idxmin/idxmax give positions for iloc.
    strikes = snap["strike"].astype(float).reset_index(drop=True)

    if side == "call":
        target = sell_strike + width
        beyond = strikes[strikes >= target]
        pos = beyond.idxmin() if not beyond.empty else strikes.idxmax()
        row = snap.iloc[int(pos)]
        return {
            "strike": float(row["strike"]),
            "delta": float(row["delta_call"]),
            "mid": _safe_mid(row["bid_call"], row["ask_call"]),
            "bid": float(row["bid_call"]),
            "ask": float(row["ask_call"]),
        }

    target = sell_strike - width
    beyond = strikes[strikes <= target]
    pos = beyond.idxmax() if not beyond.empty else strikes.idxmin()
    row = snap.iloc[int(pos)]
    return {
        "strike": float(row["strike"]),
        "delta": float(row["delta_put"]),
        "mid": _safe_mid(row["bid_put"], row["ask_put"]),
        "bid": float(row["bid_put"]),
        "ask": float(row["ask_put"]),
    }


# ======================================================
# CREDIT
# ======================================================
def calc_ic_credit_gross(call_sell, put_sell, call_buy, put_buy, contracts: int) -> Dict:
    """Compute the gross entry credit of the four-leg position from leg mids.

    Args:
        call_sell, put_sell, call_buy, put_buy: leg dicts holding a ``"mid"``.
        contracts: number of spreads.

    Returns:
        ``{"credit_points": <per-share credit, 4 dp>,
        "credit_gross": <credit * 100 * contracts, 2 dp>}``.
    """
    points = call_sell["mid"] + put_sell["mid"]
    points -= call_buy["mid"]
    points -= put_buy["mid"]

    gross = points * 100 * contracts

    return dict(
        credit_points=round(float(points), 4),
        credit_gross=round(float(gross), 2),
    )


def calc_entry_credit_net(credit_gross: float, slippage: float) -> Dict:
    """Subtract a flat slippage amount from the gross entry credit.

    Returns:
        ``{"slippage": <slippage, 2 dp>, "credit_net": <gross - slippage, 2 dp>}``.
    """
    net = credit_gross - slippage
    result = {}
    result["slippage"] = round(slippage, 2)
    result["credit_net"] = round(net, 2)
    return result


# ======================================================
# FAST SNAPSHOT CREDIT (CLAVE DE SPEED)
# ======================================================
def credit_for_snapshot_fast(df_idx, ts, legs, contracts: int) -> float:
    """Return the gross credit of the four legs at snapshot ``ts``.

    Args:
        df_idx: chain frame indexed by ``(timestamp, strike)``.
        ts: snapshot timestamp to look up.
        legs: dict with ``call_sell``/``put_sell``/``call_buy``/``put_buy``
            entries, each holding a ``"strike"``.
        contracts: number of spreads.

    A leg whose ``(ts, strike)`` row (or quote column) is missing counts as
    0.0. Mids come from ``_safe_mid``.
    """
    def leg_mid(strike, side):
        bid_col, ask_col = ("bid_call", "ask_call") if side == "call" else ("bid_put", "ask_put")
        try:
            row = df_idx.loc[(ts, strike)]
            return _safe_mid(row[bid_col], row[ask_col])
        except KeyError:
            return 0.0

    short_call = leg_mid(legs["call_sell"]["strike"], "call")
    short_put = leg_mid(legs["put_sell"]["strike"], "put")
    long_call = leg_mid(legs["call_buy"]["strike"], "call")
    long_put = leg_mid(legs["put_buy"]["strike"], "put")

    return (short_call + short_put - long_call - long_put) * 100 * contracts


# ======================================================
# ENGINE — MISMO OUTPUT, MUCHÍSIMO MÁS RÁPIDO
# ======================================================
def run_backtest_engine(params: dict) -> dict:
    """Run the per-day IC/IB backtest over a date range using a thread pool.

    Recognized ``params`` keys (defaults in parentheses):
    - ``symbol``, ``entry_time``: required.
    - ``date_from``, ``date_to``: required, ``YYYY-MM-DD``.
    - ``strategy`` ("IC"); ``"IB"`` forces a 0.50 short delta.
    - ``delta_call`` (10, in delta points); ``width_call`` (20);
      ``contracts`` (1); ``slippage`` (0.0).
    - ``tp`` (0.0), ``sl`` (0.0).
    - ``confirm_n`` (5): consecutive snapshots needed to confirm TP/SL.

    Returns:
        ``{"ok": False, "error": ...}`` for missing or invalid required
        inputs. Otherwise a summary dict:
        - ``per_day`` is sorted by date.
        - ``missing_days`` is sorted ISO dates.
        - ``errors`` lists ``{"date", "error"}`` for days that raised.
    """
    symbol_raw = params.get("symbol")
    entry_time = params.get("entry_time")
    strategy = params.get("strategy", "IC")

    if not symbol_raw or not entry_time:
        return {"ok": False, "error": "missing symbol/entry_time"}

    date_from = params.get("date_from")
    date_to = params.get("date_to")
    if not date_from or not date_to:
        return {"ok": False, "error": "missing date_from/date_to"}
    try:
        d0 = _parse_date(date_from)
        d1 = _parse_date(date_to)
    except ValueError as e:
        return {"ok": False, "error": f"invalid date: {e}"}

    # Previously hard-coded to 5; still the default when not supplied.
    confirm_n = int(params.get("confirm_n") or 5)

    dn = int(params.get("delta_call", 10))
    width = float(params.get("width_call", 20))
    contracts = int(params.get("contracts", 1))
    slippage = float(params.get("slippage", 0.0))
    TP = float(params.get("tp", 0.0))
    SL = float(params.get("sl", 0.0))

    chain_symbol = normalize_symbol_for_chain(symbol_raw)
    days = list_weekdays(d0, d1)
    # delta_call is given in delta points (10 -> 0.10), clamped to [0.01, 0.50].
    target_delta_abs = 0.50 if strategy == "IB" else max(0.01, min(0.50, dn / 100))

    per_day, missing_days, errors = [], [], []

    max_workers = min(6, os.cpu_count() or 2)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(
                _run_single_day,
                d,
                chain_symbol,
                entry_time,
                strategy,
                target_delta_abs,
                width,
                contracts,
                slippage,
                TP,
                SL,
                False,          # return_intraday
                False,          # full_day_path
                confirm_n,
            ): d
            for d in days
        }

        for future in as_completed(futures):
            day = futures[future]
            try:
                res = future.result()
            except Exception as e:
                # Keep the failing day so the error can be traced.
                errors.append({"date": day.isoformat(), "error": str(e)})
                continue
            if "_missing" in res:
                missing_days.append(res["_missing"])
            else:
                per_day.append(res)

    return {
        "ok": True,
        "step": "ENTRY_AND_PATH",
        "symbol_input": symbol_raw,
        "symbol_chain": chain_symbol,
        "strategy": strategy,
        "target_delta_abs": target_delta_abs,
        "width": width,
        "days_total": len(days),
        "days_processed": len(per_day),
        # as_completed yields in finishing order; sort for deterministic output.
        "missing_days": sorted(missing_days),
        "errors": sorted(errors, key=lambda x: x["date"]),
        "per_day": sorted(per_day, key=lambda x: x["date"]),
    }
