8.1. Orthogonal Matching Pursuit
This notebook compares the runtime of the Orthogonal Matching Pursuit (OMP) implementation in scikit-learn with that of cr-sparse.
import matplotlib.pyplot as plt
import numpy as np
from jax.config import config
# Enable 64-bit (float64) arithmetic in JAX so cr-sparse works in the same
# precision as the NumPy / scikit-learn results it is compared against below.
# NOTE(review): newer JAX releases deprecate importing ``jax.config`` as a
# module (``from jax.config import config``); ``import jax`` followed by
# ``jax.config.update(...)`` is the portable spelling — confirm against the
# JAX version this notebook is pinned to.
config.update("jax_enable_x64", True)
import jax.numpy as jnp
# sklearn
from sklearn.linear_model import OrthogonalMatchingPursuit
from sklearn.datasets import make_sparse_coded_signal
# cr-sparse imports
import cr.sparse as crs
import cr.sparse.pursuit.omp as crs_omp
# Problem size: a dictionary of ``n_components`` atoms in an
# ``n_features``-dimensional space, and a signal built from
# ``n_nonzero_coefs`` of those atoms.
n_components, n_features = 10000, 2000
n_nonzero_coefs = 100
# y: measured signal, X: dictionary, w: ground-truth sparse code.
y, X, w = make_sparse_coded_signal(
    n_samples=1,
    n_components=n_components,
    n_features=n_features,
    n_nonzero_coefs=n_nonzero_coefs,
    random_state=0,
)
# Older scikit-learn releases return the dictionary as
# (n_features, n_components). Newer ones (``data_transposed`` defaults to False
# from 1.3) return it as (n_components, n_features). Both solvers below are
# called as ``fit(X, y)`` / ``solve(X, y)`` with one row per measurement and
# one column per atom, so normalise to the (n_features, n_components) layout.
# The check is unambiguous because n_components != n_features.
if X.shape == (n_components, n_features):
    X = X.T
X.shape
(2000, 10000)
# Ground-truth support: positions of the nonzero entries of the sparse code w.
idx, = np.nonzero(w)
idx
array([ 11, 72, 74, 173, 176, 341, 423, 504, 643, 1441, 1595,
1601, 1732, 1868, 2086, 2174, 2215, 2216, 2371, 2665, 2764, 2923,
3011, 3281, 3345, 3372, 3433, 3450, 3546, 3572, 3586, 3628, 3690,
3714, 3777, 3822, 3911, 4080, 4085, 4150, 4238, 4251, 4334, 4450,
4544, 4567, 4674, 4960, 5244, 5428, 5434, 5581, 5649, 5657, 5781,
5807, 5836, 5933, 6043, 6055, 6190, 6229, 6298, 6344, 6395, 6499,
6505, 6735, 6750, 6810, 6832, 6897, 7304, 7307, 7335, 7370, 7503,
7505, 7514, 7574, 7604, 7615, 7699, 8111, 8178, 8307, 8309, 8360,
8422, 8438, 8780, 8909, 9212, 9263, 9517, 9643, 9644, 9649, 9891,
9982])
# Distorted copy of the clean signal, drawn from a seeded generator so it is
# reproducible like the ``random_state=0`` signal.
# NOTE(review): ``y_noisy`` is not used by the fit in this cell, which runs on
# the clean ``y``.
y_noisy = y + 0.05 * np.random.default_rng(0).standard_normal(y.shape)

# scikit-learn < 1.2 normalised the regressors by default, so normalisation is
# disabled explicitly there. The ``normalize`` parameter was removed in 1.4;
# passing it now raises TypeError, and no normalisation is done by default.
try:
    sklearn_omp = OrthogonalMatchingPursuit(n_nonzero_coefs=n_nonzero_coefs, normalize=False)
except TypeError:
    sklearn_omp = OrthogonalMatchingPursuit(n_nonzero_coefs=n_nonzero_coefs)
