7.2. 2D Wavelet Decomposition and Reconstruction
# Enable 64-bit (double precision) arrays in JAX. JAX defaults to 32-bit
# floats; x64 lets the cr.sparse results below be compared with the
# NumPy-based PyWavelets results at double precision. This must run before
# any JAX arrays are created.
# `from jax import config` replaces `from jax.config import config`: the
# `jax.config` submodule import was deprecated and removed in newer JAX
# releases, while this form works on both old and new versions.
from jax import config
config.update("jax_enable_x64", True)
import numpy as np
import jax.numpy as jnp
import skimage.data
import cr.sparse.wt as wt
import pywt
import matplotlib.pyplot as plt
%matplotlib inline
# Load scikit-image's "camera" sample image as a NumPy array, plus a JAX copy.
image = skimage.data.camera()
# block_until_ready() makes sure the JAX copy is fully materialized before it
# is used (and timed) in the cells below.
image_jax = jnp.array(image).block_until_ready()
# Wavelet name in PyWavelets naming; both libraries accept the same string.
w_name = 'bior1.3'
# Single-level 2D forward DWT with each library. Both return the coefficients
# as (approximation, (detail, detail, detail)).
coeffs_np = pywt.dwt2(image, w_name)
coeffs_jax = wt.dwt2(image_jax, w_name)
# Unpack the four sub-bands from each result for a band-by-band comparison.
LL_np, (LH_np, HL_np, HH_np) = coeffs_np
LL_jax, (LH_jax, HL_jax, HH_jax) = coeffs_jax
# Check that every sub-band from cr.sparse matches PyWavelets within
# jnp.allclose's default tolerances. The line after each call is the
# recorded notebook output.
jnp.allclose(LL_np, LL_jax)
DeviceArray(True, dtype=bool)
jnp.allclose(LH_np, LH_jax)
DeviceArray(True, dtype=bool)
jnp.allclose(HL_np, HL_jax)
DeviceArray(True, dtype=bool)
jnp.allclose(HH_np, HH_jax)
DeviceArray(True, dtype=bool)
np_wavelet = pywt.Wavelet(w_name)
np_times = %timeit -o pywt.dwt2(image, np_wavelet)
4.45 ms ± 43 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
jax_wavelet = wt.build_wavelet(w_name)
jax_times = %timeit -o wt.dwt2(image_jax, jax_wavelet)
2.42 ms ± 1.87 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
gain = np_times.average / jax_times.average
print(gain)
1.8403317015626668
# Inverse transform: reconstruct the image from each library's coefficients.
rec_np = pywt.idwt2(coeffs_np, w_name)
rec_jax = wt.idwt2(coeffs_jax, w_name)
# Perfect-reconstruction checks: each library recovers the original image,
# and the two reconstructions agree with each other. The line after each call
# is the recorded notebook output.
np.allclose(image, rec_np)
True
jnp.allclose(image_jax, rec_jax)
DeviceArray(True, dtype=bool)
jnp.allclose(rec_np, rec_jax)
DeviceArray(True, dtype=bool)
np_times = %timeit -o pywt.idwt2(coeffs_np, np_wavelet)
3.87 ms ± 42.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
jax_times = %timeit -o wt.idwt2(coeffs_jax, jax_wavelet)
1.16 ms ± 8.83 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
gain = np_times.average / jax_times.average
print(gain)
3.344081957774347