import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
from context import samplers as samplers
# Fix the global NumPy RNG seed so the np.random draws made by this script
# (initial states, plot sub-sampling) repeat from run to run.
# NOTE(review): the samplers module presumably uses the same global RNG — confirm.
np.random.seed(10)
# ── helpers ──────────────────────────────────────────────────────────────────
def soft(x, tau):
    """Soft-threshold ``x`` by ``tau``, elementwise.

    Computes ``sign(x) * max(|x| - tau, 0)``: every entry is pulled towards
    zero by ``tau`` and entries with magnitude at most ``tau`` become zero.
    """
    shrunk_magnitude = np.maximum(np.abs(x) - tau, 0)
    return np.sign(x) * shrunk_magnitude
def make_prox_kernel(tau, grad, lam, gamma, p, beta=1):
    """Build the one-step Prox-ULA kernel for a Moreau-smoothed l1 term.

    The drift handed to the sampler is ``grad(x)`` plus
    ``(x - soft(x, lam * gamma)) / gamma``, i.e. the gradient of the Moreau
    envelope (parameter ``gamma``) of ``lam * ||x||_1``.

    Args:
        tau: step size forwarded to ``samplers.one_step_langevin``.
        grad: callable returning the gradient of the smooth part at ``x``.
        lam: weight of the l1 term.
        gamma: Moreau envelope parameter.
        p: forwarded unchanged to the sampler (presumably the state
            dimension — confirm against ``samplers``).
        beta: forwarded unchanged to the sampler (presumably an inverse
            temperature — confirm against ``samplers``).

    Returns:
        A callable ``z -> samplers.one_step_langevin(z, p, grad_F, tau, beta)``.
    """
    threshold = lam * gamma

    def grad_F(x):
        # Smooth gradient plus the Moreau-envelope gradient of the l1 term.
        return grad(x) + (x - soft(x, threshold)) / gamma

    def kernel(z):
        return samplers.one_step_langevin(z, p, grad_F, tau, beta)

    return kernel
def make_hadamard_kernel(tau, grad, lam, p, beta=1):
    """Build the one-step Hadamard Langevin kernel.

    All arguments are captured and forwarded unchanged to
    ``samplers.one_step_hadamard``; their exact meaning is defined there
    (presumably: ``tau`` step size, ``grad`` gradient of the smooth part,
    ``lam`` l1 weight, ``p`` dimension of ``x``, ``beta`` inverse
    temperature — confirm against ``samplers``).

    Returns:
        A callable ``z -> samplers.one_step_hadamard(z, p, grad, tau, lam, beta)``.
    """

    def kernel(z):
        return samplers.one_step_hadamard(z, p, grad, tau, lam, beta)

    return kernel
def project_hadamard(z, p):
    """Map a lifted state ``z = (u, v)`` to ``x = u * v`` (elementwise).

    The split is taken along the last axis, so the function accepts either a
    single state of length ``2 * p`` or a whole chain of shape
    ``(n_samples, 2 * p)``; in the latter case every row is projected.

    Args:
        z: array-like whose last axis holds ``u`` in the first ``p`` entries
            and ``v`` in the remaining ones.
        p: number of leading entries that form ``u``.

    Returns:
        ``numpy.ndarray`` with the elementwise product ``u * v``.
    """
    z = np.asarray(z)
    # Ellipsis indexing keeps 1-D behaviour unchanged and adds batch support.
    return z[..., :p] * z[..., p:]
def acf(arr, max_lag=60):
    """Return the sample autocorrelation of ``arr`` at lags ``0..max_lag``.

    The estimate at lag ``k`` is ``sum(v[:n-k] * v[k:]) / sum(v ** 2)`` with
    ``v = arr - mean(arr)``, so the value at lag 0 is 1.

    Args:
        arr: one-dimensional array-like of samples.
        max_lag: largest lag to evaluate.

    Returns:
        ``numpy.ndarray`` of length ``max_lag + 1``. Lags at or beyond the
        series length have no overlapping pairs and are reported as 0. If the
        series is empty or has zero variance the autocorrelation is undefined
        and every entry is NaN.
    """
    series = np.asarray(arr, dtype=float)
    n = len(series)
    n_lags = max(max_lag + 1, 0)
    if n == 0:
        return np.full(n_lags, np.nan)
    centred = series - series.mean()
    den = (centred ** 2).sum()
    if den == 0:
        # Constant series: 0 / 0, reported as NaN without a runtime warning.
        return np.full(n_lags, np.nan)
    rho = np.zeros(n_lags)
    # Only lags below n have overlapping pairs; the rest stay at 0.
    for k in range(min(n_lags, n)):
        rho[k] = (centred[: n - k] * centred[k:]).sum() / den
    return rho
