Recovering spikes via TNIPM

This example has the following features:

  • A sparse signal consists of a small number of spikes.

  • The sensing matrix is a random dictionary with orthonormal rows.

  • The number of measurements is one fourth of the ambient dimension.

  • The measurements are corrupted by noise.

  • The Truncated Newton Interior-Point Method (TNIPM), also known as the l1-ls algorithm, is used for recovery (its objective is stated below).
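
For reference, the l1-ls formulation recovers the signal by solving an l1-regularized least squares problem of the form (the exact scaling of the quadratic term may differ in the implementation):

  minimize  ||A x - b||_2^2  +  tau * ||x||_1

where tau > 0 trades data fidelity against the sparsity of the solution.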

Let's import the necessary libraries.

import jax.numpy as jnp
from jax import random
norm = jnp.linalg.norm

import matplotlib as mpl
import matplotlib.pyplot as plt

import cr.sparse as crs
import cr.sparse.data as crdata
import cr.sparse.lop as lop
import cr.sparse.cvx.l1ls as l1ls
from cr.nimble.dsp import (
  hard_threshold_by,
  support,
  largest_indices_by
)

Setup

# Number of measurements
m = 2**10
# Ambient dimension
n  = 2**12
# Number of spikes (sparsity)
k = 160
print(f'{m=}, {n=}')

key = random.PRNGKey(0)
keys = random.split(key, 4)
m=1024, n=4096

The Spikes

xs, omega = crdata.sparse_spikes(keys[0], n, k)
plt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')
plt.plot(xs)
[Figure: the sparse spike signal xs]
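
As a quick sanity check (assuming the second value returned by sparse_spikes holds the spike locations), the support of xs should match it:

# the generated spike locations should coincide with the support of xs
print(bool(jnp.array_equal(jnp.sort(omega), support(xs))))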

The Sensing Matrix

A = lop.random_orthonormal_rows_dict(keys[1], m, n)
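
Since the rows of the dictionary are orthonormal, A A^H = I, so applying A.trans followed by A.times to any vector should return it unchanged. A small sanity check (not part of the original example):

# with orthonormal rows, A (A^H y) = y for every y in R^m
y = random.normal(keys[3], (m,))
print(float(norm(A.times(A.trans(y)) - y)))  # expect a value near zero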

Measurement process

# Clean measurements
bs = A.times(xs)
# Noise
sigma = 0.01
noise = sigma * random.normal(keys[2], (m,))
# Noisy measurements
b = bs + noise
plt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')
plt.plot(b)
[Figure: the noisy measurements b]
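
As a rough indication of the noise level, we can also compute the signal-to-noise ratio of the measurements (a small sketch; this value is not part of the original output):

# measurement SNR in dB
snr_db = 20 * jnp.log10(norm(bs) / norm(noise))
print(f'measurement SNR = {float(snr_db):.2f} dB')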

Recovery using TNIPM

# Estimate the regularization parameter as a fraction of ||A^H b||_inf
# (the natural scale for tau; for large enough tau the solution collapses to zero)
Atb = A.trans(b)
tau = float(0.1 * jnp.max(jnp.abs(Atb)))
print(f'{tau=}')
# Now run the solver
sol = l1ls.solve_jit(A, b, tau)

# number of L1-LS iterations
iterations = int(sol.iterations)
# number of A x operations
n_times = int(sol.n_times)
# number of A^H y operations
n_trans = int(sol.n_trans)
print(f'{iterations=} {n_times=} {n_trans=}')

# norm of the recovered solution
x_norm = norm(sol.x)
print(f'{x_norm=:.3e}')

# relative error
rel_error = norm(xs - sol.x) / norm(xs)
print(f'{rel_error=:.3e}')
tau=0.05101185632358787
iterations=17 n_times=173 n_trans=174
x_norm=1.122e+01
rel_error=1.224e-01
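
The data fidelity can be checked directly through the measurement residual (a small sketch; this value is not part of the original output):

# norm of the measurement residual b - A x for the recovered solution
res_norm = float(norm(b - A.times(sol.x)))
print(f'{res_norm=:.3e}')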

Solution

plt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')
plt.subplot(211)
plt.plot(xs)
plt.subplot(212)
plt.plot(sol.x)
[Figure: original spikes (top) and recovered solution (bottom)]

The magnitudes of non-zero values

plt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')
plt.plot(jnp.sort(jnp.abs(sol.x)))
[Figure: sorted magnitudes of the entries of the recovered solution]
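
The sorted magnitudes typically show a gap between the k largest entries (the true spikes, somewhat shrunk by the l1 penalty) and the remaining small entries; this gap suggests where a hard threshold can be placed. A quick way to locate it (not part of the original output):

# compare the k-th largest magnitude with the next one
mags = jnp.sort(jnp.abs(sol.x))[::-1]
print(float(mags[k - 1]), float(mags[k]))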

Thresholding for large values

x = hard_threshold_by(sol.x, 0.5)
plt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')
plt.subplot(211)
plt.plot(xs)
plt.subplot(212)
plt.plot(x)
[Figure: original spikes (top) and thresholded solution (bottom)]
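
We can also count how many entries survive the threshold and compare with the true sparsity (not part of the original output):

# number of entries surviving the hard threshold vs. the true sparsity
print(int(jnp.count_nonzero(x)), k)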

Verifying the support recovery

support_xs = support(xs)
support_x = support(x)
jnp.all(jnp.equal(support_xs, support_x))
Array(True, dtype=bool)

Improvement using least squares over support

The l1 penalty shrinks the magnitudes of the recovered spikes. Re-solving an ordinary (unregularized) least squares problem restricted to the identified support removes this bias.

# Identify the sub-matrix of columns corresponding to the large entries of the recovered solution
support_x = largest_indices_by(sol.x, 0.5)
AI = A.columns(support_x)
print(AI.shape)

# Solve the least squares problem over these columns
x_I, residuals, rank, s  = jnp.linalg.lstsq(AI, b)
# fill the non-zero entries into the sparse least squares solution
x_ls = jnp.zeros_like(xs)
x_ls = x_ls.at[support_x].set(x_I)

# relative error
ls_rel_error = norm(xs - x_ls) / norm(xs)
print(f'{ls_rel_error=:.3e}')

plt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')
plt.subplot(211)
plt.plot(xs)
plt.subplot(212)
plt.plot(x_ls)
(1024, 160)
ls_rel_error=2.070e-02

[Figure: original spikes (top) and the least squares solution over the recovered support (bottom)]
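
In this run, the least squares refinement reduces the relative error by roughly a factor of six (from 1.224e-01 to 2.070e-02). The improvement factor can be computed directly:

# improvement in relative error from the least squares debiasing step
print(float(rel_error / ls_rel_error))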

Total running time of the script: (0 minutes 5.898 seconds)
