Demonstrating Hadamard Langevin

[1]:
import os

import matplotlib.pyplot as plt
import numpy as np

from context import samplers as samplers

Create a matrix \(A\in \mathbb{R}^{m\times p}\) and measurements \(y\in\mathbb{R}^m\). We will sample from

\[\pi\propto \exp\left(-\beta\left(\lambda \|x\|_1 + \frac12\|Ax-y\|^2\right)\right)\]

where \(\lambda>0\) and \(\beta>0\).

[2]:
# Seed NumPy's global generator so the synthetic problem (x0, A, y and hence lam)
# is the same on every Restart & Run All.
# NOTE(review): if `samplers` also draws from np.random's global state, the
# sampler runs below become reproducible as well -- confirm against that module.
np.random.seed(0)

p = 20  # signal dimension
m = 40  # number of measurements

# Make x0 sparse: draw random signs, then zero out 9/10 of the entries
x0 = np.sign(np.random.randn(p,))
n_zero = 9*p//10
x0[np.random.permutation(p)[:n_zero]] = 0
plt.stem(x0)

# Random Gaussian measurement matrix and noiseless measurements
A = np.random.randn(m,p)/np.sqrt(m)/4
y = A@x0

# Regularisation weight: 10% of max|A^T y|
lam = np.max(np.abs(A.T@y))*.1
# Inverse temperature of the target density
beta = 5.

print('lambda ', lam, ', beta ', beta)


def fval(x):
    """Smooth part of the negative log-density: 0.5*||A x - y||^2."""
    return np.linalg.norm(A@x - y)**2 * 0.5


def grad(x):
    """Gradient of `fval`: A^T (A x - y)."""
    return A.T@(A@x - y)


lambda  0.005524786083441602 , beta  5.0
../_images/vignettes_3_synthetic_3_1.png

Proximal Langevin

[3]:
import time


# Book-keeping: every sampler run appends one entry to each of these lists
method_list = []  # sampler names
sample_list = []  # generated samples
time_list = []    # elapsed time of each run, from time.time() differences

# Sample budget, given as powers of ten
n_exp = 5
m_exp = 4
n = 10**n_exp        # number of samples to generate
burn_in = 10**m_exp  # number of burn-in samples

# Lipschitz constant of the smooth part (2-norm of A^T A)
Lf = np.linalg.norm(A.T@A, 2)

# Parameter of the Moreau-envelope regularisation
gamma= 1/Lf/5
# gamma= 1/Lf/100

# Step size shared by the Langevin-type samplers below
tau = gamma/5/(Lf*gamma+1)
print(tau)
0.1928833042471887
[4]:
#define the gradient of the Moreau-envelope smoothing of the negative log-density
def soft(x,tau):
    """Soft-thresholding: shrink |x| by tau (clamped at zero) and keep the sign of x.

    Works elementwise on arrays as well as on scalars.
    """
    shrunk_magnitude = np.maximum(np.abs(x) - tau, 0)
    return np.sign(x) * shrunk_magnitude
def grad_F(x):
    """Gradient of the smoothed objective: grad(x) + (x - soft(x, lam*gamma))/gamma."""
    return grad(x) + (x - soft(x,lam*gamma))/gamma


# Run proximal Langevin
print('Running prox-l1 sampler ...\n')
name = 'Prox-l1'


def Iterate(x):
    """One Langevin step driven by the smoothed gradient `grad_F`."""
    return samplers.one_step_langevin(x,p, grad_F, tau,beta)


# Random starting point (also reused by the proxl1-MALA cell)
xinit = np.random.randn(p,)

t1 = time.time()
samples_proxl1 = samplers.generate_samples_x(Iterate, xinit, n, burn_in)
elapsed = time.time()-t1

# Record name, samples and run time for the comparison plots
time_list.append(elapsed)
method_list.append(name)
sample_list.append(samples_proxl1)
Running prox-l1 sampler ...

[79]:
# Proximal Langevin with a Metropolis-Hastings correction
print('Running Proxl1 MALA sampler ...\n')
name = 'proxl1-MALA'


