Source code for cr.sparse._src.dict.props
# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax import jit
import jax.numpy as jnp
from cr.nimble import mat_hermitian
def gram(A):
    """Computes the Gram matrix of a dictionary :math:`A`

    For a real ``A`` this is :math:`G = A^T A`. For any other input the
    product ``mat_hermitian(A) @ A`` is formed (``mat_hermitian`` is
    presumably the conjugate transpose) and only its real part is
    returned, i.e. the imaginary part of the inner products is dropped.
    """
    if not jnp.isrealobj(A):
        # complex case: keep only the real part of the product
        return (mat_hermitian(A) @ A).real
    return A.T @ A
def frame(A):
    """Computes the frame matrix of a dictionary :math:`A`

    For a real ``A`` this is :math:`F = A A^T`. For any other input the
    product ``A @ mat_hermitian(A)`` is formed (``mat_hermitian`` is
    presumably the conjugate transpose) and only its real part is
    returned.
    """
    if not jnp.isrealobj(A):
        # complex case: keep only the real part of the product
        return (A @ mat_hermitian(A)).real
    return A @ A.T
def coherence_with_index(A):
    """Returns the coherence of a dictionary A along with indices of most correlated atoms

    The coherence is the largest magnitude of the inner product between
    two distinct atoms (columns) of ``A``.

    Args:
        A: 2-D dictionary whose columns are the atoms; real or complex.

    Returns:
        A pair ``(max_val, index)`` where ``max_val`` is the coherence and
        ``index`` is the ``(row, col)`` tuple produced by
        ``jnp.unravel_index`` locating the first maximal off-diagonal
        entry of the absolute Gram matrix.
    """
    # Magnitudes of all pairwise inner products. The modulus of the full
    # (possibly complex) Gram matrix is used: taking the real part before
    # the absolute value would under-report the correlation between
    # complex atoms (e.g. atoms that differ only by a factor of 1j).
    G = jnp.abs(jnp.conjugate(A.T) @ A)
    n = G.shape[0]
    # set diagonals to 0 so that self-correlations are never selected
    G = G.at[jnp.diag_indices(n)].set(0)
    index = jnp.unravel_index(jnp.argmax(G, axis=None), G.shape)
    max_val = G[index]
    return max_val, index
@jit
def coherence(A):
    """Computes the coherence of a dictionary

    JIT-compiled wrapper around ``coherence_with_index`` which returns
    only the coherence value and discards the index of the atom pair.
    """
    return coherence_with_index(A)[0]
def frame_bounds(A):
    """Computes the frame bounds (largest and smallest singular value)

    Returns an array of two entries: the first and the last singular
    value of ``A`` as reported by ``jnp.linalg.svd``.
    """
    # singular values only, no U / V factors
    singular_values = jnp.linalg.svd(A, full_matrices=False, compute_uv=False)
    first_and_last = jnp.array([0, -1])
    return singular_values[first_and_last]
def upper_frame_bound(A):
    """Computes the upper frame bound for a dictionary

    This is the first singular value of ``A`` as reported by
    ``jnp.linalg.svd``.
    """
    # singular values only, no U / V factors
    singular_values = jnp.linalg.svd(A, full_matrices=False, compute_uv=False)
    return singular_values[0]
def lower_frame_bound(A):
    """Computes the lower frame bound for a dictionary

    This is the last singular value of ``A`` as reported by
    ``jnp.linalg.svd``.
    """
    # singular values only, no U / V factors
    singular_values = jnp.linalg.svd(A, full_matrices=False, compute_uv=False)
    return singular_values[-1]
@jit
def babel(A):
    """Computes the babel function for a dictionary (generalized coherence)

    Entry ``k - 1`` of the result is the largest, over all atoms, of the
    sum of the ``k`` biggest inner-product magnitudes between that atom
    and the other atoms.

    Args:
        A: 2-D dictionary whose columns are the atoms; real or complex.
            The atoms are expected to be unit norm, because the largest
            entry of every Gram row is dropped as the self similarity.

    Returns:
        A 1-D array with one entry fewer than the number of atoms.
    """
    # Magnitudes of all pairwise inner products. The modulus of the full
    # (possibly complex) Gram matrix is used: taking the real part before
    # the absolute value would under-report the correlation between
    # complex atoms.
    G = jnp.abs(jnp.conjugate(A.T) @ A)
    # sort each row in ascending order
    G = jnp.sort(G, axis=1)
    # reverse each row and drop the largest entry [self similarity is 1]
    G = G[:, -2::-1]
    # compute cumulative sums along each row
    sums = jnp.cumsum(G, axis=1)
    # worst case over all atoms for every subset size
    return jnp.max(sums, axis=0)
def mutual_coherence_with_index(A, B):
    """Mutual coherence between two dictionaries A and B along with indices of most correlated atoms

    Returns a pair ``(max_val, index)``: the largest absolute entry of
    ``mat_hermitian(A) @ B`` and its ``(row, col)`` position as produced
    by ``jnp.unravel_index``.
    """
    # absolute inner products of atoms of A with atoms of B
    cross = jnp.abs(mat_hermitian(A) @ B)
    # flat position of the maximum value, then its 2-D index
    flat_position = jnp.argmax(cross, axis=None)
    index = jnp.unravel_index(flat_position, cross.shape)
    # maximum value together with where it was found
    return cross[index], index
def mutual_coherence(A, B):
    """Mutual coherence between two dictionaries A and B

    Wrapper around ``mutual_coherence_with_index`` which returns only the
    coherence value and discards the index of the atom pair.
    """
    return mutual_coherence_with_index(A, B)[0]