# In [None]:
import os,sys
import numpy as np
import pandas as pd
import glob
import copy
from scipy.stats import f
import scipy.optimize as sopt
import sklearn.metrics
import math
import matplotlib.pyplot as plt
import time

## 2022-05-09 CMTF (Coupled Matrix and Tensor Factorization) algorithm for joint analysis of microbiome and metabolomic data
## Reference:
## Tan Z.C., Murphy M.C., Alpay H.S., Taylor S.D., Meyer A.S.
## Tensor-structured decomposition improves systems serology analysis
## Molecular Systems Biology, 17 (9) (2021), p. e10243

## Acar E, Kolda TG, Dunlavy DM (2011) All-at-once optimization for coupled
## matrix and tensor factorization. arXiv https://arxiv.org/abs/1105.3422 (nonlinear conjugate gradient optimization)

## Kolda T.G.,Bader B.W.,Tensor decompositions and applications.
## SIAM Rev. 2009; 51: 455-500

## Also partly modified from TensorLy:Tensor Learning in Python,
##            https://arxiv.org/abs/1610.09555.



## This is a rough version for testing


class factor_matrix:
    """Container for the factor matrices of a coupled matrix-tensor model.

    Attributes are plain fields:
        tfm: list of tensor factor matrices, one per tensor mode (None until assigned).
        mfm: factor matrix for the non-shared mode of the coupled matrix (None until assigned).
        weight: per-component weights (empty list until assigned).
        variance: per-component variance summary (empty list until assigned).
    """

    def __init__(self):
        self.tfm = None
        self.mfm = None
        self.weight = []
        self.variance = []

    def kr4t_product(self,outer_idx):
        """Khatri-Rao product of the tensor factors other than ``outer_idx``.

        With the remaining factors ``[F0, F1]`` (in mode order), row ``i*F0.rows + j``
        of the result is ``F1[i] * F0[j]`` -- i.e. ``F1`` varies slowest.
        Returns None when ``tfm`` has not been populated.
        """
        if not self.tfm:
            return None
        remaining = self.tfm.copy()
        del remaining[outer_idx]
        fast, slow = remaining[0], remaining[1]
        n_rows = slow.shape[0] * fast.shape[0]
        return np.einsum("ir,jr->ijr", slow, fast).reshape(n_rows, -1)



def prototype_tensor_process(file_list,exclude_unclassified = False):
    """Build a (pathway, microbe, sample) abundance tensor from stratified pathway tables.

    Each file is read as a headerless tab-separated table whose first row is skipped
    (presumably the column header -- TODO confirm). Rows whose first column looks like
    ``PATHWAY|MICROBE`` contribute the value in column 1 to
    ``tensor[pathway, microbe, sample]``. Unstratified rows (no ``|``) and the
    ``UNMAPPED`` / ``UNINTEGRATED`` pathways are ignored.

    Args:
        file_list: paths of the per-sample tables; the sample index follows this order.
        exclude_unclassified: when true, rows whose microbe is ``"unclassified"`` are skipped.

    Returns:
        (tensor, pwy_list, mb_list, sample_list): the float tensor of shape
        (n_pathways, n_microbes, n_files), pathway and microbe names in order of first
        appearance, and sample names derived from the file names.
    """
    skip_pathways = {"UNMAPPED", "UNINTEGRATED"}
    pwy_index = {}   # pathway name -> row index, in order of first appearance
    mb_index = {}    # microbe name -> column index, in order of first appearance
    records = []     # (pathway idx, microbe idx, file idx, abundance)
    sample_list = []

    # Single pass over the files: collect entries and sample names together.
    for file_idx, file in enumerate(file_list):
        df = pd.read_csv(file, sep="\t", header=None)

        # Accept both Windows and POSIX separators (the original only split on "\\").
        filename = os.path.basename(str(file).replace("\\", "/"))
        samplename = filename.replace("_non_Homo_sapiens_pathabundance.tsv", "")
        samplename = samplename.replace("_clean_w_host_pathabundance.tsv", "")
        samplename = samplename.replace("_pathabundance.tsv", "")
        sample_list.append(samplename)

        # Row 0 is skipped, as in the original implementation.
        for feature, value in zip(df.iloc[1:, 0], df.iloc[1:, 1]):
            tmp_feature = feature.split("|")
            if len(tmp_feature) <= 1:
                continue
            tmp_pwy = str(tmp_feature[0])
            if tmp_pwy in skip_pathways:
                continue
            tmp_mb = str(tmp_feature[1])
            if exclude_unclassified and tmp_mb == "unclassified":
                continue
            pwy_index.setdefault(tmp_pwy, len(pwy_index))
            mb_index.setdefault(tmp_mb, len(mb_index))
            records.append((pwy_index[tmp_pwy], mb_index[tmp_mb], file_idx, float(value)))

    tensor = np.zeros((len(pwy_index), len(mb_index), len(file_list)), dtype=float)
    # Later duplicates overwrite earlier ones, matching the original assignment order.
    for p_idx, m_idx, s_idx, abundance in records:
        tensor[p_idx, m_idx, s_idx] = abundance

    return tensor, list(pwy_index), list(mb_index), sample_list

#log scaling of total functional pathway abundance(followed by linear scaling of stratified species)
def tensor_logscale(mm_tensor):
    """Rescale each (pathway, sample) fiber so that it sums to log(total).

    For every pathway ``i`` and sample ``j`` the microbe contributions
    ``mm_tensor[i, :, j]`` are divided by their total and multiplied by ``log(total)``.
    The relative split between microbes is kept, and the overall pathway abundance
    is put on a log scale. Fibers whose total is exactly zero are left unchanged.

    The tensor is modified in place and also returned. The computation is vectorized
    over all fibers at once instead of looping element by element in Python.
    """
    totals = mm_tensor.sum(axis=1, keepdims=True)          # shape (P, 1, S)
    nonzero = totals != 0
    # Substitute 1 for zero totals so the division/log stay finite; those fibers are masked out below.
    safe_totals = np.where(nonzero, totals, 1)
    scaled = mm_tensor / safe_totals * np.log(safe_totals)
    mm_tensor[...] = np.where(nonzero, scaled, mm_tensor)
    return mm_tensor
        
       



# centered log ratio transform of metabolomic data(relative abundance transformed)/ probabilistic quotient normalization transform 
# of metabolomic data
def mb_preprocess(mb_df):
    """Apply a row-wise centered log-ratio (CLR) transform to a metabolite matrix.

    Each row ``x`` becomes ``log(x) - mean(log(x))``. This equals the original pairwise
    formulation ``mean_e log(x / e)`` but runs in O(n) per row instead of O(n^2).
    Entries must be strictly positive; the caller (``reorder``) replaces zeros with 1
    beforehand.

    Args:
        mb_df: 2-D array-like; rows are transformed independently.

    Returns:
        2-D float numpy array with the same shape as ``mb_df``.
    """
    def pqn_trans(mb_mat):
        # Probabilistic quotient normalization. Currently unused and kept for the
        # commented-out alternative pipeline. The reference (column means) is computed
        # once instead of once per row.
        mb_mat = np.asarray(mb_mat, dtype=float)
        reference = np.mean(mb_mat, axis=0)
        dilution = np.median(mb_mat / reference, axis=1, keepdims=True)
        return mb_mat / dilution

    mat = np.asarray(mb_df, dtype=float)
    log_mat = np.log(mat)
    trans_mb_df = log_mat - log_mat.mean(axis=1, keepdims=True)
    #trans_mb_df = pqn_trans(mb_df)
    print("clr transform completed\n")
    return trans_mb_df

def reorder(metabol_df,sample_list):
    """Align metabolomics columns to ``sample_list`` and CLR-transform them.

    The columns of ``metabol_df`` must be exactly the set of names in ``sample_list``.
    Columns are reordered to ``sample_list``, zeros are replaced with 1, and the
    (sample x metabolite) matrix is passed to ``mb_preprocess``.

    Returns:
        The transformed numpy array, or None (after printing a message) when the
        sample sets differ or any ValueError is raised along the way.
    """
    try:
        if set(list(metabol_df.columns)) != set(sample_list):
            raise ValueError("elements in metabolomics file must equal to microbiome functional file\n")
        ordered_values = metabol_df[sample_list].values
        no_zeros = np.where(ordered_values == 0, 1, ordered_values)
        print("starting clr transform\n")
        return mb_preprocess(no_zeros.T)
    except ValueError as err:
        print("Critical error:", err)
        

def metabol_extract(mbpath,mapperpath,sample_list,output_label='metabolite'):
    """Load a metabolomics table and align it to the microbiome sample order.

    Args:
        mbpath: tab-separated metabolomics table (header row, first column = metabolite index).
        mapperpath: tab-separated mapping table with ``seq_id``, ``sample_id`` and ``type`` columns.
        sample_list: sequencing ids (``seq_id``) in the desired sample order.
        output_label: ``'metabolite'`` to return metabolomics sample ids and their labels,
            or ``'microbe'`` to return the sequencing ids and their labels.

    Returns:
        (new_mb_df, final_sample_list, label_dict, metabolite_list), where ``new_mb_df``
        is the result of ``reorder`` (None if the sample sets did not match).

    Raises:
        ValueError: if ``output_label`` is not one of the two supported values. It is
            checked before any file is read. (Previously a typo, ``ValueEror``, made this
            path crash with UnboundLocalError.)
        KeyError: if a sequencing id in ``sample_list`` is missing from the mapping table.
    """
    if output_label not in ('metabolite', 'microbe'):
        raise ValueError("invalid data type selected when determining output labels\n")

    mb_df = pd.read_csv(mbpath, sep="\t", header=0, index_col=0)
    metabolite_list = list(mb_df.index)

    map_df = pd.read_csv(mapperpath, sep="\t", header=0)
    strg_seqid = [str(x) for x in map_df['seq_id']]
    strg_sampid = [str(x) for x in map_df['sample_id']]
    samp_label = [str(x) for x in map_df['type']]

    ss_dict = dict(zip(strg_seqid, strg_sampid))
    # Translate sequencing ids into metabolomics sample ids, keeping the requested order.
    new_sample_list = [ss_dict[x] for x in sample_list]

    new_mb_df = reorder(mb_df, new_sample_list)

    if output_label == 'metabolite':
        final_sample_list = new_sample_list
        label_dict = dict(zip(strg_sampid, samp_label))
    else:
        final_sample_list = sample_list
        label_dict = dict(zip(strg_seqid, samp_label))
    return new_mb_df, final_sample_list, label_dict, metabolite_list


 
