from __future__ import print_function

import numpy as np


#===============================================================================
# Define the grid codes
#===============================================================================

# _grid_codes['name'] = (code, dot_size_scale, color)
_grid_codes = {
    'blood vessel': (0, 10, 'cyan'),
    'infected epithelial': (11, 5, 'royalblue'),
    'infectious epithelial': (12, 5, 'darkblue'),
    'apoptotic epithelial': (13, 5, 'grey'),
    'removed epithelial': (14, 5, 'black'),
    'burst epithelial': (15, 5, 'black'),
    'resting macrophage': (200, 5, 'red'),
    'active macrophage': (201, 5, 'darkred'),
    'apoptotic macrophage': (202, 5, 'grey'),
    'apoptotic cell': (-10, 5, 'grey')
}

#===============================================================================


#
# Initialise the MPL routines by creating an individual cache dir for each worker process.
#

def mpl_initialise():
    """Configure Matplotlib for the calling process.

    Points ``MPLCONFIGDIR`` at a per-process directory
    ``./mpl_cache/worker_<id>/matplotlib`` so that concurrent worker
    processes each get their own Matplotlib config/cache dir. It then imports
    Matplotlib and sets the global rcParams used by the plotting routines:
    LaTeX text, 18pt tick labels, path chunking, and tight 300 dpi saves.

    ``<id>`` is the part of ``multiprocessing.current_process().name`` after
    its last ``'-'`` (e.g. ``'3'`` for ``'ForkPoolWorker-3'``). A process
    name without a ``'-'`` (such as ``'MainProcess'``) is used whole, so the
    function also works outside a worker pool. Any ``':'`` (present in
    nested-process names like ``'Process-1:2'``) is replaced by ``'_'``.
    """
    import os
    from multiprocessing import current_process

    # rpartition always succeeds, unlike a strict two-way unpack of
    # split('-'), which raises ValueError for names such as 'MainProcess'.
    # ':' is not a valid path character on Windows, hence the replace.
    idn = current_process().name.rpartition('-')[2].replace(':', '_')

    # Set the MPL cache directory.
    # NOTE(review): this presumably only takes effect if Matplotlib has not
    # already resolved its config dir in this process -- confirm callers run
    # this before any other Matplotlib use.
    mpldir = './mpl_cache/worker_{0}/matplotlib'.format(idn)
    os.environ['MPLCONFIGDIR'] = mpldir

    # Import Matplotlib and define global params
    import matplotlib as mpl
    mpl.rcParams['text.usetex'] = True  # requires a working LaTeX install
    mpl.rcParams['xtick.labelsize'] = 18
    mpl.rcParams['ytick.labelsize'] = 18
    mpl.rcParams['agg.path.chunksize'] = 10000

    import matplotlib.pyplot as plt
    plt.rcParams['savefig.bbox'] = 'tight'
    plt.rcParams['savefig.dpi'] = 300


#
# Create the Matplotlib figure. Return the figure and axes.
#

def mpl_create( asp, figsz=None ):
    """Create a Matplotlib figure containing a single axes.

    Args:
        asp: aspect value passed to ``Axes.set_aspect``. The string
            ``'none'`` skips ``set_aspect`` entirely.
        figsz: optional figure size (must be a tuple), forwarded to
            ``plt.figure`` as ``figsize``.

    Returns:
        ``(fig, ax)``: the new figure and its single axes.

    Raises:
        ValueError: if ``figsz`` is given but is not a tuple.
    """
    # Validate before touching pyplot so bad input fails cheaply.
    if figsz is not None and not isinstance(figsz, tuple):
        raise ValueError('figsz must be a tuple')

    import matplotlib.pyplot as plt

    figure_kwargs = {}
    if figsz is not None:
        figure_kwargs['figsize'] = figsz
    fig = plt.figure(**figure_kwargs)

    ax = fig.subplots()
    if asp != 'none':
        ax.set_aspect(asp)

    return fig, ax


#
# Plot of cell codes on a grid in time
#

