import streamlit as st
from pathlib import Path
import time
import re
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from zoneinfo import ZoneInfo

from core.data_loader import load_surfaces
from utils.xtracker_fetch import fetch_xtracker_cached, get_pace
from utils.polymarket_fetch import fetch_event, extract_yes_buckets
from data_layer.smart_money_analyzer import analyze_smart_money_flow
from utils.helpers import normalize_bucket
from core.portfolio import get_user_positions_by_token
class DataHub:
    """Shared cache of tweet pace, probability surfaces and per-market data.

    Three layers are rebuilt together by :meth:`refresh`, which runs at most
    once every ``REFRESH_INTERVAL_S`` seconds:

    * ``global_data``  -- tweets, 24h/7d/30d pace, their mean, and the result
      of ``get_user_positions_by_token()`` (stored under ``"portfolio"``).
    * ``event_cache``  -- per-slug surface ``matrix``/``meta`` loaded from
      ``<cwd>/data`` plus the last matrix row as ``current_prices``.
    * ``market_cache`` -- per ``(slug, token)``: price, label, SMFI and, when
      the token is in the portfolio, the position's ``avgPrice``.

    Obtain the shared instance with :meth:`get_instance`; most public getters
    call :meth:`refresh` before reading.
    """

    # Minimum number of seconds between two rebuilds of the caches.
    REFRESH_INTERVAL_S = 30

    _instance = None

    @staticmethod
    def get_instance():
        """Return the shared DataHub, creating (and refreshing) it on first use."""
        if DataHub._instance is None:
            DataHub._instance = DataHub()
        return DataHub._instance

    def __init__(self):
        self.global_data = {}
        self.event_cache = {}
        self.market_cache = {}
        self.last_refresh = 0
        self.refresh()

    # ------------------------------------------------------------------ refresh

    def refresh(self):
        """Rebuild all cache layers, at most once per ``REFRESH_INTERVAL_S``.

        The timestamp is recorded before the rebuild starts, so a failure
        part-way through is not retried until the interval elapses again.
        Slugs/markets that disappear from the sources are not evicted.
        """
        now = time.time()
        if now - self.last_refresh < self.REFRESH_INTERVAL_S:
            return
        self.last_refresh = now

        # Clear the loader cache so the latest .npz files on disk are picked up.
        load_surfaces.clear()

        self._refresh_global()
        self._refresh_events()
        self._refresh_markets()

    def _refresh_global(self):
        """Fetch tweets, 24h/7d/30d pace, their mean and the user's positions."""
        self.global_data["tweets"] = fetch_xtracker_cached()
        self.global_data["pace_24h"] = get_pace(1.0)
        self.global_data["pace_7d"] = get_pace(7.0)
        self.global_data["pace_30d"] = get_pace(30.0)
        self.global_data["avg_pace"] = (
            self.global_data["pace_24h"]
            + self.global_data["pace_7d"]
            + self.global_data["pace_30d"]
        ) / 3
        self.global_data["portfolio"] = get_user_positions_by_token()

    def _refresh_events(self):
        """Load the probability surfaces from ``<cwd>/data`` into ``event_cache``."""
        folder = str(Path.cwd() / "data")
        surfaces, _ = load_surfaces(folder)
        for slug, data in surfaces.items():
            matrix = data["matrix"]
            meta = data.get("meta", {})
            if not isinstance(meta, dict):
                # Wrap non-dict meta so downstream ``meta.get(...)`` keeps working.
                meta = {"event_title": str(meta)}
            self.event_cache[slug] = {
                "matrix": matrix,
                "meta": meta,
                # Last matrix row, later indexed by bin position as the price vector.
                "current_prices": matrix[-1] if len(matrix) > 0 else None,
                "timestamps": meta.get("timestamps", []),
            }

    def _refresh_markets(self):
        """Populate ``market_cache`` for every YES bucket of every cached event."""
        portfolio = self.global_data.get("portfolio") or {}
        for slug in list(self.event_cache.keys()):
            event = fetch_event(slug)
            if not event:
                continue
            cached = self.event_cache[slug]
            current_prices = cached["current_prices"]
            # Normalized once per event instead of once per bucket.
            bin_labels = self._normalized_bin_labels(cached["meta"])
            sm = None  # smart-money result, fetched lazily once per event

            for bucket in extract_yes_buckets(event):
                token = bucket.get("token")
                if not token:
                    continue
                label = bucket.get("label", "")
                price = self._lookup_bucket_price(bin_labels, label, current_prices)
                if sm is None:
                    # Arguments depend only on the event, so one call serves all buckets.
                    sm = analyze_smart_money_flow(event.get("conditionId") or "", slug) or {}

                entry = {
                    "token_id": token,
                    "current_price": float(price),
                    "label": label,
                    "smfi": sm.get("smfi", 0),
                }
                position = portfolio.get(token)
                if position:
                    avg_price = position.get("avgPrice")
                    # Only store a usable entry price; get_rule_context casts it to float.
                    if avg_price is not None:
                        entry["PriceAVG"] = avg_price
                self.market_cache[(slug, token)] = entry

    @staticmethod
    def _normalized_bin_labels(meta):
        """Return normalized labels of ``meta["bins"]``, or None if normalizing fails."""
        try:
            return [normalize_bucket(q)[0] for q in meta.get("bins", [])]
        except Exception:
            return None

    @staticmethod
    def _lookup_bucket_price(bin_labels, label, current_prices):
        """Price of the saved bin matching ``label``; 0.5 when it cannot be determined."""
        if bin_labels is None or current_prices is None:
            return 0.5
        try:
            idx = bin_labels.index(normalize_bucket(label)[0])
            return current_prices[idx]
        except Exception:
            # Best-effort: unknown label, index out of range, or normalize failure.
            return 0.5

    # ------------------------------------------------------------------ getters

    def get_global(self, key):
        """Return ``global_data[key]`` (None if absent) after a throttled refresh."""
        self.refresh()
        return self.global_data.get(key)

    def get_event(self, slug, key):
        """Return ``event_cache[slug][key]`` (None if absent) after a throttled refresh."""
        self.refresh()
        return self.event_cache.get(slug, {}).get(key)

    def get_market(self, slug, token, key):
        """Return ``market_cache[(slug, token)][key]`` (None if absent) after a refresh."""
        self.refresh()
        return self.market_cache.get((slug, token), {}).get(key)

    def get_all_slugs(self):
        """Return every cached event slug after a throttled refresh."""
        self.refresh()
        return list(self.event_cache.keys())

    def is_event_finished(self, slug: str) -> bool:
        """Heuristic: finished events have only 1 bucket left (the final result).

        Unknown slugs are treated as finished.
        """
        if slug not in self.event_cache:
            return True
        meta = self.event_cache[slug].get("meta", {})
        bins = meta.get("bins", [])
        return len(bins) <= 1

    def get_active_slugs(self) -> list:
        """Return cached slugs whose event still has more than one bucket."""
        self.refresh()
        return [slug for slug in self.event_cache if not self.is_event_finished(slug)]

    # ------------------------------------------------------------- rule context

    def get_rule_context(self, slug=None, token=None, side=None):
        """Build the flat dict of named variables used for rule evaluation.

        Always present: Heat4h, Heat10m, Pace24h, Pace7d, Pace30d, PaceAVG.
        If ``slug`` is cached: Count, Progress, RHours, Peak, Top/Bottom/Center (0).
        If ``token`` is given: Price, Volume, Vola, optional PriceAVG, and the
        bucket's Bottom/Top/Center plus its label as Bucket.
        """
        self.refresh()
        ctx = {
            "Heat4h": self._calculate_heat(4 * 60),
            "Heat10m": self._calculate_heat(10),
            "Pace24h": float(self.global_data.get("pace_24h", 0.0)),
            "Pace7d": float(self.global_data.get("pace_7d", 0.0)),
            "Pace30d": float(self.global_data.get("pace_30d", 0.0)),
            "PaceAVG": float(self.global_data.get("avg_pace", 0.0)),
        }
        if slug and slug in self.event_cache:
            ctx.update(self._event_context(slug))
        if token is not None:
            ctx.update(self._bucket_context(slug, token, side))
        return ctx

    def _event_context(self, slug):
        """Event-level variables: tweet count in period, progress, hours left, peak."""
        event = self.event_cache[slug]
        meta = event.get("meta", {})
        bins = meta.get("bins", [])

        start, end = self._get_event_period(slug)
        # The ISO round-trip accepts str inputs and normalizes tzinfo to fixed
        # offsets, which keeps the hour arithmetic below correct across DST.
        start = datetime.fromisoformat(str(start).replace("Z", "+00:00"))
        end = datetime.fromisoformat(str(end).replace("Z", "+00:00"))

        tweets = self.global_data.get("tweets", [])
        created = (self._parse_created_at(t) for t in tweets)
        count = sum(1 for ts in created if ts is not None and start <= ts <= end)

        now = datetime.now(ZoneInfo("Europe/Berlin"))
        total_hours = (end - start).total_seconds() / 3600
        passed_hours = (now - start).total_seconds() / 3600
        progress = (
            round(max(0, min(1, passed_hours / total_hours)), 3) if total_hours > 0 else 0.0
        )
        remaining_hours = max(0.1, (end - now).total_seconds() / 3600)

        peak = 1200  # fallback when prices/bins are missing or misaligned
        current_prices = event.get("current_prices")
        if current_prices is not None and len(bins) > 0:
            max_idx = int(np.argmax(current_prices))
            if max_idx < len(bins):
                _, _low, high = normalize_bucket(str(bins[max_idx]))
                peak = int(high)

        return {
            "Count": count,
            "Progress": progress,
            "RHours": remaining_hours,
            "Peak": peak,
            "Top": 0,
            "Bottom": 0,
            "Center": 0,
        }

    def _bucket_context(self, slug, token, side):
        """Bucket-level variables: order-book stats, entry price and bucket range."""
        from utils.polymarket_fetch import get_orderbook_data

        price, volume, vola = get_orderbook_data(token, side)
        ctx = {
            "Price": round(price, 4),
            "Volume": int(volume),
            "Vola": round(vola, 4),
        }
        # The pair may be missing from market_cache (e.g. its event could not be
        # fetched during refresh), so look it up with .get() instead of indexing.
        avg_price = self.market_cache.get((slug, token), {}).get("PriceAVG")
        if avg_price is not None:
            ctx["PriceAVG"] = float(avg_price)

        label = self._lookup_label(slug, token)
        bottom, top = self._parse_bucket_range(label)
        ctx["Bottom"] = bottom
        ctx["Top"] = top
        ctx["Center"] = round((bottom + top) / 2, 1)
        ctx["Bucket"] = label
        return ctx

    def _lookup_label(self, slug, token):
        """Resolve a token's bucket label: market_cache, live event, then saved meta.

        Returns the (possibly empty) label found by the last source consulted.
        """
        token_str = str(token)

        # 1. market_cache (populated by refresh).
        label = self.market_cache.get((slug, token), {}).get("label", "")
        if label:
            return label

        # 2. Live fetch of the event.
        event = fetch_event(slug) if slug else None
        if event:
            for bucket in extract_yes_buckets(event):
                if str(bucket.get("token")) == token_str:
                    label = bucket.get("label", "")
                    break
        if label:
            return label

        # 3. Saved .npz meta: tokens and bins are paired by position.
        if slug in self.event_cache:
            meta = self.event_cache[slug].get("meta", {})
            saved_bins = meta.get("bins", [])
            for i, saved_token in enumerate(meta.get("tokens", [])):
                if str(saved_token) == token_str and i < len(saved_bins):
                    return str(saved_bins[i])
        return label

    @staticmethod
    def _parse_bucket_range(text):
        """Parse a bucket label into an inclusive ``(bottom, top)`` tweet range.

        "580+" / "580 or more" -> (580, 999); "500-519" -> (500, 519);
        "less than 100" -> (0, 99); a lone number n -> (n, n + 19);
        anything else -> (0, 19). Thousands separators are ignored.
        """
        # Strip commas so "1,000-1,019" is not read as (1, 0).
        text = str(text).lower().replace(",", "")
        m = re.search(r'(\d+)\s*(?:\+|or more|greater)', text)
        if m:
            # NOTE(review): 999 caps open-ended buckets; a bucket starting at
            # >= 999 yields Top < Bottom -- confirm this sentinel is intended.
            return int(m.group(1)), 999
        m = re.search(r'(\d+)[^\d]+(\d+)', text)
        if m:
            return int(m.group(1)), int(m.group(2))
        m = re.search(r'(?:less than|under|<)\s*(\d+)', text)
        if m:
            return 0, max(0, int(m.group(1)) - 1)
        m = re.search(r'(\d+)', text)
        if m:
            n = int(m.group(1))
            return n, n + 19
        return 0, 19

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _parse_created_at(tweet):
        """Parse a tweet's ``createdAt`` ISO string ('Z' allowed); None if missing/invalid."""
        raw = str(tweet.get("createdAt", "")).replace("Z", "+00:00")
        try:
            return datetime.fromisoformat(raw)
        except ValueError:
            return None

    def _calculate_heat(self, minutes: int):
        """Tweets per hour over the last ``minutes`` minutes (Signals-tab HeatX).

        Tweets without a parsable ``createdAt`` are ignored; returns 0.0 when
        there are no tweets or ``minutes`` is not positive.
        """
        tweets = self.global_data.get("tweets", [])
        if not tweets or minutes <= 0:
            return 0.0
        cutoff = datetime.now(ZoneInfo("Europe/Berlin")) - timedelta(minutes=minutes)
        created = (self._parse_created_at(t) for t in tweets)
        count = sum(1 for ts in created if ts is not None and ts >= cutoff)
        return round(count / (minutes / 60), 1)

    def _get_event_period(self, slug):
        """Counting period for ``slug`` as ``(start, end)`` datetimes.

        String bounds from ``get_counting_period_from_slug`` are parsed as ISO
        ('Z' allowed). Without a usable period, falls back to the last 7 days
        ending now (Europe/Berlin).
        """
        from utils.helpers import get_counting_period_from_slug

        period = get_counting_period_from_slug(slug)
        if period and len(period) == 2:
            start, end = period
            if isinstance(start, str):
                start = datetime.fromisoformat(str(start).replace("Z", "+00:00"))
            if isinstance(end, str):
                end = datetime.fromisoformat(str(end).replace("Z", "+00:00"))
            return start, end
        now = datetime.now(ZoneInfo("Europe/Berlin"))
        return now - timedelta(days=7), now

    def get_bins_for_event(self, slug: str):
        """Return ``[{"token", "label"}, ...]`` for the event's YES buckets.

        The live event is the source of truth for token/label pairing; saved
        .npz meta (paired by position) is used only if the live fetch yields
        nothing. Entries without a token are skipped; returns [] if neither
        source has any.
        """
        self.refresh()

        event = fetch_event(slug)
        if event:
            bins = [
                {"token": b.get("token"), "label": b.get("label", "")}
                for b in extract_yes_buckets(event)
                if b.get("token")
            ]
            if bins:
                return bins

        if slug in self.event_cache:
            meta = self.event_cache[slug].get("meta", {})
            # zip stops at the shorter list, matching positional pairing.
            bins = [
                {"token": token, "label": str(label)}
                for label, token in zip(meta.get("bins", []), meta.get("tokens", []))
                if token
            ]
            if bins:
                return bins

        return []