# FASTER implementation of kronecker product computation
def fast_kron(mat1,mat2):
    """Kronecker product of two matrices (previously an unfinished stub returning None).

    For 2-D inputs the product is formed with a single broadcasted outer product:
    ``out[i*p + k, j*q + l] = mat1[i, j] * mat2[k, l]``. Inputs of any other
    dimensionality fall back to ``np.kron``. The result equals ``np.kron(mat1, mat2)``.
    """
    a = np.asarray(mat1)
    b = np.asarray(mat2)
    if a.ndim != 2 or b.ndim != 2:
        return np.kron(a, b)
    m, n = a.shape
    p, q = b.shape
    # Axes (i, k, j, l) -> rows (i, k), cols (j, l) after a C-order reshape.
    return (a[:, None, :, None] * b[None, :, None, :]).reshape(m * p, n * q)



# A refined algorithm designed for avoiding intermediate data explosion when computing mode-n-matrix · kr product of factor matrix C and B
# reference:U. Kang, E. E. Papalexakis, A. Harpale, and C. Faloutsos, “Gigatensor:
#scaling tensor analysis up by 100 times - algorithms and discoveries,”
#in KDD, 2012.
def big_mat_product_compute(fm,tensorshape,idx,ten2mat=None):
    """Compute ``X_(idx) @ KR`` column by column without materializing the Khatri-Rao product.

    ``KR`` is the Khatri-Rao product of the factor matrices other than ``idx``, in the
    same layout as ``factor_matrix.kr4t_product(idx)`` (the later mode varies slowest).
    Following the GigaTensor scheme, column ``r`` is built from Hadamard products of
    the unfolding with broadcast factor columns and then summed across the row. The
    previous version could not run: ``kr1`` was undefined, ``np.full`` and
    ``np.hstack`` were called with invalid arguments, and ``ten2mat`` was never checked.

    Args:
        fm: object whose ``tfm`` is a list of factor matrices (one per mode, ``lf`` columns each).
        tensorshape: sequence with the tensor's dimension sizes.
        idx: mode being updated.
        ten2mat: list of mode-n unfoldings (Fortran order); required.

    Returns:
        array of shape (tensorshape[idx], lf).

    Raises:
        ValueError: if ``ten2mat`` is not supplied.
    """
    if ten2mat is None:
        raise ValueError("ten2mat (list of mode-n unfoldings) is required")

    lf = fm.tfm[0].shape[1]
    n_rows = tensorshape[idx]
    if lf == 0:
        return np.zeros((n_rows, 0), dtype=float)

    unfold = np.asarray(ten2mat[idx], dtype=float)
    other_fm = [mat for d, mat in enumerate(fm.tfm) if d != idx]
    other_dims = [size for d, size in enumerate(tensorshape) if d != idx]

    ones_rows = np.ones((n_rows, 1), dtype=float)
    ones_fast = np.ones((1, other_dims[0]), dtype=float)
    ones_slow = np.ones((1, other_dims[1]), dtype=float)
    ones_cols = np.ones((other_dims[0] * other_dims[1], 1), dtype=float)
    # Indicator of nonzeros: for sparse data only these positions contribute.
    bin_unfold = (unfold != 0).astype(float)

    columns = []
    for r in range(lf):
        slow = other_fm[1][:, r].reshape(1, -1)
        fast = other_fm[0][:, r].reshape(1, -1)
        n1 = unfold * (ones_rows @ np.kron(slow, ones_fast))
        n2 = bin_unfold * (ones_rows @ np.kron(ones_slow, fast))
        columns.append((n1 * n2) @ ones_cols)
    return np.hstack(columns)

# 2022-09-21 updated:independent function for khatri-rao product computation
def raw_krp(mat1,mat2):
    """Column-wise Khatri-Rao product of two matrices with the same number of columns.

    Row ``i * mat2.shape[0] + j`` of the result is ``mat1[i] * mat2[j]``, so ``mat1``
    varies slowest.
    """
    n_rows = mat1.shape[0] * mat2.shape[0]
    pairwise = mat1[:, None, :] * mat2[None, :, :]
    return pairwise.reshape(n_rows, -1)


# alternating least squares for optimizing coupled matrix and tensor factorization model
def als_optimize(fm,tensor_unfold,mb_df=None,idx=None):
    """One alternating-least-squares update for the coupled matrix-tensor model.

    Args:
        fm: factor_matrix-like object with ``tfm`` (list of tensor factors), ``mfm``
            (matrix-specific factor) and a ``kr4t_product(idx)`` method.
        tensor_unfold: list of mode-n unfoldings of the tensor (Fortran order).
        mb_df: the coupled matrix, approximated as ``fm.tfm[2] @ fm.mfm.T``.
        idx: tensor mode to update. ``None`` updates the matrix-only factor ``mfm``.

    Returns:
        The new factor matrix for the requested mode (or for ``mfm`` when idx is None).

    For mode ``idx`` the update is ``F = X_(idx) KR V^+``, where ``V`` is the Hadamard
    product of the Gram matrices of the other tensor factors. For the coupled mode 2
    the matrix term is added to both sides. ``pinv(V)`` is used directly; the former
    ``pinv(V.T @ V) @ V.T`` is mathematically identical but squares the condition number.
    """
    if idx is None:
        # Least squares for M in mb_df ≈ C @ M.T, i.e. M = (pinv(C) @ mb_df).T
        return np.matmul(mb_df.T, np.linalg.pinv(fm.tfm[2]).T)

    V = None
    for d, factor in enumerate(fm.tfm):
        if d == idx:
            continue
        gram = np.matmul(factor.T, factor)
        V = gram if V is None else gram * V

    kr_product = fm.kr4t_product(idx)           # (prod of other dims, lf)
    rhs = np.matmul(tensor_unfold[idx], kr_product)
    if idx == 2:
        # Mode 2 is shared with the coupled matrix.
        rhs = rhs + np.matmul(mb_df, fm.mfm)
        V = V + np.matmul(fm.mfm.T, fm.mfm)
    return np.matmul(rhs, np.linalg.pinv(V))
    
def refold(unfold_t,mode,shape):
    """Invert the mode-n unfolding ``reshape(moveaxis(T, mode, 0), (shape[mode], -1), order='F')``.

    Args:
        unfold_t: the mode-``mode`` unfolded tensor (a matrix).
        mode: integer mode of the unfolding.
        shape: sequence with the original tensor shape (any number of dimensions).

    Returns:
        A new array with shape ``shape``, or None (after printing a message) when
        ``mode`` is out of range. The previous mode-0 branch raised NameError
        (``recontruct_t`` typo), and only 3-way tensors were supported.
    """
    shape = list(shape)
    if not 0 <= mode < len(shape):
        print("mode excesses the number of dimensions of tensor")
        return None
    # The unfolding placed `mode` first and kept the other axes in order.
    moved_shape = [shape[mode]] + shape[:mode] + shape[mode + 1:]
    reconstruct_t = np.moveaxis(np.reshape(unfold_t, moved_shape, order='F'), 0, mode)
    # Return an independent contiguous array, as the old dstack-based branches did.
    return reconstruct_t.copy()

    

#evaluating the error
def r2_eval(fm,tensor,mat,log_check=False,mask=None):
    """Norm-ratio goodness of fit for the coupled model.

    Returns ``1 - (||X - X_hat|| + ||Y - Y_hat||) / (||X|| + ||Y||)`` using Frobenius
    norms that are not squared, so it is an R2-like score rather than the classical R2.
    ``X_hat`` is rebuilt from the mode-1 unfolding ``tfm[1] @ KR.T`` and
    ``Y_hat = tfm[2] @ mfm.T``. ``log_check`` prints intermediate shapes and norms;
    ``mask`` is accepted but unused.
    """
    shape = list(tensor.shape)
    unfolded_fit = np.matmul(fm.tfm[1], fm.kr4t_product(1).T)
    if log_check:
        print(shape)

    tensor_fit = refold(unfolded_fit, 1, shape)
    if log_check:
        print(tensor_fit.shape)

    matrix_fit = np.matmul(fm.tfm[2], fm.mfm.T)
    residual = np.linalg.norm(tensor - tensor_fit) + np.linalg.norm(mat - matrix_fit)
    total = np.linalg.norm(tensor) + np.linalg.norm(mat)
    if log_check:
        print(residual)
        print(total)
    return 1 - residual / total

def rmse_eval(fm,tensor,mat,mask=None):
    """Sum of the tensor RMSE and the matrix RMSE for the current reconstruction.

    Args:
        fm: factor_matrix-like object (``tfm``, ``mfm``, ``kr4t_product``).
        tensor: observed 3-way tensor.
        mat: observed coupled matrix.
        mask: optional ``[tensor_mask, matrix_mask]``. When truthy, only positions
            where each mask is nonzero enter the respective RMSE.

    Returns:
        ``rmse(tensor) + rmse(matrix)``.
    """
    tensor_fit = refold(np.matmul(fm.tfm[1], fm.kr4t_product(1).T), 1, list(tensor.shape))
    matrix_fit = np.matmul(fm.tfm[2], fm.mfm.T)

    if mask:
        t_pos = np.nonzero(mask[0])
        m_pos = np.nonzero(mask[1])
        t_obs, t_hat = tensor[t_pos], tensor_fit[t_pos]
        m_obs, m_hat = mat[m_pos], matrix_fit[m_pos]
    else:
        t_obs, t_hat = tensor, tensor_fit
        m_obs, m_hat = mat, matrix_fit

    tensor_rmse = np.linalg.norm(t_obs - t_hat) / np.sqrt(t_obs.size)
    matrix_rmse = np.linalg.norm(m_obs - m_hat) / np.sqrt(m_obs.size)
    return tensor_rmse + matrix_rmse

def fac_norm(fac_mat):
    """Scale every column of every factor matrix to unit L2 norm, in place.

    Columns with zero norm are left untouched. The original guarded zero columns
    for ``tfm`` but not for ``mfm``, so zero ``mfm`` columns produced NaN.

    Returns:
        The same ``fac_mat`` object.
    """
    def _column_scale(mat):
        # Per-column L2 norm, with zeros replaced by 1 to avoid 0/0.
        scale = np.linalg.norm(mat, 2, axis=0)
        return np.where(scale == 0, 1.0, scale)

    for d in range(len(fac_mat.tfm)):
        fac_mat.tfm[d] /= _column_scale(fac_mat.tfm[d])
    fac_mat.mfm /= _column_scale(fac_mat.mfm)
    return fac_mat

def simple_normalize(fac_mat):
    """Scale every column of every factor matrix by its largest absolute value, in place.

    After scaling, each nonzero column has max-abs 1. All-zero columns are left as
    they are; previously they became NaN through 0/0.

    Args:
        fac_mat: object with ``tfm`` (list of factor matrices) and ``mfm``.

    Returns:
        The same ``fac_mat`` object.
    """
    def _column_scale(mat):
        # Infinity norm per column (max |entry|), with zeros replaced by 1.
        scale = np.linalg.norm(mat, ord=np.inf, axis=0)
        return np.where(scale == 0, 1.0, scale)

    for d in range(len(fac_mat.tfm)):
        fac_mat.tfm[d] /= _column_scale(fac_mat.tfm[d])
    fac_mat.mfm /= _column_scale(fac_mat.mfm)
    return fac_mat


