import numpy as np
import re
import numba
import json
import pandas as pd
import copy
import os

class THandler:
    """
    Level-based "transposition" of the nested data.

    We may have a nested integrator structure like this:
         R
      0    1   | e.g. splitting
    00 01 10  11 | e.g. coupled integrators within each substepper
    and we typically want information about the leaves.
    Since some data is stored along the path, we still want to be able
    to retrieve that data here.

    Hence we keep one dict per node of the above tree in ``levelinfo``,
    keyed by the node's level key (a string). Information about earlier
    levels can be accessed directly by truncating the level key to the
    appropriate size, or via :meth:`get_info`, which does that search.
    """

    # Keys of exactly this form hold nested sub-integrators; they are the
    # only keys left out of the per-level info. A plain substring test on
    # "int_" would also drop unrelated keys such as "checkpoint_every".
    _SUBINT_KEY = re.compile(r"^int_[0-9]+$")

    def __init__(self, needlefun=None):
        """
        :param needlefun: Accepted for backward compatibility; not used.
        """
        self.levelinfo = {}

    def doit(self, js, levelname, basic=False):
        """
        Record the non-nested entries of ``js`` under ``levelname``.

        :param js: Dict describing one node of the tree.
        :param levelname: Level key under which the node is stored.
        :param basic: Accepted for backward compatibility; not used.
        """
        self.levelinfo[levelname] = {
            k: v for k, v in js.items() if not self._SUBINT_KEY.match(k)
        }

    def get_info(self, key, level):
        """
        Traverse levels leading to level in search of key.

        The search starts at ``level`` itself and then shortens the level
        key by one character at a time.

        :param key: Entry to look up.
        :param level: Level key at which the search starts.
        :return: Value of the first match found while moving up.
        :raises KeyError: If no level on the path holds ``key``, or a level
            on the path was never recorded.
        """
        for end in range(len(level), 0, -1):
            info = self.levelinfo[level[:end]]
            if key in info:
                return info[key]
        raise KeyError(f"{key!r} not found on the path to level {level!r}")

def get_levelname(level):
    """Return the dict key under which sub-integrator number ``level`` is stored."""
    return f"int_{level!s}"
# Matches the keys that hold nested sub-integrators ("int_0", "int_1", ...).
re_subint = re.compile("^int_[0-9]+$")
def traverse_js(js, handler, level=0, levelname="0"):
    """
    Walk the nested integrator description and hand every node to handler.

    Children are visited before the node itself, so ``handler.doit`` is
    called in post-order. The level key of a child is the parent's key with
    the child index appended as text.

    :param js: Dict describing the current node.
    :param handler: Object providing ``doit(js, levelname)``.
    :param level: Depth of the current node (0 for the root).
    :param levelname: Level key of the current node.
    """
    # Newer files state the number of children explicitly; older ones only
    # contain the "int_<i>" entries, which then have to be counted.
    if "nints" in js.keys():
        nints = js["nints"]
    elif "int_0" in js.keys():
        nints = sum(1 for key in js.keys() if re_subint.match(key))
    else:
        nints = 0

    if nints != 0:
        for idx in range(nints):
            child = js[get_levelname(idx)]
            traverse_js(child, handler, level+1, levelname+str(idx))

    # Every node, not only the leaves, is reported to the handler.
    handler.doit(js, levelname)
        
        
