# Source code for corrct.filters

# -*- coding: utf-8 -*-
"""
Filtered back-projection filters.

@author: Nicola VIGANÒ, Computational Imaging group, CWI, The Netherlands,
and ESRF - The European Synchrotron, Grenoble, France
"""

import numpy as np
from scipy.interpolate import interp1d

import skimage.transform as skt

import matplotlib.pyplot as plt

from .operators import BaseTransform
from .processing import circular_mask

from typing import Union, Sequence, Optional, Any
from numpy.typing import ArrayLike, DTypeLike, NDArray

from abc import ABC, abstractmethod
from dataclasses import dataclass
from collections.abc import Mapping

# Optional dependency: the wavelet filter bases require PyWavelets (pywt).
# `has_pywt` records its availability, so importing this module never fails.
try:
    import pywt

    has_pywt = True
except ImportError:
    print("WARNING: You need to install PyWavelets to benefit from wavelet bases.")

    has_pywt = False


class BasisOptions(ABC, Mapping):
    """Options for the different types of bases.

    Exposes the instance attributes through the read-only ``Mapping``
    interface, so option objects can be unpacked with ``**options``.
    """

    def __len__(self) -> int:
        """Return the number of stored options.

        Returns
        -------
        int
            The number of options.
        """
        return len(self.__dict__)

    def __getitem__(self, k: Any) -> Any:
        """Return the option stored under the given key.

        Parameters
        ----------
        k : Any
            The key of the requested option.

        Returns
        -------
        Any
            The requested option value.
        """
        return self.__dict__[k]

    def __iter__(self) -> Any:
        """Iterate over the option names.

        Returns
        -------
        Any
            An iterator over the option names.
        """
        return iter(self.__dict__)
@dataclass
class BasisOptionsBlocks(BasisOptions):
    """Options for the block (binned-pixel) bases.

    Fixed docstring: it previously said "wavelet bases" (copy-paste from the
    wavelet options class), but these options feed `create_basis`.
    """

    # Displacement (in pixels) at which the binning starts. None disables binning.
    binning_start: Optional[int] = 2
    # Type of pixel binning: "exponential", "incremental", or "custom".
    binning_type: str = "exponential"
    # Order of the basis functions (0 or 1 supported by `create_basis`).
    order: int = 1
    # Whether to normalize the bins by the window size.
    normalized: bool = True
@dataclass
class BasisOptionsWavelets(BasisOptions):
    """Options for the wavelet bases."""

    # Name of the wavelet to use (as understood by PyWavelets).
    wavelet: str = "bior2.2"
    # Decomposition level to reach.
    level: int = 5
    # Norm used to normalize the filters.
    norm: float = 1.0
def create_basis(
    num_pixels: int,
    binning_start: Optional[int] = 2,
    binning_type: str = "exponential",
    normalized: bool = False,
    order: int = 1,
    dtype: DTypeLike = np.float32,
) -> NDArray:
    """Compute the filter basis matrix.

    Fixed docstring defects: "fixels" typo, and the stated default of
    `normalized` (the signature default is False, not True).

    Parameters
    ----------
    num_pixels : int
        Number of filter pixels.
    binning_start : Optional[int], optional
        Starting displacement of the binning, by default 2. None disables binning.
    binning_type : str, optional
        Type of pixel binning ("exponential", "incremental", or "custom"), by default "exponential".
    normalized : bool, optional
        Whether to normalize the bins by the window size, by default False.
    order : int, optional
        Order of the basis functions. Only 0 and 1 supported, by default 1.
    dtype : DTypeLike, optional
        Data type, by default np.float32.

    Returns
    -------
    NDArray
        The filter basis, one basis function per row.

    Raises
    ------
    ValueError
        When an invalid `binning_type` is requested (only once binning is active).
    """
    # Absolute (unshifted) frequency position of each filter bin.
    filter_positions = np.abs(np.fft.fftfreq(num_pixels, 1 / num_pixels))

    window_size = 1
    window_position = 0

    basis_r = []
    while window_position < filter_positions.max():
        basis_tmp = np.zeros(filter_positions.shape, dtype=dtype)

        if order == 0:
            # Rectangular (piece-wise constant) basis function over the current window.
            binning_positions = np.logical_and(
                window_position <= filter_positions, filter_positions < (window_position + window_size)
            )
            basis_val = 1.0
            if normalized:
                basis_val /= window_size
            basis_tmp[binning_positions] = basis_val
        else:
            # Triangular (piece-wise linear) basis function, clipped at zero.
            basis_tmp = np.fmax(1.0 - filter_positions / (window_position + window_size), 0.0)

        basis_r.append(basis_tmp)

        window_position += window_size
        # Once past the binning start, grow the window according to the binning type.
        if binning_start is not None and window_position > binning_start:
            if binning_type == "exponential":
                window_size = 2 * window_size
            elif binning_type == "incremental":
                window_size += 1
            elif binning_type == "custom":
                window_size += int(np.sqrt(window_position - binning_start))
            else:
                raise ValueError(f"Invalid 'binning_type' = {binning_type}.")

    basis_r = np.array(basis_r, dtype=dtype)

    if order > 0:
        # Gram-Schmidt orthogonalization of the overlapping triangular functions.
        for ii in range(1, basis_r.shape[0]):
            for jj in range(0, ii):
                other_vec = basis_r[jj, ...]
                other_vec_norm_2 = other_vec.dot(other_vec)
                basis_r[ii, ...] -= other_vec * other_vec.dot(basis_r[ii, ...]) / other_vec_norm_2

    if normalized:
        basis_r /= np.linalg.norm(basis_r, ord=1, axis=-1, keepdims=True)

    return basis_r
