CoSaMP step by step
This example walks through the development of the
CoSaMP (Compressive Sampling Matching Pursuit) algorithm
for sparse recovery, one step at a time. It then shows how to use the
official implementation of CoSaMP in CR-Sparse.
The CoSaMP algorithm has the following inputs:
- A sensing matrix or dictionary Phi which has been used for data measurements.
- A measurement vector y.
- The sparsity level K.
The objective of the algorithm is to estimate a K-sparse solution x
such that y is approximately equal to Phi x.
A key quantity in the algorithm is the residual r = y - Phi x. Each
iteration of the algorithm improves the estimate x so
that the energy of the residual r reduces.
The algorithm proceeds as follows:
- Initialize the solution x with zero.
- Maintain an index set I (initially empty) of atoms selected as part of the solution.
- While the residual energy is above a threshold:
  - Match: Compute the inner product of each atom in Phi with the current residual r.
  - Identify: Select the indices of the 2K atoms from Phi with the largest correlation with the residual.
  - Merge: Merge these 2K indices with the currently selected indices in I to form I_sub.
  - LS: Compute the least squares solution of Phi[:, I_sub] z = y.
  - Prune: Pick the K largest entries from this least squares solution and keep their indices in I.
  - Update residual: Compute r = y - Phi_I x_I.
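The steps above can be sketched as a short loop. What follows is a minimal NumPy sketch under stated assumptions, not the CR-Sparse implementation: the function name `cosamp_sketch`, the `max_iters` cap, the default `res_norm_rtol`, and the test problem sizes are all chosen for illustration, and magnitude-based selection via `argsort` stands in for `largest_indices`.

```python
import numpy as np

def cosamp_sketch(Phi, y, K, max_iters=20, res_norm_rtol=1e-3):
    """Illustrative CoSaMP loop (hypothetical helper, not the library code)."""
    N = Phi.shape[1]
    x = np.zeros(N)
    I = np.array([], dtype=int)            # selected atoms, initially empty
    r = y.copy()                           # residual for the zero solution
    max_r_norm_sqr = (res_norm_rtol ** 2) * float(y @ y)
    for _ in range(max_iters):
        if float(r @ r) <= max_r_norm_sqr:
            break
        h = Phi.T @ r                                   # Match
        I_2k = np.argsort(np.abs(h))[-2 * K:]           # Identify
        I_sub = np.union1d(I, I_2k)                     # Merge (sorted union)
        z, *_ = np.linalg.lstsq(Phi[:, I_sub], y, rcond=None)  # LS
        keep = np.argsort(np.abs(z))[-K:]               # Prune
        I = I_sub[keep]
        x = np.zeros(N)
        x[I] = z[keep]
        r = y - Phi[:, I] @ x[I]                        # Update residual
    return x, np.sort(I)

# Small synthetic problem (sizes and seed are assumptions for this sketch)
rng = np.random.default_rng(0)
M, N, K = 64, 128, 4
Phi = rng.standard_normal((M, N))
Phi /= np.linalg.norm(Phi, axis=0)         # unit-norm atoms
support = np.sort(rng.choice(N, size=K, replace=False))
x0 = np.zeros(N)
# Amplitudes bounded away from zero so that pruning is unambiguous
x0[support] = rng.choice([-1.0, 1.0], size=K) * rng.uniform(1.0, 2.0, size=K)
y = Phi @ x0

x_hat, I_hat = cosamp_sketch(Phi, y, K)
print(np.array_equal(I_hat, support), np.linalg.norm(x_hat - x0))
```

The sketch recomputes the full-length estimate in every iteration for readability; an implementation meant for speed would likely track only the support and its coefficients.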
It is time to see the algorithm in action.
Let’s import necessary libraries
import jax
from jax import random
import jax.numpy as jnp
# Some keys for generating random numbers
key = random.PRNGKey(0)
keys = random.split(key, 4)
# For plotting diagrams
import matplotlib.pyplot as plt
# CR-Sparse modules
import cr.sparse as crs
import cr.sparse.dict as crdict
import cr.sparse.data as crdata
from cr.nimble.dsp import (
    nonzero_indices,
    nonzero_values,
    largest_indices
)
Problem Setup
# Number of measurements
M = 128
# Ambient dimension
N = 256
# Sparsity level
K = 8
The Sparsifying Basis
print(Phi.shape)
(128, 256)
Coherence of atoms in the sensing matrix
print(crdict.coherence(Phi))
0.3881940752728321
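Coherence is the largest absolute inner product between two distinct unit-norm atoms. A minimal NumPy sketch of that computation follows. It assumes a random Gaussian matrix with normalized columns as a stand-in for Phi and does not call `crdict.coherence`, so the printed value will differ from the one above.

```python
import numpy as np

rng = np.random.default_rng(0)
M, N = 128, 256

# Assumed stand-in for the sensing matrix: Gaussian entries, unit-norm columns
Phi = rng.standard_normal((M, N))
Phi /= np.linalg.norm(Phi, axis=0)

# Gram matrix of absolute inner products between atoms
G = np.abs(Phi.T @ Phi)
np.fill_diagonal(G, 0.0)    # ignore each atom's inner product with itself
mu = float(G.max())         # coherence of the dictionary
print(mu)
```

Lower coherence means the atoms are harder to confuse with each other, which makes the match step more reliable.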
A sparse model vector
x0, omega = crdata.sparse_normal_representations(key, N, K)
plt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')
plt.plot(x0)
omega contains the set of indices at which x0 is nonzero (the support of x0)
print(omega)
[ 41 60 68 89 99 198 232 244]
Compressive measurements
y = Phi @ x0
plt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')
plt.plot(y)
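Because x0 is K-sparse, the measurement vector is a linear combination of only the K atoms indexed by its support. The following NumPy sketch illustrates this with stand-in data. The random generator, the scaling of Phi, and the way the support is drawn are assumptions and do not reproduce `crdata.sparse_normal_representations`.

```python
import numpy as np

rng = np.random.default_rng(1)
M, N, K = 128, 256, 8

# Assumed stand-in for the sensing matrix
Phi = rng.standard_normal((M, N)) / np.sqrt(M)

# K-sparse model vector with Gaussian nonzero entries on a random support
omega = np.sort(rng.choice(N, size=K, replace=False))
x0 = np.zeros(N)
x0[omega] = rng.standard_normal(K)

# Compressive measurements
y = Phi @ x0

# Only the K atoms on the support contribute to y
y_from_support = Phi[:, omega] @ x0[omega]
print(y.shape, np.allclose(y, y_from_support))
```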
Development of CoSaMP algorithm
In the following, we walk through the steps of the CoSaMP algorithm.
Since we have access to x0 and omega, we can measure the
progress made by the algorithm steps by comparing the estimates
with the actual x0 and omega. Note, however, that a
real implementation of the algorithm has no access to the original model
vector.
Initialization
We assume the initial solution to be zero, so
the residual r = y - Phi x equals the measurements y
r = y
Squared norm/energy of the residual
y_norm_sqr = float(y.T @ y)
r_norm_sqr = y_norm_sqr
print(f"{r_norm_sqr=}")
r_norm_sqr=7.401212029141624
A boolean array to track the indices selected for least squares steps
flags = jnp.zeros(N, dtype=bool)
During the matching steps, 2K atoms will be picked.
K2 = 2 * K
At any time, up to 3K atoms may be selected (after the merge step).
K3 = K + K2
Number of iterations completed so far
iterations = 0
An upper limit on the residual energy, set relative to the energy of y
res_norm_rtol = 1e-3
max_r_norm_sqr = y_norm_sqr * (res_norm_rtol ** 2)
print(f"{max_r_norm_sqr=:.2e}")
max_r_norm_sqr=7.40e-06
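Comparing squared norms against `y_norm_sqr * res_norm_rtol ** 2` is the same test as checking that the residual norm is at most `res_norm_rtol` times the norm of y, without taking square roots. A small sketch of that check follows. The helper name `has_converged` and the toy vectors are hypothetical.

```python
import numpy as np

def has_converged(r, y, res_norm_rtol=1e-3):
    """Hypothetical helper: True when ||r||^2 <= rtol^2 * ||y||^2."""
    max_r_norm_sqr = float(y @ y) * (res_norm_rtol ** 2)
    return float(r @ r) <= max_r_norm_sqr

y = np.array([3.0, 4.0])            # norm 5, so the norm threshold is 0.005
r_small = 0.004 * y / 5.0           # residual with norm 0.004
r_large = 0.006 * y / 5.0           # residual with norm 0.006
print(has_converged(r_small, y), has_converged(r_large, y))
```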
First iteration
print("First iteration:")
First iteration:
Match the current residual with the atoms in Phi
h = Phi.T @ r
Pick the indices of 3K atoms with largest matches with the residual.
Since no atoms have been selected so far, we can be more aggressive
and pick 3K atoms instead of 2K in the first iteration.
I_sub = largest_indices(h, K3)
# Update the flags array
flags = flags.at[I_sub].set(True)
# Recover ``I_sub`` in sorted order from the flags array
I_sub, = jnp.where(flags)
print(f"{I_sub=}")
I_sub=Array([ 14, 30, 44, 60, 64, 78, 84, 89, 99, 116, 118, 127, 128,
149, 157, 158, 162, 168, 184, 192, 198, 203, 232, 244], dtype=int64)
Check which indices from omega are there in I_sub.
print(jnp.intersect1d(omega, I_sub))
[ 60 89 99 198 232 244]
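The identify step ranks atoms by the magnitude of their correlation with the residual, so strongly negative correlations count as much as positive ones. The sketch below shows one way to do this in NumPy. The helper `top_k_by_magnitude` is a hypothetical stand-in and is only assumed to behave like `largest_indices` from `cr.nimble`.

```python
import numpy as np

def top_k_by_magnitude(h, k):
    """Hypothetical stand-in: indices of the k entries with largest |h|."""
    order = np.argsort(np.abs(h))   # ascending by magnitude
    return order[::-1][:k]          # largest magnitude first

h = np.array([0.1, -2.0, 0.5, 1.5, -0.3])
top2 = top_k_by_magnitude(h, 2)
print(top2)
```

Here the entry with value -2.0 is ranked first because selection is by absolute value.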
Select the subdictionary of Phi consisting of atoms indexed by I_sub
Phi_sub = Phi[:, flags]
Compute the least squares solution of y over this subdictionary
x_sub, r_sub_norms, rank_sub, s_sub = jnp.linalg.lstsq(Phi_sub, y)
# Pick the indices of the K largest entries in ``x_sub``
Ia = largest_indices(x_sub, K)
print(f"{Ia=}")
Ia=Array([ 3, 7, 23, 20, 22, 8, 15, 18], dtype=int64)
We need to map the indices in Ia to the actual indices of atoms in Phi
I = I_sub[Ia]
print(f"{I=}")
I=Array([ 60, 89, 244, 198, 232, 99, 158, 184], dtype=int64)
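The pruning step works with two index spaces: positions inside the least squares solution and atom indices in the full dictionary. The toy NumPy example below, with made-up values, shows the mapping from one to the other and how the pruned coefficients are scattered into a full-length vector.

```python
import numpy as np

N = 64
I_sub = np.array([14, 30, 44, 60])           # atom indices used in the LS step
x_sub = np.array([0.02, -0.01, 0.03, 1.9])   # LS coefficients, one per entry of I_sub

# Positions (within x_sub) of the 2 largest-magnitude coefficients
Ia = np.argsort(np.abs(x_sub))[::-1][:2]
# Map positions back to atom indices in the full dictionary
I = I_sub[Ia]

# Scatter the kept coefficients into a length-N estimate
x = np.zeros(N)
x[I] = x_sub[Ia]
print(Ia, I)
```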
Select the corresponding values from the LS solution
x_I = x_sub[Ia]
We now have our first estimate of the solution
x = jnp.zeros(N).at[I].set(x_I)
plt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')
plt.plot(x0, label="Original vector")
plt.plot(x, '--', label="Estimated solution")
plt.legend()
We can check how well we picked the correct indices from the actual support of the signal
found = jnp.intersect1d(omega, I)
print("Found indices: ", found)
Found indices: [ 60 89 99 198 232 244]
We found 6 out of 8 indices in the support. Here are the missing ones.
missing = jnp.setdiff1d(omega, I)
print("Missing indices: ", missing)
Missing indices: [41 68]
It is time to compute the residual after the first iteration
Phi_I = Phi[:, I]
r = y - Phi_I @ x_I
Compute the residual energy and verify that it is still larger than the allowed tolerance
r_norm_sqr = float(r.T @ r)
print(f"{r_norm_sqr=:.2e} > {max_r_norm_sqr=:.2e}")
r_norm_sqr=8.28e-02 > max_r_norm_sqr=7.40e-06
Store the selected K indices in the flags array
flags = flags.at[:].set(False)
flags = flags.at[I].set(True)
print(jnp.where(flags))
(Array([ 60, 89, 99, 158, 184, 198, 232, 244], dtype=int64),)
Mark the completion of the iteration
iterations += 1
Second iteration
print("Second iteration:")
Second iteration:
Match the current residual with the atoms in Phi
h = Phi.T @ r
Pick the indices of 2K atoms with largest matches with the residual
I_2k = largest_indices(h, K2 if iterations else K3)
# We can check if these include the atoms missed out in first iteration.
print(jnp.intersect1d(omega, I_2k))
[41 68]
Merge (union) the set of previous K indices with the new 2K indices
flags = flags.at[I_2k].set(True)
I_sub, = jnp.where(flags)
print(f"{I_sub=}")
I_sub=Array([ 8, 25, 41, 42, 60, 66, 67, 68, 72, 89, 99, 111, 129,
158, 164, 184, 190, 195, 198, 216, 220, 232, 233, 244], dtype=int64)
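Setting entries of a boolean mask and reading the True positions back gives the union of the old and new index sets in sorted order, with duplicates removed. A small NumPy sketch with made-up indices follows. It uses in-place assignment, whereas the JAX code in this example uses `.at[...].set(...)` because JAX arrays are immutable.

```python
import numpy as np

N = 16
flags = np.zeros(N, dtype=bool)

I_prev = np.array([9, 2, 5])              # K = 3 indices kept after pruning
I_2k = np.array([5, 11, 0, 2, 7, 13])     # 2K = 6 new candidates, may overlap

flags[I_prev] = True
flags[I_2k] = True                        # merge: overlapping entries stay True
I_sub, = np.where(flags)                  # sorted, duplicate-free union
print(I_sub)
```

Because of possible overlap, the merged set holds at most 3K indices and may hold fewer.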
We can check if we found all the actual atoms
print("Found in I_sub: ", jnp.intersect1d(omega, I_sub))
Found in I_sub: [ 41 60 68 89 99 198 232 244]
Indeed we did. The set difference is empty.
print("Missing in I_sub: ", jnp.setdiff1d(omega, I_sub))
Missing in I_sub: []
Select the subdictionary of Phi consisting of atoms indexed by I_sub
Phi_sub = Phi[:, flags]
Compute the least squares solution of y over this subdictionary
x_sub, r_sub_norms, rank_sub, s_sub = jnp.linalg.lstsq(Phi_sub, y)
# Pick the indices of the K largest entries in ``x_sub``
Ia = largest_indices(x_sub, K)
print(Ia)
[ 4 9 23 18 21 10 7 2]
We need to map the indices in Ia to the actual indices of atoms in Phi
I = I_sub[Ia]
print(I)
[ 60 89 244 198 232 99 68 41]
Check if the final K indices in I include all the indices in omega
jnp.setdiff1d(omega, I)
Array([], dtype=int64)
Select the corresponding values from the LS solution
x_I = x_sub[Ia]
Here is our updated estimate of the solution
x = jnp.zeros(N).at[I].set(x_I)
plt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')
plt.plot(x0, label="Original vector")
plt.plot(x, '--', label="Estimated solution")
plt.legend()
The algorithm has no direct way of knowing that it has indeed found the solution; it can only look at the residual. It is time to compute the residual after the second iteration
Phi_I = Phi[:, I]
r = y - Phi_I @ x_I
Compute the residual energy and verify that it is now below the allowed tolerance
r_norm_sqr = float(r.T @ r)
# It turns out that it is now below the tolerance threshold
print(f"{r_norm_sqr=:.2e} < {max_r_norm_sqr=:.2e}")
r_norm_sqr=7.09e-30 < max_r_norm_sqr=7.40e-06
We have completed the signal recovery. We can stop iterating now.
iterations += 1
CR-Sparse official implementation
The JIT compiled version of this algorithm is available in the
cr.sparse.pursuit.cosamp module.
Import the module
from cr.sparse.pursuit import cosamp
Run the solver
solution = cosamp.matrix_solve_jit(Phi, y, K)
# The support for the sparse solution
I = solution.I
print(I)
[ 60 89 244 198 232 99 68 41]
The non-zero values on the support
x_I = solution.x_I
print(x_I)
[ 1.9097652 1.12094818 1.04348768 -0.82606793 0.64812788 0.33432345
0.29561749 0.08482584]
Verify that we successfully recovered the support
print(jnp.setdiff1d(omega, I))
[]
Print the residual energy and the number of iterations when the algorithm converged.
print(solution.r_norm_sqr, solution.iterations)
7.726387804898689e-30 3
Let’s plot the solution
x = jnp.zeros(N).at[I].set(x_I)
plt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')
plt.plot(x0, label="Original vector")
plt.plot(x, '--', label="Estimated solution")
plt.legend()
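Since x0 is known in this example, recovery can also be checked numerically instead of by looking at the plot. The sketch below does this in NumPy with toy vectors. The helper `recovery_report` and its tolerance are hypothetical and are not part of CR-Sparse.

```python
import numpy as np

def recovery_report(x0, x_hat, tol=1e-6):
    """Hypothetical helper comparing an estimate with the true sparse vector."""
    support_true = np.flatnonzero(x0)
    support_est = np.flatnonzero(x_hat)
    rel_err = np.linalg.norm(x_hat - x0) / np.linalg.norm(x0)
    return {
        "support_match": bool(np.array_equal(support_true, support_est)),
        "relative_error": float(rel_err),
        "success": bool(rel_err < tol),
    }

x0_toy = np.array([0.0, 1.5, 0.0, -0.5])
x_hat_toy = np.array([0.0, 1.5, 0.0, -0.5 + 1e-9])
report = recovery_report(x0_toy, x_hat_toy)
print(report)
```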