2.2. Basis Pursuit via ADMM

We are given the measurements \(b = A x\), where the unknown signal \(x\) is sparse.

We solve the problem:

\[ \tag{BP} {\min}_{x} \|x\|_{1} \; \text{s.t.} \, A x = b \]

We work with a sensing matrix \(A\) of size \(M \times N\) (with \(M < N\)) whose rows are orthonormal.

%load_ext autoreload
%autoreload 2
from jax.config import config
config.update("jax_enable_x64", True)
from jax import jit
import jax.numpy as jnp
import numpy as np
np.set_printoptions(precision=6)
from jax.numpy.linalg import norm
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import cr.sparse as crs
import cr.sparse.dict as crdict
import cr.sparse.data as crdata
from cr.sparse import lop
from cr.sparse.cvx.adm import yall1
# Problem size: A is M x N (M measurements, N unknowns); K is the sparsity
# level passed to the signal generator.
M, N, K = 2000, 20000, 200
# Dictionary setup: random sensing matrix with orthonormal rows
A = crdict.random_orthonormal_rows(crs.KEYS[0], M, N)
# Show A as a grayscale image; the extent squeezes the very wide matrix
# into a 2:1 box so it remains readable.
fig = plt.figure(figsize=(8, 6), dpi=100, facecolor='w', edgecolor='k')
plt.imshow(A, extent=[0, 2, 0, 1])
plt.set_cmap('gray')
plt.colorbar()
plt.title(r'$A$');
../../../../_images/basis_pursuit_7_0.png
# Ground-truth signal of length N; presumably K non-zeros drawn from a normal
# distribution (per the generator's name). omega is not used further below.
x, omega = crdata.sparse_normal_representations(crs.KEYS[1], N, K)

# Stem plot of the ground-truth signal
fig = plt.figure(figsize=(8, 6), dpi=100, facecolor='w', edgecolor='k')
plt.stem(x, markerfmt='.');
/tmp/ipykernel_31134/4035688382.py:3: UserWarning: In Matplotlib 3.3 individual lines on a stem plot will be added as a LineCollection instead of individual lines. This significantly improves the performance of a stem plot. To remove this warning and switch to the new behaviour, set the "use_line_collection" keyword argument to True.
  plt.stem(x, markerfmt='.');
../../../../_images/basis_pursuit_8_1.png
# Wrap the dense matrix A as a linear operator and JIT-compile it
T = lop.jit(lop.real_matrix(A))
# Noiseless measurements b0 = A x, computed through the operator
b0 = T.times(x)
# Stem plot of the measurement vector
fig = plt.figure(figsize=(8, 6), dpi=100, facecolor='w', edgecolor='k')
plt.stem(b0, markerfmt='.');
/tmp/ipykernel_31134/3213536479.py:2: UserWarning: In Matplotlib 3.3 individual lines on a stem plot will be added as a LineCollection instead of individual lines. This significantly improves the performance of a stem plot. To remove this warning and switch to the new behaviour, set the "use_line_collection" keyword argument to True.
  plt.stem(b0, markerfmt='.');
../../../../_images/basis_pursuit_11_1.png
# Solve (BP) with the high-level YALL1 driver, then report the iteration count
# and (going by the field names) the number of forward / adjoint applications.
sol = yall1.solve(T, b0)
tuple(int(count) for count in (sol.iterations, sol.n_times, sol.n_trans))
(58, 117, 60)
# Relative l2 recovery error: ||x_hat - x||_2 / ||x||_2
norm(sol.x - x) / norm(x)
DeviceArray(0.016301, dtype=float64)
# Ground truth (top) vs. basis-pursuit reconstruction (bottom)
fig = plt.figure(figsize=(8, 7), dpi=100, facecolor='w', edgecolor='k')

# Top panel: original signal with gray stems
plt.subplot(2, 1, 1)
plt.title('original')
plt.stem(x, markerfmt='.', linefmt='gray')

# Bottom panel: recovered signal
plt.subplot(2, 1, 2)
plt.title('reconstruction')
plt.stem(sol.x, markerfmt='.');
/tmp/ipykernel_31134/2976149975.py:4: UserWarning: In Matplotlib 3.3 individual lines on a stem plot will be added as a LineCollection instead of individual lines. This significantly improves the performance of a stem plot. To remove this warning and switch to the new behaviour, set the "use_line_collection" keyword argument to True.
  plt.stem(x, markerfmt='.', linefmt='gray');
/tmp/ipykernel_31134/2976149975.py:6: UserWarning: In Matplotlib 3.3 individual lines on a stem plot will be added as a LineCollection instead of individual lines. This significantly improves the performance of a stem plot. To remove this warning and switch to the new behaviour, set the "use_line_collection" keyword argument to True.
  plt.stem(sol.x, markerfmt='.');
../../../../_images/basis_pursuit_15_2.png
# Benchmark the high-level solver end to end; block_until_ready() makes each
# timed run wait for JAX's asynchronously dispatched result.
%timeit yall1.solve(T, b0).x.block_until_ready()
194 ms ± 149 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Same problem via the low-level API: setup, the jitted BP iteration and
# finalize are invoked as three separate steps.
b, x0, z0, w, b_max, n_times, n_trans = yall1.bp_setup(T, b0)
state = yall1.solve_bp_jit(
    T, b, x0, z0, w,
    nonneg=False,
    gamma=1.,
    tolerance=5e-3,
    max_iters=9999,
)
sol = yall1.finalize(state, b_max, n_times, n_trans)
tuple(int(count) for count in (sol.iterations, sol.n_times, sol.n_trans))
(58, 117, 60)
# Relative l2 recovery error of the low-level solution
norm(sol.x - x) / norm(x)
DeviceArray(0.016301, dtype=float64)
state = yall1.solve_bp_jit(T, b, x0, z0, w, nonneg=False, gamma=1., tolerance=5e-3, max_iters=9999)
%timeit yall1.solve_bp_jit(T, b, x0, z0, w, nonneg=False, gamma=1., tolerance=5e-3, max_iters=500)
193 ms ± 570 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# Scratch check: vector-matrix product of an integer range with a ones matrix
# (with x64 enabled at the top, the result dtype is float64).
jnp.matmul(jnp.arange(4), jnp.ones((4, 4)))
DeviceArray([6., 6., 6., 6.], dtype=float64)
# NOTE(review): scratch arithmetic left in the notebook; what these constants
# refer to is not shown here — confirm its purpose or remove this cell.
1542.85 / 445
3.467078651685393