class intinfo:
    """
    Description of one base-level integrator.

    The most commonly needed values are plain attributes. Anything else
    recorded along the branch leading to this integrator is available via
    ``obj[key]``, which is answered by a private deep copy of the handler.
    """
    def __init__(self, rhsname, intname, atol, rtol, dt, rhsevals, thandler, level):
        self.rhsname, self.intname = rhsname, intname
        self.atol, self.rtol = atol, rtol
        self.dt = dt
        self.rhsevals = rhsevals
        self.level = level
        # Deep copy, so later changes to the caller's handler do not show up here.
        self.thandler = copy.deepcopy(thandler)

    def __str__(self):
        # Tolerances are shown when at least one is set, otherwise the step size.
        no_tolerances = self.atol == 0 and self.rtol == 0
        if no_tolerances:
            settings = f"[dt{self.dt:.2e}]"
        else:
            settings = f"[r{self.rtol:.1e},a{self.atol:.1e}]"
        return f"{self.intname}({self.rhsname})@" + settings

    def __getitem__(self, key):
        # Delegates to the handler, which searches this level and its ancestors.
        return self.thandler.get_info(key, self.level)

    # no setitem since it should be immutable
def handleMeta(fn):
    """
    Load a metadata JSON file and describe every integrator found in it.

    Every level that carries an "rhsname" entry yields one ``intinfo``.
    Tolerances for such a level are read from its parent as "atol<i>" and
    "rtol<i>", where <i> is the last character of the level key. If the
    parent does not provide them, the integrator is treated as fixed-step
    and ``param.dtbase`` of the root level is used as dt.

    :param fn: Path of the JSON file.
    :return: Two-element list ``[root level info, list of intinfo]``.
    :raises KeyError: If a required entry ("rhsevals", an "intname" on the
        level or its parent, or the dtbase fallback) is missing.
    """
    # JSON is UTF-8 by specification, so do not depend on the locale default.
    with open(fn, "r", encoding="utf-8") as f:
        js = json.load(f)
    h = THandler(None)
    traverse_js(js, h)
    descrs = []
    for k, v in h.levelinfo.items():
        if "rhsname" not in v:
            continue
        rhsname = v["rhsname"]

        # go "up" one layer
        # An integrator stored directly at the root has no parent level; an
        # empty dict makes it fall through to the fixed-step handling below
        # instead of failing on the lookup of the (non-existent) level "".
        parent = h.levelinfo.get(k[:-1], {})

        try:
            intname = v["intname"]
        except KeyError:
            # non-nested integrator: the name is stored one level up
            intname = parent["intname"]
        rhsevals = v["rhsevals"]

        atol = 0
        rtol = 0
        dt = 0
        try:
            atol = parent["atol"+k[-1]]
            rtol = parent["rtol"+k[-1]]
        except KeyError:
            # not adaptive, fall back to the global base step
            dt = h.levelinfo["0"]["param"]["dtbase"]
        descrs.append(intinfo(rhsname, intname, atol, rtol, dt, rhsevals, h, k))
    return [h.levelinfo["0"], descrs]

def tolfun(x):
    """Return True if x contains "rtol" or "SSP(10)4"."""
    return any(tag in x for tag in ("rtol", "SSP(10)4"))


def dtfun(x):
    """Return True if x contains "dt" or "SSP(10)4"."""
    return any(tag in x for tag in ("dt", "SSP(10)4"))


def build_dat(errs, metas, prefix, filtfun = None):
    """
    Build a convenient dataframe from the errors and integrator infos.
    We combine any "similar" simulations here, i.e. all STSRKL2 simulations
    using an adaptive integrator are folded into one, depending on prefix/filtfun.
    We can still map back from this to the original one by storing its key
    in the "origk" column.

    Simulations are grouped by the text of their first integrator
    description up to (not including) the first "@".

    :param errs: Dictionary of simname -> error.
    :param metas: Metainformation dictionary with same keys as errs; each
        value is indexed as ``[top-level info, list of integrator infos]``.
    :param prefix: Only work on simulations whose first integrator
        description contains this text.
    :param filtfun: Optional function of the simulation name; simulations
        for which it returns a true value are skipped.
    :return: Dictionary of group name -> DataFrame with the columns
        dx, W, dt, error, rhseval, runtime, cellcount, atol, rtol, origk.
    """
    columns = ["dx", "W", "dt", "error", "rhseval", "runtime", "cellcount", "atol", "rtol", "origk"]
    idats = {}
    for mk, e in errs.items():
        if filtfun is not None and filtfun(mk):
            continue
        meta = metas[mk]
        leaf = meta[1][0]
        described = str(leaf)
        if prefix not in described:
            continue
        group = described.split("@")[0]

        # some extra stuff is recorded for convenience
        # for other things, use origk and the original data
        top = meta[0]
        param = top["param"]
        idats.setdefault(group, []).append([param["dx"],
                                            param["interface_param"],
                                            param["dtused"],
                                            e,
                                            leaf.rhsevals,
                                            top["runtime"],
                                            top["cellcount"],
                                            leaf.atol,
                                            leaf.rtol,
                                            mk
                                           ])
    return {group: pd.DataFrame(rows, columns=columns) for group, rows in idats.items()}