sklearn_omp.fit(X, y)
coef = sklearn_omp.coef_
(idx_r,) = coef.nonzero()
# Exact support-recovery check. ``array_equal`` compares the integer index
# sets exactly, and returns False (instead of raising a broadcast error) when
# the two supports have different sizes.
np.array_equal(idx, idx_r)
DeviceArray(True, dtype=bool)
# Ground-truth coefficient values on the support found above.
w.take(idx)
array([ 0.22988199, -0.94303868, -1.25317778, -0.87889387, 1.27763385,
-0.55424202, 0.59858225, 0.63647644, 0.36803094, 2.41796731,
-1.48495303, -0.12272381, -0.23279657, 1.46299794, -1.69637976,
1.26852914, 0.24736281, -0.93801921, -1.53426129, -0.7115889 ,
-1.72529192, -0.61101788, 1.41330667, 0.54418405, 0.0916008 ,
0.45598073, 0.57411609, 0.74642222, 1.46684293, -0.39081685,
0.20619313, 0.35597824, -0.20555081, -0.30032162, 1.30586014,
0.69459212, 0.72687452, -0.1225696 , -0.64648978, 0.07525342,
0.40708142, -0.90305857, 1.23610938, -0.02057273, -0.62723391,
1.13914078, 1.60910332, -0.6505078 , -0.54608446, -1.71609353,
0.67038443, 0.53563422, -0.33609932, 0.42850107, -0.48190576,
1.8132168 , -0.12322664, 0.76890924, -0.30625163, 0.61183925,
0.54834446, 1.28747139, 1.98886147, -0.20837021, -0.65145199,
-1.44431949, -1.7036624 , -1.70906404, -0.20278546, -0.06916413,
-0.30870709, 0.04939863, 0.03252949, -0.46755286, 1.80088922,
-0.44528567, 0.74070197, -0.24315753, -0.00907255, -0.31648649,
-0.44866272, -1.18214884, -0.83968022, 0.22340382, 0.89069855,
-1.02127932, -0.65027866, 1.96807574, -0.71031316, -0.13957926,
-0.30692036, 0.0853642 , -1.25947971, -0.03461537, 0.43799029,
0.80638696, 1.28971144, -2.12465697, 0.90155434, 1.54350939])
# Hand the same dictionary and signal to cr-sparse as JAX arrays. They stay
# float64 because 64-bit mode is enabled at the top of the notebook.
X_jax, y_jax = jnp.asarray(X), jnp.asarray(y)
# Run cr-sparse's OMP with the same sparsity budget as scikit-learn.
# NOTE(review): judging by its name, this first call presumably also triggers
# JIT compilation, so the later %timeit measures steady-state runtime — confirm.
solution = crs_omp.matrix_solve_jit(X_jax, y_jax, n_nonzero_coefs)
# Presumably places the recovered values (solution.x_I) at their indices
# (solution.I) in a length-n_components vector comparable to sklearn's coef_.
coef_jax = crs.build_signal_from_indices_and_values(
    n_components, solution.I, solution.x_I
)
# Both solvers should agree on the full coefficient vector.
jnp.allclose(coef, coef_jax)
DeviceArray(True, dtype=bool)
# Time scikit-learn's OMP fit on the same dictionary and signal.
# ``%timeit -o`` is IPython magic (notebook-only syntax). It prints the summary
# and returns a TimeitResult, whose ``.average`` is used for the speed-up below.
np_time = %timeit -o sklearn_omp.fit(X, y)
646 ms ± 6.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Time cr-sparse's OMP (``matrix_solve_jit``) with IPython's ``%timeit -o``.
# JAX dispatches work asynchronously, so ``block_until_ready()`` on the
# returned support indices makes each timed loop wait for the computation to
# finish rather than measuring only the dispatch.
jax_time = %timeit -o crs_omp.matrix_solve_jit(X_jax, y_jax, n_nonzero_coefs).I.block_until_ready()
59.2 ms ± 102 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# Speed-up factor: mean per-loop runtime of scikit-learn divided by that of
# cr-sparse. A value above 1 means cr-sparse was faster.
sklearn_mean = np_time.average
crs_mean = jax_time.average
gain = sklearn_mean / crs_mean
print(gain)
10.918397555390463