Source code for cr.sparse._src.lop.reshape

# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import reduce
import jax.numpy as jnp

from .lop import Operator


def reshape(in_shape, out_shape, order='C'):
    """Returns a linear operator which reshapes vectors from model space to data space

    Args:
        in_shape (int or tuple of int): Shape of arrays in the model space
        out_shape (int or tuple of int): Shape of arrays in the data space
        order: Specifies index order of data layout ['C', 'F', 'A']
            C means C-like index order (default).
            F means Fortran like order. This is the order in MATLAB arrays also.
            A is accepted for NumPy compatibility and treated as C, because
            ``jax.numpy.reshape`` does not implement ``order='A'``.

    Returns:
        (Operator): A reshaping linear operator. ``times`` reshapes to
        ``out_shape`` and ``trans`` reshapes back to ``in_shape`` using the
        same index order.

    Raises:
        AssertionError: If ``in_shape`` and ``out_shape`` describe different
            numbers of elements, or ``order`` is not one of 'C', 'F', 'A'.
    """
    def _num_elements(shape):
        # Accept both a bare integer and a sequence of dimensions.
        try:
            dims = tuple(shape)
        except TypeError:
            dims = (shape,)
        # Multiply as Python ints (arbitrary precision): jnp.prod would use
        # int32 when x64 is disabled and silently wrap for large shapes.
        return reduce(lambda acc, d: acc * int(d), dims, 1)

    in_size = _num_elements(in_shape)
    out_size = _num_elements(out_shape)
    assert in_size == out_size, "Input and output size must be equal"
    assert order in ['C', 'F', 'A'], "Invalid order"
    # jnp.reshape raises NotImplementedError for order='A'; JAX arrays expose
    # no Fortran memory layout, so NumPy's 'A' rule resolves to 'C' here.
    jax_order = 'C' if order == 'A' else order
    times = lambda x: jnp.reshape(x, out_shape, order=jax_order)
    trans = lambda x: jnp.reshape(x, in_shape, order=jax_order)
    return Operator(times=times, trans=trans, shape=(out_shape, in_shape))


def arr2vec(shape):
    """Returns a linear operator which reshapes arrays to vectors

    Args:
        shape (int or tuple of int): Shape of arrays in the model space.
            An empty tuple (0-d arrays) gives vectors of length 1.

    Returns:
        (Operator): An array to vec linear operator. ``times`` flattens an
        array of ``shape`` in C order into a vector of length
        ``prod(shape)``; ``trans`` reshapes such a vector back to ``shape``.
    """
    try:
        dims = tuple(shape)
    except TypeError:
        # A bare integer describes a 1-D shape.
        dims = (shape,)
    # Initializer 1 makes the empty shape () valid (reduce() would otherwise
    # raise on an empty sequence); int() keeps the product arbitrary-precision.
    in_size = reduce(lambda acc, d: acc * int(d), dims, 1)
    out_shape = (in_size,)
    times = lambda x: jnp.reshape(x, out_shape)
    trans = lambda x: jnp.reshape(x, shape)
    return Operator(times=times, trans=trans, shape=(out_shape, shape))