Note
Go to the end to download the full example code
Recovering spikes via TNIPM¶
This example has the following features:
A sparse signal consists of a small number of spikes.
The sensing matrix is a random dictionary with orthonormal rows.
The number of measurements is one fourth of the ambient dimension.
The measurements are corrupted by noise.
The Truncated Newton Interior-Point Method (TNIPM), a.k.a. the l1-ls algorithm, is used for recovery.
Let’s import the necessary libraries
import jax.numpy as jnp
from jax import random
norm = jnp.linalg.norm
import matplotlib as mpl
import matplotlib.pyplot as plt
import cr.sparse as crs
import cr.sparse.data as crdata
import cr.sparse.lop as lop
import cr.sparse.cvx.l1ls as l1ls
from cr.nimble.dsp import (
hard_threshold_by,
support,
largest_indices_by
)
Setup¶
# Problem dimensions: m measurements (one fourth of n), ambient dimension n,
# and k spikes (the sparsity level).
m, n = 2**10, 2**12
k = 160
print(f'{m=}, {n=}')
# Root PRNG key and four independent subkeys; keys[0] drives the spikes
# and keys[2] drives the measurement noise.
key = random.PRNGKey(0)
keys = random.split(key, 4)
m=1024, n=4096
The Spikes¶
# Draw a k-sparse spike signal of length n. omega is presumably the set of
# spike locations -- confirm against cr.sparse.data.sparse_spikes.
xs, omega = crdata.sparse_spikes(keys[0], n, k)
# Visualize the ground-truth spikes
plt.figure(figsize=(8, 6), dpi=100, facecolor='w', edgecolor='k')
plt.plot(xs)
[<matplotlib.lines.Line2D object at 0x7f1ffae2c340>]
The Sparsifying Basis¶
Measurement process¶
# Sensing operator: a random m x n dictionary with orthonormal rows.
# A is used below (and in the recovery and least-squares steps) but was
# never defined, so the script raised NameError. keys[1] is the subkey
# not consumed anywhere else.
A = lop.random_orthonormal_rows_dict(keys[1], m, n)
# Clean measurements
bs = A.times(xs)
# i.i.d. Gaussian noise with standard deviation sigma
sigma = 0.01
noise = sigma * random.normal(keys[2], (m,))
# Noisy measurements
b = bs + noise
plt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')
plt.plot(b)
[<matplotlib.lines.Line2D object at 0x7f1ff9ec6cd0>]
Recovery using TNIPM¶
# Regularization parameter: a heuristic 10% of the largest magnitude in A^T b
Atb = A.trans(b)
tau = float(0.1 * jnp.max(jnp.abs(Atb)))
print(f'{tau=}')
# Run the l1-ls (TNIPM) solver
sol = l1ls.solve_jit(A, b, tau)
# number of L1-LS iterations
iterations = int(sol.iterations)
# number of A x operations
n_times = int(sol.n_times)
# number of A^H y operations
n_trans = int(sol.n_trans)
print(f'{iterations=} {n_times=} {n_trans=}')
# residual norm ||b - A x||_2 of the recovered solution
# (norm(sol.x) would be the norm of the solution itself, not the residual)
r_norm = norm(b - A.times(sol.x))
print(f'{r_norm=:.3e}')
# relative error of the recovered signal w.r.t. the ground truth
rel_error = norm(xs - sol.x) / norm(xs)
print(f'{rel_error=:.3e}')
tau=0.05101185632358787
iterations=17 n_times=173 n_trans=174
r_norm=1.122e+01
rel_error=1.224e-01
Solution¶
# Ground truth (top panel) versus the TNIPM estimate (bottom panel)
plt.figure(figsize=(8, 6), dpi=100, facecolor='w', edgecolor='k')
plt.subplot(2, 1, 1)
plt.plot(xs)
plt.subplot(2, 1, 2)
plt.plot(sol.x)
[<matplotlib.lines.Line2D object at 0x7f1ffabc9fa0>]
The sorted magnitudes of the recovered values¶
# Magnitudes of every entry of the TNIPM solution, in ascending order
magnitudes = jnp.sort(jnp.abs(sol.x))
plt.figure(figsize=(8, 6), dpi=100, facecolor='w', edgecolor='k')
plt.plot(magnitudes)
[<matplotlib.lines.Line2D object at 0x7f1ffa861fd0>]
Thresholding for large values¶
# Hard-threshold the TNIPM solution at 0.5 so that only its large values remain
x = hard_threshold_by(sol.x, 0.5)
# Ground truth (top) versus the thresholded estimate (bottom)
plt.figure(figsize=(8, 6), dpi=100, facecolor='w', edgecolor='k')
top_panel, bottom_panel = 211, 212
plt.subplot(top_panel)
plt.plot(xs)
plt.subplot(bottom_panel)
plt.plot(x)
[<matplotlib.lines.Line2D object at 0x7f1ff9d4c730>]
Verifying the support recovery¶
Array(True, dtype=bool)
Improvement using least squares over the support¶
# Restrict A to the columns indexed by the large entries (threshold 0.5)
# of the TNIPM solution.
support_x = largest_indices_by(sol.x, 0.5)
AI = A.columns(support_x)
print(AI.shape)
# Least-squares fit of b on those columns; only the solution vector is
# needed from lstsq's (solution, residuals, rank, singular values) result.
x_I = jnp.linalg.lstsq(AI, b)[0]
# Scatter the least-squares coefficients back into a length-n sparse vector
x_ls = jnp.zeros_like(xs).at[support_x].set(x_I)
# Relative error of the refined estimate
ls_rel_error = norm(xs - x_ls) / norm(xs)
print(f'{ls_rel_error=:.3e}')
# Ground truth (top) versus the least-squares refinement (bottom)
plt.figure(figsize=(8, 6), dpi=100, facecolor='w', edgecolor='k')
plt.subplot(2, 1, 1)
plt.plot(xs)
plt.subplot(2, 1, 2)
plt.plot(x_ls)
(1024, 160)
ls_rel_error=2.070e-02
[<matplotlib.lines.Line2D object at 0x7f1ffa42d2b0>]
Total running time of the script: (0 minutes 5.898 seconds)