# Source code for cr.sparse._src.opt.indicators.lpballs
# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax import jit
import jax.numpy as jnp
from jax.numpy.linalg import qr, norm
import cr.nimble as cnb
def indicator_l2_ball(q=1., b=None, A=None):
    r"""Returns an indicator function for the closed ball :math:`\| A x - b \|_2 \leq q`

    Args:
        q (float) : Radius of the ball
        b (jax.numpy.ndarray): A vector :math:`b \in \RR^{m}`
        A (jax.numpy.ndarray): A matrix :math:`A \in \RR^{m \times n}`

    Returns:
        An indicator function

    Raises:
        ValueError: If ``q`` is a concrete (non-traced) value and ``q <= 0``.

    The indicator function is defined as:

    .. math::

        I(x) = \begin{cases}
            0 & \text{if } \| A x - b \|_2 \leq q \\
            \infty & \text{otherwise}
        \end{cases}

    Special cases:

    * ``indicator_l2_ball()`` returns the Euclidean unit ball :math:`\| x \|_2 \leq 1`.
    * ``indicator_l2_ball(q)`` returns the Euclidean ball :math:`\| x \|_2 \leq q`.
    * ``indicator_l2_ball(q, b=b)`` returns the Euclidean ball at center :math:`b`, :math:`\| x - b\|_2 \leq q`.

    Notes:

    * If center :math:`b \in \RR^m` is unspecified, we assume the center to be at origin.
    * If radius :math:`q` is unspecified, we assume the radius to be 1.
    * If the matrix :math:`A` is unspecified, we assume :math:`A` to be the identity matrix
      :math:`I \in \RR^{n \times n}`.
    * If ``q`` is a JAX tracer (this factory is itself being traced, e.g. it is
      called inside a ``jit``-compiled function), the radius check is skipped
      because the comparison cannot be converted to a Python ``bool``.
    """
    # Imported locally so that this function does not depend on an extra
    # module-level import.
    from jax.errors import ConcretizationTypeError

    if b is not None:
        b = cnb.promote_arg_dtypes(jnp.asarray(b))
    if A is not None:
        A = cnb.promote_arg_dtypes(jnp.asarray(A))

    # Validate the radius only when its value is known at call time.
    # Under tracing ``bool(q <= 0)`` raises ConcretizationTypeError (or a
    # subclass of it); in that case the check is skipped instead of failing.
    try:
        bad_radius = bool(q <= 0)
    except ConcretizationTypeError:
        bad_radius = False
    if bad_radius:
        raise ValueError("q must be greater than 0")

    if b is None and A is None:
        # Ball centered at the origin: || x ||_2 <= q
        @jit
        def indicator_q(x):
            x = cnb.promote_arg_dtypes(jnp.asarray(x))
            invalid = norm(x) > q
            return jnp.where(invalid, jnp.inf, 0)

        return indicator_q

    if A is None:
        # Ball centered at b: || x - b ||_2 <= q
        @jit
        def indicator_q_b(x):
            x = cnb.promote_arg_dtypes(jnp.asarray(x))
            # compute difference from center
            r = x - b
            invalid = norm(r) > q
            return jnp.where(invalid, jnp.inf, 0)

        return indicator_q_b

    if b is None:
        # we have q and A specified; default value for b
        b = 0.

    # General case: || A x - b ||_2 <= q
    @jit
    def indicator_q_b_A(x):
        x = cnb.promote_arg_dtypes(jnp.asarray(x))
        # compute the residual vector
        r = A @ x - b
        invalid = norm(r) > q
        return jnp.where(invalid, jnp.inf, 0)

    return indicator_q_b_A
def indicator_l1_ball(q=1., b=None, A=None):
    r"""Returns an indicator function for the closed l1 ball :math:`\| A x - b \|_1 \leq q`

    Args:
        q (float) : Radius of the ball
        b (jax.numpy.ndarray): A vector :math:`b \in \RR^{m}`
        A (jax.numpy.ndarray): A matrix :math:`A \in \RR^{m \times n}`

    Returns:
        An indicator function

    Raises:
        ValueError: If ``q`` is a concrete (non-traced) value and ``q <= 0``.

    The indicator function is defined as:

    .. math::

        I(x) = \begin{cases}
            0 & \text{if } \| A x - b \|_1 \leq q \\
            \infty & \text{otherwise}
        \end{cases}

    Special cases:

    * ``indicator_l1_ball()`` returns the l1 unit ball :math:`\| x \|_1 \leq 1`.
    * ``indicator_l1_ball(q)`` returns the l1 ball :math:`\| x \|_1 \leq q`.
    * ``indicator_l1_ball(q, b=b)`` returns the l1 ball at center :math:`b`, :math:`\| x - b\|_1 \leq q`.

    Notes:

    * If center :math:`b \in \RR^m` is unspecified, we assume the center to be at origin.
    * If radius :math:`q` is unspecified, we assume the radius to be 1.
    * If the matrix :math:`A` is unspecified, we assume :math:`A` to be the identity matrix
      :math:`I \in \RR^{n \times n}`.
    * If ``q`` is a JAX tracer (this factory is itself being traced, e.g. it is
      called inside a ``jit``-compiled function), the radius check is skipped
      because the comparison cannot be converted to a Python ``bool``.
    """
    # Imported locally so that this function does not depend on an extra
    # module-level import.
    from jax.errors import ConcretizationTypeError

    if b is not None:
        b = cnb.promote_arg_dtypes(jnp.asarray(b))
    if A is not None:
        A = cnb.promote_arg_dtypes(jnp.asarray(A))

    # Validate the radius only when its value is known at call time.
    # A plain ``assert``/``if`` on q fails when q is traced, so the
    # concretization error is caught and the check skipped in that case.
    try:
        bad_radius = bool(q <= 0)
    except ConcretizationTypeError:
        bad_radius = False
    if bad_radius:
        raise ValueError("q must be greater than 0")

    if b is None and A is None:
        # Ball centered at the origin: || x ||_1 <= q
        @jit
        def indicator_q(x):
            x = cnb.promote_arg_dtypes(jnp.asarray(x))
            invalid = cnb.arr_l1norm(x) > q
            return jnp.where(invalid, jnp.inf, 0)

        return indicator_q

    if A is None:
        # Ball centered at b: || x - b ||_1 <= q
        @jit
        def indicator_q_b(x):
            x = cnb.promote_arg_dtypes(jnp.asarray(x))
            # compute difference from center
            r = x - b
            invalid = cnb.arr_l1norm(r) > q
            return jnp.where(invalid, jnp.inf, 0)

        return indicator_q_b

    if b is None:
        # we have q and A specified; default value for b
        b = 0.

    # General case: || A x - b ||_1 <= q
    @jit
    def indicator_q_b_A(x):
        x = cnb.promote_arg_dtypes(jnp.asarray(x))
        # compute the residual vector
        r = A @ x - b
        invalid = cnb.arr_l1norm(r) > q
        return jnp.where(invalid, jnp.inf, 0)

    return indicator_q_b_A