def sort_by_var(fac_mat,mm_tensor,mb_mat,final_var):
    """Reorder latent components by the size of their contribution to the fit.

    For each component ``d``, the model is rebuilt without ``d``. The RMS-scaled
    norm of the difference from the full reconstruction (tensor part + matrix part)
    is that component's contribution. ``fac_mat.variance`` receives
    ``1 - contribution/denominator`` over the sorted contributions, where the
    denominator is the RMS-scaled norm of the full reconstruction. The factor
    columns are permuted with the same ascending order.

    NOTE(review): the sort is ascending, so the component with the SMALLEST
    contribution comes first. The inline comment in the original described the
    opposite intent -- confirm which ordering downstream code expects.
    ``mm_tensor``, ``mb_mat`` and ``final_var`` are accepted but unused.
    """
    num_lf = fac_mat.tfm[0].shape[1]
    shape = [factor.shape[0] for factor in fac_mat.tfm]

    def _reconstruct(model):
        unfolded = np.matmul(model.tfm[1], model.kr4t_product(1).T)
        return refold(unfolded, 1, shape), np.matmul(model.tfm[2], model.mfm.T)

    rc_tensor, rc_matrix = _reconstruct(fac_mat)
    t_root = np.sqrt(rc_tensor.size)
    m_root = np.sqrt(rc_matrix.size)
    denominator = np.linalg.norm(rc_tensor) / t_root + np.linalg.norm(rc_matrix) / m_root

    rc_var = []
    for d in range(num_lf):
        # A shallow copy is enough: tfm and mfm are replaced, never mutated.
        reduced = copy.copy(fac_mat)
        reduced.tfm = [np.delete(factor, d, axis=1) for factor in fac_mat.tfm]
        reduced.mfm = np.delete(fac_mat.mfm, d, axis=1)
        e_rc_tensor, e_rc_matrix = _reconstruct(reduced)
        contribution = (np.linalg.norm(rc_tensor - e_rc_tensor) / t_root
                        + np.linalg.norm(rc_matrix - e_rc_matrix) / m_root)
        rc_var.append(float(contribution))

    fac_mat.variance = [1 - float(v / denominator) for v in sorted(rc_var)]
    var_ord = np.argsort(rc_var)
    fac_mat.tfm = [factor[:, var_ord] for factor in fac_mat.tfm]
    fac_mat.mfm = fac_mat.mfm[:, var_ord]
    return fac_mat

def sort_by_weight(fac_mat):
    """Reorder components by descending weight and store normalized weights.

    ``fac_mat.variance`` becomes the weights sorted high-to-low, each divided by their
    sum. The columns of every factor matrix are permuted to match.
    """
    weights = fac_mat.weight
    total = np.sum(weights)
    fac_mat.variance = [float(w / total) for w in sorted(weights, reverse=True)]
    order = np.flip(np.argsort(weights))
    fac_mat.tfm = [factor[:, order] for factor in fac_mat.tfm]
    fac_mat.mfm = fac_mat.mfm[:, order]
    return fac_mat



#prototype function for coupling microbes-predicted functional pathways 3-order tensor and metabolites matrix
#return normalized unit length factor matrices of each dimensions and corresponding weight.
def joint_mt(mm_tensor,mb_mat,lf,tmask=None,mmask=None,max_iter=2000,tol=1e-6):
    """Fit a coupled matrix-tensor factorization by alternating least squares.

    Mode 2 of ``mm_tensor`` is coupled with the rows of ``mb_mat``; the SVD
    initialization stacks them side by side, so their sizes must agree.

    Args:
        mm_tensor: 3-way numpy array.
        mb_mat: 2-D numpy array coupled to mode 2 of the tensor.
        lf: number of latent factors.
        tmask, mmask: optional masks for the tensor / matrix. They are used only for
            error evaluation, and only when both are given and non-empty. ``None``
            (the new default) behaves like the former ``np.array([])`` defaults.
        max_iter: maximum number of ALS sweeps (default keeps the former 2000).
        tol: tolerance on the relative change of the RMSE (default keeps the former 1e-6).

    Returns:
        factor_matrix with components sorted by ``sort_by_var`` and columns scaled by
        ``simple_normalize``.

    Exits the process (``sys.exit(1)``) if ``mm_tensor`` is not 3-way, as before.
    """
    if len(mm_tensor.shape) != 3:
        print("error:", "only accept 3-way tensor as input")
        sys.exit(1)

    if tmask is not None and mmask is not None and tmask.size and mmask.size:
        mask = [tmask, mmask]
    else:
        mask = []

    # Mode-n unfoldings (Fortran order), computed once for both initialization and ALS.
    ori_unfold = [np.reshape(np.moveaxis(mm_tensor, d, 0), (mm_tensor.shape[d], -1), order='F')
                  for d in range(len(mm_tensor.shape))]

    # SVD-based initialization. Mode 2 also sees the coupled matrix.
    init_fm = factor_matrix()
    init_fm.tfm = []
    for d, m_unfold in enumerate(ori_unfold):
        if d == 2:
            m_unfold = np.hstack((m_unfold, mb_mat))
        sub_eima = np.linalg.svd(m_unfold, full_matrices=False)[0]
        if lf <= sub_eima.shape[1]:
            init_fm.tfm.append(sub_eima[:, :lf])
        else:
            # Not enough singular vectors: pad with uniform random columns.
            rnd_fm = np.random.rand(sub_eima.shape[0], lf - sub_eima.shape[1])
            init_fm.tfm.append(np.hstack((sub_eima, rnd_fm)))
    print("initialization of tensor factor matrix completed")
    print("pre-unfold completed")

    init_fm.mfm = als_optimize(init_fm, ori_unfold, mb_df=mb_mat, idx=None)
    print("initialization of matrix factor matrix completed")

    if mask:
        print('masking...')
    # rmse_eval ignores a falsy (empty) mask.
    var_est = [rmse_eval(init_fm, mm_tensor, mb_mat, mask)]
    final_var = var_est[0]

    for i in range(1, max_iter + 1):
        for d in range(len(mm_tensor.shape)):
            init_fm.tfm[d] = als_optimize(init_fm, ori_unfold, mb_df=mb_mat, idx=d)
        init_fm.mfm = als_optimize(init_fm, ori_unfold, mb_df=mb_mat, idx=None)

        var = rmse_eval(init_fm, mm_tensor, mb_mat, mask)
        var_est.append(var)
        prev = var_est[i - 1]
        change = abs(var - prev)
        # Relative change; fall back to absolute change on an exact zero error to avoid 0/0.
        rel_change = change / prev if prev != 0 else change
        if rel_change <= tol or i == max_iter:
            print(f"round {i} reached convergence or reached max iteration")
            final_var = var
            print(f'final error is {final_var}')
            break

    init_fm = sort_by_var(init_fm, mm_tensor, mb_mat, final_var)
    init_fm = simple_normalize(init_fm)
    return init_fm


    

#python implementation for computing product of tensor and vector(only for 3-ways tensor)
# Reference: Brett W. Bader, Tamara G. Kolda and others, Tensor Toolbox for MATLAB, Version 3.3, www.tensortoolbox.org, August 16, 2022
def tenvec(tensor,vector_list,mode):
    """Multiply a tensor by one vector per listed mode (tensor-times-vector).

    Args:
        tensor: numpy array of any order (3-way in this project).
        vector_list: vectors (numpy arrays or lists), one per entry of ``mode``; each is
            flattened in Fortran order.
        mode: the tensor modes to contract, matching ``vector_list`` position by position.

    Returns:
        A numpy scalar when every mode is contracted. Otherwise a 1-D array with the
        remaining entries flattened in Fortran (column-major) order, the same layout the
        previous unfolding-based code produced. An empty array is returned when
        ``mode`` is empty.

    Raises:
        TypeError: for a vector that is neither an ndarray nor a list.
        ValueError: for out-of-range or duplicate modes, or mismatched counts/lengths.
            These were previously printed and ignored. Contractions over non-trailing
            mode sets (e.g. [0, 2]) previously crashed and now work.
    """
    op_vector_list = []
    for vec in vector_list:
        if not isinstance(vec, (np.ndarray, list)):
            raise TypeError("unsupported data type of vector")
        op_vector_list.append(np.asarray(vec).reshape(-1, order='F'))

    mode = list(mode)
    ndim = len(tensor.shape)
    for m in mode:
        if m not in range(ndim):
            raise ValueError("number of dimensions in tensor does not match the modes involved")
    if len(set(mode)) != len(mode):
        raise ValueError("duplicate modes are not supported")
    if len(mode) != len(op_vector_list):
        raise ValueError("number of modes involved does not match the number of vectors")
    for vec, m in zip(op_vector_list, mode):
        if vec.shape[0] != tensor.shape[m]:
            raise ValueError("vector length does not match the corresponding tensor dimensions")

    if not mode:
        return np.array([])

    result = np.asarray(tensor)
    # Contract from the highest mode down so the lower axis indices stay valid.
    for vec, m in sorted(zip(op_vector_list, mode), key=lambda pair: pair[1], reverse=True):
        result = np.tensordot(result, vec, axes=([m], [0]))
    if result.ndim == 0:
        return result[()]
    return result.reshape(-1, order='F')