def Iterate(x):
    """One MALA step: the non-smoothed objective is passed together with `grad_F`."""
    # Objective handed to the MALA step: smooth part plus the l1 penalty
    objective = lambda z: fval(z) + lam*np.sum(np.abs(z))
    return samplers.one_step_MALA(x, p, objective, grad_F, tau,beta)


# Starts from `xinit`, which is defined in the prox-l1 cell
t1 = time.time()
samples_px_mala = samplers.generate_samples_x(Iterate, xinit, n, burn_in)
elapsed = time.time()-t1

time_list.append(elapsed)
method_list.append(name)
sample_list.append(samples_px_mala)
Running Proxl1 MALA sampler ...

Hadamard Langevin

[5]:
# Reset the sample budget (the Gibbs cell overwrites n and burn_in)
n = 10**n_exp        # number of samples to generate
burn_in = 10**m_exp  # number of burn-in samples

# Same step size tau as for proximal Langevin
print('Running Hadamard sampler ... \n')
name = 'Hadamard'


def Iterate(x):
    """One step of the Hadamard sampler on the stacked (u, v) state of length 2p."""
    return samplers.one_step_hadamard(x, p,grad, tau, lam,beta)


# Initialisation: uniform u and Gaussian v, stacked into one vector
uinit = np.random.rand(p,)
vinit = np.random.randn(p,)
uvinit = np.concatenate((uinit,vinit))

t1 = time.time()
samples_uv = samplers.generate_samples_x(Iterate, uvinit, n, burn_in)

# Samples of x are the elementwise (Hadamard) product of the first p and last p columns
samples_x_uv = samples_uv[:,:p]*samples_uv[:,p:]

# The timing includes the product above
elapsed = time.time()-t1
time_list.append(elapsed)

method_list.append(name)
sample_list.append(samples_x_uv)

Running Hadamard sampler ...

[25]:
# Over-parameterised (Hadamard) Langevin with a Metropolis-Hastings correction
print('Running Hadamard MALA sampler ...\n')
name = 'Hadamard-MALA'


def Iterate(x):
    """One MALA step of the Hadamard sampler on the stacked (u, v) state."""
    return samplers.one_step_MALA_hadamard(x,p, fval, grad, tau, lam,beta)


# Same starting point as the plain Hadamard run (uinit, vinit come from that cell)
uvinit = np.concatenate((uinit,vinit))

t1 = time.time()
samples_uv_mala = samplers.generate_samples_x(Iterate, uvinit, n, burn_in)
# Map back to x: elementwise product of the first p and last p columns
samples_x_mala = samples_uv_mala[:,:p]*samples_uv_mala[:,p:]
elapsed = time.time()-t1

time_list.append(elapsed)
method_list.append(name)
sample_list.append(samples_x_mala)


Running Hadamard MALA sampler ...

Running the Gibbs sampler

[6]:
print('Running Gibbs sampler ... \n')
name = 'Gibbs'

# The Gibbs sampler uses a smaller budget: 10x fewer samples, 1000x fewer burn-in steps.
# This overwrites the global n and burn_in.
n = 10**(n_exp-1)        # number of samples to generate
burn_in = 10**(m_exp-3)  # number of burn-in samples

# Start from the all-ones vector
init = np.ones(p,)

t1 = time.time()
gibbs_x_samples = samplers.gibbs_sampler(A,y, lam,init, n, burn_in=burn_in, beta=beta)
elapsed = time.time()-t1

time_list.append(elapsed)
method_list.append(name)
sample_list.append(gibbs_x_samples)


Running Gibbs sampler ...

Visualize the results

[15]:
# Plot the running means: one figure per sampler, one line per coordinate
from itertools import cycle

M = len(method_list)

colours = cycle(['red', 'blue', 'black', 'orange','yellow','orange'])

# savefig does not create missing directories, so make sure the target exists
os.makedirs('results', exist_ok=True)

for i in range(M):
    samples = sample_list[i]
    name = method_list[i]
    # Cumulative sum along the sample axis divided by the number of samples
    # seen so far (1, 2, ..., len(samples)), broadcast over the coordinates.
    running_mean = np.cumsum(samples, axis=0) / (np.arange(len(samples))[:, None] + 1)
    c = next(colours)
    print(name)
    # The last method is drawn slightly transparent, all others fully opaque
    alpha = 1 if i < M-1 else .9
    plt.plot(running_mean, color=c, label=name, alpha=alpha)
    plt.ylim([-3,3])
    plt.savefig('results/'+name+'means.pdf', bbox_inches='tight')
    plt.show()