def linpow(dx, n, pre):
    """
    Straight line with slope n and intercept pre, evaluated at dx.

    This is the logarithm of the power law y = A*x**n when dx is the
    logarithm of x and pre the logarithm of A.
    """
    scaled = n*dx
    return scaled + pre

class TimedData:
    """
    Helper class to reload data only if some file related to it is more recent
    than the last time.

    Inherit from this, use is_current before doing actual work and skip stuff
    if not do_reload. mtimes needs to be updated by the inheriting class though,
    which should be done after everything else is done.
    """
    def __init__(self):
        # name -> modification time recorded by the inheriting class
        self.mtimes = {}

    def is_current(self, pn, fn):
        """
        Check whether the data named pn has to be (re)loaded from file fn.

        :param pn: Name under which the data is tracked in mtimes.
        :param fn: Path of the file backing that data.
        :return: Tuple ``(mt, do_reload)``: the file's modification time
            (0 if the file is missing) and whether a reload is needed. A
            missing file is reported on stdout and never triggers a reload.
        """
        try:
            mt = os.path.getmtime(fn)
        except FileNotFoundError as e:
            print(e)
            print("skipping due to missing file", pn)
            return 0, False
        # Reload unless a recorded time exists that is at least as recent.
        do_reload = pn not in self.mtimes or mt > self.mtimes[pn]
        return mt, do_reload

    def remove_old(self, pns):
        """
        Drop every tracked name that is not listed in pns.

        The names are removed from all containers returned by get_alldata.
        A container that never held a stale name is left untouched.

        :param pns: Names that are still valid.
        """
        datas = self.get_alldata() # this should be overridden by subclass
        if isinstance(datas, dict):
            # The default get_alldata returns the mtimes dict itself; looping
            # over it directly would yield its keys, not a container.
            datas = [datas]
        # Collect first, so mtimes is not modified while it is iterated.
        stale = [pn for pn in self.mtimes if pn not in pns]
        for pn in stale:
            for d in datas:
                d.pop(pn, None)

    def get_alldata(self):
        """
        Return the containers remove_old has to clean up.

        Subclasses should override this and return a list of all their
        name-keyed dicts; the default only knows about mtimes.
        """
        return self.mtimes

# common functions
# Default per-phase parameters (index 0 / 1 selects the phase) of the
# quadratic energies evaluated in calc_ga as 0.5 * ks[a] * (ca[a] - cs0[a])**2:
#   cs0 -- value of ca[a] at which the energy of phase a is minimal
#   ks  -- prefactor of the quadratic term of phase a
# NOTE(review): presumably equilibrium concentrations and curvatures of the
# two phases -- confirm against the model description.
cs0 = np.array([0.02, 0.98])
ks = np.array([500, 500])
def calc_psia(c, phiv, a, cs0, ks):
    """
    Evaluate -cs0[a]*mu - mu**2/(2*ks[a]) with mu = calc_mu(c, phiv).

    :param c: Forwarded to calc_mu.
    :param phiv: Forwarded to calc_mu.
    :param a: Index selecting the entry of cs0 and ks.
    :param cs0: Indexable holding the linear coefficient per index.
    :param ks: Indexable holding the quadratic coefficient per index.
    """
    potential = calc_mu(c, phiv)
    linear_part = -cs0[a]*potential
    quadratic_part = potential**2 / (2*ks[a])
    return linear_part - quadratic_part
