"""
This script extracts data and calculates and plots statistics for
the ASPMCI paper from a specified experiment results HDF5 file.

"""

from __future__ import division
import itertools

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from os import makedirs
import errno

import magni
from magni.utils.plotting import colour_collections

# Setup - change `data_path` if the result file is in a different location.
# It should correspond to the `result_folder` variable in
# 'aspmci_reconstructions.py' (line 488 at the time of writing).
data_path = './aspmci_reconstructions/'
hdf_database_name = 'aspmci_reconstructions.hdf5'
try:
    # Disable external LaTeX rendering of figure text.
    mpl.rc('text', usetex=False)
except AttributeError:
    # NOTE(review): presumably guards against a matplotlib build where this
    # rc call is unavailable; the setting is optional, so it is skipped.
    pass

# x-axis values for every curve.
# NOTE(review): assumed to match the 9 'delta' values stored in the results
# file; a length mismatch would make the plot calls fail. Confirm.
undersampling_ratios = np.linspace(0.1, 0.3, 9)

# Load the per-run reconstruction metrics and average them for each
# (sampling pattern, reconstruction algorithm, delta) combination.
hdf_store = pd.HDFStore(data_path + hdf_database_name)
reconstruction_metrics = hdf_store.select('/simulation_results/metrics')
grouped = reconstruction_metrics.groupby(
    ['sampling_pattern', 'reconstruction_algorithm', 'delta']
).mean()
levels = grouped.index.levels

# Human-readable names used to build the legend labels.
pattern_names = {
    'rect_spiral': 'spiral',
    'uniform_lines': 'raster',
}
recon_names = {
    'cubic_interpolation': 'Interpolation',
    'tv': 'Total variation',
    'ell_1': 'L1-minimisation',
    'iht': 'IHT',
    'ist': 'IST',
    'l1_amp': 'Laplace AMP',
    'bg_amp': 'Bernoulli-Gauss AMP',
}
# One line style per sampling pattern, picked by its position in levels[0].
line_specs = ('-', '--')

# Global figure size and colour cycle for all plots below.
# NOTE(review): newer matplotlib versions replaced `axes.color_cycle` with
# `axes.prop_cycle`; confirm how magni's setup_matplotlib handles this key.
plot_style = {
    'figure': {'figsize': (10, 3)},
    'axes': {
        'color_cycle': (colour_collections['cb4']['OrRd'] +
                        colour_collections['cb3']['Blues']),
    },
}
magni.utils.plotting.setup_matplotlib(plot_style)

# Create a figure with one axis per metric. Iterate over the sampling
# patterns and reconstruction algorithms, and plot the mean PSNR, SSIM and
# reconstruction time of each combination against the undersampling ratio.
# The figure and the HDF5 store are always released, even if plotting or
# saving fails.
fig, axes = plt.subplots(1, 3)
try:
    for j, samp_patt in enumerate(levels[0]):
        for recon_alg in levels[1]:
            # Skip the 'fixed_*' algorithm variants; they have no entry in
            # `recon_names`.
            # NOTE(review): the pattern test only excludes a pattern named
            # exactly 'spiral'. The spiral pattern is keyed 'rect_spiral' in
            # `pattern_names`, so it is still plotted. Confirm that is intended.
            if samp_patt == 'spiral' or recon_alg in ('fixed_iht', 'fixed_ist'):
                continue

            label = recon_names[recon_alg] + ', ' + pattern_names[samp_patt]
            style = line_specs[j]  # one line style per sampling pattern
            axes[0].plot(undersampling_ratios,
                         grouped['psnr'][samp_patt][recon_alg],
                         style, label=label)
            axes[1].plot(undersampling_ratios,
                         grouped['ssim'][samp_patt][recon_alg],
                         style, label=label)
            if recon_alg == 'bg_amp':
                # No timing curve is drawn for this algorithm. An all-NaN
                # line is plotted instead, presumably to keep the colour
                # cycle aligned with the PSNR/SSIM axes.
                times = np.nan * np.ones_like(undersampling_ratios)
            else:
                times = grouped['time'][samp_patt][recon_alg]
            axes[2].plot(undersampling_ratios, times, style, label=label)

    # Raw string: `\d` is not a valid escape in a normal string literal.
    delta_label = r'Undersampling ratio, $\delta$'
    for ax in axes:
        ax.set_xlabel(delta_label)
    axes[0].set_ylabel('PSNR [dB]')
    axes[1].set_ylabel('SSIM')
    axes[2].set_ylabel('Time [s]')
    # One shared legend attached to the middle axis. The negative
    # `borderaxespad` pushes it outside the axes area.
    axes[1].legend(loc='lower center', ncol=4, borderaxespad=-10,
                   borderpad=.6)
    plt.tight_layout()

    fig_dir = '../figures/fig3/'
    try:
        makedirs(fig_dir)
    except OSError as e:
        # An already-existing directory is fine; re-raise anything else.
        if e.errno != errno.EEXIST:
            raise
    plt.savefig(fig_dir + 'graphs-psnr-ssim-time.pdf', bbox_inches='tight',
                pad_inches=0)
finally:
    plt.close(fig)
    hdf_store.close()
