Source code for cr.sparse._src.lop.identity

# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax.numpy as jnp

from .lop import Operator

from .util import apply_along_axis

def identity(in_dim, out_dim=None, axis=0):
    """Build an identity linear operator mapping model space to data space.

    When the two spaces have different dimensions the operator is a
    "rectangular identity": the leading samples shared by both spaces are
    copied unchanged, surplus samples are dropped, and missing samples are
    filled with zeros. The forward (``times``) and adjoint (``trans``)
    functions always form a truncate / zero-pad pair.

    Args:
        in_dim (int): Dimension of the model space.
        out_dim (int): Dimension of the data space. If omitted, the data
            space is given the same dimension as the model space.
        axis (int): For multi-dimensional array input, the axis along which
            the linear operator will be applied.

    Returns:
        Operator: An identity linear operator of shape ``(out_dim, in_dim)``.

    Example:
        A square identity operator::

            >>> T = lop.identity(4)
            >>> T.times(jnp.arange(4) + 0.)
            DeviceArray([0., 1., 2., 3.], dtype=float32)
            >>> T.trans(jnp.arange(4))
            DeviceArray([0, 1, 2, 3], dtype=int32)

        A tall identity operator (output has more dimensions)::

            >>> T = lop.identity(4, 6)
            >>> T.times(jnp.arange(4) + 0.)
            DeviceArray([0., 1., 2., 3., 0., 0.], dtype=float32)
            >>> T.trans(T.times(jnp.arange(4) + 0.))
            DeviceArray([0., 1., 2., 3.], dtype=float32)

        A wide identity operator (output has fewer dimensions)::

            >>> T = lop.identity(4, 3)
            >>> T.times(jnp.arange(4) + 0.)
            DeviceArray([0., 1., 2.], dtype=float32)
            >>> T.trans(T.times(jnp.arange(4) + 0.))
            DeviceArray([0., 1., 2., 0.], dtype=float32)

        By default T applies along columns of a matrix (axis=0)::

            >>> T.times(jnp.arange(20).reshape(4, 5))
            DeviceArray([[ 0,  1,  2,  3,  4],
                         [ 5,  6,  7,  8,  9],
                         [10, 11, 12, 13, 14]], dtype=int32)

        Identity operator applying along rows of a 2D matrix::

            >>> T = lop.identity(4, 3, axis=1)
            >>> T.times(jnp.arange(20).reshape(5, 4))
            DeviceArray([[ 0,  1,  2],
                         [ 4,  5,  6],
                         [ 8,  9, 10],
                         [12, 13, 14],
                         [16, 17, 18]], dtype=int32)
    """
    if out_dim is None:
        # Square operator: data space mirrors the model space.
        out_dim = in_dim

    if in_dim == out_dim:
        # Same dimension on both sides: both directions pass data through.
        def forward(x):
            return x

        def adjoint(x):
            return x

    elif in_dim > out_dim:
        # Wide operator: forward drops the trailing samples, adjoint
        # restores the length by appending zeros.
        surplus = in_dim - out_dim

        def forward(x):
            return x[:out_dim]

        def adjoint(x):
            return jnp.pad(x, (0, surplus))

    else:
        # Tall operator: forward appends zeros, adjoint drops them again.
        surplus = out_dim - in_dim

        def forward(x):
            return jnp.pad(x, (0, surplus))

        def adjoint(x):
            return x[:in_dim]

    # The helper returns the pair adapted for the requested axis.
    # NOTE(review): presumably it maps the 1-D functions over `axis` of an
    # N-D input -- confirm against `.util.apply_along_axis`.
    forward, adjoint = apply_along_axis(forward, adjoint, axis)
    return Operator(times=forward, trans=adjoint, shape=(out_dim, in_dim))