# Source code for cr.sparse._src.lop.tv
# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Total variation linear operator
"""
import jax.numpy as jnp
import cr.nimble as cnb
from .lop import Operator
from .util import apply_along_axis
# Names of the supported boundary conditions (accepted values of ``kind``).
# 'regular': the last sample is replicated before differencing.
REGULAR = 'regular'
# 'dirichlet': a zero is appended after the last sample before differencing.
DIRICHLET = 'dirichlet'
# 'circular': the first sample is appended after the last one (wrap-around).
CIRCULAR = 'circular'
def diff_fwd_1d_regular(x):
    """Forward first-order difference with the last sample replicated.

    The last entry of ``x`` is appended to the signal before differencing,
    so the result has the same length as ``x`` and its final entry is
    ``x[-1] - x[-1]``.

    Args:
        x: 1D signal.

    Returns:
        Array of successive differences, same length as ``x``.
    """
    tail = jnp.expand_dims(x[-1], axis=0)
    return jnp.diff(x, append=tail)
def diff_adj_1d_regular(x):
    """Adjoint counterpart of the 'regular' forward difference.

    The last entry of ``x`` is zeroed (it corresponds to the always-zero
    last output of the forward operator), the signal is negated, and a
    leading zero is prepended before differencing. The result has the same
    length as ``x``.

    Args:
        x: 1D signal.

    Returns:
        Array with entries ``[-x[0], x[0]-x[1], ..., x[-2]]``.
    """
    # Zero the last sample first, then negate; the input array is not modified.
    negated = -(x.at[-1].set(0))
    return jnp.diff(negated, prepend=jnp.array([0]))
def diff_fwd_1d_dirichlet(x):
    """Forward first-order difference with a zero appended at the end.

    A single zero is appended to ``x`` before differencing, so the result
    has the same length as ``x`` and its final entry is ``0 - x[-1]``.

    Args:
        x: 1D signal.

    Returns:
        Array of successive differences, same length as ``x``.
    """
    return jnp.diff(x, append=jnp.array([0]))
def diff_adj_1d_dirichlet(x):
    """Adjoint counterpart of the 'dirichlet' forward difference.

    Subtracts ``x`` from its right-shifted copy produced by
    ``cnb.vec_shift_right``.

    Args:
        x: 1D signal.

    Returns:
        ``cnb.vec_shift_right(x) - x``.
    """
    # NOTE(review): presumably vec_shift_right fills the vacated first slot
    # with zero, which is what the adjoint requires -- confirm in cr.nimble.
    return cnb.vec_shift_right(x) - x
def diff_fwd_1d_circular(x):
    """Forward first-order difference with wrap-around at the end.

    The first entry of ``x`` is appended to the signal before differencing,
    so the result has the same length as ``x`` and its final entry is
    ``x[0] - x[-1]``.

    Args:
        x: 1D signal.

    Returns:
        Array of successive differences, same length as ``x``.
    """
    head = jnp.expand_dims(x[0], axis=0)
    return jnp.diff(x, append=head)
def diff_adj_1d_circular(x):
    """Adjoint counterpart of the 'circular' forward difference.

    Subtracts ``x`` from its right-rotated copy produced by
    ``cnb.vec_rotate_right``.

    Args:
        x: 1D signal.

    Returns:
        ``cnb.vec_rotate_right(x) - x``.
    """
    # NOTE(review): presumably vec_rotate_right moves the last sample to the
    # front (cyclic shift by one) -- confirm in cr.nimble.
    return cnb.vec_rotate_right(x) - x
def tv(n, kind='regular', axis=0):
    r"""Build a total variation (first-order difference) operator for 1D signals.

    Args:
        n (int): Dimension of the model space
        kind (str): Boundary condition used when taking differences; one of
            ``'regular'``, ``'dirichlet'`` or ``'circular'``
        axis (int): For multi-dimensional array input, the axis along which
            the linear operator will be applied

    Returns:
        (Operator): A linear operator of shape ``(n, n)`` which computes
        the variation in 1D signals

    Raises:
        NotImplementedError: If ``kind`` is not one of the supported
            boundary conditions.

    Note:
        To compute the total variation, first apply the linear operator
        and then compute the norm of the resulting variation.
    """
    # (boundary kind, forward difference, adjoint difference)
    candidates = (
        (REGULAR, diff_fwd_1d_regular, diff_adj_1d_regular),
        (DIRICHLET, diff_fwd_1d_dirichlet, diff_adj_1d_dirichlet),
        (CIRCULAR, diff_fwd_1d_circular, diff_adj_1d_circular),
    )
    for name, forward, adjoint in candidates:
        if kind == name:
            break
    else:
        # No entry matched the requested boundary condition.
        raise NotImplementedError(f"The kind {kind} is not supported")
    # Lift the 1D routines so they act along the requested axis.
    times, trans = apply_along_axis(forward, adjoint, axis)
    return Operator(times=times, trans=trans, shape=(n,n))
def tv2D(shape, kind='regular'):
    r"""Returns a total variation linear operator for 2D images

    Args:
        shape: Shape of the input images (model space). It is passed
            through unchanged as both entries of the operator's ``shape``.
        kind (str): Boundary condition for handling differences; one of
            ``'regular'``, ``'dirichlet'`` or ``'circular'``

    Returns:
        (Operator): A linear operator which computes the variation in 2D images

    Raises:
        NotImplementedError: If ``kind`` is not one of the supported
            boundary conditions.

    Note:
        The output is a complex image. The horizontal differences are stored
        in the real part and the vertical differences are stored in the
        imaginary part, i.e. the operator is :math:`T = D_h + j D_v`.

        The adjoint is the conjugate transpose
        :math:`T^H = D_h^T - j D_v^T`. For a real-valued model space, the
        real part of its output, :math:`D_h^T \Re(Y) + D_v^T \Im(Y)`, is the
        adjoint with respect to the real inner product.

        To compute the total variation, we first apply the linear operator
        and then compute the norm of the variation image.
    """
    if kind == REGULAR:
        times1d = diff_fwd_1d_regular
        trans1d = diff_adj_1d_regular
    elif kind == DIRICHLET:
        times1d = diff_fwd_1d_dirichlet
        trans1d = diff_adj_1d_dirichlet
    elif kind == CIRCULAR:
        times1d = diff_fwd_1d_circular
        trans1d = diff_adj_1d_circular
    else:
        raise NotImplementedError(f"The kind {kind} is not supported")

    def times(X):
        """Forward total variation
        """
        # horizontal variation (differences along each row)
        Dh = jnp.apply_along_axis(times1d, 1, X)
        # vertical variation (differences along each column)
        Dv = jnp.apply_along_axis(times1d, 0, X)
        # combine them to complex output
        return Dh + Dv * 1j

    def trans(X):
        """Adjoint total variation
        """
        # adjoint of the horizontal differences
        Dh = jnp.apply_along_axis(trans1d, 1, X)
        # adjoint of the vertical differences
        Dv = jnp.apply_along_axis(trans1d, 0, X)
        # The forward operator multiplies the vertical part by +1j, so the
        # conjugate transpose must multiply it by -1j. Using +1j here would
        # give the plain transpose, which fails the adjointness (dot) test.
        return Dh - Dv * 1j

    return Operator(times=times, trans=trans, shape=(shape,shape))