"""
This script extracts reconstructed images for the ASPMCI paper
from a specified experiment results HDF5 file.

"""

from __future__ import division
import itertools

import matplotlib.pyplot as plt
import numpy as np
import tables as tb
import pandas as pd
from os import makedirs
import errno

import magni

# Setup
extract_many = False  # Will extract a large selection of the reconstructed images
extract_specifics = True  # Will extract only the selection of reconstructed images shown in the accompanying research article

# Setup - the following path can be changed if the result file is in a
# different location. This should correspond to the `result_folder`
# variable in 'aspmci_reconstructions.py', line 488:
data_path = './aspmci_reconstructions/'
hdf_database_name = 'aspmci_reconstructions.hdf5'

# Experiment grid used when extracting "many" images.
images = tuple('image_{}.mi'.format(k) for k in range(7))
sampling_patterns = ('rect_spiral', 'uniform_lines')
undersampling_ratios = np.linspace(0.1, 0.3, 3)  # 0.1, 0.2, 0.3
reconstruction_algorithms = ('cubic_interpolation', 'ist', 'iht', 'ell_1', 'tv', 'l1_amp', 'bg_amp')

# Load the table of reconstruction quality metrics (PSNR, SSIM, ...).
# This script never writes to the results file, so the store is opened
# read-only (pandas defaults to append mode) and closed again as soon as the
# table has been read into memory, instead of being held open for the rest
# of the script.
hdf_store = pd.HDFStore(data_path + hdf_database_name, mode='r')
try:
    reconstruction_metrics = hdf_store.select('/simulation_results/metrics')
finally:
    hdf_store.close()

# Extract a selection of reconstructed images: one for each original
# undersampling rate, image, algorithm, and sampling pattern
if extract_many:
    try:
        makedirs('../figures/')
    except OSError as e:
        # An already existing output folder is fine; re-raise any other error.
        if e.errno != errno.EEXIST:
            raise
    with tb.File(data_path + hdf_database_name, mode='r') as h5_file:
        # itertools.product visits the combinations in the same order as the
        # equivalent nested loops (ratio outermost, sampling pattern innermost).
        for ur, image, algorithm, pattern in itertools.product(
                undersampling_ratios, images, reconstruction_algorithms,
                sampling_patterns):
            # Ratio as it appears in HDF5 node names and output file names,
            # e.g. 0.2 -> '0_2'. '{:.12g}' matches Python 2's str(float) for
            # these values and yields the same text on Python 3, where str()
            # would instead give the full repr of the float.
            ur_str = '{:.12g}'.format(ur).replace('.', '_')
            node_path = '/simulation_results/{}/{}/d{}/{}/'.format(
                image.replace('.', '_'), pattern, ur_str, algorithm)
            recon_img = h5_file.get_node(node_path, name='reconstructed_img_vec').read()
            img_shape = h5_file.get_node(node_path, name='img_shape').read()
            disp_img = magni.imaging.vec2mat(recon_img, img_shape)

            # figsize chosen arbitrarily, but keep it matched to the savefig
            # dpi below: the pair was chosen to obtain 128x128 resolution.
            plt.figure(figsize=(4, 4))
            # Metrics row for exactly this (ratio, image, algorithm, pattern).
            selection = reconstruction_metrics.query(
                'delta == @ur and image == @image and '
                'reconstruction_algorithm == @algorithm and '
                'sampling_pattern == @pattern')
            plt.title('PSNR: {:.2f} dB / SSIM: {:.2f}'.format(
                selection['psnr'].iat[0], selection['ssim'].iat[0]))
            fig = magni.imaging.visualisation.imshow(disp_img, show_axis='none')
            # Axis modification suggested in http://stackoverflow.com/a/26610602/865169
            plt.axis('off')
            fig.axes.get_xaxis().set_visible(False)
            fig.axes.get_yaxis().set_visible(False)
            plt.savefig('../figures/recon-{}-{}-{}-{}.pdf'.format(
                            image.replace('.mi', ''), algorithm, pattern, ur_str),
                        bbox_inches='tight',
                        pad_inches=0,
                        dpi=41.5)  # dpi matched to the figsize above (128x128 resolution)
            plt.close()

