"""

SlimeTree-RLM: Two-Regime Comparison

Mild (near-uniform) vs Concentrated: cancellation holds in both.

Addresses: "Is SC just keeping π uniform?"

"""

 

import numpy as np

import matplotlib.pyplot as plt

import matplotlib

matplotlib.rcParams.update({

    'font.size': 11, 'font.family': 'serif',

    'axes.grid': True, 'grid.alpha': 0.3, 'lines.linewidth': 1.5,

})

 

# K: number of arms/strategies; T: steps per run; N_runs: seeds per regime;
# BURN: burn-in steps excluded from summary statistics; W: moving-average window.
K = 3; T = 600; N_runs = 50; BURN = 60; W = 15

 

def softmax(theta):
    """Return the softmax distribution of *theta* (max-shifted for stability)."""
    exps = np.exp(theta - np.max(theta))
    return exps / np.sum(exps)

 

def run(mode, eta_max, g_fail, fail_rate, base_loss, seed=42):
    """Simulate T steps of natural-gradient updates over a K-way softmax policy.

    Parameters
    ----------
    mode : str
        'fixed' uses the constant step size ``eta_max``; any other value
        ('adaptive') scales the step by (1 - GS), where GS is the recent
        dominance-stability estimate.
    eta_max : float
        Maximum / base learning rate.
    g_fail : float
        Magnitude of the failure-signal gradient spike (zero-sum across arms).
    fail_rate : float
        Per-step probability of a failure event.
    base_loss : np.ndarray
        Length-K baseline loss vector for the arms.
    seed : int
        Seed for the per-run RandomState; runs are fully reproducible.

    Returns
    -------
    dict
        Keys 'pi', 'pi_min', 'GS', 'eta', 'delta_norm', 'F_dag_norm', 'C_t',
        'product', 'switches', 'is_failure', each a length-T np.ndarray
        ('pi' is (T, K)).
    """
    rng = np.random.RandomState(seed)
    theta = np.zeros(K)  # softmax logits, one per arm
    log = dict(pi=[], pi_min=[], GS=[], eta=[], delta_norm=[],
               F_dag_norm=[], C_t=[], product=[], switches=[], is_failure=[])
    dom_hist = []  # history of which arm dominated at each step
    for t in range(T):
        # Current policy; clipped away from 0 then renormalized so 1/pi is finite.
        pi = softmax(theta); pi = np.clip(pi, 1e-12, None); pi /= pi.sum()
        loss = base_loss + 0.1 * rng.randn(K)  # noisy per-arm loss
        is_f = rng.rand() < fail_rate  # failure event this step?
        # Zero-sum spike on failure: coefficients (1, -0.4, -0.6) sum to 0.
        fail_sig = np.array([g_fail, -g_fail*0.4, -g_fail*0.6]) if is_f else np.zeros(K)
        # Policy-gradient-style term plus the failure signal.
        grad = pi * (loss - pi @ loss) + fail_sig
        grad_T = grad - np.mean(grad)  # project onto the simplex tangent (mean-zero)
        ng = grad_T / pi  # natural gradient: elementwise divide by pi (diagonal F^-1)
        dom = np.argmax(pi); dom_hist.append(dom)
        w = min(40, len(dom_hist))  # stability window of up to 40 steps
        # GS = fraction of recent steps with the same dominant arm, capped at 0.995
        # so that (1 - GS) stays bounded away from zero.
        GS = min(sum(1 for d in dom_hist[-w:] if d == dom) / w, 0.995)
        eta = eta_max if mode == 'fixed' else eta_max * (1.0 - GS)
        delta = eta * ng
        pi_min = pi.min()
        dom_before = np.argmax(pi)  # dominant arm before applying the update
        log['pi'].append(pi.copy()); log['pi_min'].append(pi_min)
        log['GS'].append(GS); log['eta'].append(eta)
        log['delta_norm'].append(np.linalg.norm(delta))
        log['F_dag_norm'].append(1.0/pi_min)  # proxy for the pseudo-inverse Fisher norm
        log['C_t'].append(pi_min / max(1-GS, 1e-10))  # SC invariant C(t)
        log['product'].append(eta / pi_min)  # eta * ||F-dagger|| cancellation product
        log['is_failure'].append(is_f)
        # Gradient step, elementwise-clipped to [-5, 5] to bound any single update.
        theta = theta - np.clip(delta, -5, 5)
        # Record whether this update changed the dominant arm.
        log['switches'].append(1 if np.argmax(softmax(theta)) != dom_before else 0)
    for k in log: log[k] = np.array(log[k])
    return log

 

