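# 2D Wavelet Decomposition and Reconstruction

This notebook compares the single-level 2D discrete wavelet transform (`dwt2`) and its inverse (`idwt2`) in two implementations: PyWavelets and the JAX-based `cr.sparse.wt` module. It checks that both produce the same coefficients and reconstructions, and then times each implementation.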

In [1]:
# Enable 64-bit (float64) arrays in JAX before any arrays are created. This way the
# JAX transforms run in double precision and can be compared tightly against PyWavelets.
# The config object is reached through the top-level `jax` package: importing it via
# `from jax.config import config` is deprecated and removed in recent JAX releases.
import jax
jax.config.update("jax_enable_x64", True)

In [2]:
import numpy as np
import jax
import jax.numpy as jnp

In [3]:
import skimage.data
import cr.sparse.wt as wt
import pywt

In [4]:
import matplotlib.pyplot as plt
%matplotlib inline

In [5]:
# Test picture: scikit-image's built-in "camera" image (a NumPy array).
image = skimage.data.camera()
# Copy the image to the JAX device once, up front, and wait for the copy to
# complete, so the transfer is not part of any later measurement.
image_jax = jnp.array(image)
image_jax = image_jax.block_until_ready()

In [6]:
# Wavelet under test, given by its PyWavelets name ("bior1.3" is a wavelet from the
# biorthogonal family). Both libraries below receive this same string.
w_name = "bior1.3"

In [7]:
# Single-level 2D DWT of the image with PyWavelets; this is the baseline the
# JAX results are compared against.
coeffs_np = pywt.dwt2(image, wavelet=w_name)

In [8]:
# Same single-level 2D DWT computed with cr.sparse's JAX implementation; the result
# is unpacked below with the same (LL, (LH, HL, HH)) nesting as the PyWavelets one.
coeffs_jax = wt.dwt2(image_jax, w_name)

In [9]:
# Split the nested PyWavelets result into one name per sub-band.
(LL_np, (LH_np, HL_np, HH_np)) = coeffs_np

In [10]:
# Split the nested JAX result into one name per sub-band.
(LL_jax, (LH_jax, HL_jax, HH_jax)) = coeffs_jax

In [11]:
# LL sub-band: do PyWavelets and JAX agree within jnp.allclose's default tolerances?
jnp.allclose(LL_np, LL_jax)

DeviceArray(True, dtype=bool)

In [12]:
# LH sub-band: do PyWavelets and JAX agree within jnp.allclose's default tolerances?
jnp.allclose(LH_np, LH_jax)

DeviceArray(True, dtype=bool)

In [13]:
# HL sub-band: do PyWavelets and JAX agree within jnp.allclose's default tolerances?
jnp.allclose(HL_np, HL_jax)

DeviceArray(True, dtype=bool)

In [14]:
# HH sub-band: do PyWavelets and JAX agree within jnp.allclose's default tolerances?
jnp.allclose(HH_np, HH_jax)

DeviceArray(True, dtype=bool)

In [15]:
# Build the PyWavelets wavelet object once, outside the timed statement, so that
# only the forward transform itself is measured.
np_wavelet = pywt.Wavelet(w_name)
np_times = %timeit -o pywt.dwt2(image, np_wavelet)

4.48 ms ± 12.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [16]:
jax_wavelet = wt.build_wavelet(w_name)
jax_times = %timeit -o wt.dwt2(image_jax, jax_wavelet)

656 µs ± 7.08 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [17]:
# Speed-up of the JAX forward transform over PyWavelets: ratio of the mean
# per-loop times (values above 1 mean JAX is faster).
gain = np_times.average / jax_times.average
print(gain)

6.831058197177011


In [18]:
# Inverse 2D DWT with PyWavelets, using the same wavelet as the forward step.
rec_np = pywt.idwt2(coeffs_np, wavelet=w_name)

In [19]:
# Inverse 2D DWT with cr.sparse's JAX implementation, using the same wavelet name.
rec_jax = wt.idwt2(coeffs_jax, w_name)

In [20]:
# Round-trip check: does the PyWavelets reconstruction match the original image?
np.allclose(image, rec_np)

True

In [21]:
# Round-trip check: does the JAX reconstruction match the original image?
jnp.allclose(image_jax, rec_jax)

DeviceArray(True, dtype=bool)

In [22]:
# Do the PyWavelets and JAX reconstructions agree with each other?
jnp.allclose(rec_np, rec_jax)

DeviceArray(True, dtype=bool)

In [23]:
# Time the PyWavelets inverse transform. Note: this rebinds np_times, so the
# forward-transform timing is no longer available under that name.
np_times = %timeit -o pywt.idwt2(coeffs_np, np_wavelet)

3.4 ms ± 19.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [24]:
jax_times = %timeit -o wt.idwt2(coeffs_jax, jax_wavelet)

614 µs ± 3.96 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [25]:
# Speed-up of the JAX inverse transform over PyWavelets: ratio of the mean
# per-loop times from the two inverse-transform timing cells (values above 1 mean JAX is faster).
gain = np_times.average / jax_times.average
print(gain)

5.544721592895543