def create_basis_wavelet(
    num_pixels: int,
    wavelet: str = "bior2.2",
    level: int = 5,
    norm: float = 1.0,
    dtype: DTypeLike = np.float32,
) -> NDArray:
    """Compute filter basis matrix from a stationary wavelet decomposition.

    Parameters
    ----------
    num_pixels : int
        Number of wavelet filters.
    wavelet : str, optional
        The wavelet to use, by default "bior2.2".
    level : int, optional
        The decomposition level to reach, by default 5.
    norm : float, optional
        The norm to use, by default 1.0.
    dtype : DTypeLike, optional
        Data type, by default np.float32.

    Returns
    -------
    NDArray
        The filter basis, one filter per row (high-pass filters of each level,
        plus the final low-pass filter), in unshifted (fft) ordering.

    Raises
    ------
    ImportError
        When PyWavelets is not installed.
    """
    if not has_pywt:
        print("WARNING: You need to install PyWavelets to benefit from wavelet bases.")
        raise ImportError("PyWavelets (pywt) module not found.")

    # max_level = pywt.swt_max_level(num_pixels)
    # if level > max_level:
    #     print(f"WARNING: Requested level {level} is too high for {num_pixels} pixels. Max allowed is {max_level}.")
    #     level = max_level

    w = pywt.Wavelet(wavelet)
    # Strip the zero padding from the wavelet decomposition filters.
    dec_lo = np.trim_zeros(w.dec_lo)
    dec_hi = np.trim_zeros(w.dec_hi)

    # NOTE(review): crop_size_l / crop_size_u appear to be unused — confirm and remove.
    crop_size_l = num_pixels // 2
    crop_size_u = num_pixels - crop_size_l

    # Center the high-pass filter in a num_pixels-wide window.
    pad_size_u = (num_pixels - len(dec_hi)) // 2
    pad_size_l = num_pixels - len(dec_hi) - pad_size_u
    pad_width = (pad_size_l, pad_size_u)
    basis_hi_tmp = np.pad(dec_hi, pad_width=np.array(pad_width))

    # Center the low-pass filter in a num_pixels-wide window.
    pad_size_u = (num_pixels - len(dec_lo)) // 2
    pad_size_l = num_pixels - len(dec_lo) - pad_size_u
    pad_width = (pad_size_l, pad_size_u)
    basis_lo_tmp = np.pad(dec_lo, pad_width=np.array(pad_width))

    # Centered (fftshifted) frequency coordinates, used as interpolation grid.
    coords = np.fft.fftshift(np.fft.fftfreq(num_pixels, 1 / num_pixels))

    basis_r = []
    basis_r.append(basis_hi_tmp)
    for _ in range(level - 1):
        basis_lo_old = basis_lo_tmp.copy()

        # Stretch the high-pass filter by a factor 2 (evaluate at half coordinates),
        # then combine it with the previous level's low-pass filter.
        int_obj = interp1d(coords, basis_hi_tmp, kind="linear")
        basis_hi_tmp = int_obj((coords + 1 - len(coords) % 2) / 2)
        basis_hi_tmp = np.convolve(basis_hi_tmp, basis_lo_old, mode="same")

        # Same stretching/refinement for the low-pass filter.
        int_obj = interp1d(coords, basis_lo_tmp, kind="linear")
        basis_lo_tmp = int_obj((coords + 1 - len(coords) % 2) / 2)
        basis_lo_tmp = np.convolve(basis_lo_tmp, basis_lo_old, mode="same")

        basis_r.append(basis_hi_tmp)
    # Append the final low-pass (approximation) filter once, after the last level.
    basis_r.append(basis_lo_tmp)

    basis_r = np.array(basis_r, dtype=dtype)
    basis_r /= np.linalg.norm(basis_r, axis=-1, ord=norm, keepdims=True)

    # Undo the centering, to match the unshifted fft frequency ordering.
    return np.fft.ifftshift(basis_r, axes=(-1,))
