# -*- coding: utf-8 -*-
"""
Script for running the simulations described in the paper:

Reconstruction Algorithms in Undersampled AFM Imaging

Copyright (c) 2015, Thomas Arildsen, Christian Schou Oxvig,
    Patrick Steffen Pedersen, Jan Østergaard, and Torben Larsen
All rights reserved.

**Combinations that are tested**

- 7 images
- 2 sampling patterns: Rect Spiral, Uniform lines
- 9 undersampling ratios: 0.10 0.125 0.15 0.175 0.20 0.225 0.25 0.275 0.30
- 1 Dictionary: DCT
- 7 Reconstruction algorithms: ell_1, TV, IHT, IST, Interpolation, AMP-L1,
  AMP-BG
- 2 Evaluation indicators: PSNR, SSIM

See the bottom of the script for more details about the specific combinations.

**Requirements**

To run the script, the following must be available:

- the seven AFM images
- magni : http://vbn.aau.dk/en/publications/
          magni(194fc193-7913-4b88-85d6-25570aa43ae1).html, magni_1.3.0
- pyunlocbox : https://github.com/epfl-lts2/pyunlocbox,
  commit:f9fafb070df125c38da86b346a330743e8065910
- amp_bgg_solver.py : bundled with this script

**Output**

All output is placed in the "aspmci_reconstructions" folder. The output is:

- An HDF5 database ("aspmci_reconstructions.hdf5") containing all results.
- Overview figures (PNG, one per colormap) of each reconstruction

"""


from __future__ import division
import json
import os
import time
import warnings

import matplotlib as mpl; mpl.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import scipy.misc
from scipy import interpolate
from skimage.measure import structural_similarity as calc_ssim
import tables as tb

import pyunlocbox as ulb
import amp_bgg_solver
import magni


