import sys
import numpy as np
import h5py
from gtda.homology import CubicalPersistence
from gtda.diagrams import PersistenceImage
import itertools
from multiprocessing import Pool
from struct import *

# define functions

def matrix(a, b):
    """Return the 2x2 ``np.matrix`` ``[[a, -conj(b)], [b, conj(a)]]``.

    For complex ``a``, ``b`` with ``|a|**2 + |b|**2 == 1`` this is an SU(2)
    matrix; the normalisation is neither checked nor enforced here.
    """
    conj_a, conj_b = a.conjugate(), b.conjugate()
    return np.matrix([[a, -conj_b],
                      [b, conj_a]])

# turn quaternion entries into SU(2) matrices
def linksQ(e):
    """Convert 16 quaternion components into a 4-tuple of 2x2 matrices.

    ``e`` is read in consecutive groups of four ``(q0, q1, q2, q3)``; each group
    becomes ``matrix(q0 + 1j*q3, q2 + 1j*q1)``. Presumably the four groups are
    the links in the four lattice directions -- confirm against the file format.
    """
    return tuple(
        matrix(complex(e[base], e[base + 3]), complex(e[base + 2], e[base + 1]))
        for base in (0, 4, 8, 12)
    )

def parseConfig(cnfg):
    """Decode a binary gauge configuration into an array of link matrices.

    Byte layout read here (big-endian):
      * a 28-byte header of five int32 values (NG, T, X, Y, Z) followed by
        one float64 (plaq); only the four lattice extents are used;
      * then 128 bytes (16 float64) per site, with sites stored t-slowest
        and z-fastest.

    Each site's 16 numbers are passed to ``linksQ``. The result is
    ``np.array`` of the nested ``[t][x][y][z]`` lists of those tuples,
    presumably shaped (T, X, Y, Z, 4, 2, 2).
    """
    header = unpack('>iiiiid', cnfg[0:28])
    n_t, n_x, n_y, n_z = header[1:5]
    first_site = 28
    site_size = 128

    def links_at(t, x, y, z):
        # Row-major site index -> byte window holding that site's 16 doubles.
        start = first_site + site_size * (((t * n_x + x) * n_y + y) * n_z + z)
        return linksQ(unpack('>dddddddddddddddd', cnfg[start:start + site_size]))

    return np.array([[[[links_at(t, x, y, z) for z in range(n_z)]
                       for y in range(n_y)]
                      for x in range(n_x)]
                     for t in range(n_t)])

def wilsonLoop(conf, t, x, y, z, directions):
    """Return the real trace of the elementary Wilson loop at site (t, x, y, z).

    Only the first two entries of ``directions`` are read; call them (mu, nu).
    Treating ``conf[t, x, y, z, d]`` as the 2x2 link leaving the site in
    direction d, the loop is the ordered product
    ``U_mu(n) * U_nu(n + mu) * U_mu(n + nu)^H * U_nu(n)^H``.
    Neighbouring sites wrap periodically using ``conf.shape[:4]``. The base
    site itself is indexed as given, without wrapping. The result is a NumPy
    floating scalar.
    """
    mu, nu = directions[0], directions[1]
    extent = conf.shape[:4]
    site = (t, x, y, z)

    def hop(axis):
        # Site one step forward along `axis`, with every coordinate wrapped.
        return tuple((coord + (1 if k == axis else 0)) % extent[k]
                     for k, coord in enumerate(site))

    u_mu_here = np.matrix(conf[t, x, y, z, mu])
    u_nu_here = np.matrix(conf[t, x, y, z, nu])
    u_nu_ahead = np.matrix(conf[hop(mu) + (nu,)])
    u_mu_ahead = np.matrix(conf[hop(nu) + (mu,)])

    loop = u_mu_here * u_nu_ahead * u_mu_ahead.H * u_nu_here.H
    return np.asarray(np.real(loop.trace()))[0, 0]

# Column index (0-5) used for each lattice plane (mu, nu) with mu < nu;
# the planes are numbered in lexicographic order.
plane_labels = {plane: column for column, plane in enumerate(itertools.combinations(range(4), 2))}

def plaqs(conf):
    """Evaluate ``wilsonLoop`` for every site and every plane of ``conf``.

    Returns a float array of shape ``conf.shape[:4] + (6,)``. Entry
    ``[t, x, y, z, plane_labels[(mu, nu)]]`` holds
    ``wilsonLoop(conf, t, x, y, z, (mu, nu))``.
    """
    lattice_shape = conf.shape[:4]
    result = np.zeros(lattice_shape + (6,))
    for site in np.ndindex(*lattice_shape):
        for plane, column in plane_labels.items():
            result[site + (column,)] = wilsonLoop(conf, *site, plane)
    return result

# pm[i] holds the two unit steps along axis i: pm[i][0] = +e_i, pm[i][1] = -e_i.
pm = np.stack([np.eye(4, dtype=int), -np.eye(4, dtype=int)], axis=1)

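# Build the filtration in 3 sweeps over the refined (2x) grid:
#  - first give each 2-cube half the value of the plaquette in the plane orthogonal to it
#  - then set each 3-cube to the max of its faces and each 1-cube to the min of its cofaces
#  - then set each 4-cube to the max of its faces and each 0-cube to the min of its cofaces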
def cubicalFiltration(plaqs: np.ndarray) -> np.ndarray:
    """Build a periodic cubical filtration on a 2x-refined 4D grid.

    Parameters
    ----------
    plaqs : ndarray
        Plaquette values indexed ``[t, x, y, z, plane]``, with ``plane`` a column
        taken from ``plane_labels``. Presumably this is the (T, X, Y, Z, 6) output
        of the module-level ``plaqs()`` function, which this parameter shadows
        inside this function.

    Returns
    -------
    ndarray of shape (2T, 2X, 2Y, 2Z)
        One value per cell of the refined grid. Along each axis, an odd index
        means the cell extends along that axis and an even index means it does
        not. The cell dimension is therefore the number of odd coordinates.

    Values are assigned in three sweeps; each sweep reads only what an earlier
    sweep wrote:
      1. 2-cells: 0.5 * the plaquette in the plane of the cell's even
         (orthogonal) axes. The plaquette is read at base site ``coord // 2``,
         shifted by +1 along each odd axis and wrapped periodically.
      2. 1-cells: min over the 6 cells one step away along the even axes
         (their 2-dimensional cofaces). 3-cells: max over the 6 cells one step
         away along the odd axes (their 2-dimensional faces).
      3. 0-cells: min over the 8 neighbouring 1-cells. 4-cells: max over the 8
         neighbouring 3-cells.
    These min/max rules keep every cell's value >= the values of its faces,
    as a sublevel-set filtration requires.

    NOTE(review): the orthogonal-plane lookup plus the +1 shift looks like the
    dual-lattice correspondence between refined-grid cells and lattice
    plaquettes -- confirm.
    Reads the module-level ``plane_labels`` and ``pm``.
    """
    # Refined grid: index 2*n lies on lattice coordinate n; index 2*n+1 spans n -> n+1.
    filt = np.zeros(tuple([2*s for s in plaqs.shape[:4]]))
    # Sweep 1: 2-cells take their value directly from a plaquette.
    for t,x,y,z in itertools.product(*[range(2*s) for s in plaqs.shape[:4]]):
        # number of odd coordinates = dimension of this cell
        dim = t%2 + x%2 + y%2 + z%2
        if dim == 2:
            # look at orthogonal directions
            # (the even axes, i.e. the plane the cell does not extend in; the tuple
            # comes out ascending, so it is a valid plane_labels key)
            orthDirs = tuple([i for i in range(4) if [t,x,y,z][i]%2 == 0])
            # 1 along each axis the cell extends in: shifts the base site by +1 there
            parDirs = [1 if [t,x,y,z][i]%2 == 1 else 0 for i in range(4)]
            p = plaqs[((t//2)+parDirs[0])%plaqs.shape[0],
                        ((x//2)+parDirs[1])%plaqs.shape[1],
                        ((y//2)+parDirs[2])%plaqs.shape[2],
                        ((z//2)+parDirs[3])%plaqs.shape[3],
                        plane_labels[orthDirs]]
            # halved -- presumably to map a trace in [-2, 2] onto [-1, 1]; confirm
            filt[t,x,y,z] = 0.5 * p
    # Sweep 2: 1-cells and 3-cells, derived only from the 2-cells written above.
    for t,x,y,z in itertools.product(*[range(2*s) for s in plaqs.shape[:4]]):
        dim = t%2 + x%2 + y%2 + z%2
        if dim == 1:
            # +/- unit steps along the even axes reach the 2-cells containing this edge
            dirs = np.concatenate(pm[[i for i in range(4) if [t,x,y,z][i]%2 == 0]])
            filt[t,x,y,z] = min([filt[(t+dt)%(2*plaqs.shape[0]),(x+dx)%(2*plaqs.shape[1]),(y+dy)%(2*plaqs.shape[2]),(z+dz)%(2*plaqs.shape[3])] for [dt,dx,dy,dz] in dirs])
        elif dim == 3:
            # +/- unit steps along the odd axes reach this 3-cell's 2-dimensional faces
            dirs = np.concatenate(pm[[i for i in range(4) if [t,x,y,z][i]%2 == 1]])
            filt[t,x,y,z] = max([filt[(t+dt)%(2*plaqs.shape[0]),(x+dx)%(2*plaqs.shape[1]),(y+dy)%(2*plaqs.shape[2]),(z+dz)%(2*plaqs.shape[3])] for [dt,dx,dy,dz] in dirs])
    # Sweep 3 must run after sweep 2: 0-cells read 1-cells and 4-cells read 3-cells.
    for t,x,y,z in itertools.product(*[range(2*s) for s in plaqs.shape[:4]]):
        dim = t%2 + x%2 + y%2 + z%2
        if dim == 0:
            # every axis is even: the 8 neighbours are the edges meeting at this vertex
            dirs = np.concatenate(pm[[i for i in range(4) if [t,x,y,z][i]%2 == 0]])
            filt[t,x,y,z] = min([filt[(t+dt)%(2*plaqs.shape[0]),(x+dx)%(2*plaqs.shape[1]),(y+dy)%(2*plaqs.shape[2]),(z+dz)%(2*plaqs.shape[3])] for [dt,dx,dy,dz] in dirs])
        elif dim == 4:
            # every axis is odd: the 8 neighbours are this 4-cell's 3-dimensional faces
            dirs = np.concatenate(pm[[i for i in range(4) if [t,x,y,z][i]%2 == 1]])
            filt[t,x,y,z] = max([filt[(t+dt)%(2*plaqs.shape[0]),(x+dx)%(2*plaqs.shape[1]),(y+dy)%(2*plaqs.shape[2]),(z+dz)%(2*plaqs.shape[3])] for [dt,dx,dy,dz] in dirs])
    return filt

if __name__ == '__main__':
    # Pipeline: read settings -> load gauge configurations -> plaquettes ->
    # cubical filtrations -> persistence diagrams -> persistence images -> HDF5.

    # load settings
    # "input" holds key=value lines. The recognised keys are betas
    # (whitespace-separated floats), Nt, Ns, runs and, optionally, NP
    # (number of worker processes, default 1). Lines without "=" and lines
    # with other keys are ignored.

    with open("input", "r") as f:
        input_lines = f.read().splitlines()

    betas, Nt, Ns, runs, NP = None, None, None, None, 1
    for line in input_lines:
        key, sep, value = line.partition("=")
        if not sep:
            continue
        key = key.strip()
        if key == "betas":
            # split() rather than split(" "): repeated or trailing spaces would
            # otherwise yield empty tokens and crash in float('').
            betas = [float(b) for b in value.split()]
        elif key == "Nt":
            Nt = int(value)
        elif key == "Ns":
            Ns = int(value)
        elif key == "runs":
            runs = int(value)
        elif key == "NP":
            NP = int(value)

    # An empty "betas=" line is treated like a missing one.
    if (not betas) or (Nt is None) or (Ns is None) or (runs is None):
        sys.exit("Missing arguments from input. Make sure there is a value for betas, Nt, Ns and runs.")

    # load configs: configs[i][r] is run r at betas[i]

    configs = []
    for b in betas:
        configs_b = []
        for r in range(runs):
            fn = f"confs/run1_{Nt}x{Ns}x{Ns}x{Ns}nc2b{b:.6f}an1.000000n{r}"
            with open(fn, 'rb') as f:
                configs_b.append(parseConfig(f.read()))
        configs.append(configs_b)

    # compute filtrations (one worker pool is reused for both stages)

    with Pool(NP) as p:
        plaqsC = [p.map(plaqs, cs) for cs in configs]
        filts = [p.map(cubicalFiltration, cs) for cs in plaqsC]

    # compute persistence diagrams

    homology_dims = [0,1,2,3]

    # All four axes are periodic. Infinite deaths are reported as 1.1,
    # presumably chosen just above the largest filtration value (1.0) -- confirm.
    cp = CubicalPersistence(homology_dimensions=homology_dims, coeff=2, periodic_dimensions=np.array([True, True, True, True]), reduced_homology=False, infinity_values=1.1, n_jobs=NP)

    pers = [cp.fit_transform(fs) for fs in filts]

    # compute persistence images

    def weight(x):
        return x
    pi_res = 25
    pi_sigma = 0.05

    pi = PersistenceImage(sigma=pi_sigma, n_bins=pi_res, weight_function=weight, n_jobs=NP)
    # Fit once on a dummy diagram whose points span [-1, 1] in every homology
    # dimension. This fixes one sampling grid shared by all betas.
    _=pi.fit([sum([[[-1,-1,d],[-1,1,d],[1,1,d]] for d in homology_dims], [])])

    # transform, not fit_transform: fit_transform would refit on each beta's own
    # diagrams, discarding the shared grid above and making images incomparable.
    pis = [pi.transform(ps) for ps in pers]

    # save persistence images

    filename = "pis_Nt=" + str(Nt) + "_Ns=" + str(Ns) + ".h5"

    with h5py.File(filename, 'w') as hf:
        hf.create_dataset("persistence_images", data=pis)
