9.1. Computation of row-wise and column-wise normsΒΆ
import jax
import jax.numpy as jnp
import cr.sparse as crs
M = 10
p = 3
N = 5
A = jnp.arange(1, M+1) * jnp.ones((p, 1))
A = jnp.sign(jax.random.normal(jax.random.PRNGKey(0), (p, M))) * A
A
DeviceArray([[ -1., 2., -3., 4., 5., 6., -7., 8., -9., 10.],
[ -1., 2., 3., 4., 5., 6., -7., 8., 9., -10.],
[ -1., 2., 3., -4., -5., -6., -7., -8., -9., -10.]], dtype=float32)
# Euclidean norm squared for each column
crs.sqr_norms_l2_cw(A)
DeviceArray([ 3., 12., 27., 48., 75., 108., 147., 192., 243., 300.], dtype=float32)
B = A.T
B
DeviceArray([[ -1., -1., -1.],
[ 2., 2., 2.],
[ -3., 3., 3.],
[ 4., 4., -4.],
[ 5., 5., -5.],
[ 6., 6., -6.],
[ -7., -7., -7.],
[ 8., 8., -8.],
[ -9., 9., -9.],
[ 10., -10., -10.]], dtype=float32)
# Euclidean norm squared for each row
crs.sqr_norms_l2_rw(B)
DeviceArray([ 3., 12., 27., 48., 75., 108., 147., 192., 243., 300.], dtype=float32)
# Euclidean norm for each column
crs.norms_l2_cw(A)
DeviceArray([ 1.7320508, 3.4641016, 5.196152 , 6.928203 , 8.660254 ,
10.392304 , 12.124355 , 13.856406 , 15.588457 , 17.320507 ], dtype=float32)
# Euclidean norm for each row
crs.norms_l2_rw(B)
DeviceArray([ 1.7320508, 3.4641016, 5.196152 , 6.928203 , 8.660254 ,
10.392304 , 12.124355 , 13.856406 , 15.588457 , 17.320507 ], dtype=float32)
# L1 or city block norm for each column
crs.norms_l1_cw(A)
DeviceArray([ 3., 6., 9., 12., 15., 18., 21., 24., 27., 30.], dtype=float32)
# L1 or city block norm for each row
crs.norms_l1_rw(B)
DeviceArray([ 3., 6., 9., 12., 15., 18., 21., 24., 27., 30.], dtype=float32)
# L-inf or Chebyshev norm for each column
crs.norms_linf_cw(A)
DeviceArray([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.], dtype=float32)
# L-inf or Chebyshev norm for each row
crs.norms_linf_rw(B)
DeviceArray([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.], dtype=float32)
# l2-normalize each column
print(crs.normalize_l2_cw(A))
[[-0.57735026 0.57735026 -0.5773503 0.57735026 0.5773503 0.5773503
-0.57735026 0.57735026 -0.5773503 0.5773503 ]
[-0.57735026 0.57735026 0.5773503 0.57735026 0.5773503 0.5773503
-0.57735026 0.57735026 0.5773503 -0.5773503 ]
[-0.57735026 0.57735026 0.5773503 -0.57735026 -0.5773503 -0.5773503
-0.57735026 -0.57735026 -0.5773503 -0.5773503 ]]
# l2-normalize each row
print(crs.normalize_l2_rw(B))
[[-0.57735026 -0.57735026 -0.57735026]
[ 0.57735026 0.57735026 0.57735026]
[-0.5773503 0.5773503 0.5773503 ]
[ 0.57735026 0.57735026 -0.57735026]
[ 0.5773503 0.5773503 -0.5773503 ]
[ 0.5773503 0.5773503 -0.5773503 ]
[-0.57735026 -0.57735026 -0.57735026]
[ 0.57735026 0.57735026 -0.57735026]
[-0.5773503 0.5773503 -0.5773503 ]
[ 0.5773503 -0.5773503 -0.5773503 ]]
# l1-normalize each column
print(crs.normalize_l1_cw(A))
[[-0.33333334 0.33333334 -0.33333334 0.33333334 0.33333334 0.33333334
-0.33333334 0.33333334 -0.33333334 0.33333334]
[-0.33333334 0.33333334 0.33333334 0.33333334 0.33333334 0.33333334
-0.33333334 0.33333334 0.33333334 -0.33333334]
[-0.33333334 0.33333334 0.33333334 -0.33333334 -0.33333334 -0.33333334
-0.33333334 -0.33333334 -0.33333334 -0.33333334]]
# l1-normalize each row
print(crs.normalize_l1_rw(B))
[[-0.33333334 -0.33333334 -0.33333334]
[ 0.33333334 0.33333334 0.33333334]
[-0.33333334 0.33333334 0.33333334]
[ 0.33333334 0.33333334 -0.33333334]
[ 0.33333334 0.33333334 -0.33333334]
[ 0.33333334 0.33333334 -0.33333334]
[-0.33333334 -0.33333334 -0.33333334]
[ 0.33333334 0.33333334 -0.33333334]
[-0.33333334 0.33333334 -0.33333334]
[ 0.33333334 -0.33333334 -0.33333334]]