def run_simulation_task(img_folder=None, result_folder=None, h5_name=None,
                        task=None):
    """
    Function to run for each simulation task.

    The following elements are part of this simulation:

    * Load and downsample image
    * Scan / sample image
    * Setup reconstruction
    * Reconstruct image
    * Evaluate reconstruction result
    * Save results

    Parameters
    ----------
    img_folder : str
        The path to the folder containing the mi-files.
    result_folder : str
        The path to the root folder to save result figures to. The HDF5
        database `h5_name` in this folder must already contain the
        ``/simulation_results`` group with its ``metrics`` table.
    h5_name : str
        The name of the HDF5 file in which the results are saved.
    task : dict
        The simulation task specification with the keys 'image',
        'sampling_pattern', 'delta' and 'reconstruction_algorithm'.

    Notes
    -----
    One row is appended to the ``metrics`` table, and the arrays of the run
    are stored in the group
    ``/simulation_results/<image>/<sampling_pattern>/d<delta>/<algorithm>``
    (with '.' and '-' replaced by '_'). Summary figures are written to the
    matching directory below `result_folder`.

    """

    # Load image
    image_size = 128
    # NOTE(review): assumes 512 x 512 source scans; the assert below checks
    # the resulting size.
    downsampling = int(512 / image_size)
    mi_img = magni.afm.io.read_mi_file(
        os.path.join(img_folder, task['image'])).get_buffer('Topography')[0]
    mi_img_data = mi_img.data
    if task['image'] == 'image_0.mi':
        # Remove single outlier (on a copy, so the loaded buffer is untouched)
        mi_img_data = mi_img_data.copy()
        mi_img_data[511, 0] = mi_img_data[510, 1]
    img = magni.imaging.visualisation.stretch_image(
        mi_img_data[::downsampling, ::downsampling], 1.0)
    assert img.shape == (image_size, image_size)
    assert np.allclose(img.min(), 0.0)
    assert np.allclose(img.max(), 1.0)
    h, w = img.shape

    # Scanning setup
    img_coords, Phi = get_sampling_setup(
        task['sampling_pattern'], task['delta'], h, w)
    unique_pixels = magni.imaging.measurements.unique_pixels(img_coords)

    # De-tilt based on the sampled pixels only; unique_pixels holds (x, y),
    # i.e. (column, row), coordinates.
    scan_mask = np.zeros((h, w), dtype=np.bool_)
    scan_mask[unique_pixels[:, 1], unique_pixels[:, 0]] = True
    img_detilt, tilt = magni.imaging.preprocessing.detilt(
        img, mask=scan_mask, return_tilt=True)

    # Convert image to vector
    img_vec = magni.imaging.mat2vec(img)
    img_detilt_vec = magni.imaging.mat2vec(
        magni.imaging.visualisation.stretch_image(img_detilt, 1.0))
    tilt_vec = magni.imaging.mat2vec(tilt)
    assert np.allclose(img_detilt_vec.min(), 0.0)
    assert np.allclose(img_detilt_vec.max(), 1.0)

    # Reconstruction setup
    Psi = magni.imaging.dictionaries.utils.get_function_handle(
        'matrix', 'DCT')((h, w))
    domain = magni.imaging.domains.MultiDomainImage(Phi, Psi)
    domain.measurements = Phi.dot(img_detilt_vec)

    # Reconstruction
    reconstructed_img_vec, reconstruction_time, rec_coefs = reconstruct_image(
        task['reconstruction_algorithm'], domain.measurements, Phi, Psi,
        img_coords, h, w)

    # Evaluation
    psnr = magni.imaging.evaluation.calculate_psnr(
        img_detilt_vec, reconstructed_img_vec, 1.0)

    with warnings.catch_warnings():
        # Ignore irrelevant copy warning
        warnings.simplefilter('ignore')
        ssim = calc_ssim(
            magni.imaging.vec2mat(img_detilt_vec, (h, w)),
            magni.imaging.vec2mat(reconstructed_img_vec, (h, w)),
            dynamic_range=1)

    # Save results in database
    h5_file = os.path.join(result_folder, h5_name)
    with magni.utils.multiprocessing.File(h5_file, mode='a') as h5file:
        # Save metrics
        row = h5file.root.simulation_results.metrics.row
        row['image'] = task['image']
        row['sampling_pattern'] = task['sampling_pattern']
        row['delta'] = task['delta']
        row['reconstruction_algorithm'] = task['reconstruction_algorithm']
        row['psnr'] = psnr
        row['ssim'] = ssim
        row['time'] = reconstruction_time
        row.append()

    save_path = '/'.join(['/simulation_results',
                          task['image'],
                          task['sampling_pattern'],
                          'd' + str(task['delta'])])

    save_path = _fix_str_representation(save_path)

    with magni.utils.multiprocessing.File(h5_file, mode='a') as h5file:
        # Save arrays and tasks
        db_group = h5file.create_group(
            save_path, task['reconstruction_algorithm'],
            createparents=True)

        h5file.create_array(db_group, 'img_vec', obj=img_vec)
        h5file.create_array(db_group, 'img_detilt_vec', obj=img_detilt_vec)
        h5file.create_array(db_group, 'tilt_vec', obj=tilt_vec)
        h5file.create_array(db_group, 'domain_measurements',
                            obj=domain.measurements)
        h5file.create_array(db_group, 'img_coords', obj=img_coords)
        h5file.create_array(db_group, 'reconstructed_coefficients_vec',
                            obj=rec_coefs)
        h5file.create_array(db_group, 'reconstructed_img_vec',
                            obj=reconstructed_img_vec)
        h5file.create_array(db_group, 'task', obj=json.dumps(task).encode())
        h5file.create_array(db_group, 'img_shape', obj=(h, w))

    # Save summary figures.
    # NOTE(review): the measurement panel masks the original image `img`,
    # whereas the other two panels show the de-tilted image; confirm intent.
    measurement_img = magni.imaging.mat2vec(
        magni.imaging.visualisation.mask_img_from_coords(img, unique_pixels))

    figs = [magni.imaging.vec2mat(fig_vec, (h, w))
            for fig_vec in [measurement_img, img_detilt_vec,
                            reconstructed_img_vec]]
    titles = ['Measurements', 'Original', 'Reconstruction']

    # save_path starts with '/', which os.path.join would treat as an
    # absolute path; strip it so the figures end up below result_folder.
    out_dir = os.path.join(result_folder, save_path.lstrip('/'),
                           task['reconstruction_algorithm'])
    try:
        os.makedirs(out_dir)
    except OSError:
        # Several workers run concurrently and share parent directories;
        # only fail if the directory really could not be created.
        if not os.path.isdir(out_dir):
            raise

    for colormap in ['coolwarm', 'afmhot']:
        magni.utils.plotting.setup_matplotlib(
            {'figure': {'figsize': (20, 12)}}, cmap=colormap)
        fig = magni.imaging.visualisation.imsubplot(figs, 1, titles=titles)

        fig.suptitle(
            '{}\n PSNR: {:.2f} dB, SSIM: {:.2f}, time: {:.2f} s'.format(
                out_dir, psnr, ssim, reconstruction_time), fontsize=20)

        # Save this figure explicitly rather than pyplot's "current" one
        fig.savefig(os.path.join(out_dir, 'summary_{}.png'.format(colormap)))
        plt.close(fig)


