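"""
This script extracts the original (ground truth) images for the ASPMCI paper
from the experiment results HDF5 file configured in the Setup section below
and saves them as PNG files. It also generates the colorbar images for the
cool-warm and afmhot colormaps.

"""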

from __future__ import division
import itertools

#import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import tables as tb
import pandas as pd
from os import makedirs
import errno

import magni

# Setup - change `data_path` if the result file lives somewhere else. It must
# point at the same folder as the `result_folder` variable in
# 'aspmci_reconstructions.py' (around line 488 of that script).
data_path = './aspmci_reconstructions/'
hdf_database_name = 'aspmci_reconstructions.hdf5'

# Names of the seven test images; each name is part of the node path used to
# look up results in the database.
images = tuple('image_{}.mi'.format(k) for k in range(7))
# Side length in pixels of the (square) images.
imsize = 128
# Experiment parameters which, together with the image name, select a node in
# the result database.
sampling_patterns = ('rect_spiral', 'uniform_lines')
undersampling_ratios = np.linspace(0.1, 0.3, 9)
reconstruction_algorithms = (
    'iht',
    'ist',
    'cubic_interpolation',
    'fixed_ist',
    'fixed_iht',
    'ell_1',
    'tv',
)

# Extract ground truth images and save each one as a 128x128 PNG in fig2.
fig2_folder = '../figures/fig2/'
try:
    makedirs(fig2_folder)
except OSError as e:
    # An already existing output folder is fine; any other failure
    # (permissions, invalid path, ...) is re-raised.
    if e.errno != errno.EEXIST:
        raise
# Open the result database once for all images instead of once per image.
with tb.File(data_path + hdf_database_name, mode='r') as h5_file:
    for image in images:
        # Read 'img_detilt_vec' from the node for the first sampling pattern,
        # undersampling ratio and reconstruction algorithm.
        # NOTE(review): this presumes the stored ground truth is identical
        # for every such combination - confirm.
        org_img = h5_file.get_node(
            '/simulation_results/{}/{}/d{}/{}/'.format(
                image.replace('.', '_'), sampling_patterns[0],
                str(undersampling_ratios[0]).replace('.', '_'),
                reconstruction_algorithms[0]), name='img_detilt_vec').read()
        disp_img = magni.imaging.vec2mat(org_img, (imsize, imsize))
        # figsize is chosen arbitrarily, but the dpi passed to savefig below
        # is matched to it so that the saved PNG is 128x128 pixels.
        plt.figure(figsize=(1, 1))
        fig = magni.imaging.visualisation.imshow(disp_img, show_axis='none')
        # Axis modifications suggested in http://stackoverflow.com/a/26610602/865169
        plt.axis('off')
        fig.axes.get_xaxis().set_visible(False)
        fig.axes.get_yaxis().set_visible(False)
        # 'image_k.mi' -> 'image_k.png': replace only the extension, not every
        # occurrence of 'mi' in the name.
        png_name = image.rsplit('.', 1)[0] + '.png'
        plt.savefig(fig2_folder + png_name, bbox_inches='tight',
                    pad_inches=0, dpi=165.5)
        # Release the figure; otherwise one figure per image stays open.
        plt.close()

# Generate colorbar images for the cool-warm and afmhot colormaps in fig1.
fig1_folder = '../figures/fig1/'
try:
    makedirs(fig1_folder)
except OSError as e:
    # An already existing output folder is fine; re-raise anything else.
    if e.errno != errno.EEXIST:
        raise
# A 2 x 1024 grid whose colour values run linearly from 0 to 1 along x.
y = np.array([0, 1])
x = np.linspace(0, 1, 1024)
X, Y = np.meshgrid(x, y)
fig, axes = plt.subplots(1, 1)
p_mesh = axes.pcolormesh(X, Y, X, vmin=0, vmax=1, edgecolor='face',
                         cmap='coolwarm')
# Hide the axis, ticks and labels so that only the colour gradient is saved.
plt.axis('off')
fig.axes[0].get_xaxis().set_visible(False)
fig.axes[0].get_yaxis().set_visible(False)
plt.savefig(fig1_folder + 'colorbar-coolwarm.png', bbox_inches='tight',
            pad_inches=0)
# Recolour the existing mesh instead of stacking a second mesh on top of it.
p_mesh.set_cmap('afmhot')
plt.savefig(fig1_folder + 'colorbar-afmhot.png', bbox_inches='tight',
            pad_inches=0)
plt.close(fig)
