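"""
Backtester: Point of Gravity (PG) strategy — mean reversion
===========================================================
Logic:
  - Trigger: |price - PG| >= TRIGGER_RATIO_MIN * expected_move
  - CCS (bearish): price >> PG  → short_call = (price + 10) rounded to the nearest 5,
                                  long_call  = short_call + 5
  - PCS (bullish): price << PG  → short_put  = (price - 10) rounded to the nearest 5,
                                  long_put   = short_put - 5
  - Entry filters: no entry after HORA_MAX_ENTRADA; credit must be >= CREDITO_MIN_USD
  - TP: TP_USD | SL: SL_USD, each confirmed after TP_CONFIRMACION / SL_CONFIRMACION
    consecutive ticks | settled at the end of the day if neither is hit
  - SPX multiplier: x100
  - Only 1 trade per day
"""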

import os
import sys
import glob
import pandas as pd
import numpy as np

# ── Paths ──────────────────────────────────────────────────────────────────────
# Directory scanned for the per-day prediction CSVs (prediction_$SPX_<date>.csv).
PRED_DIR = "/var/www/html/backtestingmarket/predictor_data/data/SPX"
# Directory holding the per-day option chain parquet files.
CHAIN_DIR = "/var/www/html/flask_project/chains"
# Destination of the CSV with one row per executed trade.
OUT_PATH = "/var/www/html/backtestingmarket/backtest_results_pg.csv"
# Converts option prices (points) into USD per contract.
MULTIPLIER = 100
# Take-profit and stop-loss thresholds on the trade P&L, in USD.
TP_USD = 250
SL_USD = -250

# ── Date filter (edit here or pass on the command line) ───────────────────────
FECHA_DESDE = "2025-08-01"   # inclusive  (None = no limit)
FECHA_HASTA = None           # inclusive  (None = no limit)

# ── Quality filters ───────────────────────────────────────────────────────────
HORA_MAX_ENTRADA = "14:00"   # do not open a trade after this time (HH:MM)
CREDITO_MIN_USD = 150         # minimum credit in USD required to open the trade
TRIGGER_RATIO_MIN = 0.5      # minimum distance/expected-move ratio that fires the trigger


# ── Exit confirmation (anti-spike) ────────────────────────────────────────────
SL_CONFIRMACION = 3   # consecutive ticks in the SL zone needed to confirm the exit
TP_CONFIRMACION = 3   # consecutive ticks in the TP zone needed to confirm the exit

# ── Helpers ────────────────────────────────────────────────────────────────────


def round_to_strike(price, base=5):
    """Snap *price* to the nearest multiple of *base*, kept to one decimal place."""
    whole_steps = round(price / base)
    snapped = whole_steps * base
    return round(snapped, 1)


def get_mid(chain_ts, strike, option_type):
    """Return the bid/ask midpoint of one option in a chain snapshot.

    Args:
        chain_ts: DataFrame with a ``strike`` column and the quote columns
            ``bid_call``/``ask_call``/``bid_put``/``ask_put``.
        strike: Strike to look up (exact equality against ``strike``).
        option_type: ``'call'`` reads the call columns; any other value
            reads the put columns.

    Returns:
        ``(bid + ask) / 2`` for the first row matching *strike*, or ``None``
        when the strike is absent or either quote is missing or non-positive.
    """
    row = chain_ts[chain_ts['strike'] == strike]
    if row.empty:
        return None

    side = 'call' if option_type == 'call' else 'put'
    bid = row[f'bid_{side}'].iloc[0]
    ask = row[f'ask_{side}'].iloc[0]

    # A missing quote (NaN/None) must be rejected explicitly: `NaN <= 0` is
    # False, so without this check a NaN mid would be returned and slip
    # through every later `<`/`<=` filter on the credit.
    if pd.isna(bid) or pd.isna(ask):
        return None
    if bid <= 0 or ask <= 0:
        return None
    return (bid + ask) / 2


def spread_value(chain_ts, short_strike, long_strike, option_type):
    """Return the mid price of the vertical spread (short leg minus long leg).

    Both legs are priced with ``get_mid`` on the same snapshot; the result is
    ``None`` as soon as either leg has no usable mid.
    """
    leg_mids = [get_mid(chain_ts, leg_strike, option_type)
                for leg_strike in (short_strike, long_strike)]
    if any(mid is None for mid in leg_mids):
        return None
    short_mid, long_mid = leg_mids
    return short_mid - long_mid


# ── Backtester de un día ───────────────────────────────────────────────────────

def _expiry_spread_usd(opt_type, short_strike, long_strike, precio_final):
    """Return the USD settlement value of the vertical spread at *precio_final*."""
    width = abs(short_strike - long_strike)
    if opt_type == 'call':
        in_the_money = precio_final - short_strike
    else:
        in_the_money = short_strike - precio_final
    # Clamp to [0, width]: worthless on the safe side of the short strike,
    # capped at the full width once the long strike is reached.
    return min(max(in_the_money, 0.0), width) * MULTIPLIER


def backtest_day(date_str):
    """Simulate at most one credit-spread trade for the day *date_str*.

    Reads the day's prediction CSV (columns ``timestamp``, ``precio_actual``,
    ``punto_de_gravedad``, ``movimiento_esperado``) and option chain parquet,
    walks the prediction rows in file order, opens a call (CCS) or put (PCS)
    credit spread on the first row that passes every entry filter, and then
    monitors it until a confirmed TP/SL or the end of the data.

    Args:
        date_str: Date token used in both file names.

    Returns:
        ``None`` when either file is missing or the chain has no timestamped
        rows (the day cannot be evaluated); ``{}`` when no trade was opened;
        otherwise a dict describing the executed trade.
    """
    pred_path = os.path.join(PRED_DIR, f"prediction_$SPX_{date_str}.csv")
    chain_path = os.path.join(CHAIN_DIR, f"optionChain_$SPX_{date_str}.parquet")

    if not (os.path.exists(pred_path) and os.path.exists(chain_path)):
        return None   # missing files → skipped

    pred = pd.read_csv(pred_path, parse_dates=['timestamp'])
    chain = pd.read_parquet(chain_path)
    chain['timestamp'] = pd.to_datetime(chain['timestamp'])

    # Split the chain into one frame per snapshot once, instead of
    # boolean-filtering the whole chain on every prediction row.
    grouped = list(chain.groupby('timestamp', sort=True))
    if not grouped:
        # No usable snapshot: previously this crashed with IndexError as soon
        # as a trigger fired. Treated like a day without chain data.
        return None
    chain_times = pd.DatetimeIndex([stamp for stamp, _ in grouped])
    snapshots = [frame for _, frame in grouped]

    def snapshot_at(ts):
        # First snapshot at or after `ts`; past the end, fall back to the last.
        pos = chain_times.searchsorted(ts, side='left')
        return snapshots[min(pos, len(snapshots) - 1)]

    trade = None
    trade_open = False
    sl_ticks = 0
    tp_ticks = 0

    for _, row in pred.iterrows():
        ts = row['timestamp']
        precio = row['precio_actual']
        pg = row['punto_de_gravedad']
        me = row['movimiento_esperado']
        distancia = precio - pg

        # ── Look for an entry ─────────────────────────────────────────────────
        if not trade_open:
            if ts.strftime('%H:%M') > HORA_MAX_ENTRADA:
                continue

            # A zero/negative/NaN expected move would fire the trigger
            # spuriously (and divide by zero in ratio_vs_me): skip the row.
            if not (me > 0):
                continue

            threshold = me * TRIGGER_RATIO_MIN
            if distancia >= threshold:
                direction, opt_type, sign = 'CCS', 'call', 1
            elif distancia <= -threshold:
                direction, opt_type, sign = 'PCS', 'put', -1
            else:
                continue

            # Short leg 10 points beyond the price (away from PG), long leg
            # 5 points further out.
            short_strike = round_to_strike(precio + sign * 10)
            long_strike = short_strike + sign * 5

            entry_credit = spread_value(snapshot_at(ts), short_strike,
                                        long_strike, opt_type)
            # `not (x > 0)` also rejects NaN, which `x <= 0` lets through.
            if entry_credit is None or not (entry_credit > 0):
                continue

            if entry_credit * MULTIPLIER < CREDITO_MIN_USD:
                continue

            trade_open = True
            sl_ticks = 0
            tp_ticks = 0
            trade = {
                'date':             date_str,
                'direction':        direction,
                'trigger_time':     ts.strftime('%H:%M:%S'),
                'precio_entry':     round(precio, 2),
                'pg':               round(pg, 2),
                'me':               round(me, 2),
                'distancia':        round(distancia, 2),
                'ratio_vs_me':      round(distancia / me, 3),
                'short_strike':     short_strike,
                'long_strike':      long_strike,
                'opt_type':         opt_type,
                'entry_credit':     round(entry_credit, 4),
                'entry_credit_usd': round(entry_credit * MULTIPLIER, 2),
                'exit_time':        None,
                'exit_spread_usd':  None,
                'pnl_usd':          None,
                'exit_reason':      None,
            }

        # ── Monitor the open position (also runs on the entry row) ────────────
        current_val = spread_value(snapshot_at(ts), trade['short_strike'],
                                   trade['long_strike'], trade['opt_type'])
        if current_val is None or pd.isna(current_val):
            # No usable quote on this tick: leave the counters untouched.
            continue

        pnl_usd = (trade['entry_credit'] - current_val) * MULTIPLIER

        # Anti-spike: an exit needs N consecutive ticks inside its zone; any
        # tick outside that zone resets the counter.
        if pnl_usd >= TP_USD:
            tp_ticks += 1
            sl_ticks = 0
        elif pnl_usd <= SL_USD:
            sl_ticks += 1
            tp_ticks = 0
        else:
            tp_ticks = 0
            sl_ticks = 0

        if tp_ticks >= TP_CONFIRMACION:
            exit_reason = 'TP'
        elif sl_ticks >= SL_CONFIRMACION:
            exit_reason = 'SL'
        else:
            continue

        trade['exit_time'] = ts.strftime('%H:%M:%S')
        trade['exit_spread_usd'] = round(current_val * MULTIPLIER, 2)
        trade['pnl_usd'] = round(pnl_usd, 2)
        trade['exit_reason'] = exit_reason
        trade_open = False
        # Only one trade per day: nothing left to do once it is closed.
        break

    # ── End-of-day settlement if the position is still open ───────────────────
    if trade_open:
        # Settlement price is the last prediction row's price; the exit time
        # is the last chain snapshot's time.
        precio_final = pred.iloc[-1]['precio_actual']
        spread_expiry = _expiry_spread_usd(trade['opt_type'],
                                           trade['short_strike'],
                                           trade['long_strike'],
                                           precio_final)
        pnl_usd = trade['entry_credit'] * MULTIPLIER - spread_expiry

        trade['exit_time'] = chain_times[-1].strftime('%H:%M:%S')
        trade['exit_spread_usd'] = round(spread_expiry, 2)
        trade['pnl_usd'] = round(pnl_usd, 2)
        trade['exit_reason'] = 'EXPIRY'
        trade['precio_expiry'] = round(precio_final, 2)

    # None = day not evaluable | {} = no trade | populated dict = executed trade
    return trade if trade is not None else {}


# ── Main ───────────────────────────────────────────────────────────────────────

def run_backtest(fecha_desde=None, fecha_hasta=None):
    """Run ``backtest_day`` over every available date and report the results.

    Dates are taken from the prediction file names found in ``PRED_DIR`` and
    filtered (inclusive, string comparison) by the given bounds; a falsy bound
    falls back to ``FECHA_DESDE`` / ``FECHA_HASTA``.

    Prints one line per day plus a global summary, and writes the trades to
    ``OUT_PATH``.

    Returns:
        A DataFrame with one row per executed trade, or ``None`` when no
        trade was executed in the period.
    """
    desde = fecha_desde if fecha_desde else FECHA_DESDE
    hasta = fecha_hasta if fecha_hasta else FECHA_HASTA

    heavy_rule = '=' * 90
    dash = '—'

    # Date token = file name without its fixed prefix and extension.
    available = []
    for path in glob.glob(os.path.join(PRED_DIR, "prediction_$SPX_*.csv")):
        token = os.path.basename(path).replace("prediction_$SPX_", "")
        available.append(token.replace(".csv", ""))
    available.sort()

    def in_period(day):
        if desde is not None and day < desde:
            return False
        if hasta is not None and day > hasta:
            return False
        return True

    dates = [day for day in available if in_period(day)]

    print(heavy_rule)
    print(f"  BACKTEST — PG Mean Reversion  |  TP: ${TP_USD}  SL: ${SL_USD}")
    print(f"  Período: {desde} → {hasta}  |  Días disponibles: {len(dates)}")
    print(heavy_rule)
    print(f"{'Fecha':<12} {'Dir':<4} {'Trigger':>8} {'Precio':>8} {'PG':>8} "
          f"{'Dist':>7} {'Ratio':>6} {'Strikes':<13} {'Crédito$':>9} "
          f"{'Salida':>8} {'Razón':<7} {'P&L':>8}")
    print('-' * 90)

    executed = []
    days_no_trigger = 0
    days_skipped = 0

    for day in dates:
        result = backtest_day(day)

        # None → the day could not be evaluated (files not found).
        if result is None:
            days_skipped += 1
            continue

        # Empty dict → files were fine but nothing was traded that day.
        if not result:
            days_no_trigger += 1
            print(f"{day:<12} {dash:<4} {dash:>8} {dash:>8} {dash:>8} "
                  f"{dash:>7} {dash:>6} {dash:<13} {dash:>9} "
                  f"{dash:>8} {dash:<7} {'sin trigger':>8}")
            continue

        executed.append(result)

        strikes_txt = f"{result['short_strike']:.0f}/{result['long_strike']:.0f}"
        ratio_txt = f"{result['ratio_vs_me']:+.2f}x"

        if result['pnl_usd'] is not None:
            pnl_txt = f"${result['pnl_usd']:+.0f}"
        else:
            pnl_txt = "N/A"

        # Only trades settled at the end of the day carry a closing price.
        if result.get('precio_expiry'):
            closing_txt = f"  precio_cierre={result['precio_expiry']}"
        else:
            closing_txt = ""

        print(f"{day:<12} {result['direction']:<4} {result['trigger_time']:>8} "
              f"{result['precio_entry']:>8.1f} {result['pg']:>8.1f} "
              f"{result['distancia']:>7.1f} {ratio_txt:>6} {strikes_txt:<13} "
              f"${result['entry_credit_usd']:>8.2f} {str(result['exit_time']):>8} "
              f"{result['exit_reason']:<7} {pnl_txt:>8}{closing_txt}")

    print(heavy_rule)
    print(f"  Días sin chain: {days_skipped}  |  Días sin trigger: {days_no_trigger}  |  Trades ejecutados: {len(executed)}")

    if not executed:
        print("  Sin trades en el período.")
        return

    df = pd.DataFrame(executed)
    valid = df[df['pnl_usd'].notna()]
    pnl = valid['pnl_usd']

    total_pnl = pnl.sum()
    win_rate = (pnl > 0).mean() * 100
    avg_win = pnl[pnl > 0].mean()
    avg_loss = pnl[pnl < 0].mean()
    n_tp, n_sl, n_exp = ((valid['exit_reason'] == reason).sum()
                         for reason in ('TP', 'SL', 'EXPIRY'))
    n_ccs, n_pcs = ((valid['direction'] == side).sum()
                    for side in ('CCS', 'PCS'))

    print(f"\n{heavy_rule}")
    print("  RESUMEN GLOBAL")
    print(heavy_rule)
    print(f"  Total trades      : {len(valid)}  (CCS: {n_ccs} | PCS: {n_pcs})")
    print(f"  P&L Total         : ${total_pnl:+,.2f}")
    print(f"  Win Rate          : {win_rate:.1f}%")
    # The mean of an empty selection is NaN: there were no wins / no losses.
    if np.isnan(avg_win):
        print("  Avg Win  : N/A")
    else:
        print(f"  Avg Win           : ${avg_win:+.2f}")
    if np.isnan(avg_loss):
        print("  Avg Loss : N/A")
    else:
        print(f"  Avg Loss          : ${avg_loss:+.2f}")
    print(f"  Salidas TP/SL/EXP : {n_tp} / {n_sl} / {n_exp}")
    print(heavy_rule)

    df.to_csv(OUT_PATH, index=False)
    print(f"\n  CSV guardado en: {OUT_PATH}")

    return df


if __name__ == "__main__":
    # Usage: script [FECHA_DESDE [FECHA_HASTA]]; any other argument count
    # runs with the module-level defaults.
    cli_dates = sys.argv[1:]
    if len(cli_dates) == 1:
        run_backtest(fecha_desde=cli_dates[0])
    elif len(cli_dates) == 2:
        desde_arg, hasta_arg = cli_dates
        run_backtest(fecha_desde=desde_arg, fecha_hasta=hasta_arg)
    else:
        run_backtest()
