# Source code for cr.sparse._src.lop.windowed_op

# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

from jax import lax
import jax.numpy as jnp
from .lop import Operator
from .util import apply_along_axis

def windowed_op(m, T, overlap=0, axis=0):
    """A wrapper to convert an operator into an overcomplete windowed operator.

    The forward operation splits its input into ``n_blocks`` consecutive
    chunks of length ``T.shape[1]``. It applies ``T.times`` to each chunk and
    adds the results into windows of length ``T.shape[0]``, placed
    ``T.shape[0] - overlap`` samples apart. The summed signal is truncated to
    length ``m``.

    The transpose operation zero-pads its input so the last window fits,
    extracts each window and applies ``T.trans`` to it. It then concatenates
    the per-window results.

    * The window length is determined by the output size of T (``T.shape[0]``).
    * The input length ``n`` is ``n_blocks * T.shape[1]``, where ``n_blocks``
      is the smallest number of windows (at least one) whose span covers the
      ``m`` output samples.
    * ``overlap`` must be less than the window length.

    If ``overlap`` is zero, then this operator behaves like a block diagonal
    operator in which each block is processed by T.

    Args:
        m (int): Length of the output
        T (Operator): The linear operator to be wrapped
        overlap (int): The amount of overlap between two consecutive windows
        axis (int): Axis along which the operator acts; passed to
            ``apply_along_axis``

    Returns:
        Operator: A linear operator of shape ``(m, n)`` with the same ``real``
        flag as ``T``

    Raises:
        AssertionError: If ``overlap`` is not less than the window length
    """
    # The shape of the underlying operator
    tm, tn = T.shape
    # window length
    w = tm
    assert overlap < w, "Overlap must be less than window size"
    # distance between the starting points of two consecutive windows
    offset = w - overlap
    # number of blocks: ceil((m - w) / offset) + 1, at least 1.
    # Integer ceil-division avoids float rounding for large sizes.
    n_blocks = max(1, -((w - m) // offset) + 1)
    # input length
    n = n_blocks * tn
    real = T.real
    # base accumulator dtype; widened per call to also hold the input dtype
    dtype = jnp.float64 if real else jnp.complex128
    # total length spanned by all windows; by construction yl >= m
    yl = tm + (n_blocks - 1) * offset
    m_range = jnp.arange(tm)
    n_range = jnp.arange(tn)

    def times1d(x):
        x = jnp.asarray(x)
        # Accumulate in a dtype that can hold both the operator's dtype and
        # the input's. A fixed real accumulator would silently drop the
        # imaginary part of complex contributions.
        yf = jnp.zeros(yl, dtype=jnp.result_type(dtype, x.dtype))

        def body_func(i, y):
            # i-th input chunk of length tn
            xw = x[i * tn + n_range]
            # positions of the i-th output window
            idx = i * offset + m_range
            # overlapping windows are summed
            return y.at[idx].add(T.times(xw))

        y = lax.fori_loop(0, n_blocks, body_func, yf)
        return y[:m]

    def trans1d(x):
        x = jnp.asarray(x)
        # pad x with zeros so that the last window lies entirely inside it
        x = jnp.pad(x, (0, yl - m))
        # same dtype reasoning as in times1d (complex T => complex output)
        ya = jnp.zeros(n, dtype=jnp.result_type(dtype, x.dtype))

        def body_func(i, y):
            # i-th window of the (padded) input
            xw = x[i * offset + m_range]
            return y.at[i * tn + n_range].set(T.trans(xw))

        return lax.fori_loop(0, n_blocks, body_func, ya)

    times, trans = apply_along_axis(times1d, trans1d, axis)
    return Operator(times=times, trans=trans, shape=(m, n), real=real)