from __future__ import annotations

from datetime import datetime, timedelta, date
from pathlib import Path
from typing import List, Dict, Optional

import pandas as pd
from concurrent.futures import ThreadPoolExecutor, as_completed
import os

# ======================================================
# CONFIG
# ======================================================
# Directory searched for per-day option-chain parquet files (see find_parquet_for_day).
BASE_DATA_DIR = Path("/var/www/html/flask_project/chains")
# File-name template; find_parquet_for_day fills in the symbol and zero-padded date parts.
FILE_PATTERN = "optionChain_{symbol}_{yyyy}-{mm}-{dd}.parquet"
# Symbols that normalize_symbol_for_chain prefixes with "$" before they are used in file names.
INDEX_SYMBOLS = {"SPX", "RUT", "XSP"}



def _run_single_day(
    d: date,
    chain_symbol: str,
    entry_time: str,
    strategy: str,
    target_delta_abs: float,
    width: float,
    contracts: int,
    slippage: float,
    TP: float,
    SL: float,
) -> Dict:
    """Backtest one trading day: open the four-leg position and walk it forward.

    Loads the parquet chain for ``d``, picks the short call/put closest to
    ``target_delta_abs`` in the snapshot nearest ``entry_time``, adds long wings
    ``width`` strike points further out, prices the entry at mid, then re-prices
    the same four strikes at every later snapshot until the take-profit or
    stop-loss threshold is crossed or the data runs out ("EOD").

    Args:
        d: Trading day to simulate.
        chain_symbol: Symbol as used in the parquet file name.
        entry_time: Entry time of day (e.g. "09:30"); the closest snapshot is used.
        strategy: Accepted but not read by this function.
        target_delta_abs: Target absolute delta for both short strikes.
        width: Distance in strike points from each short strike to its wing.
        contracts: Number of contracts; amounts are scaled by 100 * contracts.
        slippage: Flat amount subtracted once from the gross entry credit.
        TP: Take-profit threshold on P&L; 0 or negative disables it.
        SL: Stop-loss threshold on loss; 0 or negative disables it.

    Returns:
        ``{"_missing": "<iso date>"}`` when no parquet file exists for ``d``;
        otherwise a dict with the chosen legs, entry pricing, exit info
        (reason "TP", "SL" or "EOD", timestamp, final P&L) and a result label.
    """
    p = find_parquet_for_day(chain_symbol, d)
    if not p:
        # No chain file for this day: report it as missing instead of failing.
        return {"_missing": d.isoformat()}

    df = pd.read_parquet(p)

    # Normalize types up front; unparseable timestamps become NaT.
    df["timestamp"] = pd.to_datetime(df["timestamp"], errors="coerce")
    df["strike"] = df["strike"].astype(float)

    # All rows of the snapshot whose time of day is closest to entry_time.
    snap_entry = find_snapshot_by_time(df, entry_time)

    sells = find_sell_strikes_by_delta(snap_entry, target_delta_abs)
    call_sell, put_sell = sells["call_sell"], sells["put_sell"]

    # Long wings sit `width` strike points beyond each short strike.
    call_buy = pick_wing_strike(snap_entry, call_sell["strike"], width, "call")
    put_buy = pick_wing_strike(snap_entry, put_sell["strike"], width, "put")

    legs = {
        "call_sell": call_sell,
        "call_buy": call_buy,
        "put_sell": put_sell,
        "put_buy": put_buy,
    }

    entry_gross = calc_ic_credit_gross(call_sell, put_sell, call_buy, put_buy, contracts)
    entry_net = calc_entry_credit_net(entry_gross["credit_gross"], slippage)

    entry_ts = pd.to_datetime(snap_entry["timestamp"].iloc[0])
    underlying_entry = float(snap_entry["underlying_price"].iloc[0]) if "underlying_price" in snap_entry.columns else None

    # Index by (timestamp, strike) so each later snapshot is priced by direct lookups.
    df_idx = df.set_index(["timestamp", "strike"]).sort_index()
    ts_all = df_idx.index.get_level_values(0).unique()
    # Only snapshots strictly after the entry snapshot are evaluated.
    ts_all = ts_all[ts_all > entry_ts]

    pnl_path = []
    exit_reason = "EOD"
    exit_ts = entry_ts
    final_credit_now = entry_gross["credit_gross"]

    for ts in ts_all:
        credit_now = credit_for_snapshot_fast(df_idx, ts, legs, contracts)
        # P&L = net credit taken in at entry minus the current mid value of the four legs.
        pnl = entry_net["credit_net"] - credit_now

        pnl_path.append({
            "ts": str(ts),
            "credit_now_gross": round(float(credit_now), 2),
            "pnl": round(float(pnl), 2),
        })

        final_credit_now = credit_now
        exit_ts = ts

        # Thresholds are compared against the unrounded P&L; the recorded path is rounded.
        if TP > 0 and pnl >= TP:
            exit_reason = "TP"
            break
        if SL > 0 and pnl <= -SL:
            exit_reason = "SL"
            break

    # With no snapshot after entry, P&L is reported as 0.0.
    final_pnl = pnl_path[-1]["pnl"] if pnl_path else 0.0

    return {
        "date": d.isoformat(),
        "file": str(p),
        "entry_time": entry_time,
        "underlying_entry": underlying_entry,
        "legs": legs,
        "entry": {
            "timestamp": str(entry_ts),
            "credit_points": entry_gross["credit_points"],
            "credit_gross": entry_gross["credit_gross"],
            "slippage": entry_net["slippage"],
            "credit_net": entry_net["credit_net"],
        },
        "exit": {
            "reason": exit_reason,
            "timestamp": str(exit_ts),
            "time_hhmm": exit_ts.strftime("%H:%M"),
            "final_credit_now_gross": round(float(final_credit_now), 2),
            "final_pnl": round(float(final_pnl), 2),
            "tp": TP,
            "sl": SL,
        },
        "result": {
            "label": "WIN" if final_pnl > 0 else "LOSS" if final_pnl < 0 else "FLAT",
            "is_positive": final_pnl > 0,
            "pnl": round(float(final_pnl), 2),
        },
    }



