# Source code for cr.sparse._src.lop.normest
# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Power iteration method for linear operators
"""
from jax import jit, lax, random
import jax.numpy as jnp
from typing import NamedTuple
import cr.nimble as cnb
class NormEstState(NamedTuple):
    """State for the norm estimation algorithm

    Used as the loop carry of the ``lax.while_loop`` in ``normest``: each
    power iteration step takes one state and returns the next one.
    """
    # (unnormalized) eigen vector guess; ``normest`` normalizes it at the
    # start of every power iteration step before applying the operator
    v: jnp.ndarray
    # eigen value estimate from the previous step (``normest`` seeds it
    # with a -1e20 sentinel so the first convergence check fails)
    old_estimate: float
    # new eigen value estimate (seeded with a 1e20 sentinel); inside the
    # loop this is a JAX scalar produced by ``jnp.abs``
    new_estimate: float
    # number of power iteration steps performed so far
    iterations: int
class NormEstSolution(NamedTuple):
    """Solution of the eigen vector estimate

    Bundles an eigen vector estimate, the matching eigen value estimate and
    the iteration count.

    NOTE(review): ``normest`` in this module returns only a scalar norm
    estimate and does not construct this tuple; confirm where it is used.
    """
    v: jnp.ndarray
    """The estimated eigen vector"""
    # NOTE(review): whether ``s`` holds an eigen value of A^H A or its square
    # root (a singular value / norm) depends on the producer -- confirm
    s : float
    """The estimated eigen value"""
    iterations: int
    """Number of iterations to converge"""
def normest(
    operator,
    max_iters=100,
    error_tolerance=1e-6):
    """Estimates the norm of a linear operator by power method

    Power iteration is run on the operator obtained by applying
    ``operator.times`` followed by ``operator.trans`` (i.e. :math:`A^H A`,
    assuming ``trans`` applies the adjoint). The square root of the largest
    eigen value estimate is returned, which under that assumption estimates
    the largest singular value (spectral norm) of :math:`A`.

    The starting vector is drawn with the fixed keys ``cnb.KEYS[0]`` (and
    ``cnb.KEYS[1]`` for the imaginary part of complex operators) rather than
    from a caller-supplied key.

    Args:
        operator (cr.sparse.lop.Operator): A linear operator :math:`A`
        max_iters (int): Maximum number of iterations
        error_tolerance (float): Tolerance for relative change in largest eigen value

    Returns:
        (float): An estimate of the norm. If the iterate is mapped to zero
        (e.g. a zero operator) the estimate is ``0`` rather than ``nan``.
    """
    shape = operator.input_shape
    # initial eigen vector; complex operators get a complex starting vector
    b = random.normal(cnb.KEYS[0], shape)
    if not operator.real:
        bi = random.normal(cnb.KEYS[1], shape)
        b = b + bi * 1j
    # normalize it
    b = b / cnb.arr_l2norm(b)

    def init():
        # The sentinel estimates are far apart so that the first convergence
        # check always reports "not converged".
        return NormEstState(v=b, old_estimate=-1e20, new_estimate=1e20, iterations=0)

    def cond(state):
        # Relative change test |new - old| > tol * |old|. For old != 0 this is
        # the same as |(new - old) / old| > tol, but it avoids the division so
        # that old == new == 0 is treated as converged instead of forming 0/0.
        change = jnp.abs(state.new_estimate - state.old_estimate)
        threshold = error_tolerance * jnp.abs(state.old_estimate)
        not_converged = jnp.greater(change, threshold)
        # keep going while the algorithm hasn't converged and iterations remain
        return jnp.logical_and(state.iterations < max_iters, not_converged)

    def body(state):
        """One step of power iteration."""
        v = state.v
        # Normalize. If the previous step mapped the iterate to zero (zero
        # operator, or a vector in the null space of A), dividing by a zero
        # norm would fill the state with NaN; keep the zero vector instead so
        # the estimate becomes exactly 0.
        v_norm = cnb.arr_l2norm(v)
        v = v / jnp.where(v_norm > 0, v_norm, 1)
        # compute the next vector: times, then trans
        v_new = operator.times(v)
        v_new = operator.trans(v_new)
        # estimate the eigen value: v^H (A^H A v) with unit-norm v
        new_estimate = jnp.vdot(v, v_new)
        # largest singular value is non-negative
        new_estimate = jnp.abs(new_estimate)
        # v_new is stored unnormalized; the next step normalizes it
        return NormEstState(v=v_new, old_estimate=state.new_estimate,
            new_estimate=new_estimate, iterations=state.iterations + 1)

    state = lax.while_loop(cond, body, init())
    # The loop ended either by convergence or by reaching max_iters.
    return jnp.sqrt(state.new_estimate)
# JIT-compiled variant of ``normest``. The operator (positional argument 0)
# is marked static because it carries Python callables (``times``/``trans``)
# rather than arrays.
normest_jit = jit(normest, static_argnums=0)