Source code for HadamardLangevin.utils

import pywt
import numpy as np
from numpy.fft import fft, ifft


[docs] def getWaveletTransforms(n,wavelet_type = "db2",level = 5, weight=1): mode = "periodization" coeffs_tpl = pywt.wavedec(data=np.zeros(n), wavelet=wavelet_type, mode=mode, level=level) coeffs_1d, coeff_slices, coeff_shapes = pywt.ravel_coeffs(coeffs_tpl) coeffs_tpl_rec = pywt.unravel_coeffs(coeffs_1d, coeff_slices, coeff_shapes) scaling_vec = np.zeros_like(coeffs_1d) for i,slice in enumerate(coeff_slices): if i==0: scaling_vec[slice] += weight**i else: scaling_vec[slice['d']] += weight**i def py_W(x): alpha = pywt.wavedec(data=x, wavelet=wavelet_type, mode=mode, level=level) alpha, _, _ = pywt.ravel_coeffs(alpha) return alpha def py_Ws(alpha): coeffs = pywt.unravel_coeffs(alpha, coeff_slices, coeff_shapes,output_format='wavedec') rec = pywt.waverec(coeffs, wavelet=wavelet_type, mode=mode) return rec return py_W, py_Ws,scaling_vec
[docs] def getWaveletTransforms_2D(n,m,wavelet_type = "db2",level = 5, weight=1): mode = "periodization" coeffs_tpl = pywt.wavedecn(data=np.zeros((n, m)), wavelet=wavelet_type, mode=mode, level=level) coeffs_1d, coeff_slices, coeff_shapes = pywt.ravel_coeffs(coeffs_tpl) coeffs_tpl_rec = pywt.unravel_coeffs(coeffs_1d, coeff_slices, coeff_shapes) scaling_vec = np.zeros_like(coeffs_1d) for i,slice in enumerate(coeff_slices): if i==0: scaling_vec[slice] += weight**i else: scaling_vec[slice['ad']] += weight**i scaling_vec[slice['da']] += weight**i scaling_vec[slice['dd']] += weight**i def py_W(im): alpha = pywt.wavedecn(data=im, wavelet=wavelet_type, mode=mode, level=level) alpha, _, _ = pywt.ravel_coeffs(alpha) return alpha def py_Ws(alpha): coeffs = pywt.unravel_coeffs(alpha, coeff_slices, coeff_shapes) im = pywt.waverecn(coeffs, wavelet=wavelet_type, mode=mode) return im return py_W, py_Ws, scaling_vec
# define filter
[docs] def GaussianFilter(s,n): x = np.hstack((np.arange(0,n//2), np.arange(-n//2,0))) h = np.exp( (-x**2)/(2*s**2) ) h = h/sum(h) return h
[docs] def GaussianFilter_2d(s,n,m): x = np.hstack((np.arange(0,n//2), np.arange(-n//2,0))) y = np.hstack((np.arange(0,m//2), np.arange(-m//2,0))) [X,Y] = np.meshgrid(y,x) h = np.exp( (-X**2-Y**2)/(2*s**2) ) h = h/sum(sum(h)) return h
[docs] def rFISTA(proxF, dG, gamma, xinit,niter,mfunc): tol = 1e-16 x = xinit z = x t=1 fval = [] for k in range(niter): xkm = x ykm = z x = proxF( z - gamma*dG(z), gamma ) tnew = (1+np.sqrt(1+4*t**2))/2 z = x + (t-1)/(tnew)*(x-xkm) t = tnew if np.sum((ykm-x)*(x-xkm))>0: z=x; fval.append(mfunc(x)) if np.linalg.norm(xkm-x)<tol: break return x, fval
[docs] def ISTA(proxF, dG, gamma, xinit,niter,mfunc): x = xinit fval = [] for k in range(niter): x = proxF( x - gamma*dG(x), gamma ) fval.append(mfunc(x)) return x, fval