# ======================================================
# HELPERS
# ======================================================
def normalize_symbol_for_chain(symbol: str) -> str:
    """Map a user-facing ticker to the symbol used in chain file names.

    Index symbols listed in ``INDEX_SYMBOLS`` get a leading "$"; every other
    symbol is returned unchanged.
    """
    if symbol not in INDEX_SYMBOLS:
        return symbol
    return "$" + symbol


def _parse_date(s: str) -> date:
    """Parse a "YYYY-MM-DD" value (stringified first) into a ``date``.

    Raises:
        ValueError: If the text does not match the format.
    """
    text = str(s)
    parsed = datetime.strptime(text, "%Y-%m-%d")
    return parsed.date()


def list_weekdays(d0: date, d1: date) -> List[date]:
    """Return every Monday-Friday date from ``d0`` to ``d1`` inclusive.

    An empty list is returned when ``d1`` is before ``d0``.
    """
    span_days = (d1 - d0).days
    calendar = (d0 + timedelta(days=offset) for offset in range(span_days + 1))
    return [day for day in calendar if day.weekday() < 5]


def find_parquet_for_day(symbol: str, day: date) -> Optional[Path]:
    """Return the parquet chain file for ``symbol`` on ``day``, or None if absent.

    The file name is built from ``FILE_PATTERN`` and looked up directly inside
    ``BASE_DATA_DIR``.

    Args:
        symbol: Chain symbol as it appears in file names (e.g. "$SPX").
        day: Trading day.

    Returns:
        The resolved path if it exists, otherwise None. A ``symbol`` that is
        not a plain file-name fragment (contains a path separator, so the file
        could land outside ``BASE_DATA_DIR``) is treated as absent.
    """
    fname = FILE_PATTERN.format(
        symbol=symbol,
        yyyy=f"{day.year:04d}",
        mm=f"{day.month:02d}",
        dd=f"{day.day:02d}",
    )
    # `symbol` ultimately comes from caller-supplied params. Reject anything
    # that is not a single path component, so "../" or "/" cannot escape
    # BASE_DATA_DIR after .resolve() normalizes the path.
    if Path(fname).name != fname:
        return None
    p = (BASE_DATA_DIR / fname).resolve()
    return p if p.exists() else None


def _safe_mid(bid, ask) -> float:
    """Return the bid/ask midpoint, or 0.0 if either side is NA or not positive."""
    if pd.isna(bid) or pd.isna(ask):
        return 0.0
    b, a = float(bid), float(ask)
    return 0.0 if (b <= 0 or a <= 0) else (b + a) / 2.0