Prox-l1
../_images/vignettes_3_synthetic_14_1.png
Hadamard
../_images/vignettes_3_synthetic_14_3.png
Gibbs
../_images/vignettes_3_synthetic_14_5.png

Check the effective sample size

[28]:
#effective sample size
import numpy as np
import arviz as az
import matplotlib.pyplot as plt


# Compute the Effective Sample Size (ESS) of every sampler's output.
# np.expand_dims(samples, 0) adds a leading axis of length 1 before the
# samples are handed to arviz.
idata_list = [
    az.convert_to_inference_data(np.expand_dims(samples, 0))
    for samples in sample_list
]
ess_list = [az.ess(idata) for idata in idata_list]




# Visualization: grouped bar chart of log(ESS) per coordinate, one bar per method
labels = [str(i) for i in range(p)]

fig, ax = plt.subplots()

M = len(method_list)
width = 0.2  # the width of the bars

# Bars of method i sit at x + i*width; the ticks are placed at the first method's bars
x = np.arange(p)-width
for i in range(M):
    ess_values = np.log(ess_list[i]['x'].values)
    ax.bar(x +i*width, ess_values, width, label=method_list[i])

# The bars show the natural log of the ESS, so the axis is labelled accordingly
ax.set_ylabel('log(ESS)',fontsize=16)
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.legend(fontsize=16)
ax.tick_params(axis='x', labelsize=12)
ax.tick_params(axis='y', labelsize=16)

fig.tight_layout()
# savefig does not create missing directories, so make sure the target exists
os.makedirs('results', exist_ok=True)
plt.savefig('results/ESS.pdf', bbox_inches='tight')
plt.show()

# Display the minimum ESS over all coordinates (on a log scale), one bar per method
if p>1:
    fig, ax = plt.subplots()

    M = len(method_list)
    width = 0.2  # the width of the bars

    x = np.arange(1)-width
    xticks=[]
    for i in range(M):
        print(method_list[i])
        # Natural log of the smallest per-coordinate ESS for this method
        ess_values = np.log((ess_list[i]['x'].values).min())
        print(ess_values)
        ax.bar(x +i*width, ess_values, width, label=method_list[i])
        xticks.append(x[0] +i*width)

    # The bars show log(min ESS), so the axis is labelled accordingly
    ax.set_ylabel('log(min ESS)',fontsize=16)
    ax.set_xticks(xticks)
    ax.set_xticklabels(method_list,fontsize=16)
    ax.tick_params(axis='x', labelsize=16)
    ax.tick_params(axis='y', labelsize=16)

    fig.tight_layout()
    # savefig does not create missing directories, so make sure the target exists
    os.makedirs('results', exist_ok=True)
    plt.savefig('results/ESS_min.pdf', bbox_inches='tight')

    plt.show()
../_images/vignettes_3_synthetic_16_0.png
Prox-l1
4.005851575122584
Hadamard
6.36752302816277
Gibbs
9.155003688288497
../_images/vignettes_3_synthetic_16_2.png

Plot the densities if in dimension 1 or 2

[198]:
from scipy import integrate
#For dimension p=1 and p=2, get true density via numerical integration
if p==1:
    def f(x):
        """Unnormalised target density exp(-beta*(fval(x) + lam*|x|)) at a scalar x."""
        potential = fval(np.array([x]))+lam*np.abs(x)
        return np.exp(-potential*beta)

    # Normalising constant by numerical integration over the real line
    Total_mass = integrate.quad(f, -np.inf, np.inf)[0]

    def density(x):
        """Normalised target density at a scalar x."""
        return f(x)/Total_mass

    # First moment of the target, used as the reference for the running-mean error
    mean_true = integrate.quad(lambda x: x*f(x), -np.inf, np.inf)[0]/Total_mass

    def rho(xgrid):
        """Evaluate the normalised density at every point of `xgrid`."""
        values = [density(t) for t in xgrid]
        return np.array(values)