class Filter(ABC):
    """Base FBP filter."""

    # Filter stored in its Fourier-space representation (real-valued array).
    fbp_filter: NDArray[np.floating]
    # Padding mode used for the linear-convolution padding (numpy.pad mode, lower-case).
    pad_mode: str
    # Whether to use the real-input `rfft` instead of the complex `fft`.
    use_rfft: bool
    # Data type of the filter.
    dtype: DTypeLike

    def __init__(
        self,
        fbp_filter: Union[ArrayLike, NDArray[np.floating], None],
        pad_mode: str,
        use_rfft: bool,
        dtype: DTypeLike,
    ) -> None:
        """Initialize Base FBP filter.

        Parameters
        ----------
        fbp_filter : Union[ArrayLike, NDArray[np.floating], None]
            The filter. If None, an identity (all-pass) filter is used.
        pad_mode : str
            The padding mode.
        use_rfft : bool
            Whether to use the `rfft` or complex `fft`.
        dtype : DTypeLike
            The data type of the filter.
        """
        self.dtype = dtype
        self.pad_mode = pad_mode.lower()
        self.use_rfft = use_rfft

        if fbp_filter is None:
            # Identity filter: multiplying by [1.0] in Fourier leaves the data unchanged.
            self.fbp_filter = np.array([1.0], dtype=dtype)
        else:
            # Only the real part is kept: the filter is stored as a real array.
            self.fbp_filter = np.array(np.real(fbp_filter), dtype=dtype)

    def get_padding_size(self, data_wu_shape: Sequence[int]) -> int:
        """Compute the projection padding size for a linear convolution.

        Parameters
        ----------
        data_wu_shape : Sequence[int]
            The shape of the data.

        Returns
        -------
        int
            The padding size of the last dimension.
        """
        # Next power of two >= twice the data width, with a minimum of 64.
        return max(64, int(2 ** np.ceil(np.log2(2 * data_wu_shape[-1]))))

    def to_fourier(self, data_wu: NDArray) -> NDArray:
        """Transform the data to Fourier space along the last axis.

        Parameters
        ----------
        data_wu : NDArray
            The real-space data.

        Returns
        -------
        NDArray
            The Fourier-space data.
        """
        if self.use_rfft:
            return np.fft.rfft(data_wu, axis=-1)
        else:
            return np.fft.fft(data_wu, axis=-1)

    def to_real(self, data_wu: NDArray) -> NDArray:
        """Transform the data back to real space along the last axis.

        Parameters
        ----------
        data_wu : NDArray
            The Fourier-space data.

        Returns
        -------
        NDArray
            The real-space data (real part only, for the complex fft path).
        """
        if self.use_rfft:
            return np.fft.irfft(data_wu, axis=-1)
        else:
            return np.fft.ifft(data_wu, axis=-1).real

    @property
    def filter_fourier(self) -> NDArray[np.floating]:
        """Fourier representation of the filter.

        Returns
        -------
        NDArray[np.floating]
            The filter in Fourier.
        """
        return self.fbp_filter.copy()

    @property
    def filter_real(self) -> NDArray[np.floating]:
        """Real-space representation of the filter.

        Returns
        -------
        NDArray[np.floating]
            The filter in real-space.
        """
        fbp_filter_r = self.to_real(self.fbp_filter)
        # Center the zero frequency for a natural real-space view.
        return np.fft.fftshift(fbp_filter_r, axes=(-1,))

    @filter_real.setter
    def filter_real(self, fbp_filter_r: NDArray[np.floating]) -> None:
        # NOTE(review): the getter applies fftshift but this setter does not
        # undo it before transforming — confirm the intended convention.
        self.fbp_filter = self.to_fourier(fbp_filter_r)

    @filter_fourier.setter
    def filter_fourier(self, fbp_filter_f: NDArray[np.floating]) -> None:
        self.fbp_filter = fbp_filter_f.copy()

    @property
    def num_filters(self) -> int:
        """Return the number of stored filters (size of the second-to-last axis).

        Returns
        -------
        int
            The number of filters.
        """
        return np.array(self.fbp_filter, ndmin=2).shape[-2]

    def apply_filter(self, data_wu: NDArray, fbp_filter: Optional[NDArray] = None) -> NDArray:
        """Apply the filter to the data_wu.

        Parameters
        ----------
        data_wu : NDArray
            The sinogram.
        fbp_filter : NDArray, optional
            The filter to use. The default is None (use the stored filter).

        Returns
        -------
        NDArray
            The filtered sinogram. When multiple filters are stored, the results
            are stacked along a new leading axis.
        """
        data_wu_shape = data_wu.shape

        if fbp_filter is None:
            local_filter = self.fbp_filter
        else:
            local_filter = fbp_filter

        # Pad the last axis so that the convolution is linear, not circular.
        prj_size_pad = self.get_padding_size(data_wu_shape)
        pad_edge_u = (prj_size_pad - data_wu_shape[-1]) / 2
        pad_width = np.zeros((len(data_wu_shape), 2), dtype=int)
        pad_width[-1, :] = (int(np.ceil(pad_edge_u)), int(np.floor(pad_edge_u)))

        prj_pad = np.pad(data_wu, pad_width=tuple(pad_width), mode=self.pad_mode)  # type: ignore
        # Roll so that the original data starts at index 0 of the padded axis.
        prj_pad = np.roll(prj_pad, shift=-pad_width[-1][0], axis=-1)
        prj_f = self.to_fourier(prj_pad)

        local_filter = np.array(local_filter, ndmin=2)
        num_filters = local_filter.shape[-2]
        if num_filters > 1:
            # Filter bank: apply each filter and stack the results.
            prjs_f = [np.array([])] * num_filters
            for ii, f in enumerate(local_filter):
                prj_f_ii = prj_f * np.array(f, ndmin=len(data_wu_shape))
                prj_f_ii = self.to_real(prj_f_ii)
                prjs_f[ii] = prj_f_ii[..., : data_wu_shape[-1]]
            return np.array(prjs_f)
        else:
            prj_f *= np.array(local_filter, ndmin=len(data_wu_shape))
            # Crop back to the original width after the inverse transform.
            return self.to_real(prj_f)[..., : data_wu_shape[-1]]

    @abstractmethod
    def compute_filter(self, data_wu: NDArray) -> None:
        """Compute the FBP filter for the given data.

        Parameters
        ----------
        data_wu : NDArray
            The reference sinogram / projection data.
        """

    def __call__(self, data_wu: NDArray) -> NDArray:
        """Filter the sinogram, by first computing the filter, and then applying it.

        Parameters
        ----------
        data_wu : NDArray
            The unfiltered sinogram.

        Returns
        -------
        NDArray
            The filtered sinogram.
        """
        self.compute_filter(data_wu)
        return self.apply_filter(data_wu, self.fbp_filter)

    def plot_filters(self, fourier_abs: bool = False):
        """Plot the stored filters, in real and Fourier space, side by side.

        Parameters
        ----------
        fourier_abs : bool, optional
            Whether to plot the absolute value of the Fourier filters, by default False.
        """
        filters_r = np.array(self.filter_real, ndmin=2)
        filters_f = np.array(self.filter_fourier, ndmin=2)

        f, axes = plt.subplots(1, 2, figsize=(10, 4))
        for ii in range(self.num_filters):
            axes[0].plot(filters_r[ii, ...], label=f"Filter-{ii}")
        axes[0].set_title("Real-space")
        axes[0].set_xlabel("Pixel")
        for ii in range(self.num_filters):
            filt_f = filters_f[ii, ...]
            if fourier_abs:
                filt_f = np.abs(filt_f)
            axes[1].plot(filt_f, label=f"Filter-{ii}")
        axes[1].set_title("Fourier-space")
        axes[1].set_xlabel("Frequency")
        axes[0].grid()
        axes[1].grid()
        # Only show the legend when there is more than one filter to distinguish.
        if self.num_filters > 1:
            axes[1].legend()
        f.tight_layout()
