Source code for cr.sparse._src.lop.dot
# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax.numpy as jnp
from .lop import Operator
import cr.nimble as cnb
[docs]def dot(v, adjoint=False, axis=0):
"""Returns a linear operator T such that :math:`T x = \\langle v , x \\rangle = v^H x`
Args:
v (jax.numpy.ndarray): The vector/array with which the inner product will be computed
adjoint (bool): Indicates if we need the dot operator or its adjoint
axis (int): For multi-dimensional array input, the axis along which
the linear operator will be applied
Note:
axis parameter is useful only if v is 1D.
"""
v = jnp.asarray(v)
assert v.ndim >= 1, "v cannot be a scalar"
# make sure that v is inexact
v = cnb.promote_arg_dtypes(v)
n = v.shape[0] if v.ndim == 1 else v.shape
m = 1
def times1d(x):
result = cnb.arr_rdot(v, x)
return jnp.expand_dims(result, 0)
def times(x):
x = jnp.asarray(x)
assert x.shape == v.shape, "shape of x must be same as shape of v"
if x.ndim == v.ndim:
return times1d(x)
if v.ndim == 1:
return jnp.apply_along_axis(times1d, axis, x)
raise ValueError("axis parameter is not supported for ND arrays")
def trans(x):
# the inner product must be real
assert jnp.isrealobj(x), "The data space is real as this linear operator represents a real inner product"
return v * x
if adjoint:
m,n = n, m
times, trans = trans, times
return Operator(times=times, trans=trans, shape=(m, n))