def ess(arr, max_lag=None):
    """Estimate the effective sample size of a one-dimensional chain.

    Uses ``n / (1 + 2 * sum(rho_k))`` where the autocorrelations ``rho_k``
    are summed from lag 1 up to, but excluding, the first lag at which the
    ACF falls below 0.05. If the ACF never falls below 0.05 within
    ``max_lag``, every computed lag (``1..max_lag``) is included.

    Args:
        arr: one-dimensional array-like of samples.
        max_lag: largest lag to consider; defaults to
            ``min(300, len(arr) // 3)``.

    Returns:
        The effective sample size as a float (NaN when the ACF is NaN).
    """
    series = np.asarray(arr, dtype=float)
    n = len(series)
    if max_lag is None:
        max_lag = min(300, n // 3)
    rho = acf(series, max_lag)
    below = rho < 0.05
    if below.any():
        # First lag whose autocorrelation is below the threshold.
        cutoff = int(np.argmax(below))
    else:
        # No crossing: keep all computed lags, including lag ``max_lag``.
        cutoff = len(rho)
    rho_sum = 1 + 2 * rho[1:cutoff].sum()
    return n / rho_sum
# ── experiment setup ──────────────────────────────────────────────────────────
# Target: Laplace prior, no likelihood (pure l1)
p = 2
lam = 1.0
tau_h = 0.05  # HLD step size
gamma = 0.1  # Moreau envelope parameter for prox
# Prox-ULA step size, tied to gamma (about 0.018 for gamma = 0.1).
tau_p = gamma / 5 / (gamma + 1)


def grad_G(x):
    """Gradient of the smooth term G; zero because the target is the pure Laplace prior."""
    return np.zeros(p)


N_samples = 50_000
burn_in = 10_000
# ── run HLD ──────────────────────────────────────────────────────────────────
hld_kernel = make_hadamard_kernel(tau_h, grad_G, lam, p)
# Initial lifted state (u, v): u = |N(0, 1)| + 0.5 is strictly positive,
# v ~ N(0, 1). u is drawn before v.
u_init = np.abs(np.random.randn(p)) + 0.5
v_init = np.random.randn(p)
init_hld = np.concatenate([u_init, v_init])
# Keep the full (u, v) chain so it can also be plotted in the lifted space.
uv_chain = samplers.generate_samples_x(hld_kernel, init_hld, N_samples, burn_in)
u_chain, v_chain = uv_chain[:, :p], uv_chain[:, p:]
x_hld = u_chain * v_chain  # project: x = u * v
# ── run Prox-ULA ─────────────────────────────────────────────────────────────
# One-step Prox-ULA kernel for the same target, using step size tau_p and
# Moreau parameter gamma.
prox_kernel = make_prox_kernel(tau_p, grad_G, lam, gamma, p)
# Standard-normal starting point of length p.
init_prox = np.random.randn(p)
# NOTE(review): presumably returns the post-burn-in chain as an
# (n_samples, p) array — later code indexes x_prox[:, 0]; confirm in samplers.
x_prox = samplers.generate_samples_x(prox_kernel, init_prox, N_samples, burn_in)
# ── diagnostics ──────────────────────────────────────────────────────────────
# All diagnostics are computed on the first coordinate of each chain,
# HLD first and Prox-ULA second.
ess_h, ess_p = (ess(chain[:, 0]) for chain in (x_hld, x_prox))
# Percentage of samples with |x_1| < 0.1.
sparsity_h, sparsity_p = (
    np.mean(np.abs(chain[:, 0]) < 0.1) * 100 for chain in (x_hld, x_prox)
)
print(f"ESS x1 | HLD: {ess_h:.0f} Prox-ULA: {ess_p:.0f}")
print(f"Sparsity| HLD: {sparsity_h:.1f}% Prox-ULA: {sparsity_p:.1f}%")
# Autocorrelation up to lag 60, used by the ACF panel.
acf_h, acf_p = (acf(chain[:, 0], 60) for chain in (x_hld, x_prox))
# ── plotting ─────────────────────────────────────────────────────────────────
# Colour palette shared by all panels.
PURPLE, CORAL = "#534AB7", "#D85A30"  # HLD series / Prox-ULA series
TEAL, GRAY = "#1D9E75", "#888780"  # reference curves / guides and annotations
# 2 x 3 grid of panels on a 14 x 9 inch figure.
fig = plt.figure(figsize=(14, 9))
gs = gridspec.GridSpec(
    nrows=2,
    ncols=3,
    figure=fig,
    hspace=0.42,
    wspace=0.38,
)
# ── 1. Marginal histogram (x1) ────────────────────────────────────────────
ax1 = fig.add_subplot(gs[0, :2])
bins = np.linspace(-6, 6, 100)
# Overlay the x_1 histograms of both samplers on a common binning.
for draws, face, opacity, tag in (
    (x_hld[:, 0], PURPLE, 0.55, "HLD"),
    (x_prox[:, 0], CORAL, 0.45, "Prox-ULA"),
):
    ax1.hist(draws, bins=bins, density=True, color=face, alpha=opacity, label=tag)
# Reference curve: Laplace density lam / 2 * exp(-lam * |x|).
xs_true = np.linspace(-6, 6, 300)
laplace_pdf = lam / 2 * np.exp(-lam * np.abs(xs_true))
ax1.plot(xs_true, laplace_pdf, color=TEAL, lw=2, label=r"True $\pi(x_1)$")
ax1.set_xlabel(r"$x_1$", fontsize=11)
ax1.set_ylabel("density", fontsize=11)
ax1.set_title(r"Marginal histogram — $x_1$", fontsize=12)
ax1.legend(fontsize=10)
ax1.set_xlim(-6, 6)
# ESS summary, placed in axes coordinates near the top-right corner.
ess_note = f"ESS HLD={ess_h:.0f} Prox={ess_p:.0f}"
ax1.text(
    0.97,
    0.92,
    ess_note,
    transform=ax1.transAxes,
    ha="right",
    fontsize=9,
    color=GRAY,
    style="italic",
)
# ── 2. ACF ───────────────────────────────────────────────────────────────────
ax2 = fig.add_subplot(gs[0, 2])
lags = np.arange(len(acf_h))
# One ACF curve per sampler, HLD first.
for rho, line_colour, sampler_name in (
    (acf_h, PURPLE, "HLD"),
    (acf_p, CORAL, "Prox-ULA"),
):
    ax2.plot(lags, rho, color=line_colour, lw=1.8, label=sampler_name)
# Dashed guide at 0.05, the ACF cutoff used by ess().
ax2.axhline(0.05, color=GRAY, lw=0.8, ls="--")
ax2.set_xlabel("lag", fontsize=11)
ax2.set_ylabel("ACF", fontsize=11)
ax2.set_title(r"ACF — $x_1$", fontsize=12)
ax2.set_ylim(-0.15, 1.05)
ax2.legend(fontsize=10)
# ── 3. x = u⊙v scatter ───────────────────────────────────────────────────────
xx, yy = np.meshgrid(np.linspace(-6, 6, 200), np.linspace(-6, 6, 200))
# l1 potential of the target, lam * (|x1| + |x2|); the contours of its
# negative are level sets of the log-density.
Z_xy = lam * (np.abs(xx) + np.abs(yy))
ax3 = fig.add_subplot(gs[1, 0])
# Plot at most 2000 randomly chosen samples. The index range comes from the
# chain actually returned by the sampler, so idx is always valid for x_hld,
# u_chain and v_chain regardless of how burn-in is handled.
n_hld = len(x_hld)
idx = np.random.choice(n_hld, size=min(2000, n_hld), replace=False)
ax3.contour(xx, yy, -Z_xy, levels=10, cmap='gray_r', alpha=0.3)
ax3.scatter(x_hld[idx, 0], x_hld[idx, 1],
            s=4, color=PURPLE, alpha=0.35, linewidths=0)
ax3.set_xlabel(r"$x_1$", fontsize=11)
ax3.set_ylabel(r"$x_2$", fontsize=11)
ax3.set_title(r"Samples $x = u \odot v$ (HLD)", fontsize=12)
# Dashed coordinate axes.
ax3.axhline(0, color=GRAY, lw=0.5, ls="--")
ax3.axvline(0, color=GRAY, lw=0.5, ls="--")
ax3.set_xlim(-6, 6)
ax3.set_ylim(-6, 6)
# ── 4. Lifted space (u1, v1) coloured by |x1| ────────────────────────────────
# Grid over u in (0, 4] and v in [-4, 4]; u starts at 0.001 so log(u) is finite.
uu, vv = np.meshgrid(np.linspace(0.001, 4, 200), np.linspace(-4, 4, 200))
# Background potential lam / 2 * (u^2 + v^2) - log(u).
# NOTE(review): presumably the lifted-space potential targeted by
# samplers.one_step_hadamard — confirm against that implementation.
Z_uv = 0.5 * lam * uu**2 + 0.5 * lam * vv**2 - np.log(uu)
ax4 = fig.add_subplot(gs[1, 1])
# Colour each plotted point by |x_1|; the colour scale is capped at the 95th
# percentile so a few large values do not wash out the rest.
x1_abs = np.abs(x_hld[idx, 0])
norm = Normalize(vmin=0, vmax=np.percentile(x1_abs, 95))
sc = ax4.scatter(u_chain[idx, 0], v_chain[idx, 0],
                 c=x1_abs, cmap="plasma", s=5, alpha=0.5,
                 norm=norm, linewidths=0)
ax4.contour(uu, vv, -Z_uv, levels=10, cmap='gray_r', alpha=0.3)
plt.colorbar(sc, ax=ax4, label=r"$|x_1|$", pad=0.02)
ax4.set_xlabel(r"$u_1$ $(u_1 > 0)$", fontsize=11)
ax4.set_ylabel(r"$v_1$", fontsize=11)
ax4.set_title(r"Lifted space $(u_1, v_1)$", fontsize=12)
# Dashed line at v_1 = 0, where x_1 = u_1 * v_1 changes sign.
ax4.axhline(0, color=GRAY, lw=0.5, ls="--")
# ── 5. Sign-quadrant visitation ───────────────────────────────────────────────
ax5 = fig.add_subplot(gs[1, 2])
signs = np.sign(x_hld)
# Boolean masks for the sign of each coordinate (exact zeros match neither).
pos1, neg1 = signs[:, 0] > 0, signs[:, 0] < 0
pos2, neg2 = signs[:, 1] > 0, signs[:, 1] < 0
# Fraction of samples falling in each sign quadrant of (x_1, x_2).
quadrants = {
    r"$(+,+)$": np.mean(pos1 & pos2),
    r"$(+,-)$": np.mean(pos1 & neg2),
    r"$(-,+)$": np.mean(neg1 & pos2),
    r"$(-,-)$": np.mean(neg1 & neg2),
}
labels = list(quadrants)
vals = [100 * share for share in quadrants.values()]
bars = ax5.bar(labels, vals, color=PURPLE, alpha=0.75, width=0.5)
# Reference level: 25% per quadrant.
ax5.axhline(25, color=TEAL, lw=1.5, ls="--", label="true (25% each)")
ax5.set_ylabel("% time", fontsize=11)
ax5.set_title("Sign-quadrant visitation (HLD)", fontsize=12)
ax5.set_ylim(0, 40)
ax5.legend(fontsize=10)
# Print the percentage just above each bar, centred horizontally.
for rect, pct in zip(bars, vals):
    x_centre = rect.get_x() + rect.get_width() / 2
    ax5.text(x_centre, pct + 0.5, f"{pct:.1f}%",
             ha="center", va="bottom", fontsize=9)
# ── save ──────────────────────────────────────────────────────────────────────
fig.suptitle("Hadamard Langevin Dynamics — sampling diagnostics", fontsize=13, y=1.01)
# Write a vector (PDF) and a raster (PNG, 150 dpi) copy of the figure, then
# display it.
for out_path, extra_kwargs in (
    ("hld_diagnostics.pdf", {}),
    ("hld_diagnostics.png", {"dpi": 150}),
):
    plt.savefig(out_path, bbox_inches="tight", **extra_kwargs)
plt.show()
print("Saved hld_diagnostics.pdf and hld_diagnostics.png")