# Extract specific images (from the notebook analyse_data.ipynb)
if extract_specifics:
    # Entries: (undersampling ratio, image, algorithm, sampling pattern,
    # output subfolder). Each ratio must equal the stored `delta` value
    # exactly, because the metrics query below compares floats with ==
    # (hence the literal 0.22499999999999998).
    image_ids = ((0.1,  'image_1.mi', 'ist', 'rect_spiral', 'fig4'),
                 (0.1,  'image_6.mi', 'l1_amp', 'uniform_lines', 'fig4'),
                 (0.1,  'image_0.mi', 'iht', 'rect_spiral', 'fig4'),
                 (0.1,  'image_5.mi', 'ell_1', 'uniform_lines', 'fig4'),
                 (0.1,  'image_3.mi', 'cubic_interpolation', 'rect_spiral', 'fig5'),
                 (0.15, 'image_3.mi', 'ell_1', 'rect_spiral', 'fig5'),
                 (0.1,  'image_3.mi', 'l1_amp', 'rect_spiral', 'fig5'),
                 (0.22499999999999998,  'image_3.mi', 'ist', 'rect_spiral', 'fig5'),
                 (0.1,  'image_3.mi', 'tv', 'uniform_lines', 'fig5'),
                 (0.3,  'image_3.mi', 'bg_amp', 'rect_spiral', 'fig5'),
                 (0.3,  'image_0.mi', 'iht', 'rect_spiral', 'fig6'),
                 (0.3,  'image_0.mi', 'bg_amp', 'rect_spiral', 'fig6'),
                 (0.3,  'image_0.mi', 'ell_1', 'uniform_lines', 'fig6'),
                 (0.3,  'image_0.mi', 'l1_amp', 'rect_spiral', 'fig6'),
                 (0.3,  'image_0.mi', 'cubic_interpolation', 'uniform_lines', 'fig6'),
                 (0.3,  'image_0.mi', 'tv', 'uniform_lines', 'fig6'),
                 (0.3,  'image_0.mi', 'ist', 'rect_spiral', 'fig6'))

    # Create each output subfolder on its own, so that an already existing
    # folder does not prevent the remaining ones from being created.
    for subdir in sorted(set(entry[-1] for entry in image_ids)):
        try:
            makedirs('../figures/{}/'.format(subdir))
        except OSError as e:
            # An already existing folder is fine; re-raise any other error.
            if e.errno != errno.EEXIST:
                raise

    with tb.File(data_path + hdf_database_name, mode='r') as h5_file:
        for ur, image, algorithm, pattern, subdir in image_ids:
            # Ratio as it appears in HDF5 node names and output file names.
            # '{:.12g}' matches Python 2's str(float) (0.22499999999999998 ->
            # '0.225') and yields the same text on Python 3, where str()
            # would give the full repr instead.
            ur_str = '{:.12g}'.format(ur).replace('.', '_')
            node_path = '/simulation_results/{}/{}/d{}/{}/'.format(
                image.replace('.', '_'), pattern, ur_str, algorithm)
            recon_img = h5_file.get_node(node_path, name='reconstructed_img_vec').read()
            img_shape = h5_file.get_node(node_path, name='img_shape').read()
            disp_img = magni.imaging.vec2mat(recon_img, img_shape)

            # figsize chosen arbitrarily, but keep it matched to the savefig
            # dpi below: the pair was chosen to obtain 128x128 resolution.
            plt.figure(figsize=(4, 4))
            # Metrics row for exactly this (ratio, image, algorithm, pattern).
            selection = reconstruction_metrics.query(
                'delta == @ur and image == @image and '
                'reconstruction_algorithm == @algorithm and '
                'sampling_pattern == @pattern')
            plt.title('PSNR: {:.2f} dB / SSIM: {:.2f}'.format(
                selection['psnr'].iat[0], selection['ssim'].iat[0]))
            fig = magni.imaging.visualisation.imshow(disp_img, show_axis='none')
            # Axis modification suggested in http://stackoverflow.com/a/26610602/865169
            plt.axis('off')
            fig.axes.get_xaxis().set_visible(False)
            fig.axes.get_yaxis().set_visible(False)
            plt.savefig('../figures/{}/recon-{}-{}-{}-{}.pdf'.format(
                            subdir, image.replace('.mi', ''), algorithm,
                            pattern, ur_str),
                        bbox_inches='tight',
                        pad_inches=0,
                        dpi=41.5)  # dpi matched to the figsize above (128x128 resolution)
            plt.close()
