"""
SlimeTree-RLM: Two-Regime Comparison
Mild (near-uniform) vs Concentrated: cancellation holds in both.
Addresses: "Is SC just keeping π uniform?"
"""
import os
from collections import deque

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
# Paper-style plotting defaults applied to every figure produced by this script.
matplotlib.rcParams.update(
    {
        "font.family": "serif",
        "font.size": 11,
        "axes.grid": True,
        "grid.alpha": 0.3,
        "lines.linewidth": 1.5,
    }
)
K = 3        # number of components of π (one per strategy in `names`)
T = 600      # steps simulated per run
N_runs = 50  # seeds in each per-regime sweep
BURN = 60    # leading steps excluded from the summary statistics
W = 15       # moving-average window for the plotted η·‖F†‖ curves
def softmax(theta):
    """Return the softmax of ``theta``, normalised over all of its elements.

    The maximum is subtracted before exponentiating so ``np.exp`` cannot
    overflow; softmax is shift-invariant, so the result is unchanged.
    """
    shifted = theta - np.max(theta)
    weights = np.exp(shifted)
    return weights / weights.sum()
def run(mode, eta_max, g_fail, fail_rate, base_loss, seed=42):
    """Simulate one T-step trajectory of the softmax policy π = softmax(θ).

    Parameters
    ----------
    mode : {'fixed', 'adaptive'}
        'fixed' uses step size ``eta_max`` every step; 'adaptive' uses
        ``eta_max * (1 - GS)`` where GS is the share of the last ≤40 steps whose
        argmax(π) equals the current one (capped at 0.995).
    eta_max : float
        Base step size.
    g_fail : float
        Magnitude of the shock added to the gradient on failure steps.
    fail_rate : float
        Per-step probability that a failure step occurs.
    base_loss : array-like of length K
        Mean per-strategy loss; N(0, 0.1²) noise is added each step.
    seed : int
        Seed for this run's private ``np.random.RandomState``.

    Returns
    -------
    dict[str, np.ndarray]
        Per-step logs (length T): ``pi`` (T, K), ``pi_min``, ``GS``, ``eta``,
        ``delta_norm`` (norm of the *unclipped* proposed step), ``F_dag_norm``
        (1/pi_min), ``C_t`` (pi_min / (1 - GS)), ``product`` (eta / pi_min),
        ``switches`` (1 when argmax of π after the update differs from before)
        and ``is_failure``.

    Raises
    ------
    ValueError
        If ``mode`` is not 'fixed' or 'adaptive'. Previously any other string
        silently ran the adaptive schedule.
    """
    if mode not in ('fixed', 'adaptive'):
        raise ValueError(f"mode must be 'fixed' or 'adaptive', got {mode!r}")
    rng = np.random.RandomState(seed)
    theta = np.zeros(K)
    log = dict(pi=[], pi_min=[], GS=[], eta=[], delta_norm=[],
               F_dag_norm=[], C_t=[], product=[], switches=[], is_failure=[])
    # Only the most recent 40 dominant indices are ever consulted for GS.
    dom_hist = deque(maxlen=40)
    for _ in range(T):
        # Floor π at 1e-12 so the 1/π scaling below stays finite.
        pi = softmax(theta)
        pi = np.clip(pi, 1e-12, None)
        pi /= pi.sum()
        loss = base_loss + 0.1 * rng.randn(K)
        is_f = rng.rand() < fail_rate
        # Failure shock: +g_fail on component 0, -0.4/-0.6·g_fail on 1 and 2.
        # Its length is hard-coded to 3, so this assumes K == 3.
        fail_sig = (np.array([g_fail, -g_fail * 0.4, -g_fail * 0.6])
                    if is_f else np.zeros(K))
        grad = pi * (loss - pi @ loss) + fail_sig
        grad_T = grad - np.mean(grad)  # remove the mean component
        ng = grad_T / pi               # componentwise 1/π scaling (norm logged as 1/pi_min)
        dom = np.argmax(pi)
        dom_hist.append(dom)
        # Fraction of the recent window (including this step) dominated by the
        # current argmax; capped so that 1 - GS >= 0.005.
        GS = min(dom_hist.count(dom) / len(dom_hist), 0.995)
        eta = eta_max if mode == 'fixed' else eta_max * (1.0 - GS)
        delta = eta * ng
        pi_min = pi.min()
        log['pi'].append(pi.copy())
        log['pi_min'].append(pi_min)
        log['GS'].append(GS)
        log['eta'].append(eta)
        log['delta_norm'].append(np.linalg.norm(delta))
        log['F_dag_norm'].append(1.0 / pi_min)
        log['C_t'].append(pi_min / max(1 - GS, 1e-10))
        log['product'].append(eta / pi_min)
        log['is_failure'].append(is_f)
        # Element-wise clip of the applied step; delta_norm above is pre-clip.
        theta = theta - np.clip(delta, -5, 5)
        log['switches'].append(1 if np.argmax(softmax(theta)) != dom else 0)
    return {key: np.array(values) for key, values in log.items()}
