import re
import numpy as np
import pandas as pd
from datetime import datetime, timezone, timedelta
import streamlit as st
from utils.polymarket_fetch import parse_bucket_range
from config import SESSION



# ====================== CONSTANTS & HELPERS ======================BEGINS_HERE the logic inside this block MUST BE KEPT AS THEY ARE!!
## Mandatory for retrieving and processing Polymarket data — do not replace.
# Canonical 20-tweet-wide bucket labels: "0-19", "20-39", ..., "580-599" (30 bins).
# NOTE(review): normalize_bucket() formats labels as "low - high" (with spaces),
# so compare formats before matching its output against these strings.
STANDARD_BINS = [f"{i*20}-{i*20+19}" for i in range(30)]

def parse_lower(label: str) -> int:
    """Return the lower bound of a bucket label as an int.

    The label is lower-cased and stripped, then:
    - open-ended labels (containing '+', 'more', 'over' or 'greater') return
      their first number, or 99999 if they contain no digits;
    - labels with two numbers separated by non-digits ("800-819",
      "800 to 819") return the first number;
    - otherwise the first number in the label is returned, or 0 if none.

    NOTE(review): this function is defined again, with an identical body,
    later in this module; that later definition is the one bound at import.
    """
    s = str(label).lower().strip()
    # Open-ended top bucket, e.g. "600+" or "600 or more".
    if any(k in s for k in ['+', 'more', 'over', 'greater']):
        m = re.search(r'(\d+)', s)
        return int(m.group(1)) if m else 99999
    # Two-number range: take the first (lower) number.
    m = re.search(r'(\d+)[^\d]+(\d+)', s)
    if m:
        return int(m.group(1))
    # Single number (e.g. "less than 20") or no number at all.
    m = re.search(r'(\d+)', s)
    return int(m.group(1)) if m else 0

def get_event_title(meta, filename):
    """Return a human-readable event title from loosely-typed metadata.

    Args:
        meta: None, a str, a dict, or any other object.
        filename: fallback used when ``meta`` is None or of an unsupported type.

    Returns:
        - ``str(filename)[:100]`` when ``meta`` is None or not a str/dict;
        - ``meta[:120]`` when ``meta`` is already a str;
        - for a dict, the first truthy value among the "event_title", "title"
          and "name" keys, converted to str and NOT truncated, or "" if none
          is set. A single-element numpy array is unwrapped with ``.item()``.

    NOTE(review): a multi-element numpy array stored under one of those keys
    raises ValueError in the ``or`` chain (array truth value is ambiguous), so
    the "else title" branch below is effectively unreachable for size > 1.
    """

    # Case 1: meta is None
    if meta is None:
        return str(filename)[:100]

    # Case 2: meta is already a string (e.g. a title stored directly)
    if isinstance(meta, str):
        return meta[:120]

    # Case 3: meta is a dict (normal case)
    if isinstance(meta, dict):
        title = meta.get("event_title") or meta.get("title") or meta.get("name") or ""
        # Values loaded from .npz files may arrive wrapped in numpy arrays.
        if isinstance(title, np.ndarray):
            title = str(title.item() if title.size == 1 else title)
        return str(title)

    # Case 4: anything else (fallback)
    return str(filename)[:100]

def get_counting_period_from_slug(slug):
    """Infer an event's (start, end) counting window from its saved surface file.

    Loads ``data/elon_surface_{slug}_FULL.npz`` and reads its ``end_date`` and
    ``event_slug`` entries. The dates 2 days and 7 days before ``end_date``
    are rendered as "<month>-<day>" (e.g. "march-3"), and the stored event
    slug is searched for each one.

    Returns:
        ``(end - 2 days, end)`` if the 2-day start appears in the slug, else
        ``(end - 7 days, end)`` if the 7-day start appears, else ``None``
        (implicit fall-through). Returns ``False`` if an Exception is raised
        (missing file, unparsable date, ...); the exception is printed.

    NOTE(review): callers receive two different falsy failure values (None
    and False). The match is a plain substring test, so e.g. "march-3" also
    matches a slug containing "march-30". The NpzFile returned by ``np.load``
    is never closed.
    """
    try:
        fi=np.load(f"data/elon_surface_{slug}_FULL.npz")
        # end_date must stringify to something datetime.fromisoformat accepts.
        endd=datetime.fromisoformat(str(fi.get('end_date')))
        d2=timedelta(days=2)
        d7=timedelta(days=7)
        st7=endd-d7
        # "<full month name>-<day without leading zero>", e.g. "march-3".
        # %B is the locale's full month name.
        ststr7=st7.strftime('%B-').lower()+str(int(st7.strftime('%d').lower()))
        st2=endd-d2
        ststr2=st2.strftime('%B-').lower()+str(int(st2.strftime('%d').lower()))

        # The 2-day window is checked first, so it wins if both dates appear.
        if ststr2 in str(fi.get('event_slug')).lower():
            return (endd-d2, endd)
        elif ststr7 in str(fi.get('event_slug')).lower():
            return (endd-d7, endd)
    except Exception as e:
        print(e)
        return False

