# frontend/tabs/heatmap.py
import streamlit as st
import plotly.graph_objects as go
import pandas as pd
import requests
import numpy as np
import re

# Backend address used when the caller does not pass `api_base`.
_DEFAULT_API_BASE = "http://172.17.0.2:3001"

# Sort key given to bin labels that contain no number, so they sort last.
_UNNUMBERED_BIN_KEY = 999999

_BIN_START_RE = re.compile(r"(\d+)")
_QUESTION_PREFIX_RE = re.compile(r"Will Elon Musk post ")
# Matches " tweets from <anything> <4-digit year>" plus an optional "?".
_DATE_SUFFIX_RE = re.compile(r" tweets from .*\d{4}\??")


def _bin_start(label) -> int:
    """Return the first integer found in *label*, or a large sentinel if none."""
    match = _BIN_START_RE.search(str(label))
    return int(match.group(1)) if match else _UNNUMBERED_BIN_KEY


def _clean_bin_label(label) -> str:
    """Strip the question prefix and the trailing date range from a bin label."""
    text = _QUESTION_PREFIX_RE.sub("", str(label))
    text = _DATE_SUFFIX_RE.sub("", text)
    return text.strip()


def render_heatmap_tab(selected_slug: str, *, api_base: str = _DEFAULT_API_BASE):
    """Fetch the surface for *selected_slug* and render it as a heatmap.

    The backend payload is read as a JSON object with three keys:
    ``timestamps``, ``bins`` and ``matrix``. ``matrix`` is indexed as
    ``[timestamp, bin]``; its columns are reordered by the first number in
    each bin label and the result is transposed so that bins run along the
    y axis and time (converted to Europe/Berlin) along the x axis.

    Any failure (network error, non-200 status, invalid JSON, malformed or
    insufficient data) is reported through a Streamlit message and the
    function returns without drawing anything.

    Args:
        selected_slug: Identifier appended to ``/api/surface/`` in the
            request URL; also used in the chart's widget key.
        api_base: Scheme, host and port of the backend. Defaults to the
            address this tab has always used.

    Returns:
        None. All output goes to the Streamlit page.
    """
    url = f"{api_base.rstrip('/')}/api/surface/{selected_slug}"
    try:
        resp = requests.get(url, timeout=25)
    except requests.RequestException as exc:
        # Connection refused, DNS failure, timeout, ... -- show a message
        # instead of letting the exception surface as a traceback.
        st.error(f"Backend unreachable: {exc}")
        return
    if resp.status_code != 200:
        st.error("Backend error")
        return

    try:
        surface = resp.json()
    except ValueError:
        st.error("Backend returned invalid JSON")
        return
    if not isinstance(surface, dict):
        st.error("Backend returned an unexpected payload")
        return

    # `or []` also covers keys that are present but null.
    timestamps = surface.get("timestamps") or []
    bins = surface.get("bins") or []
    try:
        matrix = np.array(surface.get("matrix") or [])
    except ValueError:
        # Recent NumPy versions refuse to build an array from ragged rows.
        st.error("Malformed surface data: matrix rows have unequal lengths")
        return

    st.success(f"Loaded {len(timestamps)} timestamps × {len(bins)} bins")

    if len(timestamps) < 10 or matrix.size == 0:
        st.warning("Not enough data")
        return

    # Column reordering and the transpose below are only meaningful when the
    # matrix really is (timestamps x bins); reject anything else up front.
    expected_shape = (len(timestamps), len(bins))
    if matrix.ndim != 2 or matrix.shape != expected_shape:
        st.error(
            f"Malformed surface data: matrix shape {matrix.shape} "
            f"does not match {expected_shape}"
        )
        return

    ts_berlin = pd.to_datetime(timestamps, utc=True).tz_convert("Europe/Berlin")
    # min/max rather than first/last so the range and span stay correct even
    # if the backend does not return timestamps in chronological order.
    t_start, t_end = ts_berlin.min(), ts_berlin.max()

    # Numeric (not lexicographic) ordering of bins; sorted() is stable, so
    # bins with equal keys keep their original relative order.
    order = sorted(range(len(bins)), key=lambda idx: _bin_start(bins[idx]))
    clean_labels = [_clean_bin_label(bins[idx]) for idx in order]
    sorted_matrix = matrix[:, order]  # reorder columns to match the labels

    fig = go.Figure(data=go.Heatmap(
        z=sorted_matrix.T,  # transpose: rows = bins, columns = time
        x=ts_berlin,
        y=clean_labels,
        colorscale="Viridis",
        colorbar=dict(title="Yes Price"),
    ))

    fig.update_layout(
        height=820,
        title="Full History Heatmap — Correctly Sorted Bins",
        xaxis_title="Time (Berlin)",
        yaxis_title="Tweet Count Range",
        margin=dict(l=100, r=50, t=100, b=140),
    )

    fig.update_xaxes(range=[t_start, t_end], autorange=False, tickangle=-45)

    # The key includes the timestamp count so the widget identity changes
    # when more data has been loaded for the same slug.
    st.plotly_chart(fig, use_container_width=True, key=f"heat_{selected_slug}_{len(timestamps)}")

    # Summary metrics
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Timestamps", len(timestamps))
    with col2:
        st.metric("Bins", len(clean_labels))
    with col3:
        st.metric("Span", f"{(t_end - t_start).days} days")