# NRX-044 Ashima
# generate objective function and gradients for nonlinear conjugate gradient optimization
def funcgrad_gen(mm_tensor,mb_mat,fm):
    """Build the CMTF objective and its analytic gradient for ``scipy.optimize.minimize``.

    The flat parameter vector ``x`` is laid out as
    ``[A.ravel(), B.ravel(), C.ravel(), D.ravel(), lambda, sigma]``. Here A, B, C are
    the tensor factors (modes 0, 1, 2), D is the matrix-only factor, and
    lambda / sigma are per-component weights of the tensor / matrix parts.

    Objective (all inside ``obj_f``)::

        0.5*||X - [[lambda; A, B, C]]||^2 + 0.5*||Y - C diag(sigma) D^T||^2
        + 0.5*beta_ten*sum sqrt(lambda_r^2 + eps) + 0.5*beta_mat*sum sqrt(sigma_r^2 + eps)
        + 0.5*alpha*sum_{F in A,B,C,D} sum_r (||F[:, r]|| - 1)^2

    Both returned callables take ``(x, *args)`` with
    ``args = (mm_tensor, mb_mat, alpha, beta_ten, beta_mat, eps, lf)``.
    The outer ``mm_tensor``, ``mb_mat`` and ``fm`` are not used; the data comes from ``args``.

    Fixed relative to the previous version: the gradient of the weight penalty is
    ``0.5*beta*w / sqrt(w^2 + eps)``; it was wrongly multiplied by the square root.
    The lf == 1 case no longer breaks ``np.diag(np.squeeze(...))``.
    """
    def _unpack(x, shape, mshape, lf):
        # Slice the flat vector back into factor matrices and weight rows (C order).
        start = 0
        tfm = []
        for dim in shape:
            tfm.append(x[start:start + dim * lf].reshape(dim, lf))
            start += dim * lf
        mfm = x[start:start + mshape[-1] * lf].reshape(mshape[-1], lf)
        start += mshape[-1] * lf
        tweight = x[start:start + lf].reshape(1, lf)
        start += lf
        mweight = x[start:start + lf].reshape(1, lf)
        return tfm, mfm, tweight, mweight

    def _reconstruct(tfm, mfm, tweight, mweight):
        # Weighted CP tensor and weighted coupled-matrix reconstructions.
        rc_tensor = np.einsum('ir,jr,kr->ijk', tfm[0] * tweight, tfm[1], tfm[2], optimize=True)
        rc_matrix = np.matmul(tfm[2] * mweight, mfm.T)
        return rc_tensor, rc_matrix

    def _safe_col_norms(mat):
        # Column norms with zeros replaced by 1 (the penalty gradient is then 0 there).
        norms = np.linalg.norm(mat, axis=0)
        return np.where(norms == 0, 1.0, norms)

    def obj_f(x,*args):
        mm_tensor,mb_mat,alpha,beta_ten,beta_mat,eps,lf = args
        tfm, mfm, tweight, mweight = _unpack(x, mm_tensor.shape, mb_mat.shape, lf)
        rc_tensor, rc_matrix = _reconstruct(tfm, mfm, tweight, mweight)

        # 0.5*||X - Xhat||^2 == 0.5||X||^2 - <X, Xhat> + 0.5||Xhat||^2 (original form).
        obj_func = 0.5 * np.sum(np.square(mm_tensor - rc_tensor))
        obj_func += 0.5 * np.sum(np.square(mb_mat - rc_matrix))
        # Smoothed L1 (sparsity) penalty on the component weights.
        obj_func += 0.5 * beta_ten * np.sum(np.sqrt(np.square(tweight) + eps))
        obj_func += 0.5 * beta_mat * np.sum(np.sqrt(np.square(mweight) + eps))
        # Push every factor column toward unit length.
        for factor in (*tfm, mfm):
            obj_func += 0.5 * alpha * np.sum(np.square(np.linalg.norm(factor, axis=0) - 1))
        return obj_func

    def total_grad(x,*args):
        mm_tensor,mb_mat,alpha,beta_ten,beta_mat,eps,lf = args
        tfm, mfm, tweight, mweight = _unpack(x, mm_tensor.shape, mb_mat.shape, lf)
        A, B, C = tfm
        rc_tensor, rc_matrix = _reconstruct(tfm, mfm, tweight, mweight)
        diff_tensor = mm_tensor - rc_tensor
        diff_matrix = mb_mat - rc_matrix

        # Factor gradients: -(X - Xhat)_(n) (KR of the other factors) diag(lambda).
        grad_a = -np.einsum('ijk,jr,kr->ir', diff_tensor, B, C, optimize=True) * tweight
        grad_b = -np.einsum('ijk,ir,kr->jr', diff_tensor, A, C, optimize=True) * tweight
        grad_c = -np.einsum('ijk,ir,jr->kr', diff_tensor, A, B, optimize=True) * tweight
        # C is shared with the coupled matrix Y ≈ C diag(sigma) D^T.
        grad_c -= np.matmul(diff_matrix, mfm * mweight)
        grad_d = -np.matmul(diff_matrix.T, C * mweight)

        ft_grad = [grad_a, grad_b, grad_c, grad_d]
        for grad, factor in zip(ft_grad, (A, B, C, mfm)):
            # d/dF 0.5*alpha*(||f_r|| - 1)^2 = alpha*(f_r - f_r/||f_r||)
            grad += alpha * (factor - factor / _safe_col_norms(factor))

        tw = tweight.ravel()
        mw = mweight.ravel()
        grad_tw = -np.einsum('ijk,ir,jr,kr->r', diff_tensor, A, B, C, optimize=True)
        grad_tw += 0.5 * beta_ten * tw / np.sqrt(np.square(tw) + eps)
        grad_mw = -np.einsum('ij,ir,jr->r', diff_matrix, C, mfm, optimize=True)
        grad_mw += 0.5 * beta_mat * mw / np.sqrt(np.square(mw) + eps)

        # Same ordering as the parameter vector.
        return np.concatenate([g.ravel() for g in ft_grad] + [grad_tw, grad_mw])

    return obj_f,total_grad


    


      
        
## 2022-09-08 another method for coupling microbes-predicted functional pathways 3-order tensor and metabolites matrix
## Reference:Acar, E., Papalexakis, E.E., Gürdeniz, G. et al. Structure-revealing data fusion. 
## BMC Bioinformatics 15, 239 (2014). https://doi.org/10.1186/1471-2105-15-239
def all_in_one_mt(mm_tensor,mb_mat,lf):
    """All-at-once CMTF: optimize every factor and weight jointly with scipy's BFGS.

    Factors are initialized from left singular vectors. Mode 2 of the tensor is
    stacked with ``mb_mat``, and the matrix-only factor comes from ``mb_mat.T``;
    columns are padded with uniform random values when ``lf`` exceeds the available
    rank. Both weight rows start at 1. The objective and gradient come from
    ``funcgrad_gen``.

    Returns:
        The ``scipy.optimize.OptimizeResult`` from ``sopt.minimize``.

    Exits the process (``sys.exit(1)``) if ``mm_tensor`` is not 3-way.
    """
    def _svd_loadings(mat):
        # Leading `lf` left singular vectors, padded with random columns if too few.
        left = np.linalg.svd(mat, full_matrices=False)[0]
        if lf > left.shape[1]:
            padding = np.random.rand(left.shape[0], lf - left.shape[1])
            return np.hstack((left, padding))
        return left[:, :lf]

    if len(mm_tensor.shape) != 3:
        print("error:", ValueError("only accept 3-way tensor as input"))
        sys.exit(1)

    init_fm = factor_matrix()
    init_fm.tfm = []
    for mode in range(len(mm_tensor.shape)):
        unfolded = np.reshape(np.moveaxis(mm_tensor, mode, 0), (mm_tensor.shape[mode], -1), order='F')
        if mode == 2:
            unfolded = np.hstack((unfolded, mb_mat))
        init_fm.tfm.append(_svd_loadings(unfolded))
    init_fm.mfm = _svd_loadings(mb_mat.T)

    # Tensor weights then matrix weights, both initialized to ones.
    init_fm.weight.extend(np.ones((1, lf), dtype=float) for _ in range(2))

    # Flatten parameters in the order funcgrad_gen expects.
    pieces = []
    for mode, factor in enumerate(init_fm.tfm):
        pieces.append(factor.ravel())
        print("the ",mode,"th factor matrix is of shape:",factor.shape)
    pieces.append(init_fm.mfm.ravel())
    print("the unique factor matrix is of shape:",init_fm.mfm.shape)
    for w_idx, weight_row in enumerate(init_fm.weight):
        pieces.append(weight_row.ravel())
        print("the ",w_idx,"th weight is of shape:",weight_row.shape)
    x0 = np.concatenate(pieces)
    print(x0.shape)

    func, grad = funcgrad_gen(mm_tensor, mb_mat, init_fm)
    print("obj function and gradient generation completed\n")

    # Model hyper-parameters: ALPHA = column-norm penalty, BETA_* = weight sparsity, EPS = smoothing.
    ALPHA = 1
    BETA_TEN = 1e-3
    BETA_MAT = 1e-3
    EPS = 1e-8
    ARGS = (mm_tensor, mb_mat, ALPHA, BETA_TEN, BETA_MAT, EPS, lf)
    OPTS = {
        'maxiter': None,
        'disp': True,
        'gtol': 1e-6,
        'norm': np.inf,
    }

    # Quasi-Newton (BFGS) minimization over all factors and weights at once.
    print("started BFGS optimization...\n")
    all_res = sopt.minimize(func, x0, jac=grad, args=ARGS, method="BFGS", options=OPTS)
    print("Optimization Completed\n")
    return all_res

    

##---------------------------------- Use simulated data to test the basic function of CMTF_for_MM----------------------------------##


def kr_prod(m_vector, out_idx):
    """Khatri-Rao (column-wise Kronecker) product of all factors except one.

    The matrices of ``m_vector`` other than ``m_vector[out_idx]`` are
    combined in reverse order.  For remaining matrices ``U0, U1, ..., Uk``
    (each of shape ``(n_i, R)``), column ``r`` of the result is
    ``kron(Uk[:, r], ... kron(U1[:, r], U0[:, r]))``.  With two remaining
    matrices (a 3-way tensor) this is exactly ``U1 (.) U0``.

    Parameters
    ----------
    m_vector : list of 2-D numpy arrays
        Factor matrices that all have the same number of columns R.
        The list is not modified.
    out_idx : int
        Index of the matrix to leave out; negative indices are allowed.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(prod(n_i), R)``.  If only one matrix remains, that
        matrix itself is returned.

    Raises
    ------
    IndexError
        If ``out_idx`` is out of range.
    ValueError
        If no matrix remains after leaving ``out_idx`` out.
    """
    # range(...)[idx] normalises negative indices and raises IndexError like list.pop
    skip = range(len(m_vector))[out_idx]
    operands = [mat for pos, mat in enumerate(m_vector) if pos != skip]
    if not operands:
        raise ValueError("kr_prod needs at least one factor matrix besides out_idx")
    prod = operands[0]
    for mat in operands[1:]:
        # row (i, j) pairs row i of `mat` with row j of the running product
        prod = np.einsum("ir,jr->ijr", mat, prod).reshape(mat.shape[0] * prod.shape[0], -1)
    return prod


# Ground-truth simulation: r latent factors whose loadings are drawn from N(mu, sigma).
r = 3
mu = 10
sigma = 1

# Column k of A (50 x r), B (30 x r), C (40 x r) and D (20 x r) belongs to factor k.
# Within each factor the draws happen in the order A, B, C, D.
_factor_cols = {"A": [], "B": [], "C": [], "D": []}
for i in range(r):
    for key, n_rows in (("A", 50), ("B", 30), ("C", 40), ("D", 20)):
        _factor_cols[key].append(np.random.normal(mu, sigma, (n_rows, 1)))
A, B, C, D = (np.hstack(_factor_cols[key]) for key in "ABCD")

total_vector = [A, B]

# Samples 1..20 form group 'dist1', samples 21..40 form group 'dist2'.
sample_list = list(range(1, 41))
sample_label = ['dist1' if sample <= 20 else 'dist2' for sample in sample_list]
label_dict = dict(zip(sample_list, sample_label))

# Shuffle the rows of C together with their sample ids so that
# newC[k] is the loading row of sample sample_list_new[k].
shuff_list = list(zip(sample_list, C))
np.random.shuffle(shuff_list)
sample_list_new = [sample for sample, _ in shuff_list]
newC = np.array([row for _, row in shuff_list])

total_vector.append(newC)

