4.3. Wavelet Decomposition and Reconstruction of an Image

from jax.config import config
config.update("jax_enable_x64", True)
import numpy as np
import jax.numpy as jnp
import cr.sparse.wt as wt
import skimage.data
import matplotlib.pyplot as plt
%matplotlib inline
# Load the grayscale "camera" test image and make a JAX copy of it.
# block_until_ready() waits until the copy has actually been materialized.
image = skimage.data.camera()
image_jax = jnp.array(image).block_until_ready()

# Single-level 2-D discrete wavelet transform with the biorthogonal 1.3
# wavelet. Feed the JAX copy prepared above; previously it was built but the
# NumPy array was passed instead, leaving image_jax as dead work.
coeffs2 = wt.dwt2(image_jax, 'bior1.3')
# dwt2 returns the approximation band plus a tuple of three detail bands.
LL, (LH, HL, HH) = coeffs2

# One panel per sub-band, labelled in the same order as the bands below.
titles = ['Approximation', 'Horizontal detail',
          'Vertical detail', 'Diagonal detail']
fig, axes = plt.subplots(1, 4, figsize=(12, 3))
for ax, band, title in zip(axes, (LL, LH, HL, HH), titles):
    # Nearest-neighbour interpolation shows the raw coefficients unsmoothed.
    ax.imshow(band, interpolation="nearest", cmap=plt.cm.gray)
    ax.set_title(title, fontsize=10)
    ax.set_xticks([])
    ax.set_yticks([])

fig.tight_layout()
plt.show()
../../../../../_images/cameraman_bior_9_0.png
# Invert the single-level 2-D DWT. It uses the same 'bior1.3' wavelet as the
# forward dwt2 call that produced coeffs2.
reconstruction = wt.idwt2(coeffs2, 'bior1.3')
# Check that the round trip reproduces the original image within allclose's
# default floating-point tolerances. As the last expression of the cell, the
# resulting boolean is what the notebook displays.
jnp.allclose(reconstruction, image)
DeviceArray(True, dtype=bool)