# ─── Two regimes ───
# Each entry is a kwargs bundle for run(). The regimes differ only in step size,
# failure strength/rate and the base-loss gap between strategies.
regimes = {
    'mild': dict(eta_max=0.10, g_fail=0.5, fail_rate=0.08,
                 base_loss=np.array([0.20, 0.55, 0.75])),
    'conc': dict(eta_max=0.20, g_fail=0.7, fail_rate=0.10,
                 base_loss=np.array([0.10, 0.70, 1.00])),
}
results = {}
for rname, p in regimes.items():
    print(f"\n--- {rname} ---")
    # Representative seed-42 runs: used for the trajectory panels and the
    # headline statistics printed below.
    V = run('fixed', **p, seed=42)
    A = run('adaptive', **p, seed=42)

    # Seed sweep (3000 .. 3000+N_runs-1). Each run owns its RandomState, so
    # evaluation order does not affect the results.
    sweep_seeds = [3000 + offset for offset in range(N_runs)]
    fixed_runs = [run('fixed', **p, seed=sd) for sd in sweep_seeds]
    adaptive_runs = [run('adaptive', **p, seed=sd) for sd in sweep_seeds]
    sw_V_all = [int(log_v['switches'].sum()) for log_v in fixed_runs]
    sw_A_all = [int(log_a['switches'].sum()) for log_a in adaptive_runs]
    C_all = np.array([log_a['C_t'] for log_a in adaptive_runs])
    results[rname] = dict(V=V, A=A, sw_V=sw_V_all, sw_A=sw_A_all, C=C_all)

    # Post-burn-in averages for the seed-42 runs.
    piA = A['pi']
    post_pi = piA[BURN:]
    px = post_pi.max(axis=1).mean()
    pm = post_pi.min(axis=1).mean()
    pV = V['product'][BURN:].mean()
    pA = A['product'][BURN:].mean()
    concentration = px / max(pm, 1e-10)
    reduction = pV / max(pA, 1e-10)
    print(f" π_max={px:.3f}, π_min={pm:.4f}, ratio={concentration:.1f}×")
    print(f" Vanilla: {np.mean(sw_V_all):.1f}±{np.std(sw_V_all):.1f} sw, prod={pV:.2f}")
    print(f" Adaptive: {np.mean(sw_A_all):.1f}±{np.std(sw_A_all):.1f} sw, prod={pA:.4f}")
    print(f" Reduction: {reduction:.0f}×")
# ═══════════════════════════════════
# FIGURE 5: Two-regime comparison (3×2 grid: 3 metric rows × 2 regime columns)
# ═══════════════════════════════════
fig, axes = plt.subplots(3, 2, figsize=(12, 10))
names = ['Recall', 'Explore', 'Consolidate']
cols_k = ['#1f77b4', '#ff7f0e', '#2ca02c']

# Panel titles quote the concentration of the seed-42 adaptive run, averaged
# over post-burn-in steps (the same statistic the console summary prints).
# They are computed rather than hard-coded so they cannot go stale when the
# regime parameters change.
pi_stats = {}
for rname in ('mild', 'conc'):
    pi_post = results[rname]['A']['pi'][BURN:]
    pi_stats[rname] = (pi_post.max(axis=1).mean(), pi_post.min(axis=1).mean())
mild_max, mild_min = pi_stats['mild']
conc_max, conc_min = pi_stats['conc']
regime_info = {
    'mild': (f'Mild regime\n(near-uniform, π_max/π_min ≈ {mild_max / max(mild_min, 1e-10):.1f})',
             '#2ca02c'),
    'conc': (f'Concentrated regime\n(π_max ≈ {conc_max:.2f}, π_min ≈ {conc_min:.2f})',
             '#ff7f0e'),
}