def random_weight_simulation(mu, sigma, r, tensor_shape=(50, 30, 40), n_matrix_cols=20):
    """Simulate a coupled (tensor, matrix) pair from ``r`` random latent factors.

    Four factor matrices are drawn column by column from N(mu, sigma):
    A (tensor mode 0), B (tensor mode 1), C (tensor mode 2, shared with the
    matrix) and D (matrix columns).  The rows of C are shuffled.  A, B, the
    shuffled C and D are then scaled to unit-norm columns and combined with
    unit weights:

    * ``sim_matrix = C_n @ D_n.T`` with shape ``(tensor_shape[2], n_matrix_cols)``
    * ``sim_tensor`` is refolded (mode 1) from ``B @ kr_prod([A_n, B_n, C_n], 1).T``

    Parameters
    ----------
    mu, sigma : float
        Mean and standard deviation of the loading distribution.
    r : int
        Number of latent factors (>= 1).
    tensor_shape : tuple of 3 ints, optional
        Tensor dimensions; defaults to the original hard-coded (50, 30, 40).
    n_matrix_cols : int, optional
        Number of matrix columns; defaults to the original 20.

    Returns
    -------
    sim_tensor, sim_matrix

    Raises
    ------
    ValueError
        If ``r < 1``.
    """
    if r < 1:
        raise ValueError(f"r must be at least 1, got {r}")
    n_a, n_b, n_c = tensor_shape

    draws = {"A": [], "B": [], "C": [], "D": []}
    for _ in range(r):
        # fixed draw order per factor (A, B, C, D) so a given seed always yields the same data
        draws["A"].append(np.random.normal(mu, sigma, (n_a, 1)))
        draws["B"].append(np.random.normal(mu, sigma, (n_b, 1)))
        draws["C"].append(np.random.normal(mu, sigma, (n_c, 1)))
        draws["D"].append(np.random.normal(mu, sigma, (n_matrix_cols, 1)))
    A, B, C, D = (np.hstack(draws[key]) for key in "ABCD")

    # shuffle the rows (samples) of the shared mode
    order = list(range(n_c))
    np.random.shuffle(order)
    C = C[order]

    def unit_columns(mat):
        return mat / np.linalg.norm(mat, axis=0)

    # normalise every tensor factor matrix, whatever r is
    total_vector = [unit_columns(A), unit_columns(B), unit_columns(C)]
    D = unit_columns(D)
    tensor_weight = np.ones(r)
    matrix_weight = np.ones(r)

    sim_matrix = np.matmul(total_vector[2] * matrix_weight, D.T)
    # NOTE(review): the tensor is built from the raw B, not the column-normalised
    # total_vector[1], so B's column norms act as implicit factor weights - confirm intent.
    mode2_tenkai = np.matmul(B * tensor_weight, kr_prod(total_vector, 1).T)
    sim_tensor = refold(mode2_tenkai, 1, list(tensor_shape))
    return sim_tensor, sim_matrix
    

# Simulation experiments, selected by sim_data_trigger (0 disables all of them):
#   1: ALS vs all-in-one (AIO) on rank-3 synthetic data   2: 80/20 holdout CV over lf = 3..10
#   3: relative error vs lf on pure noise                   4: single lf = 4 fit on pure noise
#   5: error vs loading standard deviation                  6: error vs sparsity
sim_data_trigger = 0


def _unpack_aio_solution(x, shape, mshape, lf):
    """Split a flat all-in-one solution vector into factor matrices and weights.

    Layout (same as the AIO initial point ``x0``): each tensor factor matrix
    in mode order (row-major), the matrix-only factor matrix, the tensor
    weights, then the matrix weights.

    Returns ``(tfm, mfm, tweight, mweight)`` where ``tfm[d]`` has shape
    ``(shape[d], lf)``, ``mfm`` has shape ``(mshape[-1], lf)`` and both
    weights have shape ``(1, lf)``.
    """
    offset = 0
    tfm = []
    for dim in shape:
        tfm.append(x[offset:offset + dim * lf].reshape(dim, lf))
        offset += dim * lf
    mfm = x[offset:offset + mshape[-1] * lf].reshape(mshape[-1], lf)
    offset += mshape[-1] * lf
    tweight = x[offset:offset + lf].reshape(1, lf)
    offset += lf
    mweight = x[offset:offset + lf].reshape(1, lf)
    return tfm, mfm, tweight, mweight


def _aio_reconstruction(x, shape, mshape, lf):
    """Rebuild the (tensor, matrix) approximations from a flat AIO solution."""
    tfm, mfm, tweight, mweight = _unpack_aio_solution(x, shape, mshape, lf)
    mode2_unfolding = np.matmul(tfm[1] * tweight, raw_krp(tfm[2], tfm[0]).T)
    tensor = refold(mode2_unfolding, 1, shape)
    matrix = np.matmul(tfm[2] * mweight, mfm.T)
    return tensor, matrix


def _als_reconstruction(decomp_fm, shape):
    """Rebuild the (tensor, matrix) approximations from an ALS factor_matrix."""
    mode2_unfolding = np.matmul(decomp_fm.tfm[1], decomp_fm.kr4t_product(1).T)
    tensor = refold(mode2_unfolding, 1, shape)
    matrix = np.matmul(decomp_fm.tfm[2], decomp_fm.mfm.T)
    return tensor, matrix


def _rmse_sum(tensor, rc_tensor, matrix, rc_matrix):
    """Per-entry RMSE of the tensor plus per-entry RMSE of the matrix."""
    return (np.linalg.norm(tensor - rc_tensor) / np.sqrt(tensor.size)
            + np.linalg.norm(matrix - rc_matrix) / np.sqrt(matrix.size))


def _relative_error_sum(tensor, rc_tensor, matrix, rc_matrix):
    """Relative Frobenius error of the tensor plus that of the matrix."""
    return (np.linalg.norm(tensor - rc_tensor) / np.linalg.norm(tensor)
            + np.linalg.norm(matrix - rc_matrix) / np.linalg.norm(matrix))


def _holdout_mask(shape, holdout=0.2):
    """Random 0/1 mask of ``shape`` with ``int(holdout * size)`` zeros (held-out entries)."""
    size = int(np.prod(shape))
    n_missing = int(holdout * size)
    # ones = size - zeros, so the reshape always succeeds regardless of rounding
    seq = np.array([0] * n_missing + [1] * (size - n_missing))
    np.random.shuffle(seq)
    return seq.reshape(shape)


if sim_data_trigger == 1:
    # normalise the simulated factor matrices to unit-norm columns
    for i in range(3):
        total_vector[i] = total_vector[i] / np.linalg.norm(total_vector[i], axis=0)
    D = D / np.linalg.norm(D, axis=0)

    test_weight = [np.array([1, 1, 1]), np.array([1, 1, 1])]
    sim_matrix = np.matmul(total_vector[2] * test_weight[1], D.T)
    # NOTE(review): uses the raw B rather than total_vector[1]; confirm this is intended.
    mode2_tenkai = np.matmul(B * test_weight[0], kr_prod(total_vector, 1).T)
    sim_tensor = refold(mode2_tenkai, 1, [50, 30, 40])
    print("Display the simulated data 's shape'")
    print(sim_matrix.shape)
    print(sim_tensor.shape)

    shape = list(sim_tensor.shape)
    mshape = list(sim_matrix.shape)
    lf = 3

    # RMSE of the ALS model
    ores_fm = joint_mt(sim_tensor, sim_matrix, lf)
    print(f'first lf loadings has negative value: {np.any(ores_fm.tfm[0]<0)}')
    first_tensor, first_matrix = _als_reconstruction(ores_fm, shape)
    first_rmse = _rmse_sum(sim_tensor, first_tensor, sim_matrix, first_matrix)
    print("final rmse of ALS is: ", first_rmse)

    # RMSE of the all-in-one model
    res_fm = all_in_one_mt(sim_tensor, sim_matrix, lf)
    final_tensor, final_matrix = _aio_reconstruction(res_fm.x, shape, mshape, lf)
    print(final_tensor.shape)
    final_rmse = _rmse_sum(sim_tensor, final_tensor, sim_matrix, final_matrix)
    print("final rmse of ALL-IN-ONCE is: ", final_rmse)

    model_idx = 0
    # (for sim data) prepare the plotting data object
    if model_idx == 1:
        decomp_fm = factor_matrix()
        decomp_fm.tfm, decomp_fm.mfm, tweight, mweight = _unpack_aio_solution(res_fm.x, shape, mshape, lf)
        decomp_fm.weight.append(tweight.squeeze())
        decomp_fm.weight.append(mweight.squeeze())
        label_dict = []
    elif model_idx == 0:
        decomp_fm = ores_fm
        final_sample_list = sample_list_new
        group_type = list(set(label_dict.values()))

elif sim_data_trigger == 2:
    sim_tensor = np.random.normal(mu, sigma, (50, 30, 40))
    sim_matrix = np.random.normal(mu, sigma, (40, 20))

    print("Display the simulated data 's shape'")
    print(sim_matrix.shape)
    print(sim_tensor.shape)

    shape = list(sim_tensor.shape)
    mshape = list(sim_matrix.shape)
    # split the entries into a train set (mask 1, ~80%) and a test set (mask 0, ~20%)
    tmask = _holdout_mask(sim_tensor.shape)
    mmask = _holdout_mask(sim_matrix.shape)

    def random_cv(sim_tensor, sim_matrix, tmask, mmask, lf):
        """Fit on the masked-in entries; return size-weighted relative errors (train, test)."""
        train_tensor = tmask * sim_tensor
        train_matrix = mmask * sim_matrix
        decomp_fm = joint_mt(train_tensor, train_matrix, lf, tmask, mmask)
        re_tensor, re_matrix = _als_reconstruction(decomp_fm, list(sim_tensor.shape))

        total = sim_tensor.size + sim_matrix.size
        weight_tensor = sim_tensor.size / total
        weight_matrix = sim_matrix.size / total

        def split_error(t_idx, m_idx):
            t_err = weight_tensor * np.linalg.norm(sim_tensor[t_idx] - re_tensor[t_idx]) / np.linalg.norm(sim_tensor[t_idx])
            m_err = weight_matrix * np.linalg.norm(sim_matrix[m_idx] - re_matrix[m_idx]) / np.linalg.norm(sim_matrix[m_idx])
            return t_err, t_err + m_err

        train_t, train_rce = split_error(np.nonzero(tmask), np.nonzero(mmask))
        print("final rce of train set tensor only is: ", train_t)
        print("final rce of train set is: ", train_rce)
        test_t, test_rce = split_error(np.nonzero(1 - tmask), np.nonzero(1 - mmask))
        print("final rce of test set tensor only is: ", test_t)
        print("final rce of test set is: ", test_rce)
        return train_rce, test_rce

    # data for drawing the latent factor vs error curve
    train_x, test_x = np.arange(3, 11), np.arange(3, 11)
    train_y = []
    test_y = []
    for lf in range(3, 11):
        train_rce, test_rce = random_cv(sim_tensor, sim_matrix, tmask, mmask, lf)
        train_y.append(train_rce)
        test_y.append(test_rce)

