Source code for cr.sparse._src.lop.basic

# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax.numpy as jnp

from .impl import _hermitian
from .lop import Operator
from .util import apply_along_axis
from jax.experimental.sparse import sparsify

###########################################################################################
#  Basic operators
###########################################################################################

def real_matrix(A):
    """Wraps a real matrix as a linear operator acting on vectors.

    Args:
        A (jax.numpy.ndarray): A real valued matrix (2D array)

    Returns:
        Operator: A linear operator wrapping the matrix

    Forward operation:

    .. math::

        y = A x

    Adjoint operation:

    .. math::

        y = A^T x

    Note:
        Both the forward and the adjoint functions assert that
        their input is a 1D array.
    """
    n_rows, n_cols = A.shape

    def forward(vec):
        assert vec.ndim == 1
        return A @ vec

    def adjoint(vec):
        assert vec.ndim == 1
        # for a 1D vector, x @ A equals A^T x without forming the transpose
        return vec @ A

    return Operator(times=forward, trans=adjoint, shape=(n_rows, n_cols))
def matrix(A, axis=0):
    """Wraps a real or complex matrix as a linear operator.

    Args:
        A (jax.numpy.ndarray): A real or complex matrix (2D array)
        axis (int): For multi-dimensional array input, the axis along which
          the linear operator will be applied

    Returns:
        Operator: A linear operator wrapping the matrix

    Forward operation:

    .. math::

        y = A x

    Adjoint operation:

    .. math::

        y = A^H x = (x^H A)^H

    Note:
        1D and 2D inputs are handled with direct matrix products.
        For 2D inputs, ``axis == 0`` applies the operator to each column
        and any other value of ``axis`` applies it to each row.
        Inputs of any other rank go through ``jnp.apply_along_axis``.
    """
    n_rows, n_cols = A.shape
    is_real = jnp.isrealobj(A)

    def forward_1d(vec):
        return A @ vec

    def adjoint_1d(vec):
        # A^H x evaluated as (x^H A)^H
        return _hermitian(_hermitian(vec) @ A)

    def times(x):
        """Forward matrix multiplication"""
        if x.ndim == 1:
            return A @ x
        if x.ndim == 2:
            # column-wise for axis 0, row-wise otherwise
            return A @ x if axis == 0 else x @ A.T
        # general case
        return jnp.apply_along_axis(forward_1d, axis, x)

    def trans(x):
        """Adjoint matrix multiplication"""
        if x.ndim == 1:
            return adjoint_1d(x)
        if x.ndim == 2:
            # row-wise adjoint: each row r becomes r @ conj(A)
            return _hermitian(A) @ x if axis == 0 else x @ jnp.conjugate(A)
        # general case
        return jnp.apply_along_axis(adjoint_1d, axis, x)

    return Operator(times=times, trans=trans,
                    shape=(n_rows, n_cols), real=is_real)
def sparse_real_matrix(A, axis=0):
    """Wraps a sparse real matrix as a linear operator.

    Args:
        A (jax.experimental.sparse.BCOO): A real valued sparse matrix
          in BCOO format
        axis (int): For multi-dimensional array input, the axis along which
          the linear operator will be applied

    Returns:
        Operator: A linear operator wrapping the matrix

    Forward operation:

    .. math::

        y = A x

    Adjoint operation:

    .. math::

        y = A^T x

    Note:
        Every inner function is wrapped with ``sparsify`` so that the
        products involving the sparse matrix are evaluated by JAX's
        sparse machinery.
    """
    n_rows, n_cols = A.shape

    @sparsify
    def forward_1d(vec):
        return A @ vec

    @sparsify
    def adjoint_1d(vec):
        # for a 1D vector, x @ A equals A^T x
        return vec @ A

    @sparsify
    def times(x):
        """Forward matrix multiplication"""
        if x.ndim == 1:
            return A @ x
        if x.ndim == 2:
            # column-wise for axis 0, row-wise otherwise
            return A @ x if axis == 0 else x @ A.T
        # general case
        return jnp.apply_along_axis(forward_1d, axis, x)

    @sparsify
    def trans(x):
        """Adjoint matrix multiplication"""
        if x.ndim == 1:
            return adjoint_1d(x)
        if x.ndim == 2:
            # column-wise for axis 0, row-wise otherwise
            return A.T @ x if axis == 0 else x @ A
        # general case
        return jnp.apply_along_axis(adjoint_1d, axis, x)

    return Operator(times=times, trans=trans, shape=(n_rows, n_cols))
def diagonal(d, axis=0):
    """Builds a linear operator that mimics multiplication by a diagonal matrix

    Args:
        d (jax.numpy.ndarray): A vector (1D array) of diagonal entries
        axis (int): For multi-dimensional array input, the axis along which
          the linear operator will be applied

    Returns:
        Operator: A linear operator wrapping the diagonal matrix multiplication

    Note:
        The forward operation multiplies the input element-wise by ``d``.
        The adjoint multiplies element-wise by ``_hermitian(d)``
        (presumably the complex conjugate of ``d`` for a 1D array;
        confirm against the ``_hermitian`` helper).
    """
    assert d.ndim == 1
    size = d.shape[0]

    def scale(x):
        return d * x

    def scale_adjoint(x):
        return _hermitian(d) * x

    # extend the 1D functions so that they work along the requested axis
    times, trans = apply_along_axis(scale, scale_adjoint, axis)
    return Operator(times=times, trans=trans, shape=(size, size))
