import tarfile
import numpy as np
import os
from sys import exit

import matplotlib as mpl

# Global text/tick styling: render all text through LaTeX and enlarge the
# tick labels on both axes.
mpl.rcParams.update({
    'text.usetex': True,
    'xtick.labelsize': 18,
    'ytick.labelsize': 18,
})

import matplotlib.pyplot as plt

# Defaults applied whenever a figure is saved: tight bounding box, 300 dpi.
plt.rcParams.update({
    'savefig.bbox': 'tight',
    'savefig.dpi': 300,
})

import entry_repl_model as ode_model


#
# Global parameters
#

# should we save the plot (True: write a PDF per data file; False: show the
# figure interactively instead)
save_plot = True


# cell_data_to_plot is list, or sublist, of:
#   "entry", "replication"
# A single bare value (e.g. "entry") is also accepted; it is wrapped in a
# list further down.
cell_data_to_plot = ["entry", "replication"]


#
# Plotting parameters
#

# Top-level keys:
#   "sim_time"  value handed to the ODE solver; after being incremented by one
#               it is also the upper slice bound for the plotted samples
#   "xlabel"    x-axis label shared by all figures
#   "color"     colour palette; a series with index `ind` uses color[ind - 1]
#
# Every other (dict-valued) key describes one data set that can be listed in
# cell_data_to_plot:
#   "dim"                number of exact-solution rows taken for this data set
#   "filenames"          base names of the "<name>.txt" members read from the
#                        output archive; one figure is produced per name
#   "ylabel"             y-axis label
#   "legend_labels"      legend text -> column index in the data file
#                        (column 0 holds the time); insertion order is the
#                        order in which the series are plotted
#   "exact_sol_min_ind"  first row of the exact solution used for this data
#                        set; rows min_ind .. min_ind + dim - 1 are sliced out
plot_params = {
    "sim_time": 15,
    "xlabel": "Time (hrs)",
    "color": [
        "black", "blue", "red", "magenta", "orange"
    ],
    "entry": {
        "dim": 4,
        "filenames": [
            "re"
        ],
        "ylabel": "Entry components",
        "legend_labels": {
            # "$[R_{eu}]$": 1,
            "$[R_{eb}]$": 2,
            "$[R_{ib}]$": 3,
            "$[R_{iu}]$": 4
        },
        "exact_sol_min_ind": 1
    },
    "replication": {
        "dim": 5,
        "filenames": [
            "vr"
        ],
        "ylabel": "Replication components",
        "legend_labels": {
            "$[V]$": 1,
            "$[U]$": 2,
            "$[R]$": 3,
            "$[P]$": 4,
            "$[A]$": 5
        },
        "exact_sol_min_ind": 4
    }
}


# convert cell_data_to_plot to a list if it isn't already a list
if not isinstance(cell_data_to_plot, list):
    cell_data_to_plot = [cell_data_to_plot]


# sanity check before plotting cell data
# Only dict-valued entries of plot_params describe a data set; a plain key
# membership test would also accept "sim_time", "xlabel" or "color", which
# would then fail much later with an obscure error.
for _requested in cell_data_to_plot:
    if not isinstance(plot_params.get(_requested), dict):
        raise ValueError("unknown cell data to plot: {0}".format(_requested))


# set simulation time
sim_time = plot_params["sim_time"]


#
# get the exact solution of ODE model
#

exact_time, exact_sol = ode_model.entry_repl_solve(sim_time)


#
# construct the plot
#

# increment simulation time so that the slices [:sim_time] below also include
# the sample at index sim_time
sim_time += 1

# tar file holding the raw output; the same archive is read for every figure
tarf = "output/raw.tar.gz"

# directory the figures are saved to
sdir = "figs"

# colour palette shared by the exact curves and the data points
colors = plot_params["color"]

for cdattp in cell_data_to_plot:
    params = plot_params[cdattp]
    legend_labels = params["legend_labels"]

    # splice exact time
    exact_t = exact_time[:sim_time]

    # splice exact sol (depends only on the data set, not on the file)
    min_ind = params["exact_sol_min_ind"]
    max_ind = min_ind + params["dim"]
    exact_s = exact_sol[min_ind:max_ind, :sim_time]

    for fn in params["filenames"]:
        # name of output file which is stored in raw.tar.gz
        plot_file = "{0}.txt".format(fn)

        # Create figure with subplots
        fig, ax = plt.subplots()

        # Plot the exact solution as solid lines; these carry the legend
        # entries.
        # NOTE(review): rows of exact_s are paired with the labels by
        # position (j-th label <-> j-th row), so this presumably relies on
        # "exact_sol_min_ind" pointing at the row of the first listed label
        # and on the listed labels being consecutive -- confirm against
        # ode_model before commenting labels in or out.
        for j, (leg_lbl, ind) in enumerate(legend_labels.items()):
            ax.plot(exact_t, exact_s[j], '-', label=leg_lbl,
                    color=colors[ind - 1])

        print("tar: {0}".format(tarf))

        # load the data to plot from the tar file; the archive is only kept
        # open for as long as it takes to read the member
        with tarfile.open(tarf, "r:gz") as tar:
            member = tar.extractfile(plot_file)
            if member is None:
                # extractfile() returns None for members that are not
                # regular files or links (e.g. directories)
                raise ValueError(
                    "{0} in {1} is not a regular file".format(plot_file, tarf)
                )
            # ndmin=2 keeps the array two-dimensional even if the file
            # holds a single row
            data = np.loadtxt(member.readlines(), delimiter=' ', dtype='f8',
                              ndmin=2)

        # set time (column 0 of the data file)
        time = data[:sim_time, 0]

        # Plot the data as points, coloured like the matching exact curve;
        # the empty label keeps them out of the legend
        for ind in legend_labels.values():
            ax.plot(time, data[:sim_time, ind], '.', label="",
                    color=colors[ind - 1])

        # legend (shifted down slightly when all four entry series are shown)
        if len(legend_labels) == 4 and cdattp == "entry":
            ax.legend(loc='upper left', bbox_to_anchor=(0.0, 0.92))
        else:
            ax.legend(loc='upper left')

        # labels
        ax.set_xlabel(plot_params["xlabel"], fontsize=13)
        ax.set_ylabel(params["ylabel"], fontsize=13)

        if save_plot:
            # Create directory if it does not exist (no check-then-create
            # race)
            os.makedirs(sdir, exist_ok=True)

            out_path = "{0}/line_plot_{1}.pdf".format(sdir, fn)
            fig.savefig(out_path)
            # report only once the file has actually been written
            print("saved to: {0}".format(out_path))
        else:
            plt.show()

        plt.close(fig)

#
# exit upon successful completion
#

exit()