def plot_scatter( args ):
    """Render the cell-code grid(s) in ``dat`` as one scatter plot and save it.

    ``args`` must be a 6-tuple ``(dat, savedir, filename, fr, plt_opts, count)``
    (packed into one argument, presumably so the function can be mapped over
    a multiprocessing pool -- TODO confirm with caller):

    * ``dat``      -- mapping of key -> flat sequence of cell codes; cell
                      ``(i, j)`` is read from index ``i*grid_size + j``.
                      All keys are overlaid on the same axes.
    * ``savedir``  -- directory the image is written to.
    * ``filename`` -- %-format string; ``filename % count`` is the file name.
    * ``fr``       -- frame number used to compute the title time.
    * ``plt_opts`` -- dict of plotting options (see ``expected_options``).
    * ``count``    -- value substituted into ``filename``.

    Each cell code is matched against ``_grid_codes``. Codes with no entry
    are not drawn. Blood-vessel cells are drawn but not added to the legend,
    and all "apoptotic" states share one "apoptotic cell" legend entry.

    Raises:
        ValueError: wrong number of args, ``filename`` is a dict, plotting
            options are missing, or ``combined_plot`` is false.
    """
    # Initialise MPL (per-process config dir + global rcParams)
    mpl_initialise()

    # Imports
    import matplotlib.pyplot as plt
    import matplotlib.lines as mlines

    # Check number of args
    nargs = 6
    if len(args) != nargs:
        raise ValueError('number of args is not correct; expecting {0}'.format(nargs))

    # Unpack the arguments from args
    dat, savedir, filename, fr, plt_opts, count = args

    if isinstance(filename, dict):
        raise ValueError("filename must be a string")

    # Check expected plotting options
    expected_options = (
        "combined_plot", 'timestep', 'output_interval', 'grid_size', 'dot_size',
        'xlim', 'ylim', 'xticks', 'yticks', 'lbl', 'lblfontsize'
    )
    missing = [k for k in expected_options if k not in plt_opts]
    if missing:
        raise ValueError('plotting options missing: {0}'.format(missing))

    # Create output file name
    oname = filename % count

    # Define local plotting variables
    combined_plot = plt_opts["combined_plot"]
    dt = plt_opts['timestep']
    out_int = plt_opts['output_interval']
    grid_size = plt_opts['grid_size']
    dot_size = plt_opts['dot_size']
    xlim = plt_opts['xlim']
    ylim = plt_opts['ylim']
    xticks = plt_opts['xticks']
    yticks = plt_opts['yticks']
    lbl = plt_opts['lbl']
    lblfz = plt_opts['lblfontsize']

    if not combined_plot:
        raise ValueError("scatter plot assumes a combined plot")

    # Proxy artists used only to build the legend (one per grid state).
    legend_entries = {}
    for name, (_, _, colour) in _grid_codes.items():
        legend_entries[name] = mlines.Line2D(
            [], [], color="white", linestyle='', marker='o',
            markerfacecolor=colour, markersize=2*dot_size
        )

    # Reverse lookup: code -> (name, size scale, colour). The codes in
    # _grid_codes are distinct, so one dict lookup per cell is equivalent
    # to scanning every entry for a match.
    by_code = {}
    for name, (code, scale, colour) in _grid_codes.items():
        by_code[code] = (name, scale, colour)

    # Collect one scatter point per recognised cell, plus legend labels in
    # first-seen order.
    xs = []
    ys = []
    sizes = []
    colours = []
    legend_names = []
    for k in dat.keys():
        cells = dat[k]
        for i in range(grid_size):
            for j in range(grid_size):
                match = by_code.get(cells[i*grid_size + j])
                if match is None:
                    continue
                name, scale, colour = match
                xs.append(i)
                ys.append(j)
                sizes.append(scale*dot_size)
                colours.append(colour)

                if "apoptotic" in name:
                    label = "apoptotic cell"
                elif "vessel" in name:
                    continue  # vessels are drawn but not listed in the legend
                else:
                    label = name
                if label not in legend_names:
                    legend_names.append(label)

    #===========================================================================
    # Construct the plot
    #===========================================================================

    fig, ax = mpl_create('equal')
    try:
        ax.scatter(xs, ys, s=sizes, c=colours)

        # x limit
        if xlim is not None and len(xlim) == 2:
            ax.set_xlim(xlim[0], xlim[1])

        # y limit
        if ylim is not None and len(ylim) == 2:
            ax.set_ylim(ylim[0], ylim[1])

        # x ticks
        if isinstance(xticks, list):
            ax.set_xticks(xticks)

        # y ticks
        if isinstance(yticks, list):
            ax.set_yticks(yticks)

        # axis labels
        if lbl is not None and len(lbl) == 2:
            ax.set_xlabel(lbl[0], fontsize=lblfz)
            ax.set_ylabel(lbl[1], fontsize=lblfz)

        # NOTE(review): "* dt / dt" cancels, so this equals (fr - 1) * out_int
        # and raises ZeroDivisionError when dt == 0. Kept as-is; confirm the
        # intended conversion to hours.
        time = (fr - 1) * out_int * dt / dt
        ax.set_title('Time = {0:.1f} Hrs / {1:.1f} Days / {2:.1f} Wks'.format(time, time/24.0, time/24.0/7.0))

        # add legend below the axes
        leg_comp = tuple(legend_entries[n] for n in legend_names)
        leg_comp_n = tuple(legend_names)
        bbox = [-0.12, -1.1, 1.16, 1]
        ax.legend(
            leg_comp, leg_comp_n, loc='upper left', ncol=3, mode="expand",
            bbox_to_anchor=bbox, borderaxespad=0., handletextpad=0.01
        )

        # Save the figure
        fig.savefig('{0}/{1}'.format(savedir, oname))
    finally:
        # Always release the figure, even if plotting/saving fails, so that
        # long-lived worker processes do not accumulate open figures.
        plt.close('all')


#
# Plot of reaction-diffusion contours
#