def scalar_mult(alpha, n):
    r"""Builds a linear operator T such that :math:`T v = \alpha v`

    Args:
        alpha (float): A scalar value
        n (int): The dimension of model/data space

    Returns:
        Operator: A linear operator wrapping the scalar multiplication

    Note:
        The adjoint multiplies by the complex conjugate of ``alpha``.
    """
    factor = jnp.asarray(alpha)
    assert factor.ndim == 0, "alpha must be a scalar quantity"
    factor_conj = jnp.conjugate(factor)

    def forward(x):
        return factor * x

    def adjoint(x):
        return factor_conj * x

    return Operator(times=forward, trans=adjoint, shape=(n, n))
def zero(in_dim, out_dim=None, axis=0):
    """Builds a linear operator which maps everything to the 0 vector

    Args:
        in_dim (int): Dimension of the model space
        out_dim (int): Dimension of the data space (default in_dim)
        axis (int): For multi-dimensional array input, the axis along which
          the linear operator will be applied

    Returns:
        Operator: A zero linear operator

    Note:
        The zeros produced carry the dtype of the input array.
    """
    if out_dim is None:
        out_dim = in_dim

    def to_data_space(x):
        return jnp.zeros(out_dim, dtype=x.dtype)

    def to_model_space(x):
        return jnp.zeros(in_dim, dtype=x.dtype)

    # extend the 1D functions so that they work along the requested axis
    times, trans = apply_along_axis(to_data_space, to_model_space, axis)
    return Operator(times=times, trans=trans, shape=(out_dim, in_dim))
def flipud(n):
    """Builds an operator which reverses the order of entries in the input

    Args:
        n (int): Dimension of the model and data space

    Returns:
        Operator: An n x n operator whose forward and adjoint
        operations both apply ``jnp.flipud``
    """
    def forward(x):
        return jnp.flipud(x)

    def adjoint(x):
        # reversing the order is its own adjoint
        return jnp.flipud(x)

    return Operator(times=forward, trans=adjoint, shape=(n, n))
def sum(n):
    """Builds an operator which computes the sum of a vector

    Args:
        n (int): Dimension of the model space

    Returns:
        Operator: A 1 x n operator. The forward operation sums the input
        along axis 0 (keeping that axis with length 1). The adjoint
        repeats the input n times along axis 0.
    """
    def add_up(x):
        return jnp.sum(x, keepdims=True, axis=0)

    def replicate(x):
        return jnp.repeat(x, n, axis=0)

    return Operator(times=add_up, trans=replicate, shape=(1, n))
def pad_zeros(n, before, after):
    """Adds zeros before and after a vector.

    Args:
        n (int): Dimension of the model space (length of the input vector)
        before (int): Number of zeros inserted before the vector
        after (int): Number of zeros appended after the vector

    Returns:
        Operator: An operator of shape (before + n + after, n). The forward
        operation pads the input with zeros. The adjoint extracts the
        n entries starting at position ``before``.

    Note:
        This operator is not JIT compliant
    """
    # A single (before, after) pair makes jnp.pad apply this padding
    # to every axis of its input; it is the right spec for 1D input only.
    pad_width = (before, after)
    m = before + n + after

    def times(x):
        return jnp.pad(x, pad_width)

    def trans(x):
        # drop the padded entries, keep the original n values
        return x[before:before+n]

    # NOTE(review): matrix_safe=False presumably tells Operator that these
    # functions must only be given 1D input -- confirm against Operator.
    return Operator(times=times, trans=trans, shape=(m,n), matrix_safe=False)
def real(n):
    """Builds an operator that keeps the real parts of a vector of complex numbers

    Args:
        n (int): Dimension of the model and data space

    Returns:
        Operator: An n x n operator applying ``jnp.real`` in both directions

    Note:
        This is a self-adjoint operator.
        This is not a linear operator.
    """
    def forward(x):
        return jnp.real(x)

    def adjoint(x):
        return jnp.real(x)

    return Operator(times=forward, trans=adjoint, shape=(n, n), linear=False)
def symmetrize(n):
    """Builds an operator which constructs a symmetric vector by
    pre-pending the input in reversed order

    Args:
        n (int): Dimension of the model space

    Returns:
        Operator: A 2n x n operator. The forward operation outputs the
        reversed input followed by the input. The adjoint adds the second
        half of its input to the reversed first half.
    """
    def mirror(x):
        reflected = jnp.flipud(x)
        return jnp.concatenate((reflected, x))

    def fold(x):
        second_half = x[n:]
        # first n entries, read backwards
        first_half_reversed = x[n-1::-1]
        return second_half + first_half_reversed

    return Operator(times=mirror, trans=fold, shape=(2*n, n))