def get_sampling_setup(sampling_pattern, delta, h, w):
    """
    Build the scan path for a sampling pattern and its measurement operator.

    Parameters
    ----------
    sampling_pattern : {'rect_spiral', 'uniform_lines'}
        The sampling pattern to use.
    delta : float
        The undersampling ratio. The scan length handed to the magni
        sampling functions is ``2 * delta * h * w``.
    h : int
        The image height in pixels.
    w : int
        The image width in pixels.

    Returns
    -------
    img_coords : ndarray
        The pixel coordinates visited by the scan path.
    Phi : magni.utils.matrices.Matrix
        The sampling matrix operator built from `img_coords`.

    Raises
    ------
    ValueError
        If `sampling_pattern` is not one of the supported patterns.

    """

    scan_length = delta * 2 * h * w
    # Discretise the scan path with ten points per unit length so that it is
    # resolved densely enough.
    num_points = int(scan_length) * 10

    if sampling_pattern not in ('rect_spiral', 'uniform_lines'):
        raise ValueError('Invalid sampling pattern: {!r}'.format(
            sampling_pattern))

    sampling = magni.imaging.measurements
    if sampling_pattern == 'rect_spiral':
        img_coords = sampling.spiral_sample_image(
            h, w, scan_length, num_points, rect_area=True)
    else:
        img_coords = sampling.uniform_line_sample_image(
            h, w, scan_length, num_points)

    return img_coords, sampling.construct_measurement_matrix(img_coords, h, w)