# ======================================================
# SNAPSHOT BY TIME OF DAY (OPTIMIZED, SAME RESULT)
# ======================================================
def find_snapshot_by_time(df: pd.DataFrame, hhmm: str) -> pd.DataFrame:
    """Return every row of the snapshot whose time of day is closest to ``hhmm``.

    Only the hour and minute of each timestamp are compared; seconds and the
    calendar date are ignored. When several snapshots are equally close, the
    first one in row order wins. Rows with unparseable timestamps are ignored.

    Args:
        df: Chain rows with a ``timestamp`` column (parsed or parseable).
        hhmm: Target time such as "09:30" or "9:30"; a trailing seconds part
            (e.g. "09:30:00") is accepted and ignored.

    Returns:
        A copy of the rows whose ``timestamp`` equals the selected snapshot.

    Raises:
        ValueError: If no timestamp can be parsed or ``hhmm`` is not a valid time.
    """
    import re

    ts = pd.to_datetime(df["timestamp"], errors="coerce")
    good = ts.notna()
    if not good.any():
        raise ValueError("All timestamps invalid")

    df = df.loc[good]
    ts = ts.loc[good]

    # Split on any non-digit separator. Fixed slicing (hhmm[:2], hhmm[3:])
    # crashed on "9:30" and silently read "0930" as 09:00.
    parts = re.findall(r"\d+", str(hhmm))
    if len(parts) < 2:
        raise ValueError(f"invalid time {hhmm!r}; expected 'HH:MM'")
    hours, mins = int(parts[0]), int(parts[1])
    if not (0 <= hours < 24 and 0 <= mins < 60):
        raise ValueError(f"invalid time {hhmm!r}; expected 'HH:MM'")
    target_minutes = hours * 60 + mins

    minutes = ts.dt.hour * 60 + ts.dt.minute

    # Positional argmin: ties resolve to the first row (as idxmin did), and the
    # lookup stays valid even if df's index has duplicate labels.
    pos = (minutes - target_minutes).abs().argmin()
    snap_ts = df["timestamp"].iloc[pos]

    # .copy() so callers may add columns without writing into a slice of df.
    return df[df["timestamp"] == snap_ts].copy()


# ======================================================
# STRIKES BY DELTA (SAME LOGIC)
# ======================================================
def find_sell_strikes_by_delta(snap: pd.DataFrame, target_delta_abs: float) -> Dict:
    """Select the short call and short put whose |delta| is closest to the target.

    Args:
        snap: One chain snapshot with ``strike``, ``delta_call``, ``delta_put``,
            ``bid_call``, ``ask_call``, ``bid_put`` and ``ask_put`` columns.
            It is not modified.
        target_delta_abs: Target absolute delta (e.g. 0.10).

    Returns:
        ``{"call_sell": {...}, "put_sell": {...}}``, each holding ``strike``,
        ``delta``, ``mid``, ``bid`` and ``ask`` of the chosen row.

    Raises:
        ValueError: If the snapshot has no non-NA call or put deltas.
    """
    # Distances are kept as local Series instead of being written back as
    # columns, so the caller's DataFrame is not mutated (and pandas does not
    # warn about assigning into a filtered slice).
    call_diff = (snap["delta_call"] - target_delta_abs).abs()
    put_diff = (snap["delta_put"].abs() - target_delta_abs).abs()

    if call_diff.isna().all() or put_diff.isna().all():
        raise ValueError("snapshot has no valid call/put deltas")

    # Positional argmin skips NA values and resolves ties to the first row,
    # matching idxmin, while staying correct if the index has duplicate labels.
    call_row = snap.iloc[call_diff.argmin()]
    put_row = snap.iloc[put_diff.argmin()]

    return {
        "call_sell": {
            "strike": float(call_row["strike"]),
            "delta": float(call_row["delta_call"]),
            "mid": _safe_mid(call_row["bid_call"], call_row["ask_call"]),
            "bid": float(call_row["bid_call"]),
            "ask": float(call_row["ask_call"]),
        },
        "put_sell": {
            "strike": float(put_row["strike"]),
            "delta": float(put_row["delta_put"]),
            "mid": _safe_mid(put_row["bid_put"], put_row["ask_put"]),
            "bid": float(put_row["bid_put"]),
            "ask": float(put_row["ask_put"]),
        },
    }


