# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convolutions 1D, 2D, ND
"""
import numpy as np
import jax.numpy as jnp
import jax.scipy as jsp
from jax import lax
from .impl import _hermitian
from .lop import Operator
from .util import apply_along_axis
from cr.nimble import promote_arg_dtypes
def convolve(n, h, offset=0, axis=0):
    """Build a linear operator that convolves length-``n`` signals with ``h``.

    Args:
        n: Length of the signal the operator acts on; must be positive.
        h: One dimensional filter. It is indexed and reversed like an array.
        offset: Index inside ``h`` treated as the centre of the filter
            response; must satisfy ``0 <= offset < len(h)``.
        axis: Forwarded to ``apply_along_axis`` together with the 1-D
            forward and adjoint functions.

    Returns:
        An ``Operator`` of shape ``(n, n)`` whose ``times`` is the
        convolution with ``h`` and whose ``trans`` is the correlation with
        the conjugate of ``h``.

    Note:
        The coefficients of ``h`` are not padded. A full convolution of
        length ``n + len(h) - 1`` is computed and the ``n`` samples
        starting at the relevant offset are then extracted from it.
    """
    assert n > 0
    taps = len(h)
    # The centre of the filter response has to lie inside the filter.
    assert offset >= 0
    assert offset < taps

    # The adjoint works with the conjugated filter in its original order.
    adjoint_kernel = jnp.conjugate(h)[None, None, None]
    # lax computes a correlation, so the forward kernel is flipped to
    # obtain a true convolution.
    forward_kernel = h[::-1][None, None, None]

    # Windows of n samples cut out of the full-length results.
    adjoint_start = taps - 1 - offset
    forward_window = slice(offset, offset + n, None)
    adjoint_window = slice(adjoint_start, adjoint_start + n, None)

    # The signal is treated as a 1 x n image: no padding along the dummy
    # dimension and full padding along the signal dimension.
    padding = [(0, 0), (taps - 1, taps - 1)]
    strides = (1, 1)

    def _filter_and_crop(x, kernel, window):
        """Run the full 1-D correlation of x with kernel and crop it."""
        x, kernel = promote_arg_dtypes(x, kernel)
        full = lax.conv_general_dilated(
            x[None, None, None], kernel, strides, padding)
        return full[0, 0, 0, window]

    def times1d(x):
        """Forward convolution
        """
        return _filter_and_crop(x, forward_kernel, forward_window)

    def trans1d(x):
        """Adjoint convolution
        """
        return _filter_and_crop(x, adjoint_kernel, adjoint_window)

    times, trans = apply_along_axis(times1d, trans1d, axis)
    return Operator(times=times, trans=trans, shape=(n, n))
def convolve2D(shape, h, offset=None, axes=None):
    """Build a 2-D convolution operator for arrays of the given shape.

    This is a thin wrapper: after checking that the filter ``h`` has
    exactly two dimensions, all arguments are handed to ``convolveND``
    unchanged and its result is returned.
    """
    # Only genuinely two dimensional filters are accepted here.
    assert h.ndim == 2
    # The N dimensional implementation does the actual work.
    return convolveND(shape, h, offset=offset, axes=axes)
def convolveND(shape, h, offset=None, axes=None):
    """Build an N-dimensional convolution operator for arrays of ``shape``.

    Args:
        shape: Shape of the data array the operator acts on.
        h: Filter array with ``h.ndim`` dimensions.
        offset: Per filter dimension index of the centre of the filter
            response. Defaults to 0 along every filter dimension. A scalar
            is accepted for a one dimensional filter.
        axes: Data axes along which the filter dimensions are applied;
            ``axes[i]`` is the data axis paired with filter dimension ``i``.
            Defaults to the first ``h.ndim`` axes. Negative values count
            from the last data axis and the axes may be given in any order.

    Returns:
        An ``Operator`` with ``shape=(shape, shape)``. ``times`` convolves
        its input with ``h``; ``trans`` correlates its input with the
        complex conjugate of ``h``, which is the adjoint of ``times``.

    Raises:
        AssertionError: If ``offset`` or ``axes`` does not have exactly one
            entry per filter dimension.
    """
    # The dimensions of the filter
    filter_ndim = h.ndim
    # The dimensions of the data
    data_ndim = len(shape)
    # By default offset along each filter dimension is 0
    if offset is None:
        offset = np.zeros(filter_ndim, dtype=int)
    else:
        # atleast_1d lets a scalar offset be indexed like a sequence
        offset = np.atleast_1d(offset)
    # offset dimensions must match the filter dimensions
    assert offset.size == filter_ndim
    if axes is None:
        # By default, the convolution happens over the first filter_ndim axes
        axes = np.arange(filter_ndim)
    else:
        axes = np.atleast_1d(axes)
    # there must be exactly one data axis per filter dimension
    assert axes.size == filter_ndim
    # Map negative axes to their positive position so that the ordering
    # computed below reflects the real layout of the data.
    axes = np.where(axes < 0, axes + data_ndim, axes)

    # Slices to be extracted from the full convolution results, and the
    # shape of the filter once embedded in the data dimensions.
    f_slices = [slice(None)] * data_ndim
    a_slices = [slice(None)] * data_ndim
    h_dims = [1] * data_ndim
    for i, ax in enumerate(axes):
        ax = int(ax)
        # offset along i-th filter dimension
        off_ax = int(offset[i])
        # the filter size for i-th filter dimension
        h_ax = h.shape[i]
        # the data size along the matching data axis
        n_ax = shape[ax]
        # the offset for the adjoint operator
        adj_ax = h_ax - 1 - off_ax
        f_slices[ax] = slice(off_ax, off_ax + n_ax)
        a_slices[ax] = slice(adj_ax, adj_ax + n_ax)
        h_dims[ax] = h_ax
    f_slices = tuple(f_slices)
    a_slices = tuple(a_slices)

    # Bring the filter dimensions into the order of the data axes they act
    # on. A plain reshape would not move the coefficients, so axes given
    # out of order need an explicit transpose first. For increasing axes
    # this is the identity permutation.
    order = tuple(int(k) for k in np.argsort(axes))
    h = jnp.transpose(h, order)
    # Extend the filter to the data dimensions (no-op if they already match)
    h = jnp.reshape(h, tuple(h_dims))

    padding = [(s - 1, s - 1) for s in h.shape]
    strides = (1,) * data_ndim
    # lax computes a correlation: flip every axis for a true convolution
    h_conv = h[(slice(None, None, -1),) * data_ndim]
    # add the batch and feature dimensions expected by lax
    h_conv = h_conv[None, None]
    # The adjoint is the correlation with the conjugated filter, matching
    # the 1-D convolve operator. Conjugation leaves real filters unchanged.
    h_corr = jnp.conjugate(h)[None, None]

    def times(x):
        """Forward N-D convolution
        """
        x, f = promote_arg_dtypes(x, h_conv)
        # Make sure that x has the appropriate shape
        x = jnp.reshape(x, shape)
        result = lax.conv_general_dilated(x[None, None], f, strides,
            padding)
        result = result[0, 0]
        # keep only the window selected by the offsets
        return result[f_slices]

    def trans(x):
        """Backward N-D convolution
        """
        x, f = promote_arg_dtypes(x, h_corr)
        # Make sure that x has the appropriate shape
        x = jnp.reshape(x, shape)
        result = lax.conv_general_dilated(x[None, None], f, strides,
            padding)
        result = result[0, 0]
        # keep only the window selected by the adjoint offsets
        return result[a_slices]

    return Operator(times=times, trans=trans, shape=(shape, shape))