8.1. Orthogonal Matching Pursuit

This notebook compares the runtime of the Orthogonal Matching Pursuit (OMP) implementation in scikit-learn with the one in cr-sparse.

import matplotlib.pyplot as plt
import numpy as np
# `from jax.config import config` (the jax.config submodule) was deprecated and
# later removed from JAX; `from jax import config` is the supported spelling
# and still binds the name `config` used below.
from jax import config
# Enable 64-bit floats in JAX so results are comparable with NumPy/sklearn.
# JAX expects this flag to be set at startup, before any JAX arrays exist.
config.update("jax_enable_x64", True)
import jax.numpy as jnp
# sklearn
from sklearn.linear_model import OrthogonalMatchingPursuit
from sklearn.datasets import make_sparse_coded_signal
# cr-sparse imports
import cr.sparse as crs
import cr.sparse.pursuit.omp as crs_omp
# Problem size: a dictionary of n_components atoms, each of length n_features,
# and a single signal synthesized from n_nonzero_coefs of those atoms.
n_components, n_features = 10000, 2000
n_nonzero_coefs = 100
y, X, w = make_sparse_coded_signal(
    n_samples=1,
    n_components=n_components,
    n_features=n_features,
    n_nonzero_coefs=n_nonzero_coefs,
    random_state=0,
)
# scikit-learn >= 1.3 returns the dictionary as (n_components, n_features),
# while older releases returned (n_features, n_components). Both solvers below
# expect one atom per column (y == X @ w), so normalize the orientation here.
if X.shape[0] != n_features:
    X = np.ascontiguousarray(X.T)
X.shape
(2000, 10000)
# Ground-truth support: positions of the nonzero entries of the sparse code w.
idx, = np.nonzero(w)
idx
array([  11,   72,   74,  173,  176,  341,  423,  504,  643, 1441, 1595,
       1601, 1732, 1868, 2086, 2174, 2215, 2216, 2371, 2665, 2764, 2923,
       3011, 3281, 3345, 3372, 3433, 3450, 3546, 3572, 3586, 3628, 3690,
       3714, 3777, 3822, 3911, 4080, 4085, 4150, 4238, 4251, 4334, 4450,
       4544, 4567, 4674, 4960, 5244, 5428, 5434, 5581, 5649, 5657, 5781,
       5807, 5836, 5933, 6043, 6055, 6190, 6229, 6298, 6344, 6395, 6499,
       6505, 6735, 6750, 6810, 6832, 6897, 7304, 7307, 7335, 7370, 7503,
       7505, 7514, 7574, 7604, 7615, 7699, 8111, 8178, 8307, 8309, 8360,
       8422, 8438, 8780, 8909, 9212, 9263, 9517, 9643, 9644, 9649, 9891,
       9982])
# Distort the clean signal with additive Gaussian noise. A seeded generator
# keeps the notebook reproducible from run to run.
# NOTE: the comparison below fits the clean `y`; `y_noisy` is not used by it.
rng = np.random.default_rng(0)
y_noisy = y + 0.05 * rng.standard_normal(len(y))

# Fit scikit-learn's OMP with a fixed sparsity level. The `normalize` parameter
# was deprecated in scikit-learn 1.2 (where no normalization is already the
# default) and removed in 1.4, so only pass it on versions that accept it.
try:
    sklearn_omp = OrthogonalMatchingPursuit(
        n_nonzero_coefs=n_nonzero_coefs, normalize=False
    )
except TypeError:
    sklearn_omp = OrthogonalMatchingPursuit(n_nonzero_coefs=n_nonzero_coefs)
sklearn_omp.fit(X, y)
coef = sklearn_omp.coef_
(idx_r,) = coef.nonzero()
# Exact support-recovery check. array_equal (unlike allclose) returns False
# instead of raising when the two supports have different sizes.
jnp.array_equal(idx, idx_r)
DeviceArray(True, dtype=bool)
# Ground-truth coefficient values on the recovered support.
np.take(w, idx, axis=0)
array([ 0.22988199, -0.94303868, -1.25317778, -0.87889387,  1.27763385,
       -0.55424202,  0.59858225,  0.63647644,  0.36803094,  2.41796731,
       -1.48495303, -0.12272381, -0.23279657,  1.46299794, -1.69637976,
        1.26852914,  0.24736281, -0.93801921, -1.53426129, -0.7115889 ,
       -1.72529192, -0.61101788,  1.41330667,  0.54418405,  0.0916008 ,
        0.45598073,  0.57411609,  0.74642222,  1.46684293, -0.39081685,
        0.20619313,  0.35597824, -0.20555081, -0.30032162,  1.30586014,
        0.69459212,  0.72687452, -0.1225696 , -0.64648978,  0.07525342,
        0.40708142, -0.90305857,  1.23610938, -0.02057273, -0.62723391,
        1.13914078,  1.60910332, -0.6505078 , -0.54608446, -1.71609353,
        0.67038443,  0.53563422, -0.33609932,  0.42850107, -0.48190576,
        1.8132168 , -0.12322664,  0.76890924, -0.30625163,  0.61183925,
        0.54834446,  1.28747139,  1.98886147, -0.20837021, -0.65145199,
       -1.44431949, -1.7036624 , -1.70906404, -0.20278546, -0.06916413,
       -0.30870709,  0.04939863,  0.03252949, -0.46755286,  1.80088922,
       -0.44528567,  0.74070197, -0.24315753, -0.00907255, -0.31648649,
       -0.44866272, -1.18214884, -0.83968022,  0.22340382,  0.89069855,
       -1.02127932, -0.65027866,  1.96807574, -0.71031316, -0.13957926,
       -0.30692036,  0.0853642 , -1.25947971, -0.03461537,  0.43799029,
        0.80638696,  1.28971144, -2.12465697,  0.90155434,  1.54350939])
# Move the dictionary and signal onto the JAX device.
X_jax, y_jax = jnp.asarray(X), jnp.asarray(y)
# Solve with cr-sparse's OMP. `matrix_solve_jit` is presumably jit-compiled,
# so this first call should also warm up compilation before the timing runs.
solution = crs_omp.matrix_solve_jit(X_jax, y_jax, n_nonzero_coefs)
# Expand the (support indices, values) pair returned by the solver into a
# length-n_components coefficient vector comparable with sklearn's `coef`.
coef_jax = crs.build_signal_from_indices_and_values(
    n_components, solution.I, solution.x_I
)
# Both implementations should produce the same coefficients.
jnp.allclose(coef, coef_jax)
DeviceArray(True, dtype=bool)
# Benchmark scikit-learn's OMP fit. `-o` makes %timeit return its result
# object, whose `.average` (mean time per loop) is used for the speed-up below.
np_time = %timeit -o sklearn_omp.fit(X, y)
646 ms ± 6.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Benchmark the cr-sparse OMP solver. JAX dispatches work asynchronously, so
# block_until_ready() makes each timed loop wait for the result to be computed.
jax_time = %timeit -o crs_omp.matrix_solve_jit(X_jax, y_jax, n_nonzero_coefs).I.block_until_ready()
59.2 ms ± 102 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# Speed-up factor: mean scikit-learn time divided by mean cr-sparse time.
sklearn_mean, jax_mean = np_time.average, jax_time.average
gain = sklearn_mean / jax_mean
print(gain)
10.918397555390463