def pick_wing_strike(snap: pd.DataFrame, sell_strike: float, width: float, side: str) -> Dict:
    """Pick the long (protective) wing for one side of the spread.

    The wing targets ``sell_strike + width`` for ``side == "call"`` and
    ``sell_strike - width`` for any other ``side`` (puts). It takes the nearest
    listed strike at or beyond that target. If no strike reaches the target, it
    falls back to the most extreme strike on that side: the highest for calls,
    the lowest for puts.

    Args:
        snap: One chain snapshot with ``strike`` and the ``delta_*``, ``bid_*``
            and ``ask_*`` columns for the chosen side. It is not modified.
        sell_strike: Strike of the short leg being protected.
        width: Wing distance in strike points.
        side: ``"call"`` for the call wing; any other value selects the put wing.

    Returns:
        Dict with ``strike``, ``delta``, ``mid``, ``bid`` and ``ask`` of the wing.

    Raises:
        ValueError: If ``snap`` has no non-NA strikes.
    """
    # Cast a positional copy instead of overwriting the caller's column.
    # After reset_index, labels equal row positions.
    strikes = snap["strike"].astype(float).reset_index(drop=True)
    if strikes.isna().all():
        raise ValueError("snapshot has no valid strikes")

    if side == "call":
        target = sell_strike + width
        beyond = strikes[strikes >= target]
        # Nearest strike at/above target, else the highest strike. Use an
        # explicit max rather than the last row, which assumed sorted input.
        pos = beyond.idxmin() if not beyond.empty else strikes.idxmax()
        suffix = "call"
    else:
        target = sell_strike - width
        beyond = strikes[strikes <= target]
        # Nearest strike at/below target, else the lowest strike.
        pos = beyond.idxmax() if not beyond.empty else strikes.idxmin()
        suffix = "put"

    row = snap.iloc[pos]
    return {
        "strike": float(row["strike"]),
        "delta": float(row[f"delta_{suffix}"]),
        "mid": _safe_mid(row[f"bid_{suffix}"], row[f"ask_{suffix}"]),
        "bid": float(row[f"bid_{suffix}"]),
        "ask": float(row[f"ask_{suffix}"]),
    }


# ======================================================
# CREDIT
# ======================================================
def calc_ic_credit_gross(call_sell, put_sell, call_buy, put_buy, contracts: int) -> Dict:
    """Price the four-leg position at mid.

    Only the ``"mid"`` entry of each leg dict is read.

    Args:
        call_sell, put_sell: Short legs (credit received).
        call_buy, put_buy: Long wings (debit paid).
        contracts: Number of contracts.

    Returns:
        ``{"credit_points": ..., "credit_gross": ...}``. Points are the net mid
        credit in option-price points, rounded to 4 decimals. Gross is points
        * 100 * contracts, rounded to 2 decimals.
    """
    net_mid = call_sell["mid"] + put_sell["mid"] - call_buy["mid"] - put_buy["mid"]
    scaled = net_mid * 100 * contracts

    summary = {}
    summary["credit_points"] = round(float(net_mid), 4)
    summary["credit_gross"] = round(float(scaled), 2)
    return summary


def calc_entry_credit_net(credit_gross: float, slippage: float) -> Dict:
    """Deduct a flat slippage amount from the gross entry credit.

    Returns:
        ``{"slippage": ..., "credit_net": ...}``, both rounded to 2 decimals.
    """
    net = round(credit_gross - slippage, 2)
    rounded_slip = round(slippage, 2)
    return dict(slippage=rounded_slip, credit_net=net)


# ======================================================
# FAST SNAPSHOT CREDIT (KEY TO SPEED)
# ======================================================
def credit_for_snapshot_fast(df_idx, ts, legs, contracts: int) -> float:
    """Re-price the four legs at snapshot ``ts`` using mid quotes.

    Expects ``df_idx`` indexed by (timestamp, strike), as built in
    ``_run_single_day``, and ``legs`` holding ``call_sell``, ``put_sell``,
    ``call_buy`` and ``put_buy`` dicts that each have a ``strike``.

    Returns:
        (short mids - long mids) * 100 * contracts. A leg with no row at
        ``(ts, strike)`` contributes 0.0. NOTE(review): a missing short-leg
        quote therefore lowers this value and can look like profit to the
        caller; consider skipping such snapshots.
    """
    def mid(at_ts, strike, side):
        try:
            r = df_idx.loc[(at_ts, strike)]
        except KeyError:
            return 0.0
        # A duplicated (timestamp, strike) pair makes .loc return a DataFrame.
        # _safe_mid would then raise on an ambiguous truth value, failing the
        # whole day, so take the first matching row instead.
        if isinstance(r, pd.DataFrame):
            if r.empty:
                return 0.0
            r = r.iloc[0]
        if side == "call":
            return _safe_mid(r["bid_call"], r["ask_call"])
        return _safe_mid(r["bid_put"], r["ask_put"])

    cs = mid(ts, legs["call_sell"]["strike"], "call")
    ps = mid(ts, legs["put_sell"]["strike"], "put")
    cb = mid(ts, legs["call_buy"]["strike"], "call")
    pb = mid(ts, legs["put_buy"]["strike"], "put")

    return (cs + ps - cb - pb) * 100 * contracts