# ─── Two regimes ───

# Regime parameter sets: a mild (near-uniform) and a concentrated case.
regimes = {
    'mild': dict(eta_max=0.10, g_fail=0.5, fail_rate=0.08,
                 base_loss=np.array([0.20, 0.55, 0.75])),
    'conc': dict(eta_max=0.20, g_fail=0.7, fail_rate=0.10,
                 base_loss=np.array([0.10, 0.70, 1.00])),
}

results = {}
for rname, params in regimes.items():
    print(f"\n--- {rname} ---")
    # One representative trajectory per mode for the figures.
    vanilla = run('fixed', **params, seed=42)
    adaptive = run('adaptive', **params, seed=42)

    # Repeat over N_runs independent seeds for switch counts and C(t).
    switch_counts_v, switch_counts_a, invariant_traces = [], [], []
    for s in range(N_runs):
        seed = 3000 + s
        rep_v = run('fixed', **params, seed=seed)
        rep_a = run('adaptive', **params, seed=seed)
        switch_counts_v.append(int(rep_v['switches'].sum()))
        switch_counts_a.append(int(rep_a['switches'].sum()))
        invariant_traces.append(rep_a['C_t'])

    results[rname] = dict(V=vanilla, A=adaptive,
                          sw_V=switch_counts_v, sw_A=switch_counts_a,
                          C=np.array(invariant_traces))

    # Post-burn-in summary statistics for this regime.
    pi_traj = np.array(adaptive['pi'].tolist())
    mean_pi_min = pi_traj[BURN:].min(axis=1).mean()
    mean_pi_max = pi_traj[BURN:].max(axis=1).mean()
    prod_v = vanilla['product'][BURN:].mean()
    prod_a = adaptive['product'][BURN:].mean()

    print(f"  π_max={mean_pi_max:.3f}, π_min={mean_pi_min:.4f}, ratio={mean_pi_max/max(mean_pi_min,1e-10):.1f}×")
    print(f"  Vanilla: {np.mean(switch_counts_v):.1f}±{np.std(switch_counts_v):.1f} sw, prod={prod_v:.2f}")
    print(f"  Adaptive: {np.mean(switch_counts_a):.1f}±{np.std(switch_counts_a):.1f} sw, prod={prod_a:.4f}")
    print(f"  Reduction: {prod_v/max(prod_a,1e-10):.0f}×")

 

# ═══════════════════════════════════

# FIGURE 5: Two-regime comparison (2×3 grid)

# ═══════════════════════════════════

# Figure 5: two-regime comparison, 3 rows x 2 columns (one column per regime).
fig, axes = plt.subplots(3, 2, figsize=(12, 10))
names = ['Recall', 'Explore', 'Consolidate']
cols_k = ['#1f77b4', '#ff7f0e', '#2ca02c']

# Per-regime title text and accent color for the C(t) spaghetti plot.
regime_info = {
    'mild': ('Mild regime\n(near-uniform, π_max/π_min ≈ 1.2)', '#2ca02c'),
    'conc': ('Concentrated regime\n(π_max ≈ 0.54, π_min ≈ 0.17)', '#ff7f0e'),
}

