9.1. Computation of row-wise and column-wise normsΒΆ

import jax
import jax.numpy as jnp
import cr.sparse as crs
M = 10
p = 3
N = 5
A = jnp.arange(1, M+1) * jnp.ones((p, 1))
A = jnp.sign(jax.random.normal(jax.random.PRNGKey(0), (p, M))) * A
A
DeviceArray([[ -1.,   2.,  -3.,   4.,   5.,   6.,  -7.,   8.,  -9.,  10.],
             [ -1.,   2.,   3.,   4.,   5.,   6.,  -7.,   8.,   9., -10.],
             [ -1.,   2.,   3.,  -4.,  -5.,  -6.,  -7.,  -8.,  -9., -10.]],            dtype=float32)
# Euclidean norm squared for each column
crs.sqr_norms_l2_cw(A)
DeviceArray([  3.,  12.,  27.,  48.,  75., 108., 147., 192., 243., 300.],            dtype=float32)
B = A.T
B
DeviceArray([[ -1.,  -1.,  -1.],
             [  2.,   2.,   2.],
             [ -3.,   3.,   3.],
             [  4.,   4.,  -4.],
             [  5.,   5.,  -5.],
             [  6.,   6.,  -6.],
             [ -7.,  -7.,  -7.],
             [  8.,   8.,  -8.],
             [ -9.,   9.,  -9.],
             [ 10., -10., -10.]], dtype=float32)
# Euclidean norm squared for each row
crs.sqr_norms_l2_rw(B)
DeviceArray([  3.,  12.,  27.,  48.,  75., 108., 147., 192., 243., 300.],            dtype=float32)
# Euclidean norm for each column
crs.norms_l2_cw(A)
DeviceArray([ 1.7320508,  3.4641016,  5.196152 ,  6.928203 ,  8.660254 ,
             10.392304 , 12.124355 , 13.856406 , 15.588457 , 17.320507 ],            dtype=float32)
# Euclidean norm for each row
crs.norms_l2_rw(B)
DeviceArray([ 1.7320508,  3.4641016,  5.196152 ,  6.928203 ,  8.660254 ,
             10.392304 , 12.124355 , 13.856406 , 15.588457 , 17.320507 ],            dtype=float32)
# L1 or city block norm for each column
crs.norms_l1_cw(A)
DeviceArray([ 3.,  6.,  9., 12., 15., 18., 21., 24., 27., 30.], dtype=float32)
# L1 or city block norm for each row
crs.norms_l1_rw(B)
DeviceArray([ 3.,  6.,  9., 12., 15., 18., 21., 24., 27., 30.], dtype=float32)
# L-inf or Chebyshev norm for each column
crs.norms_linf_cw(A)
DeviceArray([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.], dtype=float32)
# L-inf or Chebyshev norm for each row
crs.norms_linf_rw(B)
DeviceArray([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.], dtype=float32)
# l2-normalize each column
print(crs.normalize_l2_cw(A))
[[-0.57735026  0.57735026 -0.5773503   0.57735026  0.5773503   0.5773503
  -0.57735026  0.57735026 -0.5773503   0.5773503 ]
 [-0.57735026  0.57735026  0.5773503   0.57735026  0.5773503   0.5773503
  -0.57735026  0.57735026  0.5773503  -0.5773503 ]
 [-0.57735026  0.57735026  0.5773503  -0.57735026 -0.5773503  -0.5773503
  -0.57735026 -0.57735026 -0.5773503  -0.5773503 ]]
# l2-normalize each row
print(crs.normalize_l2_rw(B))
[[-0.57735026 -0.57735026 -0.57735026]
 [ 0.57735026  0.57735026  0.57735026]
 [-0.5773503   0.5773503   0.5773503 ]
 [ 0.57735026  0.57735026 -0.57735026]
 [ 0.5773503   0.5773503  -0.5773503 ]
 [ 0.5773503   0.5773503  -0.5773503 ]
 [-0.57735026 -0.57735026 -0.57735026]
 [ 0.57735026  0.57735026 -0.57735026]
 [-0.5773503   0.5773503  -0.5773503 ]
 [ 0.5773503  -0.5773503  -0.5773503 ]]
# l1-normalize each column
print(crs.normalize_l1_cw(A))
[[-0.33333334  0.33333334 -0.33333334  0.33333334  0.33333334  0.33333334
  -0.33333334  0.33333334 -0.33333334  0.33333334]
 [-0.33333334  0.33333334  0.33333334  0.33333334  0.33333334  0.33333334
  -0.33333334  0.33333334  0.33333334 -0.33333334]
 [-0.33333334  0.33333334  0.33333334 -0.33333334 -0.33333334 -0.33333334
  -0.33333334 -0.33333334 -0.33333334 -0.33333334]]
# l1-normalize each row
print(crs.normalize_l1_rw(B))
[[-0.33333334 -0.33333334 -0.33333334]
 [ 0.33333334  0.33333334  0.33333334]
 [-0.33333334  0.33333334  0.33333334]
 [ 0.33333334  0.33333334 -0.33333334]
 [ 0.33333334  0.33333334 -0.33333334]
 [ 0.33333334  0.33333334 -0.33333334]
 [-0.33333334 -0.33333334 -0.33333334]
 [ 0.33333334  0.33333334 -0.33333334]
 [-0.33333334  0.33333334 -0.33333334]
 [ 0.33333334 -0.33333334 -0.33333334]]