# 4.3. Wavelet Decomposition and Reconstruction of an Image
# Enable 64-bit (double-precision) JAX arrays. JAX expects this to be set at
# startup, before any arrays are created. `from jax import config` replaces the
# old `from jax.config import config` submodule import, which newer JAX
# releases removed.
from jax import config
config.update("jax_enable_x64", True)
import numpy as np
import jax.numpy as jnp
import cr.sparse.wt as wt
import skimage.data
import matplotlib.pyplot as plt

# Plain-Python equivalent of the `%matplotlib inline` notebook magic: it still
# selects the inline backend under IPython/Jupyter, while letting the file
# parse (and skip this step) when run as an ordinary script.
try:
    get_ipython().run_line_magic("matplotlib", "inline")
except NameError:
    pass
# Load scikit-image's sample `camera` image as a NumPy array.
image = skimage.data.camera()
# JAX copy of the image; block_until_ready() waits until the array is ready.
image_jax = jnp.array(image).block_until_ready()

# Single-level 2-D DWT with the biorthogonal 'bior1.3' wavelet. Transform the
# JAX array prepared above (previously it was built but never used).
coeffs2 = wt.dwt2(image_jax, 'bior1.3')
# Approximation band followed by the three detail bands, in the order the
# plot below labels them: horizontal, vertical, diagonal.
LL, (LH, HL, HH) = coeffs2
# Show the four DWT sub-bands side by side in grayscale, without axis ticks.
titles = ['Approximation', 'Horizontal detail',
          'Vertical detail', 'Diagonal detail']
fig = plt.figure(figsize=(12, 3))
# Subplot indices are 1-based, hence start=1.
for i, (band, title) in enumerate(zip([LL, LH, HL, HH], titles), start=1):
    ax = fig.add_subplot(1, 4, i)
    ax.imshow(band, interpolation="nearest", cmap=plt.cm.gray)
    ax.set_title(title, fontsize=10)
    ax.set_xticks([])
    ax.set_yticks([])
fig.tight_layout()
plt.show()
# Invert the transform with the same wavelet and check that the round trip
# reproduces the original pixels. The bare expression on the last line is
# what the notebook displays as the cell's output.
reconstruction = wt.idwt2(
    coeffs2,
    'bior1.3',
)
jnp.allclose(reconstruction, image)
# Output: DeviceArray(True, dtype=bool)