def reconstruct_image(algorithm, measurements, Phi, Psi, img_coords, h, w):
    """
    Return a reconstructed image along with the reconstruction time.

    Parameters
    ----------
    algorithm : str
        The reconstruction algorithm to use: 'iht', 'ist',
        'cubic_interpolation', 'ell_1', 'tv', 'l1_amp' or 'bg_amp'.
    measurements : ndarray
        The m x 1 vector of measurements.
    Phi : magni.utils.matrices.Matrix
        The measurements matrix operator.
    Psi : magni.utils.matrices.Matrix
        The dictionary matrix operator.
    img_coords : ndarray
        The (x, y) scan coordinates, one row per point on the scan path.
    h : int
        The image height in pixels.
    w : int
        The image width in pixels.

    Returns
    -------
    reconstructed_img_vec : ndarray
        The n x 1 vector representing the reconstructed image, stretched by
        `stretch_image` to the range [0, 1].
    reconstruction_time : float
        The time in seconds it took to do the reconstruction.
    reconstructed_coefficients : ndarray
        The sparse coefficients in the reconstruction.

    Raises
    ------
    ValueError
        If `algorithm` is not one of the supported algorithms.

    """

    if algorithm == 'iht':
        A = magni.utils.matrices.MatrixCollection((Phi, Psi))
        t0 = time.time()
        alpha = magni.cs.reconstruction.it.run(measurements, A)
        x = Psi.dot(alpha)
        t1 = time.time()

    elif algorithm == 'ist':
        A = magni.utils.matrices.MatrixCollection((Phi, Psi))
        it_config = magni.cs.reconstruction.it.config
        # Switch the iterative thresholding module to soft thresholding (IST)
        # with a fixed kappa of 0.6.
        it_config.update({'kappa_fixed': 0.6, 'threshold_operator': 'soft'})
        try:
            t0 = time.time()
            alpha = magni.cs.reconstruction.it.run(measurements, A)
            x = Psi.dot(alpha)
            t1 = time.time()
        finally:
            # The configuration is module-level state: reset it even if the
            # reconstruction fails so it cannot leak into later 'iht' runs.
            it_config.reset()

    elif algorithm == 'cubic_interpolation':
        xx, yy = _pixel_grid(h, w)
        unique_pixels = magni.imaging.measurements.unique_pixels(img_coords)
        t0 = time.time()
        # NOTE(review): assumes the measurement order matches the order of
        # unique_pixels (as produced by construct_measurement_matrix).
        recon = interpolate.griddata(unique_pixels, measurements.ravel(),
                                     (xx, yy), method='cubic',
                                     fill_value=measurements.mean())
        x = magni.imaging.mat2vec(recon)
        t1 = time.time()
        # The coefficients are derived outside the timed region here.
        alpha = Psi.T.dot(x)

    elif algorithm == 'ell_1':
        A = magni.utils.matrices.MatrixCollection((Phi, Psi))

        def Afunc(x):
            return A.dot(x)

        def Atfunc(x):
            return A.T.dot(x)

        # min ||alpha||_1  s.t.  ||y - A alpha||_2 <= 1e-3 ||y||_2
        f1 = ulb.functions.norm_l1()
        f2 = ulb.functions.proj_b2(epsilon=1e-3*np.linalg.norm(measurements),
                                   y=measurements, A=Afunc, At=Atfunc,
                                   tight=False)
        solver = ulb.solvers.douglas_rachford()
        x0 = np.zeros((A.shape[1], 1))

        t0 = time.time()
        solution = ulb.solvers.solve([f1, f2], x0, solver)
        alpha = solution['sol']
        x = Psi.dot(alpha)
        t1 = time.time()

    elif algorithm == 'tv':
        unique_pixels = magni.imaging.measurements.unique_pixels(img_coords)
        mask = np.zeros((h, w), dtype=np.bool_)
        mask[unique_pixels[:, 1], unique_pixels[:, 0]] = True

        # Measurements scattered back onto the image grid
        meas_mat = magni.imaging.vec2mat(Phi.T.dot(measurements), (h, w))

        def Afunc(x):
            return mask * x

        f1 = ulb.functions.norm_tv()
        f2 = ulb.functions.proj_b2(epsilon=1e-3*np.linalg.norm(measurements),
                                   y=meas_mat, A=Afunc, At=Afunc)
        solver = ulb.solvers.douglas_rachford(step=1e-2)
        x0 = meas_mat

        t0 = time.time()
        solution = ulb.solvers.solve([f1, f2], x0, solver)
        x = magni.imaging.mat2vec(solution['sol'])
        alpha = Psi.T.dot(x)
        t1 = time.time()

    elif algorithm == 'l1_amp':
        A = magni.utils.matrices.MatrixCollection((Phi, Psi))
        alpha = np.zeros((A.shape[1], 1))
        y = measurements
        z = y.copy()

        t0 = time.time()
        for _ in range(300):
            alpha_z = alpha + A.T.dot(z)

            # Soft threshold at the m-th largest magnitude of alpha_z
            thres = np.sort(np.abs(alpha_z).ravel())[-A.shape[0]]
            a_t_p = (alpha_z > thres)
            a_t_m = (alpha_z < -thres)
            alpha = (alpha_z - thres) * a_t_p + (alpha_z + thres) * a_t_m

            r = y - A.dot(alpha)

            # AMP residual update with the Onsager correction term
            # (number of non-zero coefficients / m)
            z = r + z * 1/A.shape[0] * np.sum(a_t_p + a_t_m)

            if np.linalg.norm(r) < 1e-3 * np.linalg.norm(y):
                break

        x = Psi.dot(alpha)
        t1 = time.time()

    elif algorithm == 'bg_amp':
        A = magni.utils.matrices.MatrixCollection((Phi, Psi)).A
        # NOTE(review): presumably the Bernoulli-Gauss prior's sparsity rate,
        # mean and variance; confirm against amp_bgg_solver.camp.
        rho = 0.3
        theta_bar = 0.0
        theta_hat = 1.0
        d_s = 1.0E-6
        T = 300
        prior_prmts = {'rho': rho, 'theta_bar': theta_bar,
                       'theta_hat': theta_hat}
        y = measurements.flatten()

        t0 = time.time()
        alpha, _dummy = amp_bgg_solver.camp(y, A, d_s, prior_prmts, T, 1.0E-3)
        x = Psi.dot(alpha)
        t1 = time.time()

    else:
        raise ValueError('Invalid reconstruction algorithm: {!r}'.format(
            algorithm))

    reconstructed_img_vec = magni.imaging.visualisation.stretch_image(x, 1.0)
    reconstruction_time = t1 - t0
    reconstructed_coefficients = alpha

    return (reconstructed_img_vec, reconstruction_time,
            reconstructed_coefficients)


