# Source code for cr.sparse._src.cs.cs1bit.biht

# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import NamedTuple, List, Dict


import jax.numpy as jnp
from jax import vmap, jit, lax
from jax.numpy.linalg import norm

from cr.nimble.dsp import (hard_threshold,
    build_signal_from_indices_and_values)


class BIHTState(NamedTuple):
    """State carried through the iterations of Binary Iterative Hard
    Thresholding (BIHT).
    """
    # Values of the current estimate on its support ``I``
    x_I: jnp.ndarray
    # Indices (support) of the entries of the current estimate that are kept
    I: jnp.ndarray
    # Bit residual ``y - sgn(Phi x)`` for the current estimate
    r: jnp.ndarray
    # Count of residual entries that are non-zero (mismatched sign bits)
    n_mismatched_bits: jnp.ndarray
    # Number of iterations performed so far
    iterations: int
def biht(Phi, y, K, tau, max_iters=1000):
    r"""Solves the 1-bit compressive sensing problem
    :math:`\text{sgn} (\Phi x) = y` using Binary Iterative Hard Thresholding

    Args:
        Phi (jax.numpy.ndarray): A random dictionary of shape (M, N)
        y (jax.numpy.ndarray): The 1-bit measurements
        K (int): Sparsity level of solution x (number of non-zero entries).
            It fixes the shapes of ``I`` and ``x_I``, so it must be a static
            value when this function is jit-compiled.
        tau (float): Step size for the x update step
        max_iters (int): Maximum number of iterations

    Returns:
        (BIHTState): A named tuple containing the support ``I`` and values
        ``x_I`` of the solution x, plus the final residual, mismatch count
        and number of iterations performed

    We assume that :math:`x` is a K-sparse vector. We assume that the one-bit
    measurements are made as follows:

    .. math::

        y = \text{sgn} (\Phi x)

    Thus the vector y contains entries 1 and -1 for the signs of the entries
    in the measurement :math:`\Phi x`.

    The BIHT algorithm proceeds as follows:

    - Start with an estimate :math:`x = 0`
    - Compute the guess :math:`\hat{y} = \text{sgn} (\Phi x)`
    - Measure the residual :math:`r = y - \hat{y}`
    - Count the number of mismatched bits as number of places where r is
      non-zero.
    - Compute the correlation :math:`h = \Phi^T r`
    - Update x as :math:`x = x + \frac{\tau}{2} h`
    - Hard threshold x to keep only K largest entries
    - Repeat until no bits are mismatched or ``max_iters`` iterations
      have been performed

    Example:

        >>> import cr.nimble as cnb
        >>> import cr.sparse as crs
        >>> import cr.sparse.dict as crdict
        >>> import cr.sparse.data as crdata
        >>> import cr.sparse.cs.cs1bit as cs1bit
        >>> from jax.numpy.linalg import norm
        >>> from cr.nimble.dsp import build_signal_from_indices_and_values
        >>> M, N, K = 256, 512, 4
        >>> Phi = crdict.gaussian_mtx(cnb.KEYS[0], M, N, normalize_atoms=False)
        >>> x, omega = crdata.sparse_normal_representations(cnb.KEYS[1], N, K)
        >>> x = x / norm(x)
        >>> y = cs1bit.measure_1bit(Phi, x)
        >>> s0 = crdict.upper_frame_bound(Phi)
        >>> tau = 0.98 * s0
        >>> state = cs1bit.biht_jit(Phi, y, K, tau)
        >>> x_rec = build_signal_from_indices_and_values(N, state.I, state.x_I)
        >>> x_rec = x_rec / norm(x_rec)
    """
    # Only the number of atoms is needed below; the unpacking still checks
    # that Phi is 2-D.
    _, N = Phi.shape

    def residual(I, x_I):
        """Bit residual and mismatch count for the estimate whose values
        ``x_I`` sit on the atoms ``I``."""
        # compute the 1 bit output based on current x estimate
        y_hat = jnp.sign(Phi[:, I] @ x_I)
        # difference between actual y and current estimate
        r = y - y_hat
        # number of measurements whose bit differs from y
        n_mismatched_bits = jnp.sum(r != 0)
        return r, n_mismatched_bits

    def init():
        # Initial estimate x = 0. The support (first K atoms) is arbitrary
        # since every value on it is zero.
        I = jnp.arange(0, K)
        x_I = jnp.zeros(K)
        r, n_mismatched_bits = residual(I, x_I)
        return BIHTState(x_I=x_I, I=I, r=r,
            n_mismatched_bits=n_mismatched_bits, iterations=0)

    def body(state):
        # Compute the correlation of the residual with the atoms of Phi
        h = Phi.T @ state.r
        # expand the current sparse approximation to a length-N vector
        x = build_signal_from_indices_and_values(N, state.I, state.x_I)
        # step of size tau/2 along the correlation
        x = x + tau / 2 * h
        # keep only K entries of the updated estimate
        I, x_I = hard_threshold(x, K)
        r, n_mismatched_bits = residual(I, x_I)
        return BIHTState(x_I=x_I, I=I, r=r,
            n_mismatched_bits=n_mismatched_bits,
            iterations=state.iterations + 1)

    def cond(state):
        # keep iterating while some bits still disagree with y and the
        # iteration budget has not been exhausted
        return jnp.logical_and(state.n_mismatched_bits > 0,
                               state.iterations < max_iters)

    return lax.while_loop(cond, body, init())
# K fixes the shapes of I and x_I inside biht, so it has to be a
# compile-time (static) argument of the jitted function.
biht_jit = jit(biht, static_argnums=(2,))