# Source code for cr.sparse._src.opt.proximal_ops.prox
# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import reduce
from typing import NamedTuple, Callable, Tuple
import jax
from jax import jit
import jax.numpy as jnp
class ProxCapable(NamedTuple):
    r"""Represents a function which is prox capable

    The *proximal operator* for a function :math:`f` is defined as

    .. math::

        p_f(x, t) = \text{arg} \min_{z \in \RR^n} f(z) + \frac{1}{2t} \| z - x \|_2^2

    Let `op` be a variable of type ProxCapable
    which represents some prox-capable function :math:`f`. Then:

    * `op.func(x)` returns the function value :math:`f(x)`.
    * `op.prox_op(x, t)` returns the proximal vector for a step size
      :math:`t`: :math:`z = p_f(x, t)`.
    * `op.prox_vec_val(x, t)` returns the pair :math:`z,v = p_f(x, t), f(z)`.
    """
    func: Callable[[jnp.ndarray], float]
    """Definition of a prox capable function"""
    prox_op: Callable[[jnp.ndarray, float], jnp.ndarray]
    """Proximal operator for the function"""
    # The pair is ordered (proximal vector, function value), matching what the
    # builder functions in this module return.
    prox_vec_val: Callable[[jnp.ndarray, float], Tuple[jnp.ndarray, float]]
    """A wrapper function to evaluate the proximal vector and the function value at the vector"""
def build(func, prox_op):
    r"""Wraps a function and its proximal operator as a prox-capable function

    Both callables are JIT-compiled, and the combined "proximal vector and
    value" function is derived from the compiled versions.

    Args:
        func: Definition of a function :math:`f(x)`
        prox_op: Definition of its proximal operator :math:`p_f(x, t)`

    Returns:
        ProxCapable: A prox-capable function
    """
    jitted_func = jit(func)
    jitted_prox = jit(prox_op)
    # Derive the combined evaluator from the already-compiled pieces.
    combined = build_prox_value_vec_func(jitted_func, jitted_prox)
    return ProxCapable(
        func=jitted_func,
        prox_op=jitted_prox,
        prox_vec_val=combined,
    )
def build3(func, prox_op, prox_vec_val):
    r"""Assembles a prox-capable function from three user-supplied callables

    The callables are stored exactly as given; nothing is compiled or
    wrapped here.

    Args:
        func: Definition of a function :math:`f(x)`
        prox_op: Definition of its proximal operator :math:`p_f(x, t)`
        prox_vec_val: Combined function for generating both proximal point and value

    Returns:
        ProxCapable: A prox-capable function
    """
    wrapper = ProxCapable(
        func=func,
        prox_op=prox_op,
        prox_vec_val=prox_vec_val,
    )
    return wrapper
def build_from_ind_proj(indicator, projector):
    """Creates a prox-capable wrapper from a convex set's indicator and projector

    The proximal operator here simply applies the projector; the step size
    argument is accepted but not used.

    Args:
        indicator: Definition of the indicator function for the convex set
        projector: Definition of the projector function for the convex set

    Returns:
        ProxCapable: A prox-capable function
    """
    def prox_op(x, t):
        # the step size t plays no role: the result is the projection of x
        return projector(x)

    def prox_vec_val(x, t):
        # the projected point is paired with the value 0., which is what the
        # indicator function evaluates to inside the convex set
        return projector(x), 0.

    return ProxCapable(
        func=jit(indicator),
        prox_op=jit(prox_op),
        prox_vec_val=jit(prox_vec_val),
    )
def build_prox_value_vec_func(func, prox_op):
    r"""Builds a function that evaluates the proximal vector and the function value at it

    Args:
        func: Definition of a function :math:`f(x)`
        prox_op: Definition of its proximal operator :math:`p_f(x, t)`

    Returns:
        A JIT-compiled function of ``(x, t)`` which computes the pair
        :math:`z,v = p_f(x,t), f(z)`
    """
    def prox_value_vec(x, t):
        # proximal point first, then the function evaluated at that point
        z = prox_op(x, t)
        return z, func(z)

    return jit(prox_value_vec)