# frontend/tabs/surface_3d.py
import streamlit as st
import plotly.graph_objects as go
import pandas as pd
import numpy as np
import requests
import re

# First integer in a bin label, used to order buckets numerically (e.g. "120-139" -> 120).
_BIN_NUMBER_RE = re.compile(r"\d+")
# Question wording removed from each bin label so only the bucket range remains.
_LABEL_NOISE_RE = re.compile(r"Will Elon Musk post | tweets from .*\d{4}\??")

_MIN_TIMESTAMPS = 10  # fewer rows than this is treated as "not enough data"
_MAX_Y_TICKS = 12  # cap on time-axis tick labels so they stay legible


def _bin_sort_key(label) -> tuple:
    """Sort key ordering bin labels by the first integer they contain.

    Thousands separators are ignored ("1,000-1,199" sorts as 1000). Labels
    with no digits sort after every numeric label; the tuple key keeps that
    ordering without a magic sentinel value that real bins could exceed.
    """
    match = _BIN_NUMBER_RE.search(str(label).replace(",", ""))
    return (0, int(match.group())) if match else (1, 0)


def _clean_bin_label(label) -> str:
    """Strip the "Will Elon Musk post ... tweets from ... <year>?" wording from a label.

    Falls back to the unmodified label if stripping would leave it empty.
    """
    text = str(label)
    cleaned = _LABEL_NOISE_RE.sub("", text).strip()
    return cleaned or text


def _downsample_indices(n: int, max_points: int) -> np.ndarray:
    """Return evenly spaced row indices that thin ``n`` rows to at most ``max_points``.

    When ``n <= max_points`` every index ``0..n-1`` is returned. Otherwise the
    indices span the whole range and (for ``max_points >= 2``) include both the
    first and the last row, so the tail of the series is never dropped.

    Raises:
        ValueError: if ``max_points`` is less than 1.
    """
    if max_points < 1:
        raise ValueError(f"max_points must be >= 1, got {max_points}")
    if n <= max_points:
        return np.arange(n)
    # Spacing (n-1)/(max_points-1) > 1 here, so rounding yields distinct indices;
    # unique() is only a safety net.
    return np.unique(np.linspace(0, n - 1, num=max_points).round().astype(int))


def _build_surface_figure(z: np.ndarray, bucket_labels: list, time_labels: list) -> "go.Figure":
    """Build the Plotly 3D surface; rows of ``z`` are time steps, columns are buckets."""
    n_rows, n_cols = z.shape
    x_grid, y_grid = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
    # Per-point [bucket, time] strings so hover shows labels instead of grid indices.
    hover_data = [[[bucket, when] for bucket in bucket_labels] for when in time_labels]
    # Only label a handful of rows; one tick per row is unreadable.
    y_ticks = _downsample_indices(n_rows, _MAX_Y_TICKS)

    fig = go.Figure(data=[go.Surface(
        x=x_grid,
        y=y_grid,
        z=z,
        customdata=hover_data,
        colorscale='Viridis',
        colorbar=dict(title="Yes Price"),
        hovertemplate=(
            "Bucket: %{customdata[0]}<br>Time: %{customdata[1]}"
            "<br>Price: %{z:.4f}<extra></extra>"
        ),
    )])

    fig.update_layout(
        title="3D Price Surface — Sorted + Downsampled",
        scene=dict(
            xaxis=dict(title="Tweet Count Range",
                       tickvals=list(range(n_cols)), ticktext=list(bucket_labels)),
            yaxis=dict(title="Time",
                       tickvals=y_ticks.tolist(),
                       ticktext=[time_labels[i] for i in y_ticks]),
            zaxis=dict(title="Yes Price", range=[0, 1]),
            camera=dict(eye=dict(x=1.9, y=1.9, z=1.4)),
        ),
        height=820,
        margin=dict(l=0, r=0, t=60, b=0),
    )
    return fig


def render_surface_3d_tab(
    selected_slug: str,
    *,
    api_base: str = "http://172.17.0.2:3001",
    timeout: float = 25,
    max_time_points: int = 400,
) -> None:
    """Render the 3D price-surface tab for ``selected_slug``.

    Fetches ``{api_base}/api/surface/{selected_slug}``, which is expected to
    return a JSON object with ``timestamps`` (rows), ``bins`` (columns) and a
    ``matrix`` of shape ``(len(timestamps), len(bins))``. Columns are reordered
    numerically by bucket, the time axis is thinned to ``max_time_points`` rows,
    and the result is drawn as a Plotly surface followed by summary metrics.

    Args:
        selected_slug: Identifier of the surface to load; falsy renders nothing.
        api_base: Base URL of the backend API.
        timeout: Request timeout in seconds.
        max_time_points: Maximum number of time rows drawn (3D surfaces get
            slow and heavy above a few hundred rows).

    Failures (network, HTTP status, bad JSON, malformed data) are reported
    with ``st.error`` and the function returns without plotting.
    """
    if not selected_slug:
        return

    try:
        resp = requests.get(f"{api_base.rstrip('/')}/api/surface/{selected_slug}", timeout=timeout)
    except requests.RequestException as exc:
        st.error(f"Failed to load surface: {exc}")
        return
    if resp.status_code != 200:
        st.error(f"Failed to load surface (HTTP {resp.status_code})")
        return

    try:
        surface = resp.json()
    except ValueError:  # requests' JSONDecodeError subclasses ValueError
        surface = None
    if not isinstance(surface, dict):
        st.error("Failed to load surface: response is not a JSON object")
        return

    timestamps = surface.get("timestamps") or []
    bins = surface.get("bins") or []
    try:
        # dtype=float turns JSON nulls into NaN (drawn as gaps) and rejects ragged rows.
        matrix = np.asarray(surface.get("matrix") or [], dtype=float)
    except (TypeError, ValueError):
        st.error("Surface data is malformed: matrix is not a rectangular numeric grid")
        return

    st.caption(f"3D Surface: {len(timestamps)} timestamps × {len(bins)} bins")

    if len(timestamps) < _MIN_TIMESTAMPS or matrix.size == 0:
        st.warning("Not enough data for 3D view")
        return
    # Rows must line up with timestamps and columns with bins, or labels would be wrong.
    if matrix.shape != (len(timestamps), len(bins)):
        st.error(
            f"Surface data is malformed: matrix shape {matrix.shape} does not match "
            f"{len(timestamps)} timestamps × {len(bins)} bins"
        )
        return

    try:
        ts_berlin = pd.to_datetime(timestamps, utc=True).tz_convert("Europe/Berlin")
    except (TypeError, ValueError) as exc:
        st.error(f"Surface data is malformed: bad timestamps ({exc})")
        return

    # Order columns numerically by bucket ("20-39" before "100-119"), not lexically.
    order = sorted(range(len(bins)), key=lambda i: _bin_sort_key(bins[i]))
    sorted_matrix = matrix[:, order]
    clean_labels = [_clean_bin_label(bins[i]) for i in order]

    # Thin the time axis; ceil-style spacing guarantees the cap is respected.
    rows = _downsample_indices(len(ts_berlin), max_time_points)
    plot_matrix = sorted_matrix[rows]
    plot_ts = ts_berlin[rows]
    if len(plot_ts) < len(ts_berlin):
        st.info(f"Downsampled time from {len(ts_berlin)} → {len(plot_ts)} points for 3D performance")

    time_labels = list(plot_ts.strftime("%d %b %H:%M"))
    fig = _build_surface_figure(plot_matrix, clean_labels, time_labels)
    st.plotly_chart(fig, use_container_width=True)

    # Info
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Original Timestamps", len(timestamps))
    with col2:
        st.metric("Displayed Points", len(plot_ts))
    with col3:
        st.metric("Bins", len(clean_labels))
