cr.sparse.lop.block_diag
- cr.sparse.lop.block_diag(operators, axis=0)
Returns a block diagonal operator built from two or more operators.
- Parameters
- Returns
A block diagonal operator
- Return type
Assume a set of operators \(T_1, T_2, \dots, T_k\) each having the shape \((m_i, n_i)\).
The input model space dimension for the block diagonal operator is:
\[n = \sum_{i=1}^k n_i \tag{1}\]
The output data space dimension for the block diagonal operator is:
\[m = \sum_{i=1}^k m_i \tag{2}\]
For the forward mode, split the input model vector \(x \in \mathbb{F}^n\) into
\[x = \begin{bmatrix} x_1 \\ x_2 \\ \vdots \\ x_k \end{bmatrix} \tag{3}\]
such that \(\text{dim}(x_i) = n_i\).
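The dimension bookkeeping and the split of \(x\) can be sketched in a few lines. This is a minimal illustration in plain NumPy, not the cr.sparse implementation; the block shapes and the names `shapes`, `split_points` and `x_parts` are hypothetical and chosen only for this example.

```python
import numpy as np

# Hypothetical block shapes (m_i, n_i), chosen for illustration only.
shapes = [(2, 2), (3, 3), (4, 1)]

n = sum(n_i for _, n_i in shapes)  # input (model space) dimension, eq. (1)
m = sum(m_i for m_i, _ in shapes)  # output (data space) dimension, eq. (2)

# Split points are the cumulative sums of n_i, excluding the final total.
split_points = np.cumsum([n_i for _, n_i in shapes])[:-1]

# Split x into x_1, ..., x_k with len(x_i) == n_i, as in eq. (3).
x = np.arange(n) + 1.0
x_parts = np.split(x, split_points)

print(m, n)
print([part.tolist() for part in x_parts])
```

With these shapes the operator maps a vector of length 6 to a vector of length 9, and the three pieces of \(x\) have lengths 2, 3 and 1.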
Then the application of the block diagonal operator in the forward mode from model space to data space can be represented as:
\[T x = \begin{bmatrix} T_1 & 0 & \dots & 0 \\ 0 & T_2 & \dots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \dots & T_k \end{bmatrix}\begin{bmatrix} x_1 \\ x_2 \\ \vdots \\ x_k \end{bmatrix} = \begin{bmatrix} T_1 x_1 \\ T_2 x_2 \\ \vdots \\ T_k x_k \end{bmatrix} \tag{4}\]
Similarly, with the data vector \(y \in \mathbb{F}^m\) split into blocks \(y_i\) such that \(\text{dim}(y_i) = m_i\), the application in the adjoint mode from the data space to the model space can be represented as:
\[T^H y = \begin{bmatrix} T_1^H & 0 & \dots & 0 \\ 0 & T_2^H & \dots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \dots & T_k^H \end{bmatrix}\begin{bmatrix} y_1 \\ y_2 \\ \vdots \\ y_k \end{bmatrix} = \begin{bmatrix} T_1^H y_1 \\ T_2^H y_2 \\ \vdots \\ T_k^H y_k \end{bmatrix} \tag{5}\]
Examples
>>> import jax.numpy as jnp
>>> from cr.sparse import lop
>>> T1 = lop.matrix(2.*jnp.ones((2,2)))
>>> T2 = lop.matrix(3.*jnp.ones((3,3)))
>>> T = lop.block_diag([T1, T2])
>>> print(lop.to_matrix(T))
[[2. 2. 0. 0. 0.]
 [2. 2. 0. 0. 0.]
 [0. 0. 3. 3. 3.]
 [0. 0. 3. 3. 3.]
 [0. 0. 3. 3. 3.]]
>>> x = jnp.arange(5)+1
>>> print(T.times(x))
[ 6.  6. 36. 36. 36.]
>>> print(T.trans(x))
[ 6.  6. 36. 36. 36.]
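The block-wise forward and adjoint formulas can also be checked without cr.sparse. The sketch below uses plain NumPy arrays as stand-ins for the two operators in the example. The helpers `split_by`, `forward` and `adjoint` are hypothetical names introduced for illustration and are not part of the cr.sparse API.

```python
import numpy as np

# Dense stand-ins for the two operators used in the example above.
A1 = 2.0 * np.ones((2, 2))
A2 = 3.0 * np.ones((3, 3))
blocks = [A1, A2]

def split_by(sizes, v):
    """Split v into consecutive chunks with the given sizes."""
    return np.split(v, np.cumsum(sizes)[:-1])

def forward(blocks, x):
    # Forward mode: apply T_i to x_i (sizes n_i), then stack the results.
    parts = split_by([b.shape[1] for b in blocks], x)
    return np.concatenate([b @ p for b, p in zip(blocks, parts)])

def adjoint(blocks, y):
    # Adjoint mode: apply T_i^H to y_i (sizes m_i), then stack the results.
    parts = split_by([b.shape[0] for b in blocks], y)
    return np.concatenate([b.conj().T @ p for b, p in zip(blocks, parts)])

# Assemble the equivalent dense block diagonal matrix for comparison.
T_dense = np.zeros((5, 5))
T_dense[:2, :2] = A1
T_dense[2:, 2:] = A2

x = np.arange(5) + 1.0
print(np.allclose(forward(blocks, x), T_dense @ x))
print(np.allclose(adjoint(blocks, x), T_dense.conj().T @ x))
```

Both checks agree with the dense matrix, and the block-wise results are 6, 6, 36, 36, 36, the same values as in the example. Forward and adjoint coincide here only because both blocks are real and symmetric.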