def plot_contourf( args ):
    """Draw a normalised filled-contour plot for each field in ``dat``.

    ``args`` must be a 6-tuple ``(dat, savedir, filename, fr, plt_opts, count)``:

    * ``dat``      -- mapping of key -> flat sequence of values; cell
                      ``(i, j)`` is read from index ``i*grid_size + j``.
                      One figure is produced per key.
    * ``savedir``  -- directory the images are written to.
    * ``filename`` -- dict with exactly the same keys as ``dat``; each value
                      is a %-format string, and ``filename[k] % count`` names
                      the image for ``dat[k]``.
    * ``fr``       -- frame number used to compute the title time.
    * ``plt_opts`` -- dict of plotting options (see ``expected_options``).
    * ``count``    -- value substituted into each file name.

    Each field is scaled so its maximum maps to 100 (NaNs ignored). If the
    maximum is <= 1e-15, the raw values are plotted instead. The colour
    scale is fixed to 0-100.

    Raises:
        ValueError: wrong number of args, ``filename`` is not a dict, the
            keys of ``filename`` and ``dat`` differ, plotting options are
            missing, ``combined_plot`` is true, or a field holds fewer than
            ``grid_size**2`` values.
    """
    # Initialise MPL (per-process config dir + global rcParams)
    mpl_initialise()

    # Imports
    import matplotlib.pyplot as plt

    # Check number of args
    nargs = 6
    if len(args) != nargs:
        raise ValueError('number of args is not correct; expecting {0}'.format(nargs))

    # Store the arguments
    dat, savedir, filename, fr, plt_opts, count = args

    if not isinstance(filename, dict):
        raise ValueError("file name must be a dictionary")

    # Every key of dat is looked up in filename when saving, and every
    # filename key must refer to data, so the key sets must be identical.
    if set(filename.keys()) != set(dat.keys()):
        raise ValueError("file names and data keys do not match")

    # Check expected plotting options
    expected_options = (
        "combined_plot", 'timestep', 'output_interval', 'grid_size', 'dot_size',
        'xlim', 'ylim', 'xticks', 'yticks', 'lbl', 'lblfontsize'
    )
    missing = [k for k in expected_options if k not in plt_opts]
    if missing:
        raise ValueError('plotting options missing: {0}'.format(missing))

    # Create output file names
    oname = {}
    for k in filename.keys():
        oname[k] = filename[k] % count

    # Define local plotting variables
    combined_plot = plt_opts["combined_plot"]
    dt = plt_opts['timestep']
    out_int = plt_opts['output_interval']
    grid_size = plt_opts['grid_size']
    xlim = plt_opts['xlim']
    ylim = plt_opts['ylim']
    xticks = plt_opts['xticks']
    yticks = plt_opts['yticks']
    lbl = plt_opts['lbl']
    lblfz = plt_opts['lblfontsize']

    #===========================================================================
    # Construct the plot
    #===========================================================================

    if combined_plot:
        raise ValueError("contourf plot cannot use a combined plot")

    n_cells = grid_size * grid_size
    for k in dat.keys():
        # Create the matplotlib figure
        fig, ax = mpl_create('equal')
        try:
            # Row-major reshape, so field[i, j] == dat[k][i*grid_size + j].
            # Converting to float also avoids integer division when
            # normalising integer data under Python 2.
            field = np.asarray(dat[k], dtype=float)[:n_cells].reshape(grid_size, grid_size)

            # nanmax skips NaN, like the element-wise '>' comparison it replaces.
            peak = np.nanmax(field) if field.size else -np.inf
            if peak > 1e-15:
                field = field / peak * 100.0

            # Transpose so that the first grid index runs along x.
            cs = plt.contourf(field.T, levels=np.linspace(0, 100, 100, endpoint=True), vmin=0, vmax=100)

            cbar = plt.colorbar(cs)
            cbar.set_ticks(np.linspace(0, 100, 11, endpoint=True))

            # x limit
            if xlim is not None and len(xlim) == 2:
                ax.set_xlim(xlim[0], xlim[1])

            # y limit
            if ylim is not None and len(ylim) == 2:
                ax.set_ylim(ylim[0], ylim[1])

            # x ticks
            if isinstance(xticks, list):
                ax.set_xticks(xticks)

            # y ticks
            if isinstance(yticks, list):
                ax.set_yticks(yticks)

            # axis labels
            if lbl is not None and len(lbl) == 2:
                ax.set_xlabel(lbl[0], fontsize=lblfz)
                ax.set_ylabel(lbl[1], fontsize=lblfz)

            # NOTE(review): "* dt / dt" cancels, so this equals
            # (fr - 1) * out_int and raises ZeroDivisionError when dt == 0.
            # Kept as-is; confirm the intended conversion to hours.
            time = (fr - 1) * float(out_int) * float(dt) / float(dt)
            ax.set_title('Time = {0:.1f} Hrs / {1:.1f} Days / {2:.1f} Wks'.format(time, time/24.0, time/24.0/7.0))

            # Save the figure
            fig.savefig('{0}/{1}'.format(savedir, oname[k]))
        finally:
            # Release the figure even if plotting/saving fails.
            plt.close('all')
