# Source code for cr.sparse._src.lop.block_diag
# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax.numpy as jnp
from .lop import Operator
from .util import apply_along_axis
def block_diag(operators, axis=0):
    r"""Returns a block diagonal operator from 2 or more operators

    Args:
        operators (list): List of linear operators
        axis (int): For multi-dimensional array input, the axis along which
            the linear operator will be applied

    Returns:
        Operator: A block diagonal operator

    Raises:
        AssertionError: If ``operators`` is not a list or contains fewer
            than two operators.

    Assume a set of operators :math:`T_1, T_2, \dots, T_k`
    each having the shape :math:`(m_i, n_i)`.

    The input model space dimension for the block diagonal operator is:

    .. math::

        n = \sum_{i=1}^k n_i

    The output data space dimension for the block diagonal operator is:

    .. math::

        m = \sum_{i=1}^k m_i

    For the forward mode, split the input model vector
    :math:`x \in \mathbb{F}^n` into

    .. math::

        x = \begin{bmatrix} x_1 \\ x_2 \\ \vdots \\ x_k \end{bmatrix}

    such that :math:`\text{dim}(x_i) = n_i`.

    Then the application of the block diagonal operator
    in the forward mode from model space to data space
    can be represented as:

    .. math::

        T x = \begin{bmatrix}
        T_1 & 0 & \dots & 0 \\
        0 & T_2 & \dots & 0 \\
        \vdots & \vdots & \ddots & \vdots \\
        0 & 0 & \dots & T_k
        \end{bmatrix}\begin{bmatrix} x_1 \\ x_2 \\ \vdots \\ x_k \end{bmatrix}
        = \begin{bmatrix} T_1 x_1 \\ T_2 x_2 \\ \vdots \\ T_k x_k \end{bmatrix}

    Similarly the application in the adjoint mode from the
    data space to the model space can be represented as:

    .. math::

        T^H y = \begin{bmatrix}
        T_1^H & 0 & \dots & 0 \\
        0 & T_2^H & \dots & 0 \\
        \vdots & \vdots & \ddots & \vdots \\
        0 & 0 & \dots & T_k^H
        \end{bmatrix}\begin{bmatrix} y_1 \\ y_2 \\ \vdots \\ y_k \end{bmatrix}
        = \begin{bmatrix} T_1^H y_1 \\ T_2^H y_2 \\ \vdots \\ T_k^H y_k \end{bmatrix}

    Examples:
        >>> T1 = lop.matrix(2.*jnp.ones((2,2)))
        >>> T2 = lop.matrix(3.*jnp.ones((3,3)))
        >>> T = lop.block_diag([T1, T2])
        >>> print(lop.to_matrix(T))
        [[2. 2. 0. 0. 0.]
         [2. 2. 0. 0. 0.]
         [0. 0. 3. 3. 3.]
         [0. 0. 3. 3. 3.]
         [0. 0. 3. 3. 3.]]
        >>> x = jnp.arange(5)+1
        >>> print(T.times(x))
        [ 6.  6. 36. 36. 36.]
        >>> print(T.trans(x))
        [ 6.  6. 36. 36. 36.]
    """
    # Asserts are retained (rather than raising TypeError/ValueError) so that
    # callers observing AssertionError keep working.
    assert isinstance(operators, list), "operators must be a list"
    assert len(operators) >= 2, "at least two operators are required"

    # Snapshot the operators: the slices below are computed once, so a later
    # mutation of the caller's list must not change which operator is paired
    # with which slice.
    ops = tuple(operators)

    in_slices = []
    out_slices = []
    in_stop = 0   # running sum of n_i (model space offset)
    out_stop = 0  # running sum of m_i (data space offset)
    for op in ops:
        m, n = op.shape
        in_slices.append(slice(in_stop, in_stop + n))
        out_slices.append(slice(out_stop, out_stop + m))
        in_stop += n
        out_stop += m

    # After the loop the running offsets equal the total dimensions.
    shape = (out_stop, in_stop)

    def times(x):
        """Forward operation"""
        # Each operator acts on its own segment of the model vector.
        return jnp.concatenate(
            [op.times(x[part]) for op, part in zip(ops, in_slices)])

    def trans(x):
        """Adjoint operation"""
        # Each adjoint acts on its own segment of the data vector.
        return jnp.concatenate(
            [op.trans(x[part]) for op, part in zip(ops, out_slices)])

    # Project helper; presumably lifts the 1-D functions above so that they
    # act along `axis` of multi-dimensional input (see Args) - confirm in .util
    times, trans = apply_along_axis(times, trans, axis)
    return Operator(times=times, trans=trans, shape=shape)