5.4. Haar DWT2 (2D Discrete Wavelet Transform)
import jax
# Enable 64-bit precision in JAX (32-bit by default) so that the JAX results
# can be compared against the NumPy-based PyLops results at matching precision.
# `from jax.config import config` is deprecated (removed in recent JAX
# releases); `jax.config` is the supported entry point and works on older
# versions too. Set this before creating any JAX arrays.
jax.config.update("jax_enable_x64", True)
from jax import random
import numpy as np
import jax.numpy as jnp
import pylops
from cr.sparse import lop
import cr.sparse as crs

# Side length of the square test image.
n = 4000
# Next power of two at or above n (4096 for n = 4000, per the printed output);
# the wavelet-domain output of the transform is m x m.
m = crs.next_pow_of_2(n)
shape = (n, n)
# One random integer image shared by both backends: a JAX array and a NumPy
# copy of the same values (upper bound is exclusive for jax.random.randint).
x_jax = random.randint(random.PRNGKey(0), shape, -10, 10)
x_np = np.array(x_jax)
print(m, n)
4096 4000
# Number of wavelet decomposition levels, shared by both implementations.
level = 8
# Reference 2D DWT operator from PyLops (NumPy-based). No wavelet argument is
# passed, so the library default is used; the section title indicates Haar
# (TODO confirm the default for both libraries).
op_np = pylops.signalprocessing.DWT2D(shape, level=level)
# CR-Sparse 2D DWT operator, wrapped with lop.jit for JIT compilation.
op_jax = lop.jit(lop.dwt2D(shape, level=level))
# Prints (output_shape, input_shape): the n x n image maps to an m x m
# wavelet-domain array.
print(op_jax.shape)
((4096, 4096), (4000, 4000))
# Forward transform with both backends. The PyLops operator is applied to a
# flat vector, so the image is flattened and the result is reshaped to the
# m x m wavelet domain.
y_np = (op_np @ x_np.flatten()).reshape((m, m))
# The CR-Sparse operator is applied directly to the 2D image.
y_jax = op_jax.times(x_jax)
# Check that the two implementations agree element-wise (default tolerances).
jnp.allclose(y_np, y_jax)
DeviceArray(True, dtype=bool)
# Benchmark the PyLops forward transform, including the flatten/reshape
# round-trip that its vector interface requires (same expression as above).
np_time = %timeit -o (op_np @ x_np.flatten()).reshape((m, m))
1.13 s ± 15.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Benchmark the JIT-compiled forward transform (already compiled by the first
# call above). block_until_ready() waits for JAX's asynchronous dispatch to
# finish, so the actual computation is included in the timing.
jax_time = %timeit -o op_jax.times(x_jax).block_until_ready()
77.5 ms ± 17.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# Speed-up of the JAX forward transform over PyLops, using the mean time per
# loop reported by %timeit.
gain = np_time.average / jax_time.average
print(gain)
14.547531586021503
# Adjoint transform of the forward results. PyLops' rmatvec again takes a flat
# vector; its result is reshaped back to the n x n image shape.
z_np = op_np.rmatvec(y_np.flatten()).reshape(shape)
z_jax = op_jax.trans(y_jax)
# Check that the two adjoint results agree element-wise.
jnp.allclose(z_np, z_jax)
DeviceArray(True, dtype=bool)
# Benchmark the PyLops adjoint transform, including the flatten/reshape steps.
np_time = %timeit -o op_np.rmatvec(y_np.flatten()).reshape(shape)
804 ms ± 2.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jax_time = %timeit -o op_jax.trans(y_jax)
54.6 ms ± 24.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# Speed-up of the JAX adjoint transform over the PyLops adjoint, using the
# mean time per loop reported by %timeit.
gain = np_time.average / jax_time.average
print(gain)
14.711116617240743