for col_idx, rname in enumerate(['mild', 'conc']):
    R = results[rname]
    A, V = R['A'], R['V']
    title, accent = regime_info[rname]

    # Row 0: adaptive-mode π trajectories with switch steps marked in red.
    ax = axes[0, col_idx]
    pi_arr = np.array(A['pi'].tolist())
    for k in range(K):
        ax.plot(pi_arr[:, k], color=cols_k[k], alpha=0.8, linewidth=1.2,
                label=f'π({names[k]})' if col_idx == 0 else None)
    for s in np.where(A['switches'] == 1)[0]:
        ax.axvline(s, color='red', alpha=0.3, linewidth=1)
    nsw = int(A['switches'].sum())
    ax.set_title(f'{title}\nAdaptive: {nsw} switches', fontsize=10, fontweight='bold')
    ax.set_ylim(-0.02, 1.02)
    if col_idx == 0:
        ax.set_ylabel('π_w(k)')
        ax.legend(fontsize=8, loc='right')

    # Row 1: W-step moving average of the product η·‖F†‖, vanilla vs adaptive.
    ax = axes[1, col_idx]
    pv = np.convolve(V['product'], np.ones(W)/W, mode='valid')
    pa = np.convolve(A['product'], np.ones(W)/W, mode='valid')
    tt = np.arange(W-1, T)  # 'valid' convolution drops the first W-1 steps
    ax.plot(tt, pv, color='#d62728', linewidth=1.5, label='Vanilla η·‖F†‖', alpha=0.8)
    ax.plot(tt, pa, color='#2ca02c', linewidth=1.5, label='Adaptive η_eff·‖F†‖', alpha=0.8)
    ax.set_yscale('log')
    if col_idx == 0: ax.set_ylabel('η · ‖F†‖\n(log scale)')
    ax.legend(fontsize=8)

    # Row 2: SC invariant C(t) across all N_runs seeds, with mean and 5th-pct floor c.
    ax = axes[2, col_idx]
    C = R['C']
    for i in range(N_runs):
        ax.plot(C[i], color=accent, alpha=0.06, linewidth=0.5)
    ax.plot(np.mean(C, axis=0), color='black', linewidth=1.5, label='Mean')
    c_est = max(np.percentile(C[:, BURN:].flatten(), 5), 1e-8)
    ax.axhline(c_est, color='red', linestyle='--', linewidth=1,
               label=f'c = {c_est:.2f}')
    ax.set_xlabel('Step t')
    if col_idx == 0: ax.set_ylabel('C(t) = π_min/(1−GS)')
    ymax = np.percentile(C[:, BURN:].flatten(), 99)
    # BUG FIX: the original wrote min(ymax*1.3, max(ymax*1.3, 1)), which always
    # evaluates to ymax*1.3, so the intended floor of 1 on the axis top never
    # applied. Clamp the top to at least 1 so the c-line stays visible even
    # when C(t) is small.
    ax.set_ylim(bottom=0, top=max(ymax * 1.3, 1.0))
    ax.legend(fontsize=8)

plt.suptitle('η_eff cancellation holds in both uniform and concentrated regimes',
             fontsize=12, fontweight='bold', y=1.01)
plt.tight_layout()
plt.savefig('/home/claude/fig5_two_regimes.png', dpi=200, bbox_inches='tight')
plt.close()
print("\nFig 5 ✓")

 

# ═══════════════════════════════════

# Summary numbers for paper text

# ═══════════════════════════════════

print("\n" + "="*70)

print("NUMBERS FOR §6 TEXT")

print("="*70)

for rname in ['mild', 'conc']:

    R = results[rname]

    A, V = R['A'], R['V']

    piA = np.array(A['pi'].tolist())

    pm = piA[BURN:].min(axis=1).mean()

    px = piA[BURN:].max(axis=1).mean()

    pV = V['product'][BURN:].mean()

    pA = A['product'][BURN:].mean()

    dV = V['delta_norm'].max()

    dA = A['delta_norm'].max()

    c_est = max(np.percentile(R['C'][:, BURN:].flatten(), 5), 1e-8)

    C_mins = np.min(R['C'][:, BURN:], axis=1)

    margins = np.max(piA[BURN:], axis=1) - np.sort(piA[BURN:], axis=1)[:, -2]

    

    print(f"\n{rname.upper()}:")

    print(f"  π_max = {px:.3f}, π_min = {pm:.4f}")

    print(f"  Concentration: π_max/π_min = {px/max(pm,1e-10):.1f}×")

    print(f"  Vanilla: {np.mean(R['sw_V']):.1f}±{np.std(R['sw_V']):.1f} switches")

    print(f"  Adaptive: {np.mean(R['sw_A']):.1f}±{np.std(R['sw_A']):.1f} switches")

    print(f"  Switch reduction: {np.mean(R['sw_V'])/max(np.mean(R['sw_A']),0.1):.1f}×")

    print(f"  max ‖Δθ‖: vanilla={dV:.3f}, adaptive={dA:.4f} ({dV/max(dA,1e-10):.0f}×)")

    print(f"  mean η·‖F†‖: vanilla={pV:.2f}, adaptive={pA:.4f} ({pV/max(pA,1e-10):.0f}×)")

    print(f"  c = {c_est:.2f}")

    print(f"  C(t)>0 all runs: {np.sum(C_mins > 0)}/{N_runs}")

    print(f"  Δ_min = {margins.min():.4f}")