6.1. Lanczos Bidiagonalization with Partial Reorthogonalization

# Configure JAX for 64-bit computing (float64 / int64 arrays).
# `from jax import config` is used instead of `from jax.config import config`:
# the `jax.config` submodule import was deprecated and later removed in newer
# JAX releases, while `jax.config` (the Config object) is available in old
# and new versions alike. JAX recommends enabling x64 at startup, before any
# arrays are created, which is why this runs ahead of the other imports.
from jax import config
config.update("jax_enable_x64", True)
import numpy as np
import scipy as sp
import scipy.io
import jax
import jax.numpy as jnp
import cr.sparse as crs
import cr.sparse.la.svd as lasvd
# Reference problem and expected results, read from a MATLAB .mat file.
data = scipy.io.loadmat('labpro_test.mat', squeeze_me=True, struct_as_record=True)

# Array-valued fields become JAX arrays.
A, p0, alpha, beta, U, V = (
    jnp.array(data[field]) for field in ('A', 'p0', 'alpha', 'beta', 'U', 'V')
)
m, n = A.shape

# Scalar fields become plain Python ints.
rnk, k = (int(data[field]) for field in ('rnk', 'k'))

# Run the jitted lanbpro on A with parameter k and start vector p0
# (k is presumably the number of Lanczos steps — confirm with the cr.sparse
# docs), then check the computed alpha against the stored reference. The
# final expression is what the notebook cell displays.
state = lasvd.lanbpro_jit(A, k, p0)
jnp.allclose(alpha, state.alpha)
DeviceArray(True, dtype=bool)
# Check the computed beta coefficients against the stored reference values.
# The first entry of state.beta is skipped before comparing.
# NOTE(review): presumably state.beta carries one extra leading entry compared
# with the reference array — confirm against lanbpro_jit's documentation.
jnp.allclose(beta, state.beta[1:])
DeviceArray(True, dtype=bool)
# Benchmark the jitted lanbpro call. JAX dispatches work asynchronously, so
# block_until_ready() waits for the result to actually be computed; without
# it the timing would mostly measure dispatch.
# NOTE(review): the earlier lanbpro_jit(A, k, p0) call presumably already
# triggered compilation, so this should time the cached executable — confirm.
%timeit lasvd.lanbpro_jit(A, k, p0).alpha.block_until_ready()
4.89 s ± 45.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)