def discover_active_slugs():
    """Scrape the Polymarket "elon-tweets" page for event slugs.

    Collects every substring of the page HTML that starts with
    ``elon-musk-of-tweets-`` and runs up to the next quote character. It
    drops matches of 80+ characters and anything that looks like an image
    file (.jpg/.png/.gif), then returns the rest de-duplicated in first-seen
    order.

    Any exception (network error, timeout, ...) is swallowed; the slugs
    collected before the error are returned, which is normally an empty list.

    NOTE(review): the HTTP status code is not checked, so an error page is
    scanned like a normal one.
    """
    slugs = []
    try:
        r = SESSION.get("https://polymarket.com/predictions/elon-tweets", timeout=25)
        found = re.findall(r'elon-musk-of-tweets-[^"\']+', r.text)
        for s in found:
            # Filter out overly long matches and asset URLs that share the prefix.
            if len(s) < 80 and not re.search(r'\.(jpg|png|gif)', s):
                slugs.append(s)
    except:
        pass
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(slugs))

def fetch_event(slug):
    """Look up a Polymarket event or market by slug via the Gamma API.

    Tries ``/events`` and then ``/markets``. Each endpoint is queried with
    ``closed=true``, then ``closed=false``, then with no ``closed`` filter.
    The first usable hit is returned:

    - if the JSON body is a non-empty list, its first element;
    - if it is a dict, the first truthy value under "data", "events",
      "markets", "items" or "event". That is its first element when the value
      is a non-empty list, or the value itself when it is a dict.

    Non-200 responses and any exception move on to the next attempt.

    Returns:
        The first matching object, or None if every attempt fails.
    """
    for ep in ["events", "markets"]:
        for closed in [True, False, None]:
            params = {"slug": slug}
            if closed is not None:
                # The API expects lowercase "true"/"false".
                params["closed"] = str(closed).lower()
            try:
                r = SESSION.get(f"https://gamma-api.polymarket.com/{ep}", params=params, timeout=25)
                if r.status_code == 200:
                    raw = r.json()
                    if isinstance(raw, list) and raw:
                        return raw[0]
                    # Some responses wrap the payload under a container key.
                    for key in ["data", "events", "markets", "items", "event"]:
                        if isinstance(raw, dict) and raw.get(key):
                            v = raw[key]
                            if isinstance(v, list) and v: return v[0]
                            if isinstance(v, dict): return v
            except:
                continue
    return None

def extract_yes_buckets(event):
    """Build a list of ``{"label", "token"}`` bucket dicts from a Polymarket event.

    Fast path: if the event itself has ``outcomes`` and ``clobTokenIds`` lists
    of equal length with at least 8 entries, pair them up positionally.

    Fallback: iterate over ``event["markets"]`` (or treat the event as a
    single market when that is empty or missing). Keep markets whose question
    mentions both "elon" and "tweet", and use the first CLOB token id of each
    as that bucket's token (presumably the YES outcome — confirm). Results are
    de-duplicated by label (the last market with a given label wins) and
    sorted lexicographically by label, so e.g. "1000" sorts before "200".

    NOTE(review): ``json`` is not imported in the visible part of this module.
    The bare ``except`` clauses catch the resulting NameError, so JSON-encoded
    string values become ``[]`` on the fast path, and such markets are
    skipped in the fallback.
    """
    outcomes = event.get("outcomes")
    # The API may deliver list fields as JSON-encoded strings.
    if isinstance(outcomes, str):
        try: outcomes = json.loads(outcomes)
        except: outcomes = []
    clobs = event.get("clobTokenIds")
    if isinstance(clobs, str):
        try: clobs = json.loads(clobs)
        except: clobs = []
    if outcomes and clobs and len(outcomes) == len(clobs) and len(outcomes) >= 8:
        return [{"label": str(label).strip(), "token": clobs[i]} for i, label in enumerate(outcomes)]

    markets = event.get("markets", []) or [event]
    buckets = []
    for m in markets:
        q = m.get("question", "")
        if "elon" not in q.lower() or "tweet" not in q.lower():
            continue
        clobs = m.get("clobTokenIds")
        if isinstance(clobs, str):
            try: clobs = json.loads(clobs)
            except: continue
        if clobs:
            buckets.append({"label": q.strip(), "token": clobs[0]})
    # Later duplicates overwrite earlier ones.
    seen = {b["label"]: b for b in buckets}
    return sorted(seen.values(), key=lambda x: x["label"])

