Image Deblurring

This example demonstrates the following features:

  • cr.sparse.lop.convolve2D: a 2D convolution linear operator

  • cr.sparse.sls.lsqr: the LSQR algorithm for solving a least squares problem on 2D images

  • cr.sparse.lop.dwt2D: a 2D discrete wavelet basis operator

  • cr.sparse.sls.fista: the Fast Iterative Shrinkage and Thresholding Algorithm (FISTA) on 2D images

Image deblurring can be treated as a deconvolution problem if the filter used for blurring the image is known.

Please see the deconvolution example for some background.
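The notation below is ours, not necessarily the library's. Let x be the sharp image, h the known kernel, H the corresponding 2D convolution operator, Ψ the wavelet synthesis operator and λ a sparsity weight. With these symbols, the two deblurring strategies used in this example can be summarized as:

```latex
\begin{aligned}
\mathbf{b} &= \mathbf{H}\mathbf{x}, \\
\hat{\mathbf{x}}_{\mathrm{LSQR}} &\approx \arg\min_{\mathbf{x}} \lVert \mathbf{H}\mathbf{x} - \mathbf{b} \rVert_2^2, \\
\hat{\boldsymbol{\alpha}} &\approx \arg\min_{\boldsymbol{\alpha}} \lVert \mathbf{H}\boldsymbol{\Psi}\boldsymbol{\alpha} - \mathbf{b} \rVert_2^2 + \lambda \lVert \boldsymbol{\alpha} \rVert_1,
\qquad \hat{\mathbf{x}}_{\mathrm{FISTA}} = \boldsymbol{\Psi}\hat{\boldsymbol{\alpha}}.
\end{aligned}
```

The ≈ signs reflect that both solvers are stopped after a fixed number of iterations (50 in this example) rather than run to convergence.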

Let’s import the necessary libraries.

import jax.numpy as jnp
# For plotting diagrams
import matplotlib.pyplot as plt
## CR-Sparse modules
import cr.nimble as crn
# Linear operators
from cr.sparse import lop
# Image processing utilities
from cr.sparse import vision
# Solvers for sparse linear systems
from cr.sparse import sls
# Several thresholding functions are available in this module
from cr.sparse import geo
# Sample images
import skimage.data
# Configure JAX for 64-bit computing
import jax
jax.config.update("jax_enable_x64", True)

Problem Setup

image = skimage.data.checkerboard()
print(image.shape)
(200, 200)

Gaussian blur kernel

h = vision.kernel_gaussian((15, 25), (8, 4))
# plot the kernel
fig, ax = plt.subplots(1, 1, figsize=(5, 3))
him = ax.imshow(h)
ax.set_title('Blurring kernel')
fig.colorbar(him, ax=ax)
ax.axis('tight')
[Figure: Blurring kernel]
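For intuition, you can build a normalized 2D Gaussian kernel as the outer product of two 1D Gaussians. The following is a minimal NumPy sketch. `gaussian_kernel` is a hypothetical helper. We have not checked how `vision.kernel_gaussian` works internally; we assume its second argument holds the per-axis standard deviations and that the grid is centered in the middle of the kernel.

```python
import numpy as np

def gaussian_kernel(shape, sigma):
    """Hypothetical helper: normalized, separable 2D Gaussian kernel.

    shape: (rows, cols) of the kernel; sigma: (sigma_rows, sigma_cols) (assumed convention).
    """
    rows, cols = shape
    s_r, s_c = sigma
    # Coordinates measured from the middle of each axis
    r = np.arange(rows) - (rows - 1) / 2
    c = np.arange(cols) - (cols - 1) / 2
    # 1D Gaussian profiles along each axis
    g_r = np.exp(-r**2 / (2 * s_r**2))
    g_c = np.exp(-c**2 / (2 * s_c**2))
    # The outer product gives the separable 2D kernel; normalize so that it sums to 1
    h = np.outer(g_r, g_c)
    return h / h.sum()

h = gaussian_kernel((15, 25), (8, 4))
peak = np.unravel_index(np.argmax(h), h.shape)
print(h.shape, tuple(int(i) for i in peak))  # (15, 25) (7, 12)
```

Normalizing to unit sum means a constant image stays constant after blurring (away from the borders). The peak at (7, 12) agrees with the offset that the example prints when it locates the filter center.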

The linear operator for the blur kernel

Locate the center of the filter

offset = crn.arr_largest_index(h)
print(offset)
# Construct a 2D convolution operator based on the kernel
H = lop.convolve2D(image.shape, h, offset=offset)
# JIT compile the convolution operator for efficiency
H = lop.jit(H)
(Array(7, dtype=int64), Array(12, dtype=int64))
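The printed offset (7, 12) is the position of the kernel's largest entry, i.e. its center. A NumPy analogue is `np.unravel_index(np.argmax(h), h.shape)`, assuming `crn.arr_largest_index` returns that index.

The sketch below shows what a same-size convolution with such an offset, and its adjoint, compute. It assumes the convention y[m, n] = sum over k, l of h[k, l] * x[m - k + o1, n - l + o2], with zeros outside the image. `lop.convolve2D` may treat boundaries or the offset differently. The helpers `shift`, `conv2d` and `conv2d_adjoint` are hypothetical. The adjoint (a correlation) is the operation a solver such as LSQR reaches through `H.trans`.

```python
import numpy as np

def shift(x, dm, dn):
    """Return z with z[m, n] = x[m - dm, n - dn], and zero where that index is out of range."""
    M, N = x.shape
    z = np.zeros_like(x)
    m0, m1 = max(dm, 0), min(M + dm, M)
    n0, n1 = max(dn, 0), min(N + dn, N)
    if m0 < m1 and n0 < n1:  # skip shifts that move everything out of the image
        z[m0:m1, n0:n1] = x[m0 - dm:m1 - dm, n0 - dn:n1 - dn]
    return z

def conv2d(x, h, offset):
    """Same-size 2D convolution with zero padding, kernel centered at `offset`."""
    o1, o2 = offset
    y = np.zeros_like(x, dtype=float)
    for k in range(h.shape[0]):
        for l in range(h.shape[1]):
            y += h[k, l] * shift(x, k - o1, l - o2)
    return y

def conv2d_adjoint(y, h, offset):
    """Adjoint of conv2d (a correlation): every shift is applied in the opposite direction."""
    o1, o2 = offset
    x = np.zeros_like(y, dtype=float)
    for k in range(h.shape[0]):
        for l in range(h.shape[1]):
            x += h[k, l] * shift(y, o1 - k, o2 - l)
    return x

# Usage: a small kernel and its center (analogue of the offset computed above)
h = np.outer([1.0, 2.0, 1.0], [1.0, 2.0, 3.0, 2.0, 1.0])
offset = np.unravel_index(np.argmax(h), h.shape)  # (1, 2)

# An impulse at (4, 4) reproduces the kernel placed around it
x = np.zeros((9, 9))
x[4, 4] = 1.0
y = conv2d(x, h, offset)

# Adjoint check: <H x, y> == <x, H^T y> for random images
rng = np.random.default_rng(0)
xr, yr = rng.standard_normal((9, 9)), rng.standard_normal((9, 9))
lhs = np.vdot(conv2d(xr, h, offset), yr)
rhs = np.vdot(xr, conv2d_adjoint(yr, h, offset))
print(np.isclose(lhs, rhs))  # True
```

A production operator would use FFTs or a vectorized convolution instead of a double loop. The loop form makes it easy to see that the adjoint simply reverses each shift.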

The blurring

Apply the blurring operator to the original image

blurred_image = H.times(image)
# Measure the PSNR
print("Blurred PSNR: ", crn.peak_signal_noise_ratio(image, blurred_image), 'dB')
# plot the original and the blurred images
fig, ax = plt.subplots(ncols=2, figsize=(10, 5))
ax[0].imshow(image, cmap=plt.cm.gray)
ax[0].set_title('Original')
ax[1].imshow(blurred_image, cmap=plt.cm.gray)
ax[1].set_title('After blurring')
[Figure: Original, After blurring]
Blurred PSNR:  14.592858290084076 dB
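PSNR compares the mean squared error (MSE) with the square of the peak signal value: PSNR = 10 log10(peak² / MSE). Below is a minimal NumPy version. It assumes the peak is the maximum of the reference image; `crn.peak_signal_noise_ratio` may use a different peak convention internally.

```python
import numpy as np

def psnr(reference, estimate, peak=None):
    """Peak signal-to-noise ratio in dB (peak defaults to reference.max(), an assumption)."""
    reference = np.asarray(reference, dtype=float)  # cast first to avoid uint8 wrap-around
    estimate = np.asarray(estimate, dtype=float)
    if peak is None:
        peak = reference.max()
    mse = np.mean((reference - estimate) ** 2)
    if mse == 0:
        return np.inf  # identical images
    return 10 * np.log10(peak ** 2 / mse)

ref = np.array([[0, 255], [255, 0]], dtype=np.uint8)
est = np.array([[0, 245], [255, 0]], dtype=np.uint8)
print(psnr(ref, est))
```

Both PSNR values in this example use the same reference image. A gain of d dB therefore means the MSE shrank by a factor of 10^(d/10); for example, a 10 dB gain is a tenfold reduction.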

Deblurring using the LSQR algorithm

Use an all-zeros image as the initial guess for the deblurred image.

x0 = jnp.zeros_like(blurred_image)
# Run the LSQR algorithm for 50 iterations to deblur the image
sol = sls.lsqr(H, blurred_image, x0, max_iters=50)
deblurred_image = sol.x
# Measure the PSNR
print("Deblurred PSNR: ", crn.peak_signal_noise_ratio(image, deblurred_image), 'dB')
# Plot the original, blurred and deblurred image
fig, ax = plt.subplots(ncols=3, figsize=(15, 5))
ax[0].imshow(image, cmap=plt.cm.gray)
ax[0].set_title('Original')
ax[1].imshow(blurred_image, cmap=plt.cm.gray)
ax[1].set_title('After blurring')
ax[2].imshow(deblurred_image, cmap=plt.cm.gray)
ax[2].set_title('After deblurring')

print(sol)
[Figure: Original, After blurring, After deblurring]
Deblurred PSNR:  21.206455209076868 dB
x: (200, 200)
A_norm: 4.946758875785462
A_cond: 167.30699088560272
x_norm: 34971.66794553033
r_norm: 28.394130441229127
atr_norm: 3.81619135390487
iterations: 50
n_times: 50
n_trans: 50
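LSQR only needs the operator's forward action (`H.times`) and its adjoint (`H.trans`), one of each per iteration, as the `n_times` and `n_trans` counts above show. The sketch below swaps in CGLS (conjugate gradients on the normal equations) because it is short. CGLS is mathematically equivalent to LSQR in exact arithmetic, although LSQR is numerically more robust. `cgls` is a hypothetical helper, not part of CR-Sparse.

```python
import numpy as np

def cgls(A, AT, b, x0, max_iters=50, tol=1e-10):
    """Hypothetical CGLS solver for min ||A x - b||_2 with matrix-free A and AT."""
    x = np.array(x0, dtype=float)
    r = b - A(x)              # data-space residual
    s = AT(r)                 # normal-equation residual A^T (b - A x)
    p = s.copy()              # search direction
    gamma = np.vdot(s, s)
    iters = 0
    while iters < max_iters and np.sqrt(gamma) > tol:
        q = A(p)                          # one forward application per iteration
        alpha = gamma / np.vdot(q, q)     # exact line search along p
        x += alpha * p
        r -= alpha * q
        s = AT(r)                         # one adjoint application per iteration
        gamma_new = np.vdot(s, s)
        p = s + (gamma_new / gamma) * p   # next conjugate direction
        gamma = gamma_new
        iters += 1
    return x, iters

# Usage on a small, well-conditioned 1D blur (3-tap kernel, zero boundaries)
n = 20
K = 0.6 * np.eye(n) + 0.2 * np.eye(n, k=1) + 0.2 * np.eye(n, k=-1)
x_true = np.sin(np.linspace(0, 3, n))
b = K @ x_true
x_hat, iters = cgls(lambda v: K @ v, lambda v: K.T @ v, b, np.zeros(n), max_iters=100)
print(iters, np.max(np.abs(x_hat - x_true)))
```

Real blur operators are usually far worse conditioned than this toy matrix. In that setting, stopping after a fixed number of iterations, as the example does with `max_iters=50`, typically acts as a form of regularization.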

A wavelet basis for the images

Construct the basis

DWT_basis = lop.dwt2D(image.shape, wavelet='haar', level=3, basis=True)
DWT_basis = lop.jit(DWT_basis)
# Visualize the wavelet transform of the image
coefs = DWT_basis.trans(image)
fig, ax = plt.subplots(ncols=2, figsize=(10, 5))
ax[0].imshow(image, cmap=plt.cm.gray)
ax[0].set_title('Image')
ax[1].imshow(coefs, cmap=plt.cm.gray)
ax[1].set_title('Wavelet coefficients')
[Figure: Image, Wavelet coefficients]
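In this example, `DWT_basis.trans` maps an image to wavelet coefficients and `DWT_basis.times` maps coefficients back to an image. We assume that with `basis=True` and the Haar wavelet this pair is orthonormal, so `trans` inverts `times`.

The NumPy sketch below implements a multi-level orthonormal 2D Haar transform with hypothetical helpers `haar2d` and `ihaar2d`. It assumes every level sees even dimensions: a 200 × 200 image with level=3 is processed at sizes 200, 100 and 50. Its coefficient layout need not match the library's.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)

def haar2d(x, level=1):
    """Hypothetical helper: orthonormal multi-level 2D Haar analysis (image -> coefficients)."""
    c = np.array(x, dtype=float)
    n0, n1 = c.shape
    for _ in range(level):
        sub = c[:n0, :n1]
        # Rows: pairwise sums (low-pass) on top, pairwise differences (high-pass) below
        sub = np.vstack([(sub[0::2] + sub[1::2]) / SQRT2, (sub[0::2] - sub[1::2]) / SQRT2])
        # Columns: the same split along the second axis
        sub = np.hstack([(sub[:, 0::2] + sub[:, 1::2]) / SQRT2,
                         (sub[:, 0::2] - sub[:, 1::2]) / SQRT2])
        c[:n0, :n1] = sub
        n0, n1 = n0 // 2, n1 // 2  # recurse on the low-low (approximation) block
    return c

def ihaar2d(c, level=1):
    """Inverse of haar2d (coefficients -> image)."""
    x = np.array(c, dtype=float)
    n0, n1 = x.shape[0] >> (level - 1), x.shape[1] >> (level - 1)  # coarsest block size
    for _ in range(level):
        sub = x[:n0, :n1].copy()
        h0, h1 = n0 // 2, n1 // 2
        # Undo the column split
        tmp = np.empty_like(sub)
        tmp[:, 0::2] = (sub[:, :h1] + sub[:, h1:]) / SQRT2
        tmp[:, 1::2] = (sub[:, :h1] - sub[:, h1:]) / SQRT2
        # Undo the row split
        sub[0::2] = (tmp[:h0] + tmp[h0:]) / SQRT2
        sub[1::2] = (tmp[:h0] - tmp[h0:]) / SQRT2
        x[:n0, :n1] = sub
        n0, n1 = n0 * 2, n1 * 2
    return x

# Usage: a 16 x 16 checkerboard made of 4 x 4 squares
board = 255.0 * np.kron(np.indices((4, 4)).sum(axis=0) % 2, np.ones((4, 4)))
coefs = haar2d(board, level=2)
print(np.count_nonzero(np.abs(coefs) > 1e-9))      # 8: only the white squares' approximations
print(np.allclose(ihaar2d(coefs, level=2), board))  # True: perfect reconstruction
```

Because the checkerboard is piecewise constant, most of its Haar coefficients are zero. The ℓ1 penalty in the FISTA step relies on this kind of sparsity.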

Deblurring with the Fast Iterative Shrinkage and Thresholding Algorithm (FISTA)

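We combine the convolution operator and the wavelet basis operator into a single operator A. Applying A to a set of wavelet coefficients produces the corresponding blurred image.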

A = H @ DWT_basis
# Step size for the FISTA algorithm
step_size = 1.
# Thresholding function for the FISTA algorithm
threshold_func = lambda i, x : geo.soft_threshold(x, 0.02)
# Initial guess for the wavelet coefficients matrix is all zeros
x0 = jnp.zeros(DWT_basis.shape[1])
# Solve the problem min_x ||A x - b||_2^2 + lambda ||x||_1
sol = sls.fista_jit(
    # The combined convolution+wavelet basis operator
    A,
    # The blurred image as input
    b=blurred_image,
    # Initial guess for the coefficients
    x0=x0,
    # Step size for the FISTA algorithm
    step_size=step_size,
    # Thresholding function to be used for FISTA
    threshold_func=threshold_func,
    # Maximum number of iterations for which the algorithm will be run
    max_iters=50)
print(f"Number of FISTA iterations {sol.iterations}")
# Compute the deblurred image from the coefficients given by FISTA
deblurred_image = DWT_basis.times(sol.x)
# Measure the PSNR
print("Deblurred PSNR: ", crn.peak_signal_noise_ratio(image, deblurred_image), 'dB')
fig, ax = plt.subplots(ncols=3, figsize=(15, 5))
ax[0].imshow(image, cmap=plt.cm.gray)
ax[0].set_title('Original')
ax[1].imshow(blurred_image, cmap=plt.cm.gray)
ax[1].set_title('After blurring')
ax[2].imshow(deblurred_image, cmap=plt.cm.gray)
ax[2].set_title('FISTA deblurring')
[Figure: Original, After blurring, FISTA deblurring]
Number of FISTA iterations 50
Deblurred PSNR:  20.411790224661722 dB
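Each FISTA iteration does three things:

1. It takes a gradient step on the data-fit term.
2. It applies the thresholding function.
3. It extrapolates using the previous iterate.

Below is a minimal NumPy sketch of Beck and Teboulle's constant-step variant. It uses the same `threshold_func(i, x)` calling convention as the example. `fista` and `soft_threshold` here are hypothetical stand-ins, and CR-Sparse's `sls.fista_jit` may differ in details.

The sketch targets ½‖Ax − b‖² + λ‖x‖₁; the factor ½ only rescales λ relative to the objective written in the code above. Under this convention the threshold equals step_size · λ. In the example, A would be the composite H Ψ and AT its adjoint Ψᵀ Hᵀ.

```python
import numpy as np

def soft_threshold(x, tau):
    """Shrink every entry toward zero by tau; entries with |x| <= tau become 0."""
    return np.sign(x) * np.maximum(np.abs(x) - tau, 0.0)

def fista(A, AT, b, x0, step_size, threshold_func, max_iters=50):
    """Hypothetical FISTA sketch for 1/2 ||A x - b||^2 + g(x), with g handled by threshold_func."""
    x = np.array(x0, dtype=float)
    y = x.copy()      # extrapolated point
    t = 1.0           # momentum parameter
    for i in range(max_iters):
        grad = AT(A(y) - b)                               # gradient of the data-fit term at y
        x_new = threshold_func(i, y - step_size * grad)   # shrinkage (proximal) step
        t_new = (1.0 + np.sqrt(1.0 + 4.0 * t * t)) / 2.0
        y = x_new + ((t - 1.0) / t_new) * (x_new - x)     # momentum extrapolation
        x, t = x_new, t_new
    return x

# Usage: a diagonal A, where the exact minimizer is known in closed form
a = np.array([1.0, 0.5, 0.8, 0.7])
b = np.array([1.0, -2.0, 0.05, 3.0])
lam = 0.1
A = lambda v: a * v   # diagonal, so A is its own adjoint
x_hat = fista(A, A, b, np.zeros(4), step_size=1.0,   # 1 / L with L = max(a**2) = 1
              threshold_func=lambda i, v: soft_threshold(v, lam), max_iters=200)
x_star = soft_threshold(a * b, lam) / a ** 2          # closed-form minimizer, entry by entry
print(x_hat, x_star)
```

The third entry of `x_star` is exactly zero: a small coefficient is switched off rather than merely shrunk. This is the behavior the ℓ1 penalty exploits on wavelet coefficients.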

Total running time of the script: (0 minutes 9.313 seconds)
