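# Haar DWT2D: PyLops vs. CR-Sparse (JAX)

Benchmarks the forward and adjoint 8-level 2D Haar wavelet transform of a 4000×4000 random integer image using PyLops (`pylops.signalprocessing.DWT2D`) and the JIT-compiled CR-Sparse operator (`cr.sparse.lop.dwt2D`). It checks that both libraries produce matching outputs and reports the speed-up of CR-Sparse over PyLops.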

In [1]:
# Enable 64-bit dtypes in JAX before any JAX array is created; otherwise JAX
# defaults to 32-bit (e.g. the random integer image below would be int32).
# `jax.config.update` is the documented form; the `jax.config` submodule
# import (`from jax.config import config`) is deprecated/removed in newer JAX.
import jax
jax.config.update("jax_enable_x64", True)
from jax import random
import numpy as np
import jax.numpy as jnp
import pylops
from cr.sparse import lop
import cr.sparse as crs

In [4]:
# Problem size: a square n x n image. The DWT pads each axis up to the next
# power of two, m, so the transform coefficients live on an m x m grid.
n = 4000
m = crs.next_pow_of_2(n)
shape = (n,) * 2
# Reproducible random integer image (fixed PRNG key; maxval is exclusive),
# shared by both libraries: JAX array for CR-Sparse, NumPy copy for PyLops.
x_jax = random.randint(random.PRNGKey(0), shape=shape, minval=-10, maxval=10)
x_np = np.array(x_jax)
print(m, n)

4096 4000


In [10]:
# Number of wavelet decomposition levels.
level = 8
# Pass the Haar wavelet explicitly to both operators (this notebook benchmarks
# the Haar transform) instead of relying on each library's default wavelet.
op_np = pylops.signalprocessing.DWT2D(shape, wavelet="haar", level=level)
# lop.jit wraps the CR-Sparse operator so its times/trans are JIT-compiled.
op_jax = lop.jit(lop.dwt2D(shape, wavelet="haar", level=level))
print(op_jax.shape)

((4096, 4096), (4000, 4000))


In [4]:
# PyLops works on 1-D vectors: apply the forward DWT to the flattened image and
# fold the coefficients back onto the padded m x m grid.
y_np = np.reshape(op_np @ np.ravel(x_np), (m, m))

In [5]:
# Forward DWT with CR-Sparse, applied directly to the 2-D image (no flattening).
# op_jax is wrapped in lop.jit, so this first call presumably also compiles it,
# keeping compilation out of the %timeit measurement further down.
y_jax = op_jax.times(x_jax)

In [6]:
# The forward timings below are only meaningful if both libraries compute the
# same coefficients, so fail loudly on a mismatch instead of just displaying it.
forward_match = bool(jnp.allclose(y_np, y_jax))
assert forward_match, "PyLops and CR-Sparse forward DWT2D outputs differ"
forward_match

DeviceArray(True, dtype=bool)

In [7]:
np_time = %timeit -o (op_np @ x_np.flatten()).reshape((m, m))

981 ms ± 2.77 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
# Time the jitted CR-Sparse forward transform. JAX dispatches work
# asynchronously; block_until_ready() makes each timed call wait for the
# result, so the measurement covers the actual computation.
jax_time = %timeit -o op_jax.times(x_jax).block_until_ready()

34.4 ms ± 10.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [9]:
# Forward-transform speed-up: ratio of the mean per-loop times
# (PyLops / CR-Sparse); values > 1 mean CR-Sparse is faster.
gain = np_time.average / jax_time.average
print(gain)

28.514630735291732


In [10]:
# Adjoint DWT with PyLops on the flattened m x m coefficients, reshaped back to
# the original n x n image shape.
z_np = np.reshape(op_np.rmatvec(np.ravel(y_np)), shape)

In [11]:
# Adjoint DWT with CR-Sparse, applied directly to the 2-D coefficient array.
z_jax = op_jax.trans(y_jax)

In [12]:
# The adjoint timings below are only meaningful if both libraries agree, so
# fail loudly on a mismatch instead of just displaying the comparison.
adjoint_match = bool(jnp.allclose(z_np, z_jax))
assert adjoint_match, "PyLops and CR-Sparse adjoint DWT2D outputs differ"
adjoint_match

DeviceArray(True, dtype=bool)

In [13]:
np_time = %timeit -o op_np.rmatvec(y_np.flatten()).reshape(shape)

713 ms ± 5.44 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [14]:
jax_time = %timeit -o op_jax.trans(y_jax)

60.8 ms ± 25.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [15]:
# Adjoint-transform speed-up: ratio of the mean per-loop times
# (PyLops / CR-Sparse); values > 1 mean CR-Sparse is faster.
gain = np_time.average / jax_time.average
print(gain)

11.724658650916197