elif sim_data_trigger == 3:
    sim_matrix = np.random.normal(mu, sigma, (40, 20))
    sim_tensor = np.random.normal(mu, sigma, (50, 30, 40))
    print("Display the simulated data's shape")
    print(sim_matrix.shape)
    print(sim_tensor.shape)

    shape = list(sim_tensor.shape)
    mshape = list(sim_matrix.shape)

    all_x = np.arange(3, 11)
    all_y = []
    for lf in range(3, 11):
        decomp_fm = joint_mt(sim_tensor, sim_matrix, lf)
        rc_tensor, rc_matrix = _als_reconstruction(decomp_fm, shape)
        nrmse = _relative_error_sum(sim_tensor, rc_tensor, sim_matrix, rc_matrix)
        print(f'the nrmse of {lf} model is {nrmse}')
        all_y.append(nrmse)

elif sim_data_trigger == 4:
    sim_matrix = np.random.normal(mu, sigma, (40, 20))
    sim_tensor = np.random.normal(mu, sigma, (50, 30, 40))
    print("Display the simulated data's shape")
    print(sim_matrix.shape)
    print(sim_tensor.shape)

    shape = list(sim_tensor.shape)
    mshape = list(sim_matrix.shape)

    decomp_fm = joint_mt(sim_tensor, sim_matrix, 4)
    rc_tensor, rc_matrix = _als_reconstruction(decomp_fm, shape)
    nrmse = _relative_error_sum(sim_tensor, rc_tensor, sim_matrix, rc_matrix)
    print(f'the nrmse of 4 model is {nrmse}')

elif sim_data_trigger == 5:
    # investigate the effect of the loading standard deviation on reconstruction
    v_als_rmse = []
    v_aio_rmse = []
    for sigma in range(1, 11):
        sim_tensor, sim_matrix = random_weight_simulation(mu, sigma, 3)
        shape = list(sim_tensor.shape)
        mshape = list(sim_matrix.shape)

        # ALS model: relative error
        decomp_fm = joint_mt(sim_tensor, sim_matrix, 3)
        rc_tensor, rc_matrix = _als_reconstruction(decomp_fm, shape)
        v_als_rmse.append(_relative_error_sum(sim_tensor, rc_tensor, sim_matrix, rc_matrix))

        # all-in-one model: per-entry RMSE (NOTE(review): different metric from the ALS list)
        res_fm = all_in_one_mt(sim_tensor, sim_matrix, 3)
        final_tensor, final_matrix = _aio_reconstruction(res_fm.x, shape, mshape, 3)
        v_aio_rmse.append(_rmse_sum(sim_tensor, final_tensor, sim_matrix, final_matrix))

elif sim_data_trigger == 6:
    # investigate the effect of sparsity on reconstruction
    base_tensor, base_matrix = random_weight_simulation(10, 1, 3)
    shape = list(base_tensor.shape)
    mshape = list(base_matrix.shape)
    sp_als_rmse = []
    sp_aio_rmse = []
    for sp in [0, 0.2, 0.4, 0.6, 0.8]:
        # mask the ORIGINAL data each round so every level has the requested sparsity
        tmask = np.random.choice([0, 1], size=base_tensor.shape, p=[sp, 1 - sp])
        sim_tensor = base_tensor * tmask
        mmask = np.random.choice([0, 1], size=base_matrix.shape, p=[sp, 1 - sp])
        sim_matrix = base_matrix * mmask

        # ALS model: relative error
        decomp_fm = joint_mt(sim_tensor, sim_matrix, 3)
        rc_tensor, rc_matrix = _als_reconstruction(decomp_fm, shape)
        sp_als_rmse.append(_relative_error_sum(sim_tensor, rc_tensor, sim_matrix, rc_matrix))

        # all-in-one model: per-entry RMSE
        res_fm = all_in_one_mt(sim_tensor, sim_matrix, 3)
        final_tensor, final_matrix = _aio_reconstruction(res_fm.x, shape, mshape, 3)
        sp_aio_rmse.append(_rmse_sum(sim_tensor, final_tensor, sim_matrix, final_matrix))
        
        
    
    
## Quantify the separation between two groups in PCA-like results with a two-sample Hotelling's T^2 test.
## input: feature coordinates (sample x feature), ordered sample list, sample labels (sample: label); output: p-value
def separation_test(ftdist_df, final_sample_list, label_dict):
    """Two-sample Hotelling's T^2 test of how well two labelled groups separate.

    Parameters
    ----------
    ftdist_df : array-like, shape (n_samples, n_features)
        Feature coordinates (e.g. PCA-like scores).  Row ``k`` belongs to
        ``final_sample_list[k]``.  Rows are selected positionally, so NumPy
        arrays and DataFrames are both accepted.
    final_sample_list : sequence
        Sample identifiers in row order.
    label_dict : dict
        Maps sample identifier -> group label.  It must contain exactly two
        distinct labels.  Samples that are not in the dict are ignored.

    Returns
    -------
    float
        p-value of ``F = (n1+n2-p-1) / (p (n1+n2-2)) * T^2`` under
        ``F(p, n1+n2-p-1)``.

    Raises
    ------
    ValueError
        If ``label_dict`` does not hold exactly two distinct labels.
    numpy.linalg.LinAlgError
        If the pooled covariance matrix is singular.
    """
    all_label = list(set(label_dict.values()))
    if len(all_label) != 2:
        raise ValueError(f"separation_test requires exactly 2 label types, got {len(all_label)}")

    data = np.asarray(ftdist_df)
    groups = []
    for type_label in all_label:
        # positional row indices, one pass (no repeated list.index lookups)
        rows = [pos for pos, sample in enumerate(final_sample_list)
                if sample in label_dict and label_dict[sample] == type_label]
        groups.append(data[rows, :])

    g1, g2 = groups
    n1, n2 = g1.shape[0], g2.shape[0]
    n_feat = g1.shape[1]
    diff = g1.mean(axis=0) - g2.mean(axis=0)

    # pooled within-group covariance; atleast_2d keeps the one-feature case a 1x1 matrix
    pooled = sum(np.atleast_2d(np.cov(g, rowvar=False)) * (g.shape[0] - 1) for g in groups) / (n1 + n2 - 2)
    # solve() gives the same quadratic form as an explicit inverse, with better conditioning
    m_dist = diff @ np.linalg.solve(pooled, diff)
    t_sq = n1 * n2 / (n1 + n2) * m_dist

    dof2 = n1 + n2 - n_feat - 1
    f_statistic = dof2 / (n_feat * (n1 + n2 - 2)) * t_sq
    # sf avoids the cancellation of 1 - cdf for very small p-values
    return f(n_feat, dof2).sf(f_statistic)

def calculate_sparsity(tensor):
    """Return the fraction of entries of *tensor* that are exactly zero.

    Accepts any array-like (NumPy arrays of any rank, nested lists).  NaN
    entries are not zero and so count as non-zero.  An empty input has no
    zero entries and returns 0.0 instead of raising ZeroDivisionError.
    """
    arr = np.asarray(tensor)
    if arr.size == 0:
        return 0.0
    return 1 - np.count_nonzero(arr) / float(arr.size)


real_data_trigger = 1


if real_data_trigger == 1:

    filedir = "/analysis/data/"

    mbpath = "/analysis/mbdata.txt"

    mapperpath = "/analysis/seq_id_mapper.txt"

    # glob returns files in filesystem-dependent order; sort so the per-file
    # processing order (and whatever prototype_tensor_process derives from it) is reproducible
    file_list = sorted(glob.glob(os.path.join(filedir, "*.tsv")))
    exclude_unclassified = False
    fc_tensor, pwy_list, mb_list, sample_list = prototype_tensor_process(file_list, exclude_unclassified)

    fc_tensor = tensor_logscale(fc_tensor)
    print("functional profiling tensor construction & scaling completed")

    for dim, dim_size in enumerate(fc_tensor.shape):
        print(dim)
        print(dim_size)
    fc_matrix, final_sample_list, label_dict, metabolite_list = metabol_extract(mbpath, mapperpath, sample_list, output_label='metabolite')
    print("metabolomics matrix construction completed")

    group_type = list(set(label_dict.values()))
    print(group_type)

    print("ALL data construction completed")
    print(fc_tensor.shape)
    print(fc_matrix.shape)
    ten_sparsity = calculate_sparsity(fc_tensor)
    mat_sparsity = calculate_sparsity(fc_matrix)
    print(ten_sparsity)
    print(mat_sparsity)

    lf = 3
    decomp_method = 'als'   # 'als' (joint_mt) or 'aio' (all_in_one_mt)
    optim_lf_flag = 0       # 0: fit one model with `lf` factors; 1: scan lf = 1..10

    if decomp_method == 'als':
        # the basic ALS model
        if optim_lf_flag == 0:
            start_time = time.time()
            decomp_fm = joint_mt(fc_tensor, fc_matrix, lf)
            print(f'completed in {time.time()-start_time} seconds')
            final_var2 = rmse_eval(decomp_fm, fc_tensor, fc_matrix)
            print(f'final error version 2 is: {final_var2}')
            print(decomp_fm.variance)
        elif optim_lf_flag == 1:
            lf_var_list = {}
            for lf in range(1, 11):
                # restart the clock for each model so the printed time covers this fit only
                start_time = time.time()
                decomp_fm = joint_mt(fc_tensor, fc_matrix, lf)
                print(f'completed in {time.time()-start_time} seconds')
                final_var2 = rmse_eval(decomp_fm, fc_tensor, fc_matrix)
                print(f'final error version 2 is: {final_var2}')
                lf_var_list[lf] = final_var2
            print(lf_var_list)
            # draw the lf vs reconstruction-error curve to choose the number of latent factors
            plt.style.use('fivethirtyeight')
            fig = plt.figure(figsize=(10, 10))
            ax = fig.add_subplot(1, 1, 1)
            ax.plot(list(lf_var_list.keys()), list(lf_var_list.values()), color='red', alpha=1, label='covid gut')
            ax.legend(loc='upper right', fontsize=10)
            ax.set_ylabel('reconstruction error', fontsize=10)
            ax.set_xlabel('latent factors', fontsize=10)
            plt.tight_layout()
            # save before show(): on some backends show() closes the figure
            fig.savefig("/results/cg_optim_num_lf.pdf", dpi=800)
            plt.show()

    elif decomp_method == 'aio':
        # the independently weighted all-in-one model
        start_time = time.time()
        shape = list(fc_tensor.shape)
        mshape = list(fc_matrix.shape)

        res_fm = all_in_one_mt(fc_tensor, fc_matrix, lf)
        decomp_fm = factor_matrix()

        # res_fm.x packs: tensor factor matrices (mode order), matrix factor, tensor weights, matrix weights
        offset = 0
        decomp_fm.tfm = []
        for dim in shape:
            decomp_fm.tfm.append(res_fm.x[offset:offset + dim * lf].reshape(dim, lf))
            print(decomp_fm.tfm[-1].shape)
            offset += dim * lf
        decomp_fm.mfm = res_fm.x[offset:offset + mshape[-1] * lf].reshape(mshape[-1], lf)
        print(decomp_fm.mfm.shape)
        offset += mshape[-1] * lf
        decomp_fm.weight.append(res_fm.x[offset:offset + lf].reshape(1, lf).squeeze())
        offset += lf
        decomp_fm.weight.append(res_fm.x[offset:offset + lf].reshape(1, lf).squeeze())
        print(f'completed in {time.time()-start_time} seconds')
        final_var = rmse_eval(decomp_fm, fc_tensor, fc_matrix)
        print(f'final error is: {final_var}')
    else:
        print('unmatched factorization method\n')
        sys.exit(1)