class FilterCustom(Filter):
    """Custom FBP filter."""

    def __init__(
        self,
        fbp_filter: Union[ArrayLike, NDArray[np.floating], None],
        pad_mode: str = "constant",
        use_rfft: bool = True,
        dtype: DTypeLike = np.float32,
    ) -> None:
        """Initialize a user-provided (custom) FBP filter.

        Parameters
        ----------
        fbp_filter : Union[ArrayLike, NDArray[np.floating], None]
            The filter.
        pad_mode : str, optional
            The padding mode, by default "constant".
        use_rfft : bool, optional
            Whether to use the `rfft` or complex `fft`, by default True.
        dtype : DTypeLike, optional
            The data type of the filter, by default np.float32.
        """
        super().__init__(fbp_filter=fbp_filter, pad_mode=pad_mode, use_rfft=use_rfft, dtype=dtype)

    def compute_filter(self, data_wu: NDArray) -> None:
        """No-op: the filter is supplied by the user, nothing to compute."""
class FilterFBP(Filter):
    """Traditional FBP filter."""

    # Name of the selected filter (stored lower-case, one of FILTERS).
    filter_name: str

    # The supported traditional filter windows.
    FILTERS = ("ramp", "shepp-logan", "cosine", "hamming", "hann")

    def __init__(
        self, filter_name: str = "ramp", pad_mode: str = "constant", use_rfft: bool = True, dtype: DTypeLike = np.float32
    ) -> None:
        """Initialize traditional FBP filter.

        Parameters
        ----------
        filter_name : str, optional
            The filter name, by default "ramp".
        pad_mode : str, optional
            The padding mode, by default "constant".
        use_rfft : bool, optional
            Whether to use the `rfft` or complex `fft`, by default True.
        dtype : DTypeLike, optional
            The type of the filter, by default np.float32.

        Raises
        ------
        ValueError
            When the requested filter is not one of FILTERS.
        """
        if filter_name.lower() not in self.FILTERS:
            raise ValueError(f"Unknown filter {filter_name}. Available filters: {self.FILTERS}")

        super().__init__(fbp_filter=None, pad_mode=pad_mode, use_rfft=use_rfft, dtype=dtype)
        self.filter_name = filter_name.lower()

    def compute_filter(self, data_wu: NDArray) -> None:
        """Compute the traditional FBP filter for the given data.

        Parameters
        ----------
        data_wu : NDArray
            The reference sinogram / projection data.
        """
        prj_size_pad = self.get_padding_size(data_wu.shape)
        # NOTE(review): relies on scikit-image's private helper — may break across skimage versions.
        self.fbp_filter = skt.radon_transform._get_fourier_filter(prj_size_pad, self.filter_name)
        # Scale by pi / (2 * shape[-2]) — presumably shape[-2] is the number of angles; TODO confirm.
        self.fbp_filter = np.squeeze(self.fbp_filter) * np.pi / (2 * data_wu.shape[-2])

        if self.use_rfft:
            # Keep only the non-negative frequencies, matching the rfft output size.
            self.fbp_filter = self.fbp_filter[: (self.fbp_filter.shape[-1]) // 2 + 1]

    def get_available_filters(self) -> Sequence[str]:
        """Provide available FBP filters.

        Returns
        -------
        Sequence[str]
            The available filters.
        """
        return self.FILTERS
class FilterMR(Filter):
    """Data dependent FBP filter.

    This is a simplified implementation from:

    [1] Pelt, D. M., & Batenburg, K. J. (2014). Improving filtered backprojection
    reconstruction by data-dependent filtering. Image Processing, IEEE
    Transactions on, 23(11), 4750-4762.

    Code inspired by: https://github.com/dmpelt/pymrfbp
    """

    # Projector used to evaluate the data-fit term of the filter fit.
    projector: BaseTransform
    # Type of pixel binning: "exponential" or "incremental".
    binning_type: str
    # Displacement at which the binning starts (None disables binning).
    binning_start: Union[int, None]
    # Optional weight of the smoothness regularization term.
    lambda_smooth: Union[float, None]
    # Whether the filter basis has been computed (lazy initialization).
    is_initialized: bool

    def __init__(
        self,
        projector: BaseTransform,
        binning_type: str = "exponential",
        binning_start: Union[int, None] = 2,
        lambda_smooth: Optional[float] = None,
        pad_mode: str = "constant",
        use_rfft: bool = True,
        dtype: DTypeLike = np.float32,
    ) -> None:
        """Initialize data-driven FBP filter.

        Parameters
        ----------
        projector : BaseTransform
            The projector to use for handling the data.
        binning_type : str, optional
            Type of pixel binning ("exponential" or "incremental"), by default "exponential".
        binning_start : Union[int, None], optional
            From which distance to start the binning, by default 2.
        lambda_smooth : Optional[float], optional
            Smoothing parameter. The default is None.
        pad_mode : str, optional
            The padding mode, by default "constant".
        use_rfft : bool, optional
            Whether to use the `rfft` or complex `fft`, by default True.
        dtype : DTypeLike, optional
            Filter data type. The default is np.float32.

        Raises
        ------
        ValueError
            When an unsupported binning type is requested.
        """
        super().__init__(fbp_filter=None, pad_mode=pad_mode, use_rfft=use_rfft, dtype=dtype)
        self.projector = projector

        self.binning_type = binning_type.lower()
        if self.binning_type not in ("exponential", "incremental"):
            raise ValueError("Binning type should be either 'exponential' or 'incremental'.")
        self.binning_start = binning_start
        self.lambda_smooth = lambda_smooth

        # The basis depends on the data shape, so it is computed lazily.
        self.is_initialized = False

    def initialize(self, data_wu_shape: Sequence[int]) -> None:
        """Initialize filter.

        Parameters
        ----------
        data_wu_shape : Sequence[int]
            Shape of the data.
        """
        num_pad_pixels = self.get_padding_size(data_wu_shape)
        # Real-space filter basis, and its (real) Fourier representation.
        self.basis_r = create_basis(
            num_pad_pixels, binning_type=self.binning_type, binning_start=self.binning_start, dtype=self.dtype
        )
        self.basis_f = self.to_fourier(self.basis_r).real
        self.is_initialized = True

    def compute_filter(self, data_wu: NDArray) -> None:
        """Compute the data-dependent filter for the given sinogram.

        Parameters
        ----------
        data_wu : NDArray
            The sinogram.
        """
        if not self.is_initialized:
            self.initialize(data_wu.shape)

        num_sino_pixels = data_wu.shape[-1]
        sino_size = data_wu.shape[-2] * num_sino_pixels
        nrows = sino_size
        ncols = self.basis_f.shape[-2]
        if self.lambda_smooth:
            # Extra rows for the finite-difference (smoothness) regularization terms.
            grad_vol_size = num_sino_pixels * (num_sino_pixels - 1)
            nrows += 2 * grad_vol_size

        # Least-squares system: one column per basis function.
        A = np.zeros((nrows, ncols), dtype=self.dtype)
        vol_mask = circular_mask([num_sino_pixels] * 2)

        for ii, bas_f in enumerate(self.basis_f):
            # Re-project the (masked) back-projection obtained with this basis filter.
            data_wu_f = self.apply_filter(data_wu, bas_f)
            img = self.projector.T(data_wu_f)
            img *= vol_mask

            A[:sino_size, ii] = self.projector(img).flatten()
            if self.lambda_smooth:
                # Image gradients along both axes form the regularization rows.
                dx = np.diff(img, axis=-2)
                dy = np.diff(img, axis=-1)
                d = np.concatenate((dx.flatten(), dy.flatten()))
                A[sino_size:, ii] = self.lambda_smooth * d

        b = np.zeros((nrows,), dtype=self.dtype)
        b[:sino_size] = data_wu.flatten()
        # Fit the basis coefficients and assemble the final Fourier-space filter.
        fitted_components = np.linalg.lstsq(A, b, rcond=None)[0].astype(self.dtype)
        self.fbp_filter = fitted_components.dot(self.basis_f)