for col_idx, rname in enumerate(['mild', 'conc']):
    R = results[rname]
    A, V = R['A'], R['V']
    title, accent = regime_info[rname]

    # Row 0: π trajectories of the adaptive run. Red lines mark steps after
    # which argmax(π) changed.
    ax = axes[0, col_idx]
    pi_arr = A['pi']
    for k in range(K):
        ax.plot(pi_arr[:, k], color=cols_k[k], alpha=0.8, linewidth=1.2,
                label=f'π({names[k]})' if col_idx == 0 else None)
    for s in np.where(A['switches'] == 1)[0]:
        ax.axvline(s, color='red', alpha=0.3, linewidth=1)
    nsw = int(A['switches'].sum())
    ax.set_title(f'{title}\nAdaptive: {nsw} switches', fontsize=10, fontweight='bold')
    ax.set_ylim(-0.02, 1.02)
    if col_idx == 0:
        ax.set_ylabel('π_w(k)')
        ax.legend(fontsize=8, loc='right')

    # Row 1: W-step moving average of eta / pi_min, vanilla vs adaptive.
    ax = axes[1, col_idx]
    kernel = np.ones(W) / W
    pv = np.convolve(V['product'], kernel, mode='valid')
    pa = np.convolve(A['product'], kernel, mode='valid')
    # 'valid' output i corresponds to the window ending at step i + W - 1.
    tt = np.arange(W - 1, T)
    ax.plot(tt, pv, color='#d62728', linewidth=1.5, label='Vanilla η·‖F†‖', alpha=0.8)
    ax.plot(tt, pa, color='#2ca02c', linewidth=1.5, label='Adaptive η_eff·‖F†‖', alpha=0.8)
    ax.set_yscale('log')
    if col_idx == 0:
        ax.set_ylabel('η · ‖F†‖\n(log scale)')
    ax.legend(fontsize=8)

    # Row 2: C(t) for every adaptive run in the sweep, its mean, and the
    # 5th-percentile post-burn-in floor c.
    ax = axes[2, col_idx]
    C = R['C']
    for trace in C:
        ax.plot(trace, color=accent, alpha=0.06, linewidth=0.5)
    ax.plot(np.mean(C, axis=0), color='black', linewidth=1.5, label='Mean')
    c_est = max(np.percentile(C[:, BURN:].flatten(), 5), 1e-8)
    ax.axhline(c_est, color='red', linestyle='--', linewidth=1,
               label=f'c = {c_est:.2f}')
    ax.set_xlabel('Step t')
    if col_idx == 0:
        ax.set_ylabel('C(t) = π_min/(1−GS)')
    ymax = np.percentile(C[:, BURN:].flatten(), 99)
    # The previous min(a, max(a, 1)) expression always reduced to a.
    ax.set_ylim(bottom=0, top=ymax * 1.3)
    ax.legend(fontsize=8)

fig.suptitle('η_eff cancellation holds in both uniform and concentrated regimes',
             fontsize=12, fontweight='bold', y=1.01)
fig.tight_layout()
# Output directory defaults to the original location; set FIG_OUT_DIR to
# write elsewhere. The directory is created if missing.
out_dir = os.environ.get('FIG_OUT_DIR', '/home/claude')
os.makedirs(out_dir, exist_ok=True)
fig.savefig(os.path.join(out_dir, 'fig5_two_regimes.png'), dpi=200, bbox_inches='tight')
plt.close(fig)
print("\nFig 5 ✓")
# ═══════════════════════════════════
# Summary numbers for paper text
# ═══════════════════════════════════
rule = "=" * 70
print("\n" + rule)
print("NUMBERS FOR §6 TEXT")
print(rule)
for rname in ['mild', 'conc']:
    R = results[rname]
    A, V = R['A'], R['V']

    # Post-burn-in π statistics of the seed-42 adaptive run.
    piA = A['pi']
    post_pi = piA[BURN:]
    px = post_pi.max(axis=1).mean()
    pm = post_pi.min(axis=1).mean()
    # Margin between the largest and second-largest π component at each step.
    ranked = np.sort(post_pi, axis=1)
    margins = ranked[:, -1] - ranked[:, -2]

    # Step size and η·‖F†‖ (eta / pi_min) for both schedules.
    pV = V['product'][BURN:].mean()
    pA = A['product'][BURN:].mean()
    dV = V['delta_norm'].max()
    dA = A['delta_norm'].max()

    # C(t) across the seed sweep, after burn-in.
    C_post = R['C'][:, BURN:]
    c_est = max(np.percentile(C_post.flatten(), 5), 1e-8)
    C_mins = np.min(C_post, axis=1)

    sw_v_mean, sw_v_std = np.mean(R['sw_V']), np.std(R['sw_V'])
    sw_a_mean, sw_a_std = np.mean(R['sw_A']), np.std(R['sw_A'])
    concentration = px / max(pm, 1e-10)
    switch_ratio = sw_v_mean / max(sw_a_mean, 0.1)
    step_ratio = dV / max(dA, 1e-10)
    product_ratio = pV / max(pA, 1e-10)

    print(f"\n{rname.upper()}:")
    print(f" π_max = {px:.3f}, π_min = {pm:.4f}")
    print(f" Concentration: π_max/π_min = {concentration:.1f}×")
    print(f" Vanilla: {sw_v_mean:.1f}±{sw_v_std:.1f} switches")
    print(f" Adaptive: {sw_a_mean:.1f}±{sw_a_std:.1f} switches")
    print(f" Switch reduction: {switch_ratio:.1f}×")
    print(f" max ‖Δθ‖: vanilla={dV:.3f}, adaptive={dA:.4f} ({step_ratio:.0f}×)")
    print(f" mean η·‖F†‖: vanilla={pV:.2f}, adaptive={pA:.4f} ({product_ratio:.0f}×)")
    print(f" c = {c_est:.2f}")
    print(f" C(t)>0 all runs: {np.sum(C_mins > 0)}/{N_runs}")
    print(f" Δ_min = {margins.min():.4f}")