def restriction(n, I, axis=0):
    """Builds an operator which computes y = x[I] over an index set I

    Args:
        n (int): Dimension of model space
        I (jax.numpy.ndarray): Indices of the entries to be selected
        axis (int): For multi-dimensional array input, the axis along which
          the linear operator will be applied

    Returns:
        Operator: An operator of shape (len(I), n). The forward operation
        selects the entries indexed by I. The adjoint places its input at
        the positions I of a zero vector of length n.

    Note:
        NOTE(review): the adjoint writes values with ``.at[I].set``; this
        presumably assumes that the indices in I are distinct -- confirm
        with callers.
    """
    k = len(I)

    def select(vec):
        return vec[I]

    def scatter(vec):
        blank = jnp.zeros((n,), dtype=vec.dtype)
        return blank.at[I].set(vec)

    def times(x):
        if x.ndim == 1:
            return select(x)
        if x.ndim == 2:
            # pick rows when applied column-wise, columns when row-wise
            return x[I, :] if axis == 0 else x[:, I]
        # general case
        return jnp.apply_along_axis(select, axis, x)

    def trans(x):
        if x.ndim == 1:
            return scatter(x)
        # general case
        return jnp.apply_along_axis(scatter, axis, x)

    return Operator(times=times, trans=trans, shape=(k, n))
def heaviside(n, axis=0, normalized=True):
    """Builds a linear operator implementing the Heaviside step function

    Args:
        n (int): Dimension of the model space
        axis (int): For multi-dimensional array input, the axis along which
          the linear operator will be applied
        normalized: If False, then simple cumsum, otherwise the input is
          first divided by the weights sqrt(n), sqrt(n-1), ..., sqrt(1)

    Returns:
        Operator: A Heaviside linear operator

    Heaviside function is also known as the step function.
    In discrete domain, it is implemented as a cumulative sum operation.
    An n x n Heaviside matrix has ones below and on the diagonal
    and zeros elsewhere. Its adjoint is therefore the cumulative sum
    taken from the end of the vector.
    """
    weights = jnp.sqrt(jnp.arange(n, 0, -1))
    inv_weights = 1/weights

    def suffix_sums(x):
        # sum of x[i:] for every i, derived from the running sums
        running = jnp.cumsum(x)
        total = running[-1]
        return jnp.insert(total - running[:-1], 0, total)

    def forward_plain(x):
        return jnp.cumsum(x)

    def adjoint_plain(x):
        return suffix_sums(x)

    def forward_weighted(x):
        return jnp.cumsum(x * inv_weights)

    def adjoint_weighted(x):
        return suffix_sums(x) * inv_weights

    if normalized:
        times, trans = forward_weighted, adjoint_weighted
    else:
        times, trans = forward_plain, adjoint_plain
    # extend the 1D functions so that they work along the requested axis
    times, trans = apply_along_axis(times, trans, axis)
    return Operator(times=times, trans=trans, shape=(n, n))
def cumsum(n, axis=0):
    """Builds a cumulative sum operator

    This is the Heaviside operator without normalization.

    Args:
        n (int): Dimension of the model space
        axis (int): For multi-dimensional array input, the axis along which
          the linear operator will be applied

    Returns:
        Operator: The operator returned by ``heaviside`` with
        ``normalized=False``
    """
    return heaviside(n, axis=axis, normalized=False)
def inv_heaviside(n, axis=0, normalized=True):
    """Builds a linear operator that computes the inverse of Heaviside/cumsum

    Args:
        n (int): Dimension of the model space
        axis (int): For multi-dimensional array input, the axis along which
          the linear operator will be applied
        normalized(bool): Indicates if the Heaviside operator was normalized

    Returns:
        Operator: An inverse of Heaviside linear operator

    Recall that the Heaviside operator computes the cumulative sum.
    This operator computes the reverse of cumulative sum
    which is the difference of consecutive values.
    In the normalized case the differences are additionally multiplied by
    the weights sqrt(n), sqrt(n-1), ..., sqrt(1).
    """
    weights = jnp.sqrt(jnp.arange(n, 0, -1))

    def forward_plain(x):
        # x[i] - x[i-1], with an implicit 0 before the first entry
        return jnp.diff(x, prepend=0)

    def adjoint_plain(x):
        # x[i] - x[i+1], with an implicit 0 after the last entry
        return -jnp.diff(x, append=0)

    def forward_weighted(x):
        return jnp.diff(x, prepend=0) * weights

    def adjoint_weighted(x):
        return -jnp.diff(x * weights, append=0)

    if normalized:
        times, trans = forward_weighted, adjoint_weighted
    else:
        times, trans = forward_plain, adjoint_plain
    # extend the 1D functions so that they work along the requested axis
    times, trans = apply_along_axis(times, trans, axis)
    return Operator(times=times, trans=trans, shape=(n, n))
def diff(n, axis=0):
    """Builds a consecutive difference operator

    This is the inverse Heaviside operator without normalization.

    Args:
        n (int): Dimension of the model space
        axis (int): For multi-dimensional array input, the axis along which
          the linear operator will be applied

    Returns:
        Operator: The operator returned by ``inv_heaviside`` with
        ``normalized=False``
    """
    return inv_heaviside(n, axis=axis, normalized=False)