7.1. Wavelet Scaling Function

%load_ext autoreload
%autoreload 2
from jax.config import config
config.update("jax_enable_x64", True)
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import jax.numpy as jnp
import pywt
import cr.sparse.wt as wt
# Build the same wavelet twice: once with PyWavelets (the reference) and once
# with cr.sparse.wt. Then check that the sampled functions agree numerically.
name = 'db20'
ref = pywt.Wavelet(name)
our = wt.build_wavelet(name)

# Each wavefun() result is unpacked into three arrays, named here as
# (phi, psi, x): scaling function, wavelet function and sample grid.
phi, psi, x = ref.wavefun()
phi2, psi2, x2 = our.wavefun()

# The reference array is passed first because np.allclose is not symmetric
# in its arguments. The cell displays a 3-tuple of bools:
# (phi match, psi match, grid match).
tuple(
    np.allclose(theirs, ours)
    for theirs, ours in zip((phi, psi, x), (phi2, psi2, x2))
)
(True, True, True)
# Benchmark the PyWavelets reference implementation. The -o flag makes %timeit
# return its timing result, so it can be kept in ref_times for later comparison.
ref_times = %timeit -o ref.wavefun()
408 µs ± 12.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# Benchmark the cr.sparse implementation of the same wavelet. The result is
# kept in our_times (via -o) so it can be compared with ref_times.
our_times = %timeit -o our.wavefun()
487 µs ± 895 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)