# ======================================================
# ENGINE — SAME OUTPUT, MUCH FASTER
# ======================================================
def run_backtest_engine(params: dict) -> dict:
    """Run the day-by-day backtest described by ``params``.

    Recognized keys:
        ``symbol``, ``entry_time``: required.
        ``date_from``, ``date_to``: required, "YYYY-MM-DD".
        ``strategy``: default "IC"; "IB" forces a 0.50 short-strike delta.
        ``delta_call``: short-strike delta in whole percent (default 10),
            clamped to [0.01, 0.50].
        ``width_call``: wing width in strike points (default 20).
        ``contracts`` (default 1), ``slippage`` (default 0.0).
        ``tp``, ``sl``: default 0.0, meaning disabled.

    Each weekday in the range is processed concurrently by ``_run_single_day``.
    A failure on one day is recorded in ``errors`` and does not stop the others.

    Returns:
        ``{"ok": False, "error": ...}`` for missing or malformed parameters.
        Otherwise ``{"ok": True, ...}`` containing:
        - ``per_day``: per-day results sorted by date;
        - ``missing_days``: sorted ISO dates that had no parquet file;
        - ``errors``: per-day failures, each ``{"date", "error"}``.
    """
    symbol_raw = params.get("symbol")
    entry_time = params.get("entry_time")
    strategy = params.get("strategy", "IC")

    if not symbol_raw or not entry_time:
        return {"ok": False, "error": "missing symbol/entry_time"}

    # Validate dates and numbers up front so bad input yields an error dict,
    # like the check above, instead of an exception escaping the engine.
    date_from = params.get("date_from")
    date_to = params.get("date_to")
    if not date_from or not date_to:
        return {"ok": False, "error": "missing date_from/date_to"}
    try:
        d0 = _parse_date(date_from)
        d1 = _parse_date(date_to)
    except ValueError as e:
        return {"ok": False, "error": f"invalid date_from/date_to: {e}"}

    try:
        dn = int(params.get("delta_call", 10))
        width = float(params.get("width_call", 20))
        contracts = int(params.get("contracts", 1))
        slippage = float(params.get("slippage", 0.0))
        TP = float(params.get("tp", 0.0))
        SL = float(params.get("sl", 0.0))
    except (TypeError, ValueError) as e:
        return {"ok": False, "error": f"invalid numeric parameter: {e}"}

    chain_symbol = normalize_symbol_for_chain(symbol_raw)
    days = list_weekdays(d0, d1)
    target_delta_abs = 0.50 if strategy == "IB" else max(0.01, min(0.50, dn / 100))

    per_day, missing_days, errors = [], [], []

    max_workers = min(6, os.cpu_count() or 2)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(
                _run_single_day,
                d,
                chain_symbol,
                entry_time,
                strategy,
                target_delta_abs,
                width,
                contracts,
                slippage,
                TP,
                SL,
            ): d
            for d in days
        }

        for future in as_completed(futures):
            day = futures[future]
            try:
                res = future.result()
            except Exception as e:
                # Best-effort per day: record which day failed and keep going.
                errors.append({"date": day.isoformat(), "error": str(e)})
                continue
            if "_missing" in res:
                missing_days.append(res["_missing"])
            else:
                per_day.append(res)

    return {
        "ok": True,
        "step": "ENTRY_AND_PATH",
        "symbol_input": symbol_raw,
        "symbol_chain": chain_symbol,
        "strategy": strategy,
        "target_delta_abs": target_delta_abs,
        "width": width,
        "days_total": len(days),
        "days_processed": len(per_day),
        # Futures complete in arbitrary order; sort for deterministic output.
        "missing_days": sorted(missing_days),
        "errors": sorted(errors, key=lambda x: x["date"]),
        "per_day": sorted(per_day, key=lambda x: x["date"]),
    }