def factormatrix_to_table(decomp_fm, fttag, mb_list, pwy_list, metabolite_list, outpath, n_lf=3, top_n=30):
    """Write the top-|loading| features of the first latent factors to a TSV file.

    For each of the first ``n_lf`` latent factors (fewer if the model has
    fewer), two columns are written:

    * ``feature_lf<k>``: feature names, sorted by decreasing absolute loading.
      When ``decomp_fm.variance`` is non-empty, the suffix
      ``_<variance[k-1]*100>`` is appended to this column name.
    * ``loading_val_lf<k>``: the matching loading values.

    Each column holds at most ``top_n`` rows.

    Parameters
    ----------
    decomp_fm : factor_matrix-like
        Needs ``tfm`` (tensor factor matrices), ``mfm`` (matrix-only factor
        matrix) and ``variance``.
    fttag : int
        0 = pathways (``tfm[0]``, names from ``pwy_list`` after the last ':'),
        1 = microbes (``tfm[1]``, names from ``mb_list`` after '.s__'),
        2 = metabolites (``mfm``, names from ``metabolite_list``).
    outpath : str
        Destination TSV path.
    n_lf, top_n : int, optional
        Number of latent factors and of features per factor (defaults 3, 30).

    Raises
    ------
    ValueError
        If ``fttag`` is not 0, 1 or 2.
    """
    if fttag == 0:
        names = [pwy.split(":")[-1] for pwy in pwy_list]
    elif fttag == 1:
        names = [mb.split(".s__")[-1] for mb in mb_list]
    elif fttag == 2:
        names = list(metabolite_list)
    else:
        raise ValueError(f"fttag must be 0, 1 or 2, got {fttag!r}")
    factor = decomp_fm.tfm[fttag] if fttag < 2 else decomp_fm.mfm

    variance = decomp_fm.variance
    has_variance = variance is not None and len(variance) > 0

    df_dict = {}
    for lf in range(min(n_lf, factor.shape[1])):
        loadings = list(factor[:, lf])
        # one stable sort by |loading|, largest first
        ranked = sorted(zip(loadings, range(len(names))), key=lambda p: abs(p[0]), reverse=True)[:top_n]
        ft_keyname = 'feature_lf' + str(lf + 1)
        if has_variance:
            ft_keyname += "_" + str(variance[lf] * 100)
        df_dict[ft_keyname] = [names[idx] for _, idx in ranked]
        df_dict['loading_val_lf' + str(lf + 1)] = [val for val, _ in ranked]
    pd.DataFrame(df_dict).to_csv(outpath, sep='\t', index=False)
        
        
        
        
        
        
        
# quick plotting function of feature loadings
def feature_plot(decomp_fm, fttag, mb_list, pwy_list, metabolite_list, important_mb, important_mc, important_pwy):
    """Bar-plot the 30 largest-|loading| features of the first three latent factors.

    One figure is produced per latent factor.  Each figure is saved to
    ``/results/highlighted_<keyword>_top30inlf<k>.pdf``, shown, and then
    closed, so repeated calls do not accumulate open figures.

    Parameters
    ----------
    decomp_fm : factor_matrix-like
        Needs ``tfm`` (tensor factor matrices), ``mfm`` (matrix-only factor
        matrix), ``variance`` and ``weight``.
    fttag : int
        0 = pathways (``tfm[0]``), 1 = microbes (``tfm[1]``),
        2 = metabolites (``mfm``).
    mb_list, pwy_list, metabolite_list : list of str
        Feature names in factor-matrix row order.
    important_mb, important_mc, important_pwy : collection of str
        Display names whose tick labels are coloured red for fttag 2, 1 and 0
        respectively.

    Raises
    ------
    ValueError
        If ``fttag`` is not 0, 1 or 2.
    """
    if fttag == 0:
        names = [pwy.split(":")[-1] for pwy in pwy_list]
        highlight, keyword = set(important_pwy), 'pathway'
    elif fttag == 1:
        names = [mb.split(".s__")[-1] for mb in mb_list]
        highlight, keyword = set(important_mc), 'microbe'
    elif fttag == 2:
        names = list(metabolite_list)
        highlight, keyword = set(important_mb), 'metabolite'
    else:
        raise ValueError(f"fttag must be 0, 1 or 2, got {fttag!r}")

    plt.style.use('fivethirtyeight')
    variance = decomp_fm.variance
    has_variance = variance is not None and len(variance) > 0
    factor = decomp_fm.tfm[fttag] if fttag < 2 else decomp_fm.mfm

    for lf in range(3):
        loadings = list(factor[:, lf])
        # one stable sort by |loading|, keep the top 30
        ranked = sorted(zip(loadings, range(len(names))), key=lambda p: abs(p[0]), reverse=True)[:30]
        top_vals = [val for val, _ in ranked]
        top_names = [names[idx] for _, idx in ranked]
        x_coor = list(range(len(top_names)))

        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(1, 1, 1)
        # the y-range covers every loading of this factor, not only the plotted top 30
        ax.set_ylim(min(loadings) - 0.2, max(loadings) + 0.2)
        ax.set_xticks(x_coor)
        ax.set_xticklabels(top_names, rotation=45, fontsize=10)
        for ticklabel, name in zip(ax.get_xticklabels(), top_names):
            ticklabel.set_color('red' if name in highlight else 'black')

        ax.bar(x_coor, top_vals, color='blue')
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
                 rotation_mode="anchor", fontsize=10)
        if has_variance:
            var_exp = variance[lf] * 100
            ax.set_title(f'variance explained: {var_exp}%')
        elif decomp_fm.weight:
            ax.set_title(f'tensor weight: {decomp_fm.weight[0][lf]} matrix weight: {decomp_fm.weight[1][lf]}', fontsize=10)
        plt.tight_layout()

        savepath = "/results/highlighted_" + keyword + '_top30inlf' + str(lf + 1) + ".pdf"
        # save before show (show may close the figure on some backends), then release it
        fig.savefig(savepath, dpi=800)
        plt.show()
        plt.close(fig)

def enhanced_linkplot(decomp_fm, pwy_list, metabolite_list, link_dict):
    """Placeholder for a pathway/metabolite link plot; not implemented yet.

    The arguments are accepted for future use.  The function does nothing
    and returns None.
    """
    return None

# Export the top-loading features of each mode to /results/fc_<keyword>.tsv
# (fttag 0 = pathway, 1 = microbe, 2 = metabolite).
output_fac_mat_tag = 1
if output_fac_mat_tag == 1:
    for fttag, keyword in enumerate(('pathway', 'microbe', 'metabolite')):
        outpath = os.path.join('/results/', 'fc_' + keyword + ".tsv")
        factormatrix_to_table(decomp_fm, fttag, mb_list, pwy_list, metabolite_list, outpath)


# Switches for optional downstream steps (presumably read by later cells - confirm).
vistag = 0
sep_idx = 0

#print(decomp_fm.tfm[1].shape)
#print(len(mb_list))
#print(decomp_fm.tfm[0][:,0])
#print(mb_list[0])

#print(decomp_fm.mfm.shape)
#print(len(metabolite_list))
#print(decomp_fm.mfm[:,0])

## some tabulation for highlighting the features of interest in IBDMDB dataset
'''important_mb = ['NH4_C16:1 MAG','NH4_C22:5 CE','NH4_C54:6 TAG','NH4_C51:3 TAG','NH4_C52:4 TAG','NH4_C54:5 TAG','NH4_C36:3 DAG','NH4_C36:4 DAG',
              'NH4_C48:4 TAG','NH4_C48:3 TAG','NH4_C56:6 TAG','NH4_C48:0 TAG','NH4_C20:3 CE','NH4_C52:0 TAG','NH4_C18:0 CE','NH4_C20:4 CE',
              'NH4_C16:0 CE','NH4_C20:5 CE','NH4_C22:6 CE','NH4_C18:3 CE','NH4_C49:1 TAG','NH4_C56:1 TAG','NH4_C38:5 DAG','NH4_C44:2 TAG',
              'NH4_C44:0 TAG','NH4_C44:1 TAG','NH4_C53:2 TAG','NH4_C51:1 TAG','NH4_C34:2 DAG','NH4:C52:1 TAG','NH4_C32:2 DAG','NH4_C34:3 DAG',
              'NH4_C56:4 TAG','NH4_50;1 TAG','NH4_C51:2 TAG','urobilin','nicotinuric acid','C16-OH carnitine','C4 carnitine','C12:1 carnitine',
              'C8 carnitine','C12 carnitine','C18:1-OH carnitine','C14:2 carnitine','C10 carnitine','C14:1 carnitine','C9 carnitine','nicotinate','butyrate',
               'propionate','cholate','taurocholate','glycocholate','lithocholate','deoxycholate','taurochenodeoxycholate','glycochenodeoxycholate','urate',
               'uridine','arachidonate']'''
'''important_mc = ['Alistipes_finegoldii','Alistipes_putredinis','Faecalibacterium_prausnitzii','Bacteroides_vulgatus','Bacteroides_ovatus',
                'Bacteroides_fragilis','Bacteroides_stercoris','Bacteroides_uniformis','Escherichia_coli','Prevotella_copri','Roseburia_intestinalis',
               'Eubacterium_rectale','Ruminococcus_torques','Parabacteroides_distasonis','Klebsiella_pneumoniae']'''
## some tabulation for highlighting the features of interest in covid gut dataset
important_mb = ['Pyroglutamic acid','2-Methylglutaric acid','Succinic acid semialdehyde','p-Cresol','L-Norleucine']
important_mc = ['Escherichia_coli','Faecalibacterium_prausnitzii','Bacteroides_vulgatus','Bacteroides_uniformis','Bifidobacterium_longum','Akkermansia_muciniphila',
               'Hungatella_hathewayi']
important_pwy = [' L-glutamate and L-glutamine biosynthesis',' L-glutamine biosynthesis III',' incomplete reductive TCA cycle',' superpathway of L-tyrosine biosynthesis'
                ]