if p==2:
    def f(x1,x2):
        """Unnormalised target density exp(-beta*(fval(x) + lam*||x||_1)) at x=(x1,x2).

        Both l1 terms are added, matching the lam*np.sum(np.abs(x)) objective used
        for the MALA sampler.
        """
        return np.exp(-beta*( fval(np.array([x1,x2]))+lam*(np.abs(x1)+np.abs(x2))) )

    # Normalising constant by numerical integration over the plane
    Total_mass = integrate.dblquad(f, -np.inf, np.inf, -np.inf, np.inf)[0]
    density = lambda x1,x2: f(x1,x2)/Total_mass

    # NOTE(review): despite its name this integrates x1**2 (a second moment of the
    # first coordinate), whereas the p==1 branch integrates x*f(x). It is not used
    # by the p==2 plots below -- confirm which quantity is intended.
    mean_true = integrate.dblquad(lambda x1,x2: x1**2*f(x1,x2), -np.inf, np.inf, -np.inf, np.inf)[0]/Total_mass

    def rho(xgrid,ygrid):
        """Normalised density on the grid, flattened with xgrid as the outer (slow) index."""
        return np.array([density(t,s) for t in xgrid for s in ygrid])


[ ]:
import matplotlib.pyplot as plt
from scipy.special import kl_div as kl
from matplotlib import cm
#functions to evaluate the KL error in dimension p=1 and p=2

def KL_error(samples, rho, plot=True):
    """Estimate the KL divergence between the histogram of `samples` and `rho`.

    The samples are binned into a 100-bin density histogram, `rho` is evaluated
    at the bin centres, and the `kl_div` integrand is summed as a Riemann sum,
    i.e. weighted by the bin widths. Without that weight the result scales with
    the histogram range, which depends on the samples, so values would not be
    comparable between samplers or between prefixes of one chain.

    Parameters
    ----------
    samples : array_like
        One-dimensional samples (flattened by np.histogram).
    rho : callable
        Maps an array of grid points to the reference density at those points.
    plot : bool
        If True, draw the histogram and the reference density on the current axes.

    Returns
    -------
    float
        The estimated divergence.
    """
    counts, bins = np.histogram(samples,100,density=True)
    xgrid = (bins[1:] + bins[:-1])/2
    if plot:
        plt.plot(xgrid, counts)
        xgrid_fine = np.linspace(bins[0], bins[-1],1000)
        plt.plot(xgrid_fine,rho(xgrid_fine), 'r--')
    # Riemann sum: integrand at the bin centres times the bin widths
    return np.sum(kl(counts,rho(xgrid))*np.diff(bins))

def KL_error_2d(samples, rho, plot=True):
    """Estimate the KL divergence between a 2-D histogram of `samples` and `rho`.

    The first two columns of `samples` are binned into a 100x100 density
    histogram, `rho` is evaluated at the bin centres, and the `kl_div` integrand
    is summed weighted by the cell areas (a 2-D Riemann sum), so the value does
    not depend on the sample-dependent histogram range.

    Parameters
    ----------
    samples : ndarray
        Array whose columns 0 and 1 hold the two coordinates.
    rho : callable
        rho(xgrid, ygrid) returns the reference density on the grid, flattened
        with xgrid as the outer index (the layout of the histogram array).
    plot : bool
        If True, draw both surfaces in a new 3-D figure.

    Returns
    -------
    float
        The estimated divergence.
    """
    Z, xbins,ybins = np.histogram2d(samples[:,0], samples[:,1],bins=100,density=True)
    xgrid = (xbins[1:] + xbins[:-1])/2
    ygrid = (ybins[1:] + ybins[:-1])/2

    # Reference density reshaped to the (x-bin, y-bin) layout of Z
    Z2 = np.asarray(rho(xgrid,ygrid)).reshape(Z.shape)
    if plot:
        ax = plt.figure().add_subplot(projection='3d')

        # histogram2d indexes its result as [x-bin, y-bin], so the mesh must use
        # 'ij' indexing; the default 'xy' would draw the surfaces transposed.
        X,Y = np.meshgrid(xgrid,ygrid, indexing='ij')
        ax.plot_surface(X, Y, Z, alpha=0.8, cmap = cm.coolwarm,
                        linewidth=0, antialiased=False)
        ax.plot_surface(X, Y, Z2, alpha=0.1, cmap = cm.jet,
                        linewidth=0, antialiased=False)
    # Area of every histogram cell: outer product of the x and y bin widths
    cell_area = np.outer(np.diff(xbins), np.diff(ybins))
    return np.sum(kl(Z,Z2)*cell_area)

