import os

import numpy as np
import matplotlib.pyplot as plt
# from tikzplotlib import save


def _state_labels(nstates):
    """Return legend labels for the ``2*N`` state columns (``N = 2*nstates``).

    Entries ``0 .. N-1`` label the first N columns (drawn with the thicker
    line width) and entries ``N .. 2N-1`` label the remaining N columns.
    For ``nstates > 2`` no labels are defined and ``None`` placeholders are
    returned.
    """
    if nstates == 2:
        return ['$x_1$', '$x_2$', '$y_1$', '$y_2$',
                '$u_1$', '$u_2$', '$v_1$', '$v_2$']
    if nstates == 1:
        return ['$y_1$', '$y_2$',
                '$z_1$', '$z_2$']
    return [None]*(nstates*4)


def _error_labels(nstates):
    """Return legend labels for the N difference curves (``N = 2*nstates``)."""
    if nstates == 2:
        return ['$x_1-u_1$', '$x_2-u_2$', '$y_1-v_1$', '$y_2-v_2$']
    if nstates == 1:
        # NOTE(review): _state_labels names these columns y_k / z_k, so the
        # differences are presumably y_k - z_k; confirm which naming is meant.
        return ['$x_1-y_1$', '$x_2-y_2$']
    return [None]*(nstates*2)


def _finish_figure(fig, ax, nstates, title, basename):
    """Decorate *ax* and save *fig* as ``pngs/<basename>.png`` and ``.eps``.

    A legend is only drawn for ``nstates <= 2``, the cases with real labels.
    The output folders are created if they do not exist yet.
    """
    if nstates <= 2:
        ax.legend(ncol=1)
    ax.set_xlabel('time t')
    ax.set_title(title)
    fig.tight_layout()
    for folder, ext in (('pngs', '.png'), ('epss', '.eps')):
        os.makedirs(folder, exist_ok=True)
        fig.savefig(os.path.join(folder, basename + ext))


def plotit(data, timerange=None, coupling=True, fignum=100,
           timescale='continuos',
           plsetr=None, plsevals=None,
           filterdata=None, nstates=2, figsize=(5, 3), pname='2x2'):
    """Plot the state columns of *data* and their differences, saving both.

    With ``N = 2*nstates``, figure ``fignum`` shows all ``2*N`` columns and
    figure ``fignum + 1`` shows ``data[:, k] - data[:, k+N]`` for ``k < N``.
    Each figure is written to ``pngs/<pname><title>.png`` and
    ``epss/<pname><title>.eps``; the error figure's file names carry an
    ``-errors`` suffix.

    Parameters
    ----------
    data : array_like or sequence of array_like
        2-D samples (rows = time samples) with at least ``2*N`` columns.
        In ``'timescale'`` mode, a sequence of such arrays, one per segment;
        segments may have different lengths.
    timerange : array_like or sequence of array_like
        Time values matching the rows of *data* (one array per segment in
        ``'timescale'`` mode). Required despite the ``None`` default.
    coupling : bool
        Only selects the title prefix ('Coupled-System' / 'Uncoupled-System').
    fignum : int
        Number of the state figure; the error figure uses ``fignum + 1``.
    timescale : str
        ``'timescale'`` plots each segment separately, ``'discrete'`` draws
        dots instead of lines; any other value (including the default) is
        plotted as continuous lines.
    plsetr, plsevals
        Unused; kept for backward compatibility.
    filterdata : int, optional
        Plot only every ``filterdata``-th sample (ignored in ``'timescale'``
        mode).
    nstates : int
        Sets ``N = 2*nstates``; legends are only drawn for ``nstates <= 2``.
    figsize : tuple
        Size passed to ``plt.figure``.
    pname : str
        Prefix of the saved file names.

    Raises
    ------
    TypeError
        If *timerange* is not given.
    """
    if timerange is None:
        raise TypeError("plotit() requires 'timerange'")

    ttlstr = 'Coupled-System' if coupling else 'Uncoupled-System'
    if timescale == 'timescale':
        ttlstr += ' on a Timescale'
    elif timescale == 'discrete':
        ttlstr = 'Discrete Time ' + ttlstr

    if timescale == 'discrete':
        frmstrdrv = frmstrrsp = '.'  # markers only, no connecting lines
        drivelw = 2
    else:
        frmstrdrv = frmstrrsp = ''
        drivelw = 3

    N = 2*nstates
    segmented = timescale == 'timescale'
    if segmented:
        # Segments can be ragged, so they are never stacked into one array
        # (np.array on ragged input raises ValueError in recent NumPy).
        segments = list(zip(timerange, data))
    else:
        # asarray so plain lists can be fancy-indexed with filtert.
        timearray = np.asarray(timerange)
        dataarray = np.asarray(data)
        step = 1 if filterdata is None else filterdata
        filtert = np.arange(0, len(timearray), step)
        tfilt = timearray[filtert]

    # --- figure 1: all state columns -------------------------------------
    fig = plt.figure(fignum, figsize=figsize)
    ax = fig.gca()
    cml = [.25, .85, .15, .75] if N == 2 else np.linspace(0.2, .8, N)
    # Colour cycle is set on this axes only; mutating plt.rcParams would
    # leak into every figure created after this call.
    ax.set_prop_cycle(plt.cycler("color", plt.cm.plasma(cml)))
    leglist = _state_labels(nstates)
    if segmented:
        for iseg, (ctr, cdt) in enumerate(segments):
            # Label one segment only so each series appears once in the legend.
            first = iseg == 0
            for kkk in range(N):
                ax.plot(ctr, cdt[:, kkk+N], frmstrrsp, linewidth=1,
                        label=leglist[kkk+N] if first else None)
            for kkk in range(N):
                ax.plot(ctr, cdt[:, kkk], frmstrdrv, linewidth=drivelw,
                        label=leglist[kkk] if first else None)
    else:
        for kkk in range(N):
            ax.plot(tfilt, dataarray[filtert, kkk+N], frmstrrsp,
                    label=leglist[kkk+N], linewidth=1)
        for kkk in range(N):
            ax.plot(tfilt, dataarray[filtert, kkk], frmstrdrv,
                    label=leglist[kkk], linewidth=drivelw)
    _finish_figure(fig, ax, nstates, ttlstr, pname + ttlstr)

    # --- figure 2: differences column k minus column k+N ------------------
    fig = plt.figure(fignum + 1, figsize=figsize)
    ax = fig.gca()
    ax.set_prop_cycle(
        plt.cycler("color", plt.cm.cividis(np.linspace(0.2, .8, N))))
    leglist = _error_labels(nstates)
    if segmented:
        for iseg, (ctr, cdt) in enumerate(segments):
            first = iseg == 0
            for kkk in range(N):
                ax.plot(ctr, cdt[:, kkk]-cdt[:, kkk+N], frmstrrsp,
                        label=leglist[kkk] if first else None)
    else:
        for kkk in range(N):
            ax.plot(tfilt,
                    dataarray[filtert, kkk]-dataarray[filtert, kkk+N],
                    frmstrrsp, label=leglist[kkk])
    _finish_figure(fig, ax, nstates, ttlstr + ' Errors',
                   pname + ttlstr + '-errors')

