# CPU/CUDA version of tensor formulation of 2d classical models. 

import numpy as np
import scipy as sp  
import sys 
import os 
import time
from scipy.special import iv
from math import sqrt, log
from numpy import prod
from datetime import datetime
from math import sqrt, log, cos, cosh, sinh, tanh, pi
import scipy.integrate as integrate
from numpy import linalg as LA
from itertools import product

import torch 
use_cuda = torch.cuda.is_available()   # Single CUDA probe; every backend switch below keys off this flag

print("NumPy version", np.__version__)
print("SciPy version", sp.__version__)
print("Torch version", torch.__version__)


# Check if the number of input parameters is correct or not
if len(sys.argv) != 5:
    print("Usage:", str(sys.argv[0]), "<Temperature, Bond dimension, Niter, model>")
    sys.exit(1)

# Model Parameters
Temp = float(sys.argv[1])          # Temperature
D_cut = int(sys.argv[2])           # Bond dimension kept by the SVD truncation
Niters = int(sys.argv[3])          # Number of coarse-graining iterations
model = str(sys.argv[4])           # Model to run

beta = float(1.0/Temp)             # Inverse temperature
Ns = int(2**((Niters)))            # 2**Niters
vol = Ns**2                        # 4**Niters = sites represented by one fully coarse-grained tensor
Dn = int(D_cut/2.0)                # XY/GXY initial legs carry charges -Dn..Dn (2*Dn+1 states)

models_allowed = ['Ising', 'XY', 'GXY', 'Potts']  # The models that can be run using 2dTRG.py code

# Check if the model is an allowed model or not
if model not in models_allowed:
    print ("Model not supported. Exit")
    sys.exit(1)

# Start running the code
print ("STARTED", datetime.now())

# Some more model parameters
if model == 'GXY': delta, mcut = 0.8, 50       # For GXY Model
if model == 'XY': delta, mcut = 1.0, 0         # For XY Model (delta = 1 reduces GXY weights to XY)
if model == 'Potts': qstate = 3                # For 3-state Potts model

# Print GPU-related information and load the matching contraction backend
if use_cuda:
    print ('------------------------------------------------')
    print('__CUDNN VERSION:', torch.backends.cudnn.version())
    print('__Number CUDA Devices:', torch.cuda.device_count())
    print('__CUDA Device Name:',torch.cuda.get_device_name(0))
    print('__CUDA Device Total Memory [GB]:',torch.cuda.get_device_properties(0).total_memory/1e9)
    print ('------------------------------------------------')

    from opt_einsum_torch import EinsumPlanner
    # Can replace by 'import opt_einsum_torch as ee' and delete line below.
    ee = EinsumPlanner(torch.device('cuda:0'), cuda_mem_limit = 0.7)

# Import CPU-based python libraries if CUDA not available
else:
    # psutil / platform / multiprocessing are not referenced in the code shown here;
    # kept in case other tooling relies on them being importable.
    import psutil
    import platform
    import multiprocessing
    from opt_einsum import contract


def exact_free_energy_Ising(temp):
    """
    Onsager's exact free energy per site of the 2d square-lattice Ising model
    (coupling J = 1, zero field) at temperature `temp`.

    Evaluates
        -beta*f = log(cosh 2beta) + log(2)/2
                  + (1/pi) * int_0^{pi/2} log(1 + sqrt(1 - k^2 cos^2 x)) dx
    with k = 2 sinh(2beta) / cosh(2beta)^2, using adaptive quadrature.
    """
    inv_temp = 1.0 / temp
    c2b = cosh(2.0 * inv_temp)
    s2b = sinh(2.0 * inv_temp)
    k = 2.0 * s2b / c2b**2

    # abs() protects against tiny negative round-off when k is ~1 (near T_c).
    integral = integrate.quad(
        lambda x: log(1.0 + sqrt(abs(1.0 - k * k * cos(x)**2))),
        0, 0.5 * pi, epsabs=1e-13, epsrel=1e-13,
    )[0]

    minus_beta_f = integral / pi + log(c2b) + 0.5 * log(2.0)
    return -minus_beta_f / inv_temp


def SVD(t, left_indices, right_indices, D):
    '''
    Truncated left singular vectors of a tensor viewed as a matrix.

    The legs in `left_indices` (in that order) form the row index and the legs
    in `right_indices` the column index. The matrix is decomposed as U s V and
    only U, cut to its first min(D, k) columns (k = number of singular values,
    which come out in descending order), is returned reshaped to
    (*left leg sizes, min(D, k)).

    The backend is chosen from the type of `t` (torch.Tensor -> torch.linalg.svd,
    anything else -> scipy.linalg.svd) instead of the global `use_cuda`. So the
    helper also works on NumPy data when CUDA happens to be available.

    Parameters
    ----------
    t : torch.Tensor or numpy.ndarray
    left_indices, right_indices : sequences of int (lists or tuples)
        Together they must be a permutation of the axes of `t`.
    D : int
        Maximum number of singular vectors kept.
    '''
    order = list(left_indices) + list(right_indices)
    n_left = len(left_indices)
    is_torch = isinstance(t, torch.Tensor)

    T = torch.permute(t, tuple(order)) if is_torch else np.transpose(t, order)
    left_index_sizes = [int(s) for s in T.shape[:n_left]]
    right_index_sizes = [int(s) for s in T.shape[n_left:]]
    # int(): np.prod([]) is the float 1.0, which reshape would reject
    xsize, ysize = int(np.prod(left_index_sizes)), int(np.prod(right_index_sizes))

    if is_torch:
        U, _, _ = torch.linalg.svd(torch.reshape(T, (xsize, ysize)), full_matrices=False)
    else:
        U, _, _ = sp.linalg.svd(np.reshape(T, (xsize, ysize)), full_matrices=False)

    D = min(int(U.shape[1]), D)
    U = U[:, :D]
    if is_torch:
        return torch.reshape(U, tuple(left_index_sizes + [D]))
    return np.reshape(U, left_index_sizes + [D])


