# Source code for cr.sparse._src.fom.scd
# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
First Order Solver for Smooth Conic Dual Problem
"""
from .defs import FomOptions, FomState
import jax.numpy as jnp
from jax import jit, lax
import cr.nimble as cnb
import cr.sparse.opt as opt
import cr.sparse.lop as lop
from .fom import fom
def smooth_dual(prox_f: opt.ProxCapable, mu=1, x0=0):
    """Builds the smooth dual of a prox-capable function.

    Three jitted closures are created (value, gradient, gradient together
    with value) and handed to ``opt.smooth_build3``.

    For an input ``x``, let ``px`` be the proximal point obtained from
    ``prox_f`` at ``x0 + mu * x`` with step ``mu`` and ``pv`` the scalar
    returned alongside it by ``prox_f.prox_vec_val`` (presumably the value
    of the function at ``px`` -- confirm against ``opt.ProxCapable``).
    Then:

    * value:    ``-(<x, px> - pv - (0.5 / mu) * ||px - x0||^2)``
    * gradient: ``-px``

    Args:
        prox_f (cr.sparse.opt.ProxCapable): Object providing ``prox_op``
            and ``prox_vec_val``
        mu: Scaling term; used both as the step passed to the proximal
            operator and in the ``0.5 / mu`` weight of the quadratic term
        x0: Center point added to ``mu * x`` before the proximal operator
            is applied and subtracted from ``px`` in the quadratic term

    Returns:
        Whatever ``opt.smooth_build3`` builds from the three closures.
    """

    def _prepare(x):
        # Array conversion followed by cr.nimble's dtype promotion.
        return cnb.promote_arg_dtypes(jnp.asarray(x))

    def _negated_value(z, point, point_val):
        # Shared by ``func`` and ``grad_val`` so both use the same formula.
        inner = cnb.arr_rdot(z, point)
        quadratic = (0.5/mu) * cnb.arr_rnorm_sqr(point - x0)
        return -(inner - point_val - quadratic)

    @jit
    def func(x):
        """Computes the value of the function at x
        """
        z = _prepare(x)
        point, point_val = prox_f.prox_vec_val(x0 + mu * z, mu)
        return _negated_value(z, point, point_val)

    @jit
    def grad(x):
        """Computes the gradient of the smooth function at x"""
        z = _prepare(x)
        # Only the proximal point is needed here; the gradient is its negation.
        return -prox_f.prox_op(x0 + mu * z, mu)

    @jit
    def grad_val(x):
        """Computes the gradient as well as the value of the function at x"""
        z = _prepare(x)
        # One prox evaluation serves both the gradient and the value.
        point, point_val = prox_f.prox_vec_val(x0 + mu * z, mu)
        return -point, _negated_value(z, point, point_val)

    return opt.smooth_build3(func, grad, grad_val)
def scd(prox_f: opt.ProxCapable, conj_neg_h: opt.ProxCapable,
        A: lop.Operator, b, mu, x0, z0, options: FomOptions = FomOptions()):
    r"""First order solver for smooth conic dual problems driver routine

    Args:
        prox_f (cr.sparse.opt.ProxCapable): A prox-capable objective function
        conj_neg_h (cr.sparse.opt.ProxCapable): The conjugate negative :math:`h^{-}` function
        A (cr.sparse.lop.Operator): A linear operator
        b (jax.numpy.ndarray): The translation vector
        mu (float): The (positive) scaling term for the quadratic term :math:`\frac{\mu}{2} \| x - x_0 \|_2^2`
        x0 (jax.numpy.ndarray): The center point for the quadratic term
        z0 (jax.numpy.ndarray): The initial dual point
        options (FomOptions): Options for configuring the algorithm

    Returns:
        FomState: Solution of the optimization problem

    The function uses first order conic solver algorithms to solve an
    optimization problem of the form:

    .. math::

        \underset{x}{\text{minimize}}
        \left [ f(x) + \frac{\mu}{2} \| x - x_0 \|_2^2
        + h \left (\AAA(x) + b \right) \right ]

    * Both :math:`f, h` must be convex and prox-capable, although neither needs to be smooth.

    When :math:`h` is an indicator function for a convex cone :math:`\KKK`, this is
    equivalent to:

    .. math::

        \begin{split}\begin{aligned}
        & \underset{x}{\text{minimize}}
        & & f(x) + \frac{\mu}{2} \| x - x_0 \|_2^2\\
        & \text{subject to}
        & & \AAA(x) + b \in \KKK
        \end{aligned}\end{split}

    which is the smooth conic dual (SCD) model discussed in :cite:`becker2011templates`.

    The routine builds the smooth dual of ``prox_f`` (with ``mu`` and ``x0``),
    overrides the ``saddle`` and ``maximize`` options to ``True`` and
    delegates to ``fom`` starting from ``z0``.
    """
    # The caller's options are left untouched; _replace returns a new tuple.
    dual_options = options._replace(saddle=True, maximize=True)
    dual_objective = smooth_dual(prox_f, mu, x0)
    return fom(dual_objective, conj_neg_h, A, b, z0, dual_options)