Source code for cr.sparse._src.problems.problems

# Copyright 2022 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import matplotlib.pyplot as plt

from typing import Callable,Tuple, List

# Maps each public problem name to the identifier of the submodule that
# implements it. `names()` returns the sorted keys of this mapping, and
# `generate()` imports `cr.sparse._src.problems.<value>` for the requested
# key and delegates to that submodule's `generate` function.
PROBLEM_MAP = {
    'heavi-sine:fourier:heavi-side': 'prob001',
    'blocks:haar': 'prob002',
    'cosine-spikes:dirac-dct' : 'prob003',
    'complex:sinusoid-spikes:dirac-fourier' : 'prob004',
    'cosine-spikes:dirac-dct:gaussian' : 'prob005',
    'piecewise-cubic-poly:daubechies:gaussian': 'prob006',
    'signed-spikes:dirac:gaussian': 'prob007',
    'complex:signed-spikes:dirac:gaussian': 'prob008',
    'blocks:heavi-side': 'prob009',
    'blocks:normalized-heavi-side': 'prob010',
    'gaussian-spikes:dirac:gaussian': 'prob011',
    'src-sep-1': 'prob401'
}

import cr.nimble as crn
import jax.numpy as jnp
from cr.nimble.dsp import largest_indices

from .spec import Problem

def names() -> List[str]:
    """Returns the names of available problems

    Returns:
        List[str]: The keys of ``PROBLEM_MAP`` in sorted order. Any of them
        can be passed as the ``name`` argument of :func:`generate`.
    """
    # Iterating a dict yields its keys; sort a fresh list copy in place.
    available = list(PROBLEM_MAP)
    available.sort()
    return available
def generate(name: str, key=crn.KEY0, **args) -> Problem:
    """Generates a test problem by its name and problem specific arguments

    Args:
        name (str): The name of the problem to be instantiated. It must be
            one of the keys of ``PROBLEM_MAP`` (see :func:`names`).
        key: A PRNG key to be used if the problem requires randomized data
        args (dict): The list of keyword parameters to be passed to the
            problem generation code

    Returns:
        Problem: An instance of a problem

    Raises:
        AssertionError: If ``name`` is not a recognized problem name.
    """
    # Validate with an explicit raise instead of an ``assert`` statement:
    # asserts are removed when Python runs with -O, which would let an
    # unknown name fall through to a bare KeyError on the lookup below.
    # AssertionError is kept so existing callers observe the same exception.
    if name not in PROBLEM_MAP:
        raise AssertionError(
            f"Unrecognized problem: {name!r}. "
            f"Available problems: {sorted(PROBLEM_MAP)}")
    problem_id = PROBLEM_MAP[name]
    # Each problem lives in its own submodule; import it lazily on demand.
    module = importlib.import_module('.' + problem_id, 'cr.sparse._src.problems')
    # The submodule's own ``generate`` builds the problem instance.
    return module.generate(key, **args)
def plot(problem, height=4):
    """Plots the figures associated with a problem

    This is just a utility function to quickly draw all the figures
    in one shot. You may be interested in drawing specific individual
    figures associated with a problem.

    Args:
        problem: A problem object exposing ``figures`` (its length gives the
            number of subplots) and ``plot(index, axes)`` which draws the
            figure with the given index on the given axes.
        height: Height allotted to each subplot; the overall figure size
            is ``(15, height * number_of_figures)``.

    Returns:
        tuple: ``(fig, ax)`` where ``fig`` is the created matplotlib figure
        and ``ax`` is a 1-D array holding one axes object per figure.
    """
    nf = len(problem.figures)
    # squeeze=False makes subplots always return a 2-D (nf, 1) array.
    # With the default squeezing, a single subplot comes back as a bare
    # Axes object and indexing it with ax[i] fails.
    fig, grid = plt.subplots(nf, 1, figsize=(15, height*nf),
        constrained_layout=True, squeeze=False)
    # Take the only column so callers still receive a 1-D array of axes.
    ax = grid[:, 0]
    for i in range(nf):
        problem.plot(i, ax[i])
    return fig, ax
def analyze_solution(problem, solution, perc=99.9):
    """Provides a quick analysis of sparse recovery by one of the algorithms in CR-Sparse

    The report is printed to standard output; nothing is returned.

    Args:
        problem: The problem that was solved. This function reads its
            ``A`` (using ``A.shape`` and ``A.times``), ``b``, ``x``, ``y``
            and ``both`` attributes and may call ``problem.reconstruct``.
            ``x`` and ``y`` may be ``None``.
        solution: The solver output. ``solution.x`` is required;
            ``iterations`` and ``n_times`` (together with ``n_trans``)
            are reported only when present.
        perc: Value forwarded to
            ``crn.num_largest_coeffs_for_energy_percent`` when counting
            the significant coefficients of the original and
            reconstructed representations.
    """
    op = problem.A
    b0, x0, y0 = problem.b, problem.x, problem.y
    m, n = op.shape
    x = solution.x
    # Measurements implied by the recovered representation.
    b = op.times(x)
    has_x0 = x0 is not None

    print(f'm: {m}, n: {n}')
    print(f'b_norm: original: {crn.arr_l2norm(b0):.3f} reconstruction: {crn.arr_l2norm(b):.3f}'
        + f' SNR: {crn.signal_noise_ratio(b0, b):.2f} dB')

    if has_x0:
        print(f'x_norm: original: {crn.arr_l2norm(x0):.3f} reconstruction: {crn.arr_l2norm(x):.3f}'
            + f' SNR: {crn.signal_noise_ratio(x0, x):.2f} dB')

    if y0 is not None:
        # Report on y when the problem asks for both reports, or when
        # there is no reference x to compare against.
        if problem.both or not has_x0:
            y = problem.reconstruct(x)
            print(f'y_norm: original: {crn.arr_l2norm(y0):.3f} reconstruction: {crn.arr_l2norm(y):.3f}'
                + f' SNR: {crn.signal_noise_ratio(y0, y):.2f} dB')

    if has_x0:
        # Compare the index sets of the largest coefficients of the
        # reference and the recovered representations.
        k1 = crn.num_largest_coeffs_for_energy_percent(x0, perc)
        k2 = crn.num_largest_coeffs_for_energy_percent(x, perc)
        reference_support = largest_indices(x0, k1)
        recovered_support = largest_indices(x, k2)
        correct = jnp.intersect1d(reference_support, recovered_support).size
        ratio = correct / max(k1, k2)
        print(f'Sparsity: original: {k1}, reconstructed: {k2}, overlap: {correct}, ratio: {ratio:.3f}')

    # Optional solver statistics share one output line (end='').
    if hasattr(solution, 'iterations'):
        print(f'Iterations: {solution.iterations} ', end='')
    if hasattr(solution, 'n_times'):
        print(f'n_times: {solution.n_times}, n_trans: {solution.n_trans}', end='')
    print('\n')