from classes.Model import PlatformModel
import pyomo.environ as pyo
import numpy as np
from pyomo.opt import SolverFactory
import copy 
import shutil
import multiprocessing as mpp
   
class Algorithm:
    """Pricing and resource-assignment heuristics for the platform model.

    The class works on three collaborators set up in ``__init__``:

    * ``S`` -- the system description (sizes ``N``/``D``, price bounds,
      demands, capacity limits and the helpers ``compute_R_bar_bar_Lambda``,
      ``compute_opt_n_e_c`` and ``write_param_*_in_platform_file``);
    * ``users`` -- the user population (demands, bandwidths ``B``, power data
      and ``solve_users_problem``);
    * ``platform`` -- a ``PlatformModel`` used to evaluate revenue and to
      build the Pyomo model.

    As used below, only deployment index 1 (``"s2"`` in the Pyomo model) is
    served by edge nodes and/or VMs; index 0 (``"s1"``) contributes no
    transfer or queueing term.
    """

    def __init__(self, S, users):
        """Store the system/user descriptions and create the platform model.

        Args:
            S: system description object.
            users: user population object.
        """
        # Small price offset used to probe just above a candidate price.
        self.epsilon = 0.00000001
        self.S = S
        self.users = users
        self.platform = PlatformModel()

    # ------------------------------------------------------------------
    # Shared private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _round_up(value):
        """Return the ceiling of ``value``.

        Built on ``round`` (rather than ``math.ceil``) so the result type is
        exactly what the previous inline code produced.
        """
        rounded = round(value)
        return rounded + 1 if value > rounded else rounded

    def _user_dep_cost(self, i, k):
        """Return ``T[i] * Lambda * beta[i] * dep_power_consumption[i][k]``."""
        users = self.users
        return (users.T[i] * self.S.Lambda * users.beta[i]
                * users.dep_power_consumption[i][k])

    def _in_price_range(self, price):
        """Return True when ``U_min <= price <= U_max``."""
        return self.S.U_min <= price <= self.S.U_max

    def _base_response_time(self, i):
        """Return ``demands[i][1] + data_size[1] / B[i]`` for user ``i``.

        This is the user-specific part of the deployment-1 time that the code
        compares against ``R_bar``; it does not depend on the number of edge
        nodes or VMs.
        """
        return self.users.demands[i][1] + self.S.data_size[1] / self.users.B[i]

    def _servers_needed(self, load, i, service_demand):
        """Return ``ceil(load * de / (de - service_demand))``.

        ``de = R_bar - _base_response_time(i)`` is the slack left to user
        ``i`` before reaching ``R_bar``.
        """
        de = self.S.R_bar - self._base_response_time(i)
        return self._round_up((load * de) / (de - service_demand))

    def _evaluate_baseline_prices(self, U):
        """Solve the users' problem at prices ``U``, size resources, get revenue.

        Resource quantities stay 0 unless ``Lambdas[1] > 0``; they are reset
        on every call so no value leaks from a previous evaluation.

        Returns:
            The record ``(P, U, n_c, n_e, Lambda_c, Lambda_e, x)`` consumed
            by ``best_resource_assignment``.
        """
        n_e = Lambda_e = n_c = Lambda_c = 0
        x = self.users.solve_users_problem(self.S, U)
        R_bar_bar, Lambdas = self.S.compute_R_bar_bar_Lambda(self.users, x)
        if Lambdas[1] > 0:
            n_e, Lambda_e, n_c, Lambda_c = self.S.compute_opt_n_e_c(Lambdas, R_bar_bar)
        P = self.platform.platform_revenue(self.S, x, U, n_e, n_c)
        return (P, U, n_c, n_e, Lambda_c, Lambda_e, x)

    # ------------------------------------------------------------------
    # Public algorithms
    # ------------------------------------------------------------------

    def simple_algorithm(self, input_file, Path,
                         baron_executable='/home/sedghani/baron-lin64/baron'):
        """Solve the platform problem with BARON through Pyomo.

        ``input_file`` is copied to ``<Path>/<Lambda>.dat``. The parameters
        taken from ``self.S`` / ``self.users`` are appended to that copy. A
        Pyomo instance is then built with
        ``self.platform.creat_platform_model`` and solved with BARON: one-hour
        ``MaxTime``, one thread per CPU, log written to ``<Path>/Baron_log``.
        The deployment names ``"s1"`` and ``"s2"`` are hard-coded.

        Args:
            input_file: template ``.dat`` file to start from.
            Path: output directory (joined with ``"/"``).
            baron_executable: path of the BARON binary.

        Returns:
            ``(opt_P, opt_U, opt_n_c, opt_n_e, opt_x, opt_y_e, opt_y_c,
            R_constraint)``. ``opt_U``/``opt_n_c``/``opt_n_e`` are indexed by
            deployment, ``opt_y_e``/``opt_y_c``/``R_constraint`` by user, and
            ``opt_x`` by (user, deployment), in Pyomo set iteration order.
            ``R_constraint[u]`` is the sum of the demand, transfer, edge-queue
            and VM-queue terms at the solution (presumably the quantity
            bounded by ``R_constraints`` in the model -- confirm).
        """
        S, users = self.S, self.users
        opt_U = np.zeros(S.D, dtype=float)
        opt_n_c = np.zeros(S.D, dtype=float)
        opt_n_e = np.zeros(S.D, dtype=float)
        opt_x = np.zeros((S.N, S.D), dtype=float)
        opt_y_c = np.zeros(S.N, dtype=float)
        opt_y_e = np.zeros(S.N, dtype=float)
        R_constraint = np.zeros(S.N, dtype=float)

        # Build the data file: copy the template, then append the parameters.
        dst = Path + "/" + str(S.Lambda) + ".dat"
        shutil.copyfile(input_file, dst)
        s, C = users.users_memory_energy_checking_cost(S)

        S.write_param_2D_in_platform_file(S.N, S.D, dst, " s :", s)
        S.write_param_2D_in_platform_file(S.N, S.D, dst, " C :", C)
        S.write_param_1D_in_platform_file(S.D, dst, " data_size :", S.data_size, "s")
        S.write_param_2D_in_platform_file(S.N, S.D, dst, " user_demand_matrix :", users.demands)
        S.write_param_1D_in_platform_file(S.N, dst, " time :", users.T, "u")
        S.write_param_1D_in_platform_file(S.D, dst, " edge_demand_matrix :", S.D_e, "s")
        S.write_param_1D_in_platform_file(S.D, dst, " VM_demand_matrix :", S.D_c, "s")
        S.write_param_1D_in_platform_file(S.N, dst, " network_Bandwidth :", users.B, "u")
        S.write_param_0D_in_platform_file(dst, " edge_energy_coefficient :", S.Beta_e)
        S.write_param_0D_in_platform_file(dst, " Lambda :", S.Lambda)
        S.write_param_0D_in_platform_file(dst, " max_edge_number :", S.N_e)
        S.write_param_0D_in_platform_file(dst, " VM_cost :", S.VM_cost)
        S.write_param_0D_in_platform_file(dst, " R_constraints :", S.R_bar)
        S.write_param_0D_in_platform_file(dst, " T :", S.T)

        data, model = self.platform.creat_platform_model(dst, S)
        instance = model.create_instance(data)

        opt = SolverFactory("baron", executable=baron_executable)
        opt.options['threads'] = int(mpp.cpu_count())
        opt.solve(instance, options={'MaxTime': 3600}, keepfiles=True, tee=True,
                  logfile=Path + "/Baron_log")

        def _val(component):
            # Components are called with a truthy argument, as before
            # (presumably Pyomo's ``exception`` flag -- confirm).
            return component("Value")

        opt_P = _val(instance.Maximum_edge_profit)
        opt_U[0] = _val(instance.U["s1"])
        opt_U[1] = _val(instance.U["s2"])
        opt_n_c[1] = _val(instance.VM_number["s2"])
        opt_n_e[1] = _val(instance.edge_number["s2"])

        deployments = list(instance.Deployments)
        remote = [k for k in deployments if k != "s1"]
        lam = _val(instance.Lambda)

        # The queue denominators are identical for every user: compute once.
        edge_den = {
            k: _val(instance.edge_number[k]) - sum(
                _val(instance.edge_demand_matrix[k]) * _val(instance.x[u, k])
                * lam * _val(instance.y_e[u])
                for u in instance.Users)
            for k in remote
        }
        vm_den = {
            k: _val(instance.VM_number[k]) - sum(
                _val(instance.VM_demand_matrix[k]) * _val(instance.x[u, k])
                * lam * _val(instance.y_c[u])
                for u in instance.Users)
            for k in remote
        }

        for row, i in enumerate(instance.Users):
            for col, k in enumerate(deployments):
                opt_x[row][col] = _val(instance.x[i, k])
            opt_y_e[row] = _val(instance.y_e[i])
            opt_y_c[row] = _val(instance.y_c[i])

            demand_term = sum(
                _val(instance.user_demand_matrix[i, k]) * _val(instance.x[i, k])
                for k in deployments)
            transfer_term = sum(
                _val(instance.data_size[k]) * _val(instance.x[i, k])
                / _val(instance.network_Bandwidth[i])
                for k in remote)
            edge_term = sum(
                (_val(instance.edge_number[k]) * _val(instance.edge_demand_matrix[k])
                 * _val(instance.x[i, k]) * _val(instance.y_e[i])) / edge_den[k]
                for k in remote)
            vm_term = sum(
                (_val(instance.VM_number[k]) * _val(instance.VM_demand_matrix[k])
                 * _val(instance.x[i, k]) * _val(instance.y_c[i])) / vm_den[k]
                for k in remote)
            R_constraint[row] = demand_term + transfer_term + edge_term + vm_term

        return opt_P, opt_U, opt_n_c, opt_n_e, opt_x, opt_y_e, opt_y_c, R_constraint

    def Baseline_alg(self):
        """Evaluate every candidate price and return the results, best first.

        For each user ``i`` (with ``Ck = _user_dep_cost(i, k)``) two candidate
        prices are considered:

        * ``(C1 - C2) / (1 - gamma)``, kept when it lies in ``[U_min, U_max]``;
        * ``C2``, kept when ``C2 / gamma`` lies in ``[U_min, U_max]``.

        Each candidate ``p`` is evaluated at ``U = [p, p * gamma, p, ...]``.
        When ``p + epsilon >= U_min`` it is evaluated again at
        ``p + epsilon``, on a separate price vector.

        Returns:
            list of ``(P, U, n_c, n_e, Lambda_c, Lambda_e, x)`` records sorted
            by revenue ``P`` in decreasing order.
        """
        S = self.S
        candidates = []
        for i in range(S.N):
            C1 = self._user_dep_cost(i, 0)
            C2 = self._user_dep_cost(i, 1)
            U_drop = (C1 - C2) / (1 - S.gamma)
            if self._in_price_range(U_drop):
                candidates.append(U_drop)
            # NOTE(review): the range test uses C2 / gamma, but the candidate
            # is applied as U[0] = C2 (so U[1] = C2 * gamma), whereas
            # all_dep2_best_price applies it as U[1] = C2. Confirm the scale.
            if self._in_price_range(C2 * (1 / S.gamma)):
                candidates.append(C2)

        batch_result = []
        for price in candidates:
            U = np.full(S.D, price, dtype=float)
            U[1] = price * S.gamma
            batch_result.append(self._evaluate_baseline_prices(U))

            if U[0] + self.epsilon >= S.U_min:
                # Probe just above the candidate on a copy, so the record
                # appended above keeps its own price vector.
                shifted = U.copy()
                shifted[0] = U[0] + self.epsilon
                shifted[1] = shifted[0] * S.gamma
                batch_result.append(self._evaluate_baseline_prices(shifted))

        batch_result.sort(key=lambda record: record[0], reverse=True)
        return batch_result

    def resource_assignment(self, U, n_e, Lambda_e, n_c, Lambda_c, x):
        """Split deployment-1 users between edge nodes and VMs and size both.

        Steps:

        1. If ``n_c < 1``, every participant goes to the edge. If any of them
           exceeds ``R_bar`` with ``n_e`` edge nodes, ``n_e`` is re-sized for
           the worst one.
        2. Otherwise the ``int(Lambda_c / S.Lambda)`` participants with the
           largest base time go to VMs and the rest to the edge. ``n_e`` is
           sized for the largest remaining edge user.
        3. While more than ``S.N_e`` edge nodes are needed, the edge user with
           the smallest value is moved to VMs and ``n_e`` is re-sized.
        4. If ``Lambda_c > 0``, the VM count is sized on ``V[0]``.

        Args:
            U: price vector (not used here; kept for interface compatibility).
            n_e, Lambda_e: starting edge-node count and edge load parameter
                (moved by ``S.Lambda`` per reassigned user).
            n_c, Lambda_c: starting VM count and VM load parameter.
            x: per-user choices; user ``i`` participates when ``x[i][1] > 0``.

        Returns:
            ``(Opt_n_c, Opt_n_e, Lambda_e, Lambda_c, y_e, y_c)``. ``y_e`` and
            ``y_c`` are int arrays of length ``len(x)`` marking the users
            assigned to the edge pool and to VMs.
        """
        S = self.S
        y_e = np.zeros(len(x), dtype=int)
        y_c = np.zeros(len(x), dtype=int)
        L_e = S.D_e[1] * Lambda_e
        L_c = S.D_c[1] * Lambda_c
        Opt_n_e = n_e
        Opt_n_c = n_c

        participants = [i for i in range(len(x)) if x[i][1] > 0]
        # One record per participant: [user index, value, pool flag], where
        # the flag is 0 for edge and 1 for VM.
        V = []

        if participants:
            if n_c < 1:
                # NOTE(review): divides by n_e; n_e == 0 here would raise.
                queue_term = S.D_e[1] / (1 - (L_e / n_e))
                violation = False
                for i in participants:
                    y_e[i] = 1
                    v = self._base_response_time(i) + queue_term - S.R_bar
                    V.append([i, v, 0])
                    if v > 0:
                        violation = True
                if violation:
                    V.sort(key=lambda rec: rec[1])
                    # Size the edge pool for the user with the largest value.
                    Opt_n_e = self._servers_needed(L_e, V[-1][0], S.D_e[1])
            else:
                V = [[i, self._base_response_time(i), -1] for i in participants]
                V.sort(key=lambda rec: rec[1], reverse=True)
                n_cloud = int(Lambda_c / S.Lambda)
                for pos, rec in enumerate(V):
                    if pos < n_cloud:
                        y_c[rec[0]] = 1
                        rec[2] = 1
                    else:
                        y_e[rec[0]] = 1
                        rec[2] = 0
                if n_cloud < len(V):
                    # V[n_cloud] is the edge user with the largest base time.
                    Opt_n_e = self._servers_needed(L_e, V[n_cloud][0], S.D_e[1])

        while Opt_n_e > S.N_e:
            # Pick by (flag, value): edge users first, smallest value first.
            rec = min(V, key=lambda r: (r[2], r[1]))
            j = rec[0]
            y_e[j] = 0
            y_c[j] = 1
            rec[2] = 1
            Lambda_e -= S.Lambda
            Lambda_c += S.Lambda
            L_e = S.D_e[1] * Lambda_e
            L_c = S.D_c[1] * Lambda_c
            # Slack of the moved user j (bandwidth B[j], not B[position]).
            Opt_n_e = self._servers_needed(L_e, j, S.D_e[1])

        if Lambda_c > 0:
            # NOTE(review): V[0] is the largest-base-time participant only in
            # the n_c >= 1 branch. In the n_c < 1 branch it is the smallest
            # (after the violation sort) or just the first participant.
            Opt_n_c = self._servers_needed(L_c, V[0][0], S.D_c[1])

        return Opt_n_c, Opt_n_e, Lambda_e, Lambda_c, y_e, y_c

    def best_resource_assignment(self, batch_results, batch_size):
        """Run ``resource_assignment`` on the first ``batch_size`` records.

        Args:
            batch_results: records ``(P, U, n_c, n_e, Lambda_c, Lambda_e, x)``
                as returned by ``Baseline_alg``.
            batch_size: number of leading records to examine.

        Returns:
            ``(best_P, best_n_c, best_n_e, best_Lambda_e, best_Lambda_c,
            best_U, best_x, best_y_e, best_y_c)`` for the record with the
            highest revenue after reassignment. If no record yields revenue
            above 0, ``best_P`` is 0 and every other entry is ``None``.
        """
        best_P = 0
        best = (None,) * 8
        for result in batch_results[:batch_size]:
            U, n_c, n_e, Lambda_c, Lambda_e, x = result[1:7]
            opt_n_c, opt_n_e, opt_Lambda_e, opt_Lambda_c, y_e, y_c = \
                self.resource_assignment(U, n_e, Lambda_e, n_c, Lambda_c, x)
            P = self.platform.platform_revenue(self.S, x, U, opt_n_e, opt_n_c)
            if P > best_P:
                best_P = P
                best = (opt_n_c, opt_n_e, opt_Lambda_e, opt_Lambda_c,
                        copy.deepcopy(U), copy.deepcopy(x),
                        copy.deepcopy(y_e), copy.deepcopy(y_c))
        return (best_P,) + best

    def all_dep1_Baseline_alg(self):
        """Evaluate the price vector with ``U[1] = 0``; other entries are priced higher.

        The candidates are ``(C1 - C2) / (1 - gamma)`` per user that lie in
        ``[U_min, U_max]``. Every entry of ``U`` except ``U[1]`` is set to
        ``max(candidates) + epsilon``. No edge nodes or VMs are provisioned.

        Returns:
            ``(P, x, U)``.

        Raises:
            ValueError: from ``max`` when no candidate lies in the range.
        """
        S = self.S
        candidates = []
        for i in range(S.N):
            U_drop = (self._user_dep_cost(i, 0) - self._user_dep_cost(i, 1)) / (1 - S.gamma)
            if self._in_price_range(U_drop):
                candidates.append(U_drop)

        U = np.full(S.D, max(candidates) + self.epsilon, dtype=float)
        U[1] = 0
        x = self.users.solve_users_problem(S, U)
        P = self.platform.platform_revenue(S, x, U, 0, 0)
        return P, copy.deepcopy(x), copy.deepcopy(U)

    def all_dep2_best_price(self):
        """Search the best price for deployment 1 with every other price at 0.

        The candidates are ``C2 = _user_dep_cost(i, 1)`` for each user with
        ``C2 / gamma`` in ``[U_min, U_max]``. Each candidate is tried as
        ``U[1]``, and again as ``U[1] + epsilon``. The price with the highest
        revenue (no resources provisioned) is kept, and users are re-solved
        at that price.

        Returns:
            ``(max_i, participants, x, U)``. ``participants`` is the number of
            users with ``x[i][1] > 0``. ``max_i`` is the participant with the
            largest positive base time, or ``None`` when there is none.
        """
        S = self.S
        candidates = []
        for i in range(S.N):
            U_drop = self._user_dep_cost(i, 1)
            if self._in_price_range(U_drop * (1 / S.gamma)):
                candidates.append(U_drop)

        best_P = 0
        best_U = 0
        for U_k in candidates:
            U = np.zeros(S.D, dtype=float)
            U[1] = U_k
            x = self.users.solve_users_problem(S, U)
            P = self.platform.platform_revenue(S, x, U, 0, 0)
            if P > best_P:
                best_P, best_U = P, U_k

            if U[1] * (1 / S.gamma) + self.epsilon >= S.U_min:
                U[0] = 0
                U[1] = U[1] + self.epsilon
                x = self.users.solve_users_problem(S, U)
                P = self.platform.platform_revenue(S, x, U, 0, 0)
                if P > best_P:
                    # Record the shifted price that actually produced P.
                    best_P, best_U = P, U[1]

        U = np.zeros(S.D, dtype=float)
        U[1] = best_U
        x = self.users.solve_users_problem(S, U)

        participants = 0
        max_v = 0
        max_i = None
        for i in range(S.N):
            if x[i][1] > 0:
                participants += 1
                v = self._base_response_time(i)
                if v > max_v:
                    max_v, max_i = v, i
        return max_i, participants, x, U

    def only_edge(self):
        """Serve the deployment-1 users found by ``all_dep2_best_price`` on edge only.

        The edge-node count is sized on the participant with the largest base
        time. It is 0 when nobody participates.

        Returns:
            ``(feasible, P, x, U)``. ``feasible`` is False, with ``P = 0``,
            when more than ``S.N_e`` edge nodes would be needed.
        """
        max_i, participants, x, U = self.all_dep2_best_price()
        if participants == 0:
            Opt_n_e = 0
        else:
            Opt_n_e = self._servers_needed(self.S.Lambda * participants, max_i, self.S.D_e[1])
        if Opt_n_e <= self.S.N_e:
            return True, self.platform.platform_revenue(self.S, x, U, Opt_n_e, 0), x, U
        return False, 0, x, U

    def only_cloud(self):
        """Serve the deployment-1 users found by ``all_dep2_best_price`` on VMs only.

        The VM count is sized on the participant with the largest base time.
        It is 0 when nobody participates. No capacity limit is checked.

        Returns:
            ``(P, x, U)``.
        """
        max_i, participants, x, U = self.all_dep2_best_price()
        if participants == 0:
            Opt_n_c = 0
        else:
            Opt_n_c = self._servers_needed(self.S.Lambda * participants, max_i, self.S.D_c[1])
        P = self.platform.platform_revenue(self.S, x, U, 0, Opt_n_c)
        return P, x, U
    
    
    