def _pixel_grid(h, w):
    """
    Return the column (x) and row (y) index grids of an h x w image.

    Both arrays have shape (h, w) with ``xx[r, c] == c`` and
    ``yy[r, c] == r``, matching the (x, y) ordering of the scan coordinates.

    """

    return np.meshgrid(np.arange(w), np.arange(h))


def get_tasks():
    """
    Return the group structure and the list of simulation tasks.

    Every combination of image, sampling pattern, undersampling ratio and
    reconstruction algorithm becomes one task dict for the workers.

    Returns
    -------
    group_structure : tuple of str
        Names of the levels of the result hierarchy.
    tasks : list of dict
        One dict per combination with the keys 'image', 'sampling_pattern',
        'delta' and 'reconstruction_algorithm'. The algorithm varies
        fastest and the image slowest.

    """

    images = tuple('image_{}.mi'.format(index) for index in range(7))
    sampling_patterns = ('rect_spiral', 'uniform_lines')
    # 0.100, 0.125, ..., 0.300
    undersampling_ratios = np.linspace(0.1, 0.3, 9)
    reconstruction_algorithms = ('iht', 'ist', 'ell_1', 'tv',
                                 'cubic_interpolation', 'l1_amp', 'bg_amp')

    tasks = []
    for image in images:
        for pattern in sampling_patterns:
            for ratio in undersampling_ratios:
                tasks.extend({'image': image,
                              'sampling_pattern': pattern,
                              'delta': ratio,
                              'reconstruction_algorithm': algorithm}
                             for algorithm in reconstruction_algorithms)

    group_structure = ('image', 'sampling_pattern', 'undersampling_ratio',
                       'reconstruction_algorithm')

    return group_structure, tasks


def create_database(h5_path, group_structure):
    """
    Create an empty database for storing the results.

    The file is initialised via `magni.reproducibility.io.create_database`.
    A ``/simulation_results`` group is then added, holding the JSON-encoded
    `group_structure` and an empty ``metrics`` table.

    Parameters
    ----------
    h5_path : str
        Path of the HDF5 database to create.
    group_structure : tuple of str
        Names of the levels of the result hierarchy; stored as JSON.

    """

    class ReconMetrics(tb.IsDescription):
        """
        Row layout of the table holding reconstruction performance metrics.

        """

        # Task identification
        image = tb.StringCol(itemsize=10, pos=0)
        sampling_pattern = tb.StringCol(itemsize=20, pos=1)
        delta = tb.Float64Col(pos=2)
        reconstruction_algorithm = tb.StringCol(itemsize=30, pos=3)
        # Evaluation results
        psnr = tb.Float64Col(pos=4)
        ssim = tb.Float64Col(pos=5)
        time = tb.Float64Col(pos=6)

    magni.reproducibility.io.create_database(h5_path)
    with magni.utils.multiprocessing.File(h5_path, mode='a') as db:
        results = db.create_group('/', 'simulation_results')
        encoded_structure = json.dumps(group_structure).encode()
        db.create_array(results, 'group_structure', obj=encoded_structure)
        db.create_table(results, 'metrics', description=ReconMetrics,
                        expectedrows=1000)


def _fix_str_representation(string):
    """
    Return `string` with every '.' and '-' replaced by '_'.

    Used on the HDF5 group path built from the task values, e.g.
    'image_0.mi' -> 'image_0_mi' and 'd0.1' -> 'd0_1'.

    """

    return ''.join('_' if char in '.-' else char for char in string)


# Run the simulation
if __name__ == '__main__':
    img_folder = './'
    result_folder = './aspmci_reconstructions/'
    h5_name = 'aspmci_reconstructions.hdf5'
    h5_path = result_folder + h5_name

    # The database is created inside the result folder, so make sure it exists
    if not os.path.isdir(result_folder):
        os.makedirs(result_folder)

    group_structure, tasks = get_tasks()
    create_database(h5_path, group_structure)

    # One set of keyword arguments for run_simulation_task per task
    task_kwargs = [dict(img_folder=img_folder, result_folder=result_folder,
                        h5_name=h5_name, task=spec)
                   for spec in tasks]

    magni.utils.multiprocessing.config.update(workers=24)
    # NOTE(review): maxtasks=1 presumably gives each task a fresh worker
    # process; confirm against magni.utils.multiprocessing.process.
    magni.utils.multiprocessing.process(
        run_simulation_task, kwargs_list=task_kwargs, maxtasks=1)