def fetch_new_prices(token, last_ts):
    """Fetch CLOB price-history points for ``token`` newer than ``last_ts``.

    Walks forward from ``last_ts + 1`` to one day past the current time, in
    windows of at most 2 days (86400 * 2 seconds), with at most 25 requests.
    After a non-empty page the cursor moves to just past the last returned
    timestamp. On an empty page, a non-200 status or any exception it jumps
    to the end of the window. It sleeps 50 ms between requests.

    Args:
        token: CLOB token id, sent as the ``market`` query parameter.
        last_ts: last timestamp already stored, in seconds (it is compared
            against ``time.time()``).

    Returns:
        Points (dicts with at least a "t" key) de-duplicated by "t" (the last
        occurrence wins) and sorted ascending by "t".

    NOTE(review): ``time`` and ``FIDELITY`` are not imported/defined in the
    visible part of this module. ``int(time.time())`` runs outside the
    ``try`` and would raise NameError. A missing ``FIDELITY`` would be
    swallowed by the bare ``except``, so every window would be skipped.
    """
    all_pts = []
    current = last_ts + 1
    now = int(time.time()) + 86400
    chunks = 0
    while current < now and chunks < 25:  # HARD SAFETY LIMIT
        end = min(current + 86400 * 2, now)
        try:
            r = SESSION.get("https://clob.polymarket.com/prices-history", params={
                "market": token,
                "startTs": current,
                "endTs": end,
                "fidelity": FIDELITY
            }, timeout=25)
            if r.status_code == 200:
                pts = r.json().get("history", [])
                all_pts.extend(pts)
                if pts:
                    # Resume right after the newest point received.
                    current = pts[-1]["t"] + 1
                else:
                    current = end
            else:
                current = end
        except:
            current = end
        chunks += 1
        time.sleep(0.05)
    unique = {p["t"]: p for p in all_pts}
    return sorted(unique.values(), key=lambda x: x["t"])

def parse_lower(label: str) -> int:
    """Return the lower bound of a bucket label as an int.

    NOTE(review): this redefines the earlier ``parse_lower`` in this module
    with an identical body. Because it comes later, this is the binding that
    callers see.

    The label is lower-cased and stripped, then:
    - open-ended labels (containing '+', 'more', 'over' or 'greater') return
      their first number, or 99999 if they contain no digits;
    - labels with two numbers separated by non-digits return the first one;
    - otherwise the first number is returned, or 0 if there is none.
    """
    s = str(label).lower().strip()
    if any(k in s for k in ['+', 'more', 'over', 'greater']):
        m = re.search(r'(\d+)', s)
        return int(m.group(1)) if m else 99999
    m = re.search(r'(\d+)[^\d]+(\d+)', s)
    if m:
        return int(m.group(1))
    m = re.search(r'(\d+)', s)
    return int(m.group(1)) if m else 0

def get_category_from_event_title(event_title):
    """Classify an event by its counting window, parsed from its title text.

    Returns:
        "month": the text contains "<month name>-<4-digit year>"
            (e.g. "march-2026").
        "2day": it contains "<month> <day> - <month> <day>" spanning at most
            3 days, counting inclusively.
        "7day_tue" / "7day_fri": that range starts on a Tuesday / Friday.
        "other": any other range, or no recognizable range at all.

    The year is taken from the first "20" followed by two digits, defaulting
    to 2026. Unrecognized month words map to January. A range whose end falls
    before its start is assumed to wrap into the next year.

    NOTE(review): the month check expects slug-style hyphens, while the range
    check expects letters directly followed (modulo whitespace) by a day
    number. A hyphenated slug such as "november-14-november-21" therefore
    never matches the range pattern. An invalid day (e.g. "february 30")
    raises ValueError from ``datetime``.
    """
    event_title = event_title.lower()
    if re.search(r'(march|april|may|june|july|august|september|october|november|december|january|february)-\d{4}', event_title):
        return "month"
    year_match = re.search(r'20(\d{2})', event_title)
    year = int("20" + year_match.group(1)) if year_match else 2026
    match = re.search(r'([a-z]+)\s*(\d+)\s*-\s*([a-z]+)\s*(\d+)', event_title)
    if match:
        s_month, s_day, e_month, e_day = match.groups()
        month_map = {"january":1,"february":2,"march":3,"april":4,"may":5,"june":6,"july":7,"august":8,"september":9,"october":10,"november":11,"december":12}
        start = datetime(year, month_map.get(s_month, 1), int(s_day))
        end = datetime(year, month_map.get(e_month, 1), int(e_day))
        # e.g. "december 30 - january 2" crosses a year boundary.
        if end < start:
            end = end.replace(year=end.year + 1)
        duration = (end - start).days + 1
        # datetime.weekday(): Monday == 0, so 1 is Tuesday and 4 is Friday.
        weekday = start.weekday()
        if duration <= 3:
            return "2day"
        if weekday == 1:
            return "7day_tue"
        if weekday == 4:
            return "7day_fri"
        return "other"
    return "other"