[ ]:
# Report the recorded run times, then (for p<=2 only) compute and plot the KL
# error of every sampler against the numerically integrated reference density.
print(time_list)

# Matplotlib format strings, paired with the methods in run order (p==1 plots only)
colors = ['k--','k', 'g--', 'g', 'm']
print('Compute errors and plot (p<=2 only) ...\n')
if p==1:

    def get_running_error(samples, rho):
        """KL error of growing prefixes samples[:i], for i = 100, 100+skip, ..."""
        error_list = []
        # Step between prefix lengths: 1/50 of the chain length
        skip = len(samples)//50
        for i in range(100,len(samples),skip):
            error = KL_error(samples[:i],rho,plot=False)
            error_list.append(error)
        return error_list

    # NOTE: defined but not called anywhere in this cell
    def get_running_mean_error(samples, rho):
        """Absolute error of the prefix mean against mean_true (rho is unused)."""
        error_list = []
        skip = len(samples)//50
        for i in range(100,len(samples),skip):
            error = np.abs(np.mean(samples[:i])-mean_true)
            error_list.append(error)
        return error_list

    # One list of running errors per sampler
    error_list = []
    for samples,name,t in zip(sample_list,method_list,time_list):
        error_list.append(get_running_error(samples, rho))

    # Error against time: the errors are spread evenly over [0, t], where t is
    # the total recorded run time of the sampler
    for errors, name,t, c in zip(error_list,method_list,time_list,colors):
        plt.semilogy(np.linspace(0,t,len(errors)), errors,c, label=name)

    plt.legend()
    # savefig runs before title, so the saved PDF has no title
    plt.savefig('results/Error_Time.pdf', bbox_inches='tight')
    plt.title('KL error against time')
    plt.show()

    # Error against iteration: the errors are spread evenly over [0, S], where S
    # is the number of samples of the sampler
    for i in range(len(sample_list)):
        S = len(sample_list[i])
        errors= error_list[i]
        name = method_list[i]
        plt.semilogy(np.linspace(0,S,len(errors)), errors,colors[i], label=name)

    plt.legend()
    plt.savefig('results/Error.pdf', bbox_inches='tight')
    plt.title('KL error against iteration')
    plt.show()

    # Final error on the full chains, with histogram vs reference density plots
    for samples,name in zip(sample_list,method_list):
        print('Error ',name, ' :', KL_error(samples,rho,plot=True))
        # plt.xlim([-5,5])
        plt.savefig('results/density-'+name+'.pdf', bbox_inches='tight')
        plt.title(name)
        plt.show()

elif p==2:
    def get_running_error(samples, rho):
        """2-D KL error of growing prefixes samples[:i,:], for i = 100, 100+skip, ..."""
        error_list = []
        # Step between prefix lengths: 1/20 of the chain length
        skip = len(samples)//20
        for i in range(100,len(samples),skip):
            error = KL_error_2d(samples[:i,:],rho,plot=False)
            error_list.append(error)
        return error_list


    # One list of running errors per sampler
    error_list = []
    for samples,name,t in zip(sample_list,method_list,time_list):
        error_list.append(get_running_error(samples, rho))

    # Error against time (errors spread evenly over the total run time t)
    for errors, name,t in zip(error_list,method_list,time_list):
        plt.semilogy(np.linspace(0,t,len(errors)), errors, label=name)

    plt.legend()
    plt.savefig('results/Error_Time.pdf', bbox_inches='tight')
    plt.show()

    # Error against iteration (errors spread evenly over the S samples)
    for i in range(len(sample_list)):
        S = len(sample_list[i])
        errors= error_list[i]
        name = method_list[i]
        plt.semilogy(np.linspace(0,S,len(errors)), errors, label=name)

    plt.legend()
    plt.savefig('results/Error.pdf', bbox_inches='tight')
    plt.show()
[ ]:

[ ]:

[ ]:

[ ]:

[ ]:

[ ]: