
import streamlit as st
import pandas as pd
import time
from typing import List, Dict, Any
from config import SESSION, FIDELITY, PRIVATE_KEY, WALLET_ADDRESS

def render_portfolio_tab():
    """Render the wallet's open Polymarket positions as a tree in Streamlit.

    Fetches positions for ``WALLET_ADDRESS`` from the Polymarket data API
    and drops rows whose ``redeemable`` flag is not exactly ``False``. It
    then shows portfolio-level totals and renders one expander per
    ``eventSlug``, ordered by that group's ``endDate``.

    Each expander header shows, for its group:

    - sum(size * avgPrice) ("Initial");
    - sum(size * curPrice) ("Current");
    - the difference and the percentage change.

    Inside the expander, a table lists the child positions, labelled by
    their ``slug`` with the ``"<eventSlug>-"`` prefix removed.

    Returns:
        ``{}`` when no wallet address is configured (nothing is rendered);
        otherwise ``None``.

    Raises:
        Whatever ``SESSION.get`` / ``raise_for_status`` / ``.json()`` raise
        if the retry attempt in ``loadpos`` also fails.
    """
    if not WALLET_ADDRESS:
        return {}
    u = f"https://data-api.polymarket.com/positions?sizeThreshold=1&limit=100&sortBy=TOKENS&sortDirection=DESC&user={WALLET_ADDRESS}"

    def loadpos(url):
        """GET *url* and return the decoded JSON, retrying once after 2 seconds."""
        try:
            r = SESSION.get(url, timeout=15)
            r.raise_for_status()
            return r.json()
        except Exception as e:
            # Best-effort single retry; a second failure propagates to the caller.
            print(e)
            time.sleep(2)
            r = SESSION.get(url, timeout=15)
            r.raise_for_status()
            return r.json()

    positions = loadpos(u)
    if not positions:
        st.warning("No position data provided.")
        return

    df = pd.DataFrame(positions)

    # Coerce size/price columns so unparsable values contribute 0 to the sums below.
    for col in ["size", "avgPrice", "curPrice"]:
        df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0)

    # === FILTER REDEEMABLE POSITIONS (default: removed) ===
    original_len = len(df)
    # Keep only rows explicitly flagged False; missing/NaN flags are dropped as well.
    df = df[df["redeemable"].eq(False)].reset_index(drop=True)
    filtered_count = original_len - len(df)
    if filtered_count > 0:
        st.info(f"✅ Removed {filtered_count} redeemable positions (0 value / resolved).")
    # ======================================================

    if len(df) == 0:
        st.warning("No active positions remaining after filtering.")
        return

    # Portfolio-level totals (after filtering)
    portfolio_initial = (df["size"] * df["avgPrice"]).sum()
    portfolio_current = (df["size"] * df["curPrice"]).sum()
    portfolio_diff = portfolio_current - portfolio_initial
    portfolio_pct = (portfolio_diff / portfolio_initial * 100) if portfolio_initial != 0 else 0.0

    st.header("🧬 Polymarket Positions Tree")
    st.caption("Grouped by eventSlug • Only non-redeemable positions • size × price calculations")

    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Portfolio Initial", f"${portfolio_initial:,.2f}")
    with col2:
        st.metric("Portfolio Current", f"${portfolio_current:,.2f}")
    with col3:
        st.metric("Portfolio P&L", f"${portfolio_diff:,.2f}", f"{portfolio_pct:+.2f}%")
    with col4:
        st.metric("Active Positions", len(df))

    st.divider()

    # dropna=False: positions lacking an eventSlug are counted in the totals
    # above, so they must also appear in the tree instead of vanishing.
    grouped = df.groupby("eventSlug", dropna=False)

    def end_date_key(item):
        """Sort key: the group's first endDate in UTC; missing/unparsable dates sort last."""
        _, frame = item
        # utc=True makes naive and tz-aware strings comparable; errors="coerce"
        # turns unparsable values into NaT instead of raising.
        ts = pd.to_datetime(frame["endDate"].iloc[0], utc=True, errors="coerce")
        if pd.isna(ts):
            return (1, 0)
        return (0, ts.value)

    sorted_groups = sorted(grouped, key=end_date_key)

    for event_slug, group_df in sorted_groups:
        # With dropna=False, rows without an eventSlug are grouped under a NaN key.
        has_slug = not pd.isna(event_slug)
        event_label = event_slug if has_slug else "(no eventSlug)"

        # Parent aggregates
        group_initial = (group_df["size"] * group_df["avgPrice"]).sum()
        group_current = (group_df["size"] * group_df["curPrice"]).sum()
        group_diff = group_current - group_initial
        group_pct = (group_diff / group_initial * 100) if group_initial != 0 else 0.0

        with st.expander(
            f"🌳 {event_label} | "
            f"Initial: ${group_initial:,.2f} | "
            f"Current: ${group_current:,.2f} | "
            f"P&L: ${group_diff:,.2f} ({group_pct:+.2f}%)",
            expanded=False
        ):
            gcol1, gcol2, gcol3 = st.columns(3)
            with gcol1:
                st.metric("Group Initial", f"${group_initial:,.2f}")
            with gcol2:
                st.metric("Group Current", f"${group_current:,.2f}")
            with gcol3:
                st.metric("Group P&L", f"${group_diff:,.2f}", f"{group_pct:+.2f}%")

            st.markdown("**Child Positions** (slug without eventSlug prefix)")

            child_df = group_df.copy()
            if has_slug:
                prefix = f"{event_slug}-"
                # Literal prefix match: the slug is not regex-escaped, so a
                # regex-based replace could mis-strip or raise on metacharacters.
                child_df["short_slug"] = child_df["slug"].map(
                    lambda s: s[len(prefix):] if isinstance(s, str) and s.startswith(prefix) else s
                )
            else:
                child_df["short_slug"] = child_df["slug"]
            child_df["initial"] = child_df["size"] * child_df["avgPrice"]
            child_df["current"] = child_df["size"] * child_df["curPrice"]
            child_df["diff"] = child_df["current"] - child_df["initial"]
            # Division by a zero cost basis yields inf/NaN; replace those rows with 0.
            child_df["diff_pct"] = (child_df["diff"] / child_df["initial"] * 100).where(
                child_df["initial"] != 0, 0.0
            )

            display_cols = [
                "short_slug",
                "size",
                "avgPrice",
                "curPrice",
                "initial",
                "current",
                "diff",
                "diff_pct",
                "outcome"
            ]
            display_df = child_df[display_cols].round(6)

            st.dataframe(
                display_df.style.format({
                    "size": "{:,.4f}",
                    "avgPrice": "{:.5f}",
                    "curPrice": "{:.5f}",
                    "initial": "${:,.2f}",
                    "current": "${:,.2f}",
                    "diff": "${:,.2f}",
                    "diff_pct": "{:+.2f}%"
                }),
                use_container_width=True,
                hide_index=True
            )

            st.caption(f"{len(group_df)} positions • End date: {group_df['endDate'].iloc[0]}")
