import torch
import math

from livereco.api.parameters.padding import Padding
from livereco.api.parameters import DataDimensions
from livereco.core.preprocessing.process_image import flip_and_pad


def apply_padding_refractive(
    image, data_dimensions: DataDimensions, padding_options: Padding, a0_log
):
    """Flip/pad ``image`` and apply ``data_dimensions.window`` around an offset.

    The image is first passed through ``flip_and_pad`` (with its final
    positional flag set to ``True``; see that function for its meaning).
    ``a0_log`` is then subtracted from the imaginary part, both parts are
    multiplied by the window, and ``a0_log`` is added back. For a real-valued
    window this tapers the imaginary part towards ``a0_log`` instead of
    towards zero.

    Note that the tensor returned by ``flip_and_pad`` has its imaginary part
    modified in place before windowing.

    Args:
        image: Complex tensor to pad and window.
        data_dimensions: Provides ``window`` and the padding geometry.
        padding_options: Padding configuration forwarded to ``flip_and_pad``.
        a0_log: Offset removed from / restored to the imaginary part.

    Returns:
        A new complex tensor containing the windowed, padded image.
    """
    padded = flip_and_pad(image, data_dimensions, padding_options, True)
    window = data_dimensions.window

    # Shift the imaginary part so the window pulls it towards a0_log.
    padded.imag = padded.imag - a0_log

    real_part = padded.real * window
    imag_part = 1j * padded.imag * window
    windowed = real_part + imag_part

    # Undo the shift on the freshly built result.
    windowed.imag = windowed.imag + a0_log
    return windowed

def apply_l1(values, weight):
    """Soft-threshold the real and imaginary parts of ``values`` in place.

    Each component ``x`` is replaced by ``sign(x) * max(|x| - t, 0)``. The
    threshold ``t`` is ``weight.real`` for the real part and ``weight.imag``
    for the imaginary part.

    Args:
        values: Complex tensor; overwritten in place.
        weight: Object exposing ``.real`` / ``.imag`` thresholds, e.g. a
            complex tensor or a Python complex number.

    Returns:
        ``values`` itself, after modification.
    """
    # Scalar zero matching the component dtype/device, shared by both parts.
    zero = values.real.new_zeros(())

    def shrink(component, threshold):
        magnitude_left = torch.abs(component) - threshold
        return torch.sign(component) * torch.maximum(magnitude_left, zero)

    values.real = shrink(values.real, weight.real)
    values.imag = shrink(values.imag, weight.imag)
    return values

def apply_non_negativity(values, phase_max, absorption_min):
    """Clamp the real part from above and the imaginary part from below, in place.

    Args:
        values: Complex tensor; overwritten in place.
        phase_max: Upper bound applied to ``values.real``. A Python number or
            a single-element tensor; NaN disables this bound.
        absorption_min: Lower bound applied to ``values.imag``. A Python
            number or a single-element tensor; NaN disables this bound.

    Returns:
        ``values`` itself, after modification.
    """
    # torch.clamp accepts either a Python number or a tensor bound, whereas
    # torch.minimum/torch.maximum require Tensor operands and raised a
    # TypeError for the plain float bounds that the isnan checks imply.
    if not math.isnan(phase_max):
        values.real = torch.clamp(values.real, max=phase_max)
    if not math.isnan(absorption_min):
        values.imag = torch.clamp(values.imag, min=absorption_min)
    return values


def apply_filter(values, filter_kernel_real, filter_kernel_imag):
    """Filter the real and imaginary parts of ``values`` separately in Fourier space.

    Each part is transformed with ``torch.fft.rfft2`` over the last two
    dimensions and multiplied by its kernel. It is then transformed back with
    ``torch.fft.irfft2`` to its original spatial size. Any leading dimensions
    are treated as batch dimensions.

    Args:
        values: Complex tensor with at least two dimensions; overwritten in
            place.
        filter_kernel_real: Multiplier for ``rfft2(values.real)``; must
            broadcast in place against a spectrum of shape
            ``(..., H, W // 2 + 1)``.
        filter_kernel_imag: Same, for ``rfft2(values.imag)``.

    Returns:
        ``values`` itself, after modification.
    """
    # irfft2 inverts only the last two dims, so ``s`` must list just those two
    # sizes. Passing the full shape breaks batched input. ``s`` is still needed
    # so odd widths survive the round trip through the half spectrum.
    signal_size = values.shape[-2:]

    def _filter_part(part, kernel):
        spectrum = torch.fft.rfft2(part)
        spectrum *= kernel
        return torch.fft.irfft2(spectrum, s=signal_size)

    values.real = _filter_part(values.real, filter_kernel_real)
    values.imag = _filter_part(values.imag, filter_kernel_imag)
    return values


def apply_window(values, window, intensities_log):
    """Multiply ``values`` by ``window`` in place around an imaginary offset.

    ``intensities_log`` is subtracted from the imaginary part before the
    multiplication and added back afterwards. For a real-valued window this
    tapers the imaginary part towards ``intensities_log`` rather than towards
    zero.

    Args:
        values: Complex tensor; overwritten in place.
        window: Multiplier that must broadcast in place against ``values``.
        intensities_log: Offset removed from / restored to the imaginary part.

    Returns:
        ``values`` itself, after modification.
    """
    baseline = intensities_log

    values.imag = values.imag - baseline
    values.mul_(window)  # in-place, same as ``values *= window``
    values.imag = values.imag + baseline

    return values