color_box = ['blue','red','green','yellow']
if vistag == 1:
    # Draw the feature-loading figure for each feature mode in turn
    # (fttag 0 = pathway, 1 = microbe, 2 = metabolite), highlighting the
    # curated important_* features.
    for fttag in (0, 1, 2):
        print(f'plotting figure of feature {fttag}')
        feature_plot(
            decomp_fm, fttag,
            mb_list, pwy_list, metabolite_list,
            important_mb, important_mc, important_pwy,
        )
if vistag == 2:
    # Older visualisation of a simulated data set. Samples are labelled
    # '1'..'30' in ``sample_label_new`` and form three consecutive groups of
    # ten ('1'-'10', '11'-'20', '21'-'30'). Each group is scattered in its
    # own colour for the first three latent factors of ``decomp_fm.tfm[1]``.
    plt.style.use('fivethirtyeight')
    fig = plt.figure(figsize=(10,10))

    n_subgroups = 3
    subgroup_size = 10
    # One tick per label, so set_xticks and set_xticklabels always agree
    # (this was previously hard-coded to 30).
    tick_positions = list(range(len(sample_label_new)))
    print(len(sample_label_new))

    #ax.tick_params(axis='x',length = 0)
    for lf in range(3):
        var_exp = decomp_fm.variance[lf]*100
        ax = fig.add_subplot(3,1,lf+1)
        ax.set_xticks(tick_positions)
        ax.set_ylim(-1.05,1.05)
        ax.set_xticklabels(sample_label_new,rotation = 45,fontsize = 10)

        candidate_ftdist = decomp_fm.tfm[1][:,lf]

        for sub_group in range(n_subgroups):
            numerical_label = range(subgroup_size*sub_group, subgroup_size*(sub_group+1))
            # x positions of the samples labelled str(x+1) for this subgroup
            x_coor = [sample_label_new.index(str(x+1)) for x in numerical_label]
            sub_ftdist = candidate_ftdist[x_coor]
            ax.scatter(x_coor,sub_ftdist,color = color_box[sub_group],alpha = 0.5)

        ax.set_title(f'variance explained: {var_exp}%',fontsize=10)
    # tight_layout is a one-time adjustment, so it must run after the
    # subplots exist. Calling it on the still-empty figure had no effect.
    fig.tight_layout()
    plt.show()
if vistag == 3:
    # Visualise the sample-mode loadings (``decomp_fm.tfm[2]``) after CMTF:
    # one subplot per latent factor (first three), samples on the x axis,
    # coloured by their group (``label_dict`` value, order from group_type).
    plt.style.use("fivethirtyeight")
    fig = plt.figure(figsize=(10,10))

    # Optional significance test for the group separation on the first two
    # latent factors. Samples labelled "C" are left out of the test.
    if label_dict and sep_idx == 1:
        sub_dict = {key: grp for key, grp in label_dict.items() if grp != "C"}
        print(list(sub_dict.values()))

        test_ftdistdf = decomp_fm.tfm[2][:,0:2]
        test_pv = separation_test(test_ftdistdf,final_sample_list,sub_dict)
        # sorted() gives a stable order: iterating a set of str varies
        # between interpreter runs because of hash randomisation.
        tested_groups = sorted(set(sub_dict.values()))
        if len(tested_groups) >= 2:
            print(f'separation between {tested_groups[0]} and {tested_groups[1]} :')
        print(f'p value is: {test_pv}')

    # Column position of every sample, computed once. setdefault keeps the
    # first occurrence, which is what list.index returned. This replaces an
    # O(n) list.index call per sample per latent factor.
    sample_pos = {}
    for pos, sample in enumerate(final_sample_list):
        sample_pos.setdefault(sample, pos)
    # x positions of each group's samples; the same for every latent factor.
    group_coords = [
        [sample_pos[s] for s in final_sample_list if label_dict[s] == grp]
        for grp in group_type
    ]
    print(len(final_sample_list))

    for lf in range(3):
        if decomp_fm.variance:
            var_exp = decomp_fm.variance[lf]*100
        ax = fig.add_subplot(3,1,lf+1)
        ax.set_xticks(list(range(len(final_sample_list))))
        ax.set_xticklabels(final_sample_list,rotation=45,fontsize =  10)
        if decomp_fm.weight:
            print(f"present the weight of lf {lf} in microbiome functional profiling tensor: ")
            print(decomp_fm.weight[0][lf])
            print(f"present the weight of lf {lf} in metabolites profiling matrix: ")
            print(decomp_fm.weight[1][lf])

        candidate_ftdist = decomp_fm.tfm[2][:,lf]
        # pad the y range so the extreme points are not drawn on the frame
        ax.set_ylim(min(candidate_ftdist)-0.2, max(candidate_ftdist)+0.2)

        for idx, (grp, x_coor) in enumerate(zip(group_type, group_coords)):
            # Cycle through color_box, so more groups than colours no longer
            # raise IndexError.
            ax.scatter(x_coor, candidate_ftdist[x_coor],
                       color=color_box[idx % len(color_box)], alpha=0.5, label=grp)

        if decomp_fm.variance:
            ax.set_title(f'variance explained: {var_exp}%',fontsize=10)
        elif decomp_fm.weight:
            ax.set_title(f'tensor weight: {decomp_fm.weight[0][lf]} matrix weight: {decomp_fm.weight[1][lf]}',fontsize=10)
        ax.legend(loc = 'upper right',fontsize=10)
    plt.tight_layout()
    plt.show()

    
if vistag == 4:
    # 2-D scatter of the samples on the first two latent factors of the
    # sample mode (``decomp_fm.tfm[2][:, 0:2]``), coloured by group.
    plt.style.use('fivethirtyeight')
    fig = plt.figure(figsize=(10,10))
    ax = fig.add_subplot(1,1,1)
    ax.set_ylim(-1.1,1.1)
    ax.set_xlim(-1.1,1.1)

    # decomp_fm.variance can be empty (other branches of this script check it
    # before indexing). Fall back to plain axis names instead of raising
    # IndexError.
    if len(decomp_fm.variance) >= 2:
        xstr = "variance explained "+str(decomp_fm.variance[0]*100)+"%"
        ystr = "variance explained "+str(decomp_fm.variance[1]*100)+"%"
    else:
        xstr = "latent factor 1"
        ystr = "latent factor 2"
    ax.set_xlabel(xstr,fontsize = 10)
    ax.set_ylabel(ystr,fontsize = 10)

    # Optional two-group separation test. Samples labelled "SASC" are left
    # out. label_dict is only touched after the truthiness guard.
    if label_dict and sep_idx == 1:
        sub_dict = {key: grp for key, grp in label_dict.items() if grp != "SASC"}
        test_ftdistdf = decomp_fm.tfm[2][:,0:2]
        test_pv = separation_test(test_ftdistdf,final_sample_list,sub_dict)
        # sorted() keeps the reported group order stable across runs
        tested_groups = sorted(set(sub_dict.values()))
        if len(tested_groups) >= 2:
            print(f'separation between {tested_groups[0]} and {tested_groups[1]} :')
        print(f'p value is: {test_pv}')

    candidate_ftdist_2d = decomp_fm.tfm[2][:,0:2]
    # First-occurrence position of each sample (same result as list.index),
    # built once instead of one O(n) search per sample.
    sample_pos = {}
    for pos, sample in enumerate(final_sample_list):
        sample_pos.setdefault(sample, pos)
    for idx, grp in enumerate(group_type):
        x_coor = [sample_pos[s] for s in final_sample_list if label_dict[s] == grp]
        sub_ftdist = candidate_ftdist_2d[x_coor,:]
        sub_x = sub_ftdist[:,0]
        sub_y = sub_ftdist[:,1]
        # cycle colours so more than len(color_box) groups do not IndexError
        ax.scatter(sub_x,sub_y,color=color_box[idx % len(color_box)],alpha=1,s=100,label=grp)
    ax.legend(loc= 'lower right',fontsize=10)
    plt.tight_layout()
    plt.show()
    fig.savefig("/results/flight_canbin_obs_dist_2d_published2.pdf",dpi=800)


if vistag == 5:
    # Reconstruction error as a function of the number of latent factors,
    # drawn for the train split (grey) and the test split (red), then saved.
    plt.style.use('fivethirtyeight')
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)
    ax.set_ylim(0.85, 0.95)

    ax.plot(train_x, train_y, label='train set', color='grey', alpha=1)
    ax.plot(test_x, test_y, label='test set', color='red', alpha=1)

    ax.set_xlabel('latent factors', fontsize=10)
    ax.set_ylabel('reconstruction error', fontsize=10)
    ax.legend(fontsize=10, loc='upper right')

    plt.tight_layout()
    plt.show()
    fig.savefig("/results/train_test_curve_published.pdf", dpi=800)

if vistag == 6:
    # Visualise relationships between metabolites and pathways.
    # The helper in this file is named ``enhanced_linkplot``. The old call to
    # ``enhance_linkplot`` referred to an undefined name (NameError).
    enhanced_linkplot(decomp_fm,pwy_list,metabolite_list,link_dict)

if vistag == 7:
    # Effect of data variance on factorization quality: reconstruction error
    # of basic CMTF (v_als_rmse) vs advanced CMTF (v_aio_rmse) at variance
    # levels 1..10. The raw values are printed first, then plotted and saved.
    print(v_als_rmse)
    print(v_aio_rmse)

    plt.style.use('fivethirtyeight')
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)

    var_x = np.arange(1, 11)  # one x position per variance level
    ax.plot(var_x, v_als_rmse, label="basic CMTF", color='black', alpha=1)
    ax.plot(var_x, v_aio_rmse, label="advanced CMTF", color='red', alpha=1)

    ax.set_xlabel('variance', fontsize=20)
    ax.set_ylabel('reconstruction error', fontsize=20)
    ax.legend(fontsize=20, loc='upper right')

    plt.tight_layout()
    plt.show()
    fig.savefig("/results/var_on_rmse_published.pdf", dpi=800)
if vistag == 8:
    # Effect of data sparsity on factorization quality: reconstruction error
    # of basic CMTF (sp_als_rmse) vs advanced CMTF (sp_aio_rmse) at sparsity
    # levels 0 to 0.8. The raw values are printed first, then plotted and saved.
    print(sp_als_rmse)
    print(sp_aio_rmse)

    plt.style.use('fivethirtyeight')
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)

    sp_x = [0, 0.2, 0.4, 0.6, 0.8]  # sparsity level of each measurement
    ax.plot(sp_x, sp_als_rmse, label="basic CMTF", color='black', alpha=1)
    ax.plot(sp_x, sp_aio_rmse, label="advanced CMTF", color='red', alpha=1)

    ax.set_xlabel('data sparsity', fontsize=20)
    ax.set_ylabel('reconstruction error', fontsize=20)
    ax.legend(fontsize=20, loc='upper right')

    plt.tight_layout()
    plt.show()
    fig.savefig("/results/sp_on_rmse_published.pdf", dpi=800)