5.2. Diagonal Operator
import numpy as np
import jax.numpy as jnp
import pylops
from cr.sparse import lop
import cr.sparse as crs
# Problem size: one million diagonal entries.
n = 1000*1000
# Diagonal values drawn from a normal distribution (mean 0, std 1) and an
# all-ones input vector of the same length.
d_np = np.random.normal(0, 1, (n))
x_np = np.ones(n)
# Reference implementation: PyLops diagonal operator applied to x_np.
op_np = pylops.Diagonal(d_np)
y_np = op_np * x_np
# The same data as JAX arrays, for the CR-Sparse operator.
# NOTE(review): jnp.array only keeps float64 when JAX x64 mode is enabled;
# otherwise these become float32 -- confirm the comparison is like-for-like.
d_jax = jnp.array(d_np)
x_jax = jnp.array(x_np)
# CR-Sparse diagonal operator, then rebound to the result of lop.jit
# (presumably a JIT-compiled version of the operator -- see cr.sparse docs).
op_jax = lop.diagonal(d_jax)
op_jax = lop.jit(op_jax)
# Forward application through the CR-Sparse operator.
y_jax = op_jax.times(x_jax)
# Sanity check: both results should agree within np.allclose's tolerances.
np.allclose(y_np, y_jax)
True
# Time the PyLops forward product; -o makes %timeit return its result object
# so the mean can be read back below.
np_time = %timeit -o op_np * x_np
1.55 ms ± 88.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# Time the CR-Sparse forward product. block_until_ready() waits for JAX's
# asynchronous dispatch to complete, so the actual computation is measured.
jax_time = %timeit -o op_jax.times(x_jax).block_until_ready()
49.1 µs ± 3.44 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
# Speed-up factor: mean PyLops time divided by mean CR-Sparse time.
gain = np_time.average / jax_time.average
print(gain)
31.53315690766351
# Adjoint application: op_np.H is the PyLops adjoint operator; trans() is the
# CR-Sparse call compared against it here.
y1_np = op_np.H * x_np
y1_jax = op_jax.trans(x_jax)
# Sanity check: both results should agree within np.allclose's tolerances.
np.allclose(y1_np, y1_jax)
True
# Time the PyLops adjoint product; -o makes %timeit return its result object.
np_time = %timeit -o op_np.H * x_np
1.68 ms ± 77.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# Time the CR-Sparse trans() call. block_until_ready() waits for JAX's
# asynchronous dispatch to complete, so the actual computation is measured.
jax_time = %timeit -o op_jax.trans(x_jax).block_until_ready()
47.4 µs ± 653 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
# Speed-up factor: mean PyLops time divided by mean CR-Sparse time.
gain = np_time.average / jax_time.average
print(gain)
35.42500030514343