def normalize_bucket(q: str) -> tuple[str, float, float]:
    """Force lowest/highest buckets into the same 20-tweet bin format as the others.

    Maps a bucket question or label to ``(label, low, high)``. N is the first
    run of digits in the lower-cased text. The rules are applied in order,
    using substring matching:

    1. No digits: ``(q[:90], 0.0, 1000.0)``.
    2. Text contains "less than", "fewer than", "<", "under" or "below":
       ``("0 - {N-1}", 0.0, N-1)``, with N-1 clamped at 0.
    3. Text contains "or more", "or higher", "+", "greater", "over" or
       "above": ``("N - {N+19}", N, N+19)``, making the top bucket 20 wide.
    4. ``parse_bucket_range(q)`` returns a truthy ``(low, high)``:
       ``("low - high", low, high)``, with low/high returned as given by
       that helper.
    5. Fallback: ``(q[:90], N-10, N+10)``.
    """
    q_lower = q.lower().strip()

    # Extract the key number
    m = re.search(r'(\d+)', q_lower)
    if not m:
        return q[:90], 0.0, 1000.0
    num = float(m.group(1))

    # === LOWEST BUCKET (less than / < / fewer than) ===
    if any(word in q_lower for word in ["less than", "fewer than", "<", "under", "below"]):
        high = max(0, num - 1)
        label = f"0 - {int(high)}"
        return label, 0.0, float(high)

    # === HIGHEST BUCKET (or more / + / greater) ===
    if any(word in q_lower for word in ["or more", "or higher", "+", "greater", "over", "above"]):
        low = num
        high = low + 19                     # exactly matches the 20-wide bins used everywhere else
        label = f"{int(low)} - {int(high)}"
        return label, float(low), float(high)

    # === NORMAL RANGE BUCKETS (800-819 etc.) ===
    rng = parse_bucket_range(q)
    if rng:
        low, high = rng
        label = f"{int(low)} - {int(high)}"
        return label, low, high

    # Fallback
    return q[:90], num - 10, num + 10

# ====================== CONSTANTS & HELPERS ======================ENDS_HERE the logic inside this block MUST BE KEPT AS THEY ARE!!

def midpoint(low: float, high: float) -> float:
    """Return the centre of the interval ``[low, high]``.

    Works element-wise when given NumPy arrays of bounds.
    """
    span_total = low + high
    return span_total / 2

def gaussian_pdf(x, loc=0.0, scale=1.0):
    """Normal probability density; a pure-NumPy stand-in for ``scipy.stats.norm.pdf``.

    Args:
        x: scalar or array-like of evaluation points.
        loc: mean of the distribution (scalar or broadcastable array).
        scale: standard deviation (scalar or broadcastable array); must be > 0.

    Returns:
        The density at ``x`` under NumPy broadcasting: a NumPy scalar for
        scalar inputs, an array otherwise. Entries whose ``scale`` is not
        strictly positive are NaN, as in SciPy. The bare formula would instead
        return negative "densities" for negative scales and divide by zero for
        a zero scale.
    """
    x = np.asarray(x)
    scale = np.asarray(scale, dtype=float)
    valid = scale > 0
    # Evaluate with a harmless placeholder scale where the input is invalid,
    # then mask those entries to NaN; this avoids divide-by-zero warnings.
    safe_scale = np.where(valid, scale, 1.0)
    z = (x - loc) / safe_scale
    pdf = np.exp(-0.5 * z ** 2) / (safe_scale * np.sqrt(2 * np.pi))
    # [()] unwraps a 0-d result to a NumPy scalar and leaves arrays untouched.
    return np.where(valid, pdf, np.nan)[()]

def calculate_center(mids, probs):
    """Return the probability-weighted mean of the bucket midpoints as a float.

    ``np.average`` raises ZeroDivisionError if ``probs`` sums to zero.
    """
    weighted_mean = np.average(mids, weights=probs)
    return float(weighted_mean)

def calculate_skewness(mids, probs):
    """Return the weighted skewness (third standardized moment) of the midpoints.

    Returns 0.0 when the weighted variance is not strictly positive, which
    includes all probability on a single midpoint and a NaN variance.
    """
    centre = np.average(mids, weights=probs)
    deviations = mids - centre
    variance = np.average(deviations ** 2, weights=probs)
    if not variance > 0:
        return 0.0
    third_moment = np.average(deviations ** 3, weights=probs)
    return float(third_moment / variance ** 1.5)
