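# Matrix Multiplication Operator

Benchmarks the forward product $y = A x$ and the adjoint product $A^H x$ (equal to $A^T x$ here, since $A$ is real) of a dense $10000 \times 10000$ matrix operator: PyLops (NumPy) versus CR-Sparse (JAX, JIT-compiled).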

In [1]:
import numpy as np
import jax.numpy as jnp
import pylops
from cr.sparse import lop
import cr.sparse as crs

In [2]:
# Problem size: a dense m x n operator (float64, ~800 MB at 10000 x 10000).
m, n = 10000, 10000
# Seeded generator so the matrix, and every result derived from it, is reproducible
# on Restart & Run All (the original used the unseeded global NumPy RNG).
SEED = 0
rng = np.random.default_rng(SEED)
A_np = rng.normal(0, 1, (m, n))
# All-ones input vector, so A x is the vector of row sums of A.
x_np = np.ones(n)

In [3]:
# Reference implementation: wrap the dense NumPy matrix as a PyLops linear operator.
op_np = pylops.MatrixMult(A_np)

In [4]:
# Reference forward product y = A x through the PyLops operator.
# `@` is Python's dedicated matrix-multiplication operator (PEP 465), which PyLops
# operators support; it is unambiguous, unlike the older `*` overload.
y_np = op_np @ x_np

In [5]:
# Move the NumPy matrix and vector into JAX arrays.
# NOTE: with JAX's default configuration (64-bit mode disabled), float64 inputs are
# stored as float32, so the JAX side computes in lower precision than NumPy. This is
# presumably why the comparisons below use atol=1e-4.
A_jax, x_jax = jnp.asarray(A_np), jnp.asarray(x_np)

In [6]:
# CR-Sparse dense-matrix operator over the JAX array, passed through CR-Sparse's
# `lop.jit` wrapper before benchmarking.
op_jax = lop.jit(lop.matrix(A_jax))

In [7]:
# Forward product y = A x on the JAX side. Being the first call on the JIT-wrapped
# operator, it presumably also triggers compilation, which keeps compile time out of
# the %timeit measurement below.
y_jax = op_jax.times(x_jax)

In [8]:
# Check the JAX forward product against the PyLops reference. The absolute tolerance
# presumably allows for JAX computing in float32 by default versus float64 in NumPy.
np.allclose(y_np, y_jax, atol=1e-4)

True

In [9]:
np_time = %timeit -o op_np * x_np

11 ms ± 59.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [10]:
# Time the JAX forward product. JAX dispatches work asynchronously, so
# block_until_ready() makes each timed loop wait for the result to be computed.
jax_time = %timeit -o op_jax.times(x_jax).block_until_ready()

2.51 ms ± 1.36 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [11]:
# Forward-product speed-up: ratio of mean per-loop times (NumPy / JAX);
# a value > 1 means the JAX operator is faster.
gain = np_time.average / jax_time.average
print(gain)

4.370156377294028


In [12]:
# Reference adjoint product A^H x via the PyLops operator's `.H` (adjoint);
# `@` is the matrix-multiplication operator (PEP 465), used instead of `*`.
y1_np = op_np.H @ x_np

In [13]:
# Transpose product A^T x on the JAX side, waiting for the result. As the first call
# to `trans`, this presumably also compiles it ahead of the timing below.
y1_jax = op_jax.trans(x_jax).block_until_ready()

In [14]:
# Compare PyLops' adjoint (.H) with CR-Sparse's trans. They coincide here because
# A_np was drawn from a real normal distribution, so A^H = A^T.
np.allclose(y1_np, y1_jax, atol=1e-4)

True

In [15]:
np_time = %timeit -o op_np.H * x_np

11.6 ms ± 55 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [16]:
# Time the JAX transpose product. block_until_ready() makes each timed loop wait for
# JAX's asynchronously dispatched computation to finish.
jax_time = %timeit -o op_jax.trans(x_jax).block_until_ready()

2.51 ms ± 685 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [17]:
# Adjoint/transpose-product speed-up: ratio of mean per-loop times (NumPy / JAX);
# a value > 1 means the JAX operator is faster.
gain = np_time.average / jax_time.average
print(gain)

4.629653961375258