def coarse_graining(t):
    '''
    One coarse-graining step on the 4-leg tensor `t`.

    First, pairs of `t` are contracted along one bond. An isometry is obtained
    from SVD(..., D_cut) of the four-copy "pair times pair" tensor, and two copies
    of `t` are squeezed through it. The same procedure then runs on the result
    along the other pair of legs. The output therefore contains four copies of `t`.
    Leg order follows the einsum strings below; which physical direction
    each leg represents is not stated in this file — TODO confirm.

    Returns (Txy / norm, norm), where norm is the largest entry of the
    un-normalized coarse tensor (torch.max on GPU, np.max on CPU).
    '''
    def _contract(subscripts, *operands):
        # GPU planner when CUDA is in use, opt_einsum otherwise
        if use_cuda:
            return ee.einsum(subscripts, *operands)
        return contract(subscripts, *operands)

    # --- first direction: merge two copies of t ---
    pair_gram = _contract('jabe,iecd,labf,kfcd->ijkl', t, t, t, t)
    iso = SVD(pair_gram, [0, 1], [2, 3], D_cut)
    half = _contract('abi,bjdc,acel,edk->ijkl', iso, t, t, iso)

    # --- second direction: merge two copies of the half-step tensor ---
    pair_gram = _contract('aibc,bjde,akfc,flde->ijkl', half, half, half, half)
    iso = SVD(pair_gram, [0, 1], [2, 3], D_cut)
    full = _contract('abj,iacd,cbke,del->ijkl', iso, half, half, iso)

    scale = torch.max(full) if use_cuda else np.max(full)
    full /= scale
    return full, scale


def weights(index, beta, delta, cutoff=None):
    '''
    Fourier weight of charge `index` for the (generalized) XY bond factor.

    Returns  sum_{j=-cutoff}^{cutoff} I_{index-2j}(beta*delta) * I_j(beta*(1-delta)).
    By the Jacobi-Anger expansion this is the coefficient of exp(i*index*theta) in
    exp(beta*(delta*cos(theta) + (1-delta)*cos(2*theta))), truncated at |j| <= cutoff.
    For delta = 1 only j = 0 contributes, giving I_index(beta).

    Parameters
    ----------
    index : int
        Charge whose weight is requested.
    beta, delta : float
        Inverse temperature and GXY mixing parameter.
    cutoff : int, optional
        Truncation of the j sum. Defaults to the module-level `mcut`, so
        existing three-argument calls behave exactly as before.
    '''
    if cutoff is None:
        cutoff = mcut
    return sum(iv(index - 2.0 * j, beta * delta) * iv(j, beta * (1.0 - delta))
               for j in range(-cutoff, cutoff + 1))


