import math
from typing import List

import cupy as cp
import torch
from cupyx.scipy.ndimage import fourier_gaussian

def _half_spectrum_kernel(fwhm, shape, half_shape, device):
    """Return one Fourier-domain Gaussian kernel cropped to ``half_shape``.

    A zero ``fwhm`` yields an all-ones kernel (no filtering). Otherwise the result
    equals ``fourier_gaussian(ones(shape), sigma=fwhm / 2.35)[:, : shape[1] // 2 + 1]``
    (scipy/cupyx ``ndimage.fourier_gaussian`` semantics with the default ``n=-1``).
    It is computed directly in torch on ``device``: for each axis the weight is
    ``exp(-2 * pi**2 * sigma**2 * f**2)`` with ``f = fftfreq(shape[axis])``, and
    the per-axis weights are multiplied together.
    """
    if fwhm == 0.0:
        # Fresh tensor on every call so the phase and absorption kernels never alias.
        return torch.ones(half_shape, device=device, dtype=torch.float)

    # 2.35 approximates the Gaussian FWHM-to-sigma factor 2*sqrt(2*ln 2) ~= 2.3548;
    # kept unchanged so existing results do not shift.
    sigma = float(fwhm) / 2.35
    scale = (4 * math.pi * math.pi) * (sigma * sigma / -2)

    # Accumulate in float64 (the CuPy version filtered a float64 ``ones`` array)
    # and cast to float32 only at the end.
    kernel = torch.ones(half_shape, dtype=torch.float64, device=device)
    for axis, size in enumerate(half_shape):
        # Keep the first ``size`` fftfreq entries of the full axis. For axis 1 this
        # is exactly the crop the original applied after filtering the full array.
        freq = torch.fft.fftfreq(shape[axis], dtype=torch.float64, device=device)[:size]
        broadcast_shape = [1] * len(half_shape)
        broadcast_shape[axis] = size
        kernel *= torch.exp(freq * freq * scale).reshape(broadcast_shape)
    return kernel.to(torch.float)


def get_filter_kernels(gaussian_filter_fwhm, shape, device):
    """Build the Fourier-domain Gaussian filter kernels for phase and absorption.

    Args:
        gaussian_filter_fwhm: FWHM (in samples) of the Gaussian smoothing, or ``None``
            to disable filtering. The real part configures the phase kernel and the
            imaginary part the absorption kernel. A zero component yields an all-ones
            (pass-through) kernel for that part. A real scalar filters the phase only.
        shape: Full spatial shape of the array being filtered; at least 2-D.
            Presumably ``(rows, cols)``; TODO confirm against callers.
        device: Torch device on which the kernels are created.

    Returns:
        ``(filter_kernel_obj_phase, filter_kernel_obj_absorption)``: float32 tensors
        of shape ``(shape[0], shape[1] // 2 + 1, *shape[2:])``. This is the full
        kernel cropped to the first ``shape[1] // 2 + 1`` frequencies along axis 1,
        presumably to multiply a ``torch.fft.rfft2`` spectrum (confirm with the
        caller). The two tensors never share storage, so either may be modified in
        place.
    """
    half_shape = (shape[0], shape[1] // 2 + 1, *shape[2:])

    if gaussian_filter_fwhm is None:
        fwhm_phase = fwhm_absorption = 0.0
    else:
        fwhm_phase = gaussian_filter_fwhm.real
        fwhm_absorption = gaussian_filter_fwhm.imag

    filter_kernel_obj_phase = _half_spectrum_kernel(fwhm_phase, shape, half_shape, device)
    filter_kernel_obj_absorption = _half_spectrum_kernel(
        fwhm_absorption, shape, half_shape, device
    )
    return filter_kernel_obj_phase, filter_kernel_obj_absorption