# Source code for cr.sparse._src.fom.defs
# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import NamedTuple
import jax.numpy as jnp
class FomOptions(NamedTuple):
    """Options for FOCS driver routine

    Every field has a default, so ``FomOptions()`` yields a usable
    configuration; individual values can be overridden by keyword or
    through ``_replace``.
    """
    nonneg: bool = False
    "Whether output is expected to be non-negative"
    solver: str = 'at'
    "Default first order conic solver"
    max_iters: int = 1000
    "Maximum number of iterations for the solver"
    tol: float = 1e-8
    "Tolerance for convergence"
    L0: float = 1.
    "Initial estimate of Lipschitz constant"
    Lexact: float = jnp.inf
    "Known bound of Lipschitz constant"
    alpha: float = 0.9
    "Line search increase parameter, in (0,1)"
    beta: float = 0.5
    "Backtracking parameter, in (0,1). No line search if >= 1"
    mu: float = 0
    "Strong convexity parameter"
    maximize: bool = False
    "By default, we attempt minimization of the objective, otherwise maximize"
    saddle: bool = False
    "Indicates if it's a saddle point problem setup by SCD subroutine"

    def __str__(self):
        """Return a multi-line ``name=value`` summary of the solver settings.

        Only the solver selection and its numerical parameters are listed,
        one per line; ``nonneg``, ``maximize`` and ``saddle`` are not part
        of the report.
        """
        reported = (
            'solver', 'max_iters', 'tol', 'L0',
            'Lexact', 'alpha', 'beta', 'mu',
        )
        return '\n'.join(f'{name}={getattr(self, name)}' for name in reported)
class FomState(NamedTuple):
    """
    State of the FOCS method

    The record carries three parallel groups of quantities (suffixes
    ``x``, ``y`` and ``z``) followed by convergence measures and an
    iteration counter.
    """
    L: float
    "Lipschitz constant estimate"
    theta: float
    ""
    # NOTE(review): undocumented upstream; presumably the acceleration
    # parameter of the first order method -- confirm against the solver.
    x: jnp.ndarray
    ""
    # NOTE(review): undocumented upstream; presumably the current
    # solution estimate -- confirm against the solver.
    A_x: jnp.ndarray
    "A @ x "
    g_Ax: jnp.ndarray
    "gradient of f at A @ x + b"
    g_x: jnp.ndarray
    "A^H (g_Ax)"
    f_x: float
    " f(A_x + b)"
    C_x: float
    "value of nonsmooth function h at x"

    # NOTE(review): the y-group presumably mirrors the x-group above,
    # evaluated at y -- confirm against the solver.
    y: jnp.ndarray
    A_y: jnp.ndarray
    g_Ay: jnp.ndarray
    g_y: jnp.ndarray
    f_y: float
    C_y: float

    # NOTE(review): the z-group presumably mirrors the x-group above,
    # evaluated at z -- confirm against the solver.
    z: jnp.ndarray
    A_z: jnp.ndarray
    g_Az: jnp.ndarray
    g_z: jnp.ndarray
    f_z: float
    C_z: float

    # quantities for convergence check
    norm_x: float
    norm_dx: float

    # counters
    iterations: int

    def __str__(self):
        """Return a multi-line summary of the scalar fields.

        Array-valued fields are not printed. The final entry is an empty
        string, so the result ends with a newline.
        """
        lines = [
            f'L={self.L:.2f}, theta={self.theta:.2f}',
            f'f_x={self.f_x:.2f}, f_y={self.f_y:.2f}, f_z={self.f_z:.2f}',
            f'C_x={self.C_x:.2f}, C_y={self.C_y:.2f}, C_z={self.C_z:.2f}',
            f'norm_x={self.norm_x:.2f}, norm_dx={self.norm_dx:.2e}',
            f'iterations={self.iterations}',
            # Empty last entry keeps the trailing newline in the output.
            '',
        ]
        return '\n'.join(lines)

    @property
    def at_str(self):
        """Two-line progress summary in scientific notation.

        Reports the iteration count, ``L`` and ``theta`` on the first line
        and ``f_y``, ``C_z``, ``norm_x`` and ``norm_dx`` on the second.
        NOTE(review): the name suggests this is intended for the 'at'
        solver -- confirm against callers.
        """
        header = (
            f'iterations={self.iterations}, '
            f'L={self.L:.2f}, theta={self.theta:.2f}'
        )
        detail = (
            f'f_y={self.f_y:.2e}, C_z={self.C_z:.2e}, '
            f'norm_x={self.norm_x:.2e}, norm_dx={self.norm_dx:.2e}'
        )
        return '\n'.join([header, detail])