def init_tensors(model):
    '''
    Build the initial (not yet normalized) site tensor for `model`.

    XY / GXY : each leg carries an integer charge n in [-Dn, Dn]. The entry for
               (l, r, u, d) is sqrt(w_l * w_r * w_u * w_d) with
               w_n = weights(n, beta, delta), and it is zero unless l + u - r - d == 0.
    Ising    : T_abcd = sum_s W_sa W_sb W_sc W_sd with W = [[a, b], [a, -b]],
               a = sqrt(cosh beta), b = sqrt(sinh beta), so that (W W^T) has
               exp(beta) on the diagonal and exp(-beta) off it (tau = 1, no field).
    Potts    : same construction with W the Cholesky factor of the qstate x qstate
               matrix holding exp(beta) on the diagonal and 1 elsewhere.

    Reads the module-level beta, delta, Dn, qstate and use_cuda. When use_cuda
    is set, the result comes from ee.einsum (presumably a torch tensor);
    otherwise it comes from opt_einsum's contract (a NumPy array).

    Raises ValueError for a model other than 'XY', 'GXY', 'Ising', 'Potts'.
    '''
    if model == 'GXY' or model == 'XY':

        L = [sqrt(weights(i, beta, delta)) for i in range(-Dn, Dn+1)]

        if use_cuda:
            # Explicit float64: torch.tensor() on Python floats would otherwise use the
            # default dtype (float32), unlike the Ising/Potts paths built from NumPy.
            t1 = torch.tensor(L, dtype=torch.float64)
            out = ee.einsum('i,j,k,l->ijkl', t1, t1, t1, t1)
        else:
            out = contract('i,j,k,l->ijkl', L, L, L, L)

        # Charge conservation: keep only entries with l + u == r + d.
        # np.ix_ returns open-mesh index arrays (the k-th varies along axis k), so the
        # comparison broadcasts straight to a 4-leg boolean mask with no Python loop.
        charges = np.arange(-Dn, Dn + 1)
        l, r, u, d = np.ix_(charges, charges, charges, charges)
        keep = (l + u) == (r + d)
        if use_cuda:
            mask = torch.as_tensor(keep, device=out.device)
            out = torch.where(mask, out, torch.zeros_like(out))
        else:
            out = np.where(keep, out, 0.0)

        return out

    if model == 'Ising':

        tau = 1 # This is np.exp(0.250000*beta*h) for finite 'h'.
        a = np.sqrt(np.cosh(beta))
        b = np.sqrt(np.sinh(beta))
        W = np.array([[a*tau,b*tau],[(a/tau),-(b/tau)]])

        if use_cuda:
            t1 = torch.tensor(W)
            out = ee.einsum('ia,ib,ic,id->abcd', t1,t1,t1,t1)
        else:
            out = contract("ia, ib, ic, id  -> abcd", W, W, W, W)

        return out

    if model == 'Potts':

        Wnew = np.zeros((qstate, qstate))
        for i in range (qstate):
            for j in range (qstate):
                if i == j:
                    Wnew[i][j] = np.exp(beta)
                else:
                    Wnew[i][j] = 1.

        # Lower-triangular L with Wnew = L @ L.T
        L = LA.cholesky(Wnew)

        if use_cuda:
            L = torch.tensor(L)
            out = ee.einsum("ia, ib, ic, id  -> abcd", L, L, L, L)
        else:
            out = contract("ia, ib, ic, id  -> abcd", L, L, L, L)

        return out

    raise ValueError("Unsupported model: %s" % model)



if __name__ == "__main__":

    start = time.time()

    # Initial tensor: its largest entry is divided out and its log stored in C,
    # which accumulates log(normalization) per original lattice site.
    T = init_tensors(model)
    norm = torch.max(T) if use_cuda else np.max(T)
    T /= norm
    C = log(norm)

    Free = None   # set once a positive Z has been obtained
    for i in range(Niters):

        print ("Iteration #",i+1,"Timestamp:",datetime.now())
        T, norm = coarse_graining(T)
        if use_cuda:
            torch.cuda.empty_cache()
        # coarse_graining merges four tensors per step, so after i+1 steps one
        # tensor stands for 4**(i+1) original sites.
        C += log(norm)/4**(i+1)

        if i > Niters-4:
            # Only compute free energy in the last few iterations.
            # Z closes a periodic 2x2 block of four coarse tensors
            # (two in Z1, Z1 contracted with itself), i.e. 4 * 4**(i+1) = 4**(i+2)
            # original sites, so log(Z) is spread over that volume.
            Z1 = ee.einsum('aibj,bkal->ijkl', T, T) if use_cuda else contract('aibj,bkal->ijkl', T, T)
            Z = ee.einsum('abcd,badc->', Z1, Z1) if use_cuda else contract('abcd,badc->', Z1, Z1)
            if Z > 0:
                Free = -Temp * (C + (log(Z)/(4**(i+2))))
                print ("Free energy estimate:", Free)
            else:
                print ("WARNING: Z < or = 0 ")

    end = time.time()

    if Free is None:
        print ("ERROR: free energy was never computed (Niters < 1 or Z <= 0). Exit")
        sys.exit(1)

    if model == 'Ising':
        exact = exact_free_energy_Ising(Temp)
        print ("Exact answer:", exact)
        error_in_f_from_exact = abs((Free-exact)/(exact))

    path = os.path.join('./', str(model) + '_data')
    os.makedirs(path, exist_ok=True)

    stamp = str(int(datetime.now().strftime("%Y%m%d%H%M%S")))
    device_tag = '_GPU' if use_cuda else '_CPU'
    fileout = model + stamp + device_tag + '_N' + str(Niters) + '_D' + str(D_cut) + '.txt'

    with open(os.path.join(path, fileout), "a+") as f:
        if model == 'GXY' or model == 'XY':
            f.write("%4.10f \t %4.14f  \t %2.0f \t %2.0f \t %2.4f \t %6.2f \n" % (Temp, Free, Niters, D_cut, delta, end-start))
        elif model == 'Ising':
            f.write("%4.10f \t %4.10f  \t %2.0f \t %2.0f \t %2.3e \t %6.2f \n" % (Temp, Free, Niters, D_cut, error_in_f_from_exact, end-start))
        elif model == 'Potts':
            f.write("%4.10f \t %4.10f  \t %2.0f \t %2.0f \t %2.0f \t %6.2f \n" % (Temp, Free, Niters, D_cut, qstate, end-start))

    print ("COMPLETED", datetime.now())
    print("Run time (in seconds):", round(end-start,2))