def calc_psi(c, phiv):
    """
    Return calc_ga(c, phiv) - calc_mu(c, phiv)*c.

    Both helpers are evaluated with their default parameters.
    """
    potential = calc_mu(c, phiv)
    energy = calc_ga(c, phiv)
    return energy - potential*c

@numba.njit
def calc_ca(c, phiv, cs0=cs0, ks=ks):
    """
    Split the mixture value c into one value per phase.

    The result ca satisfies, for any cs0 and ks,
        phiv*ca[0] + (1-phiv)*ca[1] == c
        ks[0]*(ca[0]-cs0[0]) == ks[1]*(ca[1]-cs0[1])
    up to rounding.

    :param c: Array of mixture values.
    :param phiv: Weight of phase 0 (phase 1 is weighted with 1-phiv); must
        broadcast against c.
    :param cs0: Per-phase minimum positions, indexed 0 and 1.
    :param ks: Per-phase quadratic prefactors, indexed 0 and 1.
    :return: Array of shape (2, *c.shape) holding ca[0] and ca[1].
    """
    ca = np.zeros((2, *c.shape))
    # Common denominator of both phase values. For ks[0] == ks[1] it is just
    # that shared value, which is the only case the previous constant
    # prefactor (an undefined name "KS") was meant to cover.
    prefac = 1/(ks[1]*phiv + ks[0]*(1-phiv))
    # Offset between the two quadratic energies; enters with opposite sign
    # into the two phase values.
    offset = cs0[0]*ks[0] - cs0[1]*ks[1]
    ca[0] = (c * ks[1] + offset*(1-phiv)) * prefac
    ca[1] = (c * ks[0] - offset*phiv) * prefac
    return ca

@numba.njit
def calc_ga(c, phiv, cs0=cs0, ks=ks):
    """
    Weighted sum of the two per-phase quadratic energies.

    Each phase a contributes 0.5*ks[a]*(ca[a]-cs0[a])**2, weighted with
    phiv for phase 0 and 1-phiv for phase 1, where ca comes from calc_ca.

    :param c: Array of mixture values.
    :param phiv: Weight of phase 0.
    :param cs0: Per-phase minimum positions, indexed 0 and 1.
    :param ks: Per-phase quadratic prefactors, indexed 0 and 1.
    """
    # Forward cs0/ks so the phase values are computed with the same
    # parameters as the energies below, not with the module defaults.
    ca = calc_ca(c, phiv, cs0, ks)
    return 0.5 * ks[0] * (ca[0]-cs0[0])**2 * phiv + 0.5 * ks[1] * (ca[1]-cs0[1])**2 * (1-phiv)

class ClipInterpolation:
    """
    Piecewise-constant lookup in tabulated data.

    For a query q, the first knot in xdat is located that is larger than q
    (or equal to q, unless q is 0); the result is the ydat entry just before
    that knot. Queries for which no such knot exists get the last ydat
    entry, and queries matching the very first knot get the first one.

    Only used if we use moving window but we don't after all.
    """
    def __init__(self, xdat, ydat):
        """
        :param xdat: Sequence of knots, scanned front to back.
        :param ydat: Sequence of values, same length as xdat.
        """
        self.x = xdat
        self.y = ydat

    def __call__(self, x):
        """
        Look up one or many query values.

        :param x: Scalar or array-like of queries.
        :return: A scalar for scalar input, otherwise an array with the
            shape of x. The dtype covers both x and ydat, so float values
            are not truncated for integer queries.
        """
        queries = np.asarray(x)
        flat = queries.ravel()
        values = np.asarray(self.y)
        out = np.empty(flat.shape, dtype=np.result_type(queries.dtype, values.dtype))
        for j, q in enumerate(flat):
            hit = next(
                (i for i, knot in enumerate(self.x)
                 if knot > q or (knot == q and q != 0)),
                None,
            )
            if hit is None:
                out[j] = self.y[-1]
            else:
                # Clamp at 0: index -1 would silently wrap around to the
                # last value for queries at or below the first knot.
                out[j] = self.y[max(hit - 1, 0)]
        if queries.ndim == 0:
            return out[0]
        return out.reshape(queries.shape)

def perfplot(dats_ad, dats_dt, ax, metas, xfac=1, calc_speedup = False, xax="rhseval", yax="error",
    colmap = None, add_label=True):
    """
    Plot yax over xax for adaptive and fixed-step integrator tables.

    Only rows with W == 2.5 are drawn, sorted by xax. Adaptive tables are
    drawn with "-o" and " adaptive" appended to the label, fixed-step tables
    with "-x"; tables whose name contains "FEuler" get zorder 10.

    :param dats_ad: Dict name -> DataFrame of adaptive runs.
    :param dats_dt: Dict name -> DataFrame of fixed-step runs.
    :param ax: Object providing a matplotlib-style ``plot`` method.
    :param metas: Dict of original simulation key -> metadata, only read
        when calc_speedup is set.
    :param xfac: Divisor applied to the x values.
    :param calc_speedup: If set, print for every table the ratio between
        the xax value of the "FEuler(MPF::Obstacle)" run with the largest
        dt and the table's smallest xax value; for adaptive tables the
        rejection ratios are printed as well.
    :param xax: Column used as x axis.
    :param yax: Column used as y axis.
    :param colmap: Dict name -> colour; defaults to ``layout.colmap``.
    :param add_label: If false, all labels are empty.
    """
    if colmap is None:
        import layout
        colmap = layout.colmap

    if calc_speedup:
        # Reference cost: forward Euler at its largest time step.
        reference = dats_dt["FEuler(MPF::Obstacle)"]
        largest_dt = reference["dt"].max()
        at_largest_dt = reference[reference["dt"] == largest_dt]
        rhs_fe = float(at_largest_dt[xax].values[0])

    for name, table in dats_ad.items():
        colour = colmap[name]
        selected = table[table["W"] == 2.5]
        if calc_speedup:
            # find smallest rhseval for this integrator
            cheapest = selected[xax].min()
            rejratios = []
            try:
                tols = []
                for origkey in table["origk"]:
                    leaf = metas[origkey][1][0]
                    steps = leaf["steps"]
                    rejs = leaf["rejects"]
                    tols.append(leaf["atol0"])
                    rejratios.append(rejs/steps)
                # sort by tols, loosest first
                by_tol = sorted(zip(tols, rejratios))
                rejratios = [ratio for _, ratio in by_tol][::-1]
            except Exception as err:
                print(err)
            print("adaptive", name, "speedup", rhs_fe/float(cheapest), "rejection ratio", *rejratios)
        ordered = selected.sort_values(by=xax)
        short = name.replace("(MPF::Obstacle)", "")
        if add_label:
            legend = short + " adaptive"
        else:
            legend = ""
        ax.plot(ordered[xax]/xfac, ordered[yax], "-o", color=colour, label=legend,
                markerfacecolor="None")

    for name, table in dats_dt.items():
        colour = colmap[name]
        selected = table[table["W"] == 2.5]
        # Forward Euler is drawn on top of the other fixed-step curves.
        layer = 10 if "FEuler" in name else 1
        if calc_speedup:
            # find smallest rhseval for this integrator
            cheapest = selected[xax].min()
            print("fixed", name, "speedup", rhs_fe/float(cheapest))

        ordered = selected.sort_values(by=xax)
        short = name.replace("(MPF::Obstacle)", "")
        legend = short if add_label else ""
        ax.plot(ordered[xax]/xfac, ordered[yax], "-x", color=colour, label=legend,
                zorder=layer,
                markerfacecolor="None")
