import tarfile
import numpy as np
import os
from sys import exit

# Text rendering and tick-label size, configured before pyplot is imported.
# NOTE(review): 'text.usetex' makes matplotlib typeset all text with LaTeX,
# so a LaTeX toolchain is presumably needed wherever this script runs.
import matplotlib as mpl
mpl.rcParams.update({
    'text.usetex': True,
    'xtick.labelsize': 18,
    'ytick.labelsize': 18,
})

# Defaults used whenever a figure is saved: tight bounding box at 300 dpi.
import matplotlib.pyplot as plt
plt.rcParams.update({
    'savefig.bbox': 'tight',
    'savefig.dpi': 300,
})


#
# Global parameters
#

# True: write every figure to disk; False: display it on screen instead.
save_plot = True


# Categories of simulation output to plot. Any subset of the following
# names may be listed:
#   "epithelial", "macrophage", "virus", "cytokine"
cell_data_to_plot = [
    "epithelial",
    "macrophage",
    "virus",
    "cytokine",
]


#
# Plotting parameters
#

# Axis settings shared by all figures ("xlabel", "xticks") followed by one
# entry per plottable category. Each category lists the output files to
# plot (extension-less names; one figure is drawn per file) and the y-axis
# label used for those figures.
plot_params = dict(
    xlabel="Time (hrs)",
    xticks=[0, 24, 48, 72, 96, 120],
    epithelial=dict(
        filenames=[
            "cell_counts_epithelial_infected",
            "cell_counts_epithelial_apoptotic",
            "cell_counts_epithelial_phagocytosed",
        ],
        ylabel="Epithelial numbers",
    ),
    macrophage=dict(
        filenames=[
            "cell_counts_macrophages_alveolar_resting",
            "cell_counts_macrophages_alveolar_active",
            "cell_counts_macrophages_alveolar_apoptotic",
        ],
        ylabel="Macrophage numbers",
    ),
    virus=dict(
        filenames=[
            "cell_counts_virus_inf_intracellular",
            "cell_counts_virus_inf_extracellular",
        ],
        ylabel="Viral load (virions)",
    ),
    cytokine=dict(
        filenames=[
            "rd_total_ifn_1",
            "rd_total_il_6",
            "rd_total_il_10",
            "rd_total_phago_ck",
        ],
        ylabel="Cytokine levels (nM)",
    ),
)


#
# Data sets to plot (one figure per data set and output file)
#

# One entry per data set, keyed by the data set's directory. Every entry
# holds parallel lists (same order, same length) describing the curves drawn
# into one figure:
#   subdirs       - sub-directory of each parameter variation
#   legend_labels - legend text of each curve
#   linestyles    - matplotlib line style of each curve
# plus the single colour shared by all curves.
plot_data = {}

plot_data["section_3-3_results_macrophage_parameter_variations_virus_internalisation_rate"] = dict(
    subdirs=["rate_0-006", "rate_0-06", "rate_0-6"],
    legend_labels=[
        "Virus uptake = 0.006 hr$^{-1}$",
        "Virus uptake = 0.06 hr$^{-1}$",
        "Virus uptake = 0.6 hr$^{-1}$",
    ],
    linestyles=['-', '--', '-.'],
    color="black",
)

plot_data["section_3-3_results_macrophage_parameter_variations_activation_halfmax"] = dict(
    subdirs=["halfmax_0-0006", "halfmax_0-006", "halfmax_0-06"],
    legend_labels=[
        "Activation half-max = 0.0006 nM",
        "Activation half-max = 0.006 nM",
        "Activation half-max = 0.06 nM",
    ],
    linestyles=['-', '--', '-.'],
    color="black",
)


# Normalise cell_data_to_plot to a list: a tuple of names is converted
# element-wise, any other non-list value (e.g. a single name given as a
# string) is wrapped into a one-element list.
if isinstance(cell_data_to_plot, tuple):
    cell_data_to_plot = list(cell_data_to_plot)
elif not isinstance(cell_data_to_plot, list):
    cell_data_to_plot = [cell_data_to_plot]


# Sanity check before plotting cell data: every requested name must refer to
# a category entry of plot_params, i.e. a dict that provides "filenames" and
# "ylabel". Testing key membership alone is not sufficient, because
# plot_params also stores the shared axis settings "xlabel" and "xticks";
# those would pass a membership test and only fail later, in the middle of
# plotting, with an unrelated error.
for cdattp in cell_data_to_plot:
    category = plot_params.get(cdattp)
    if (
        not isinstance(category, dict)
        or "filenames" not in category
        or "ylabel" not in category
    ):
        raise ValueError("unknown cell data to plot: {0}".format(cdattp))


#
# construct the plot
#

# One figure is produced for every (category, output file, data set)
# combination; each figure holds one curve per parameter variation, i.e. per
# entry of the data set's "subdirs" list.
for cdattp in cell_data_to_plot:
    category = plot_params[cdattp]

    for fn in category["filenames"]:
        # name of output file which is stored in raw.tar.gz
        plot_file = "{0}.txt".format(fn)

        for directory, pdat in plot_data.items():
            fig, ax = plt.subplots()

            # "subdirs", "linestyles" and "legend_labels" are parallel lists,
            # so the index of the sub-directory selects the curve's styling.
            for i, subdir in enumerate(pdat["subdirs"]):
                # archive holding the raw output of this parameter variation
                tarf = "{0}/{1}/{2}/raw.tar.gz".format(directory, subdir, "output")

                # Print to screen
                print('\nTarfile: \n\t{0}'.format(tarf))

                with tarfile.open(tarf, "r:gz") as tar:
                    # extractfile() raises KeyError for a missing member and
                    # returns None for members that are not regular files
                    member = tar.extractfile(plot_file)
                    if member is None:
                        raise IOError(
                            "{0} in {1} is not a regular file".format(plot_file, tarf)
                        )

                    # ndmin=2 keeps the array two-dimensional even when the
                    # file contains a single row, so the column slices below
                    # remain valid
                    with member:
                        data = np.loadtxt(
                            member.readlines(), delimiter=' ', dtype='f8', ndmin=2
                        )

                # column 0 is the time, column 1 the quantity to plot
                ax.plot(
                    data[:, 0],
                    data[:, 1],
                    pdat["linestyles"][i],
                    label=pdat["legend_labels"][i],
                    color=pdat["color"],
                )

            # Decorate the axes once, after all curves have been drawn
            ax.legend(loc='upper left')
            ax.set_xlabel(plot_params["xlabel"], fontsize=13)
            ax.set_ylabel(category["ylabel"], fontsize=13)
            ax.set_xticks(plot_params["xticks"])

            if save_plot:
                sdir = "{0}/figs".format(directory)

                # Create directory if it does not exist
                os.makedirs(sdir, exist_ok=True)

                # Report the file only after it has actually been written
                out_file = "{0}/line_plot_{1}.pdf".format(sdir, fn)
                fig.savefig(out_file)
                print("saved to: {0}".format(out_file))

            else:
                plt.show()

            # Release this figure explicitly before the next one is created
            plt.close(fig)

#
# exit upon successful completion
#

exit()
