In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import preprocessing
import random
# REFINED libraries
import pickle
import scipy.misc
import sys
import Toolbox 
from Toolbox import *
import math
from scipy import stats

import Utils
import utils_nn
import model
from model import *
from utils_nn import *
from Utils import *
import utils_tl_cind
from utils_tl_cind import *
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import OneHotEncoder
from lifelines.utils import concordance_index
from sklearn.model_selection import KFold

import torch 
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.init import kaiming_uniform_
from torch.nn.init import xavier_uniform_

# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
# To force CPU execution instead, use: device = "cpu"
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Root directory of the TCGA inputs. It already ends with '/', so file names are appended
# without a leading separator. NOTE(review): absolute path — adjust for your environment.
Dir = '/TCGA/'

# RNA-seq data for each patient; 'cancer_type' is read here and 'Patient_ID' is
# deduplicated later in the experiment cell.
RNAseq_PD = pd.read_csv(Dir + 'combined_rnaseq.csv')
# Distinct cancer types present in the RNA-seq table, in order of first appearance.
Cancers = RNAseq_PD['cancer_type'].unique().tolist()
# Survival and GuanRank data for each patient.
GuanRank_all_PD = pd.read_csv(Dir + 'whole_data.csv')

In [None]:
# Split the survival/GuanRank table into the two tables used below as the train/test sources.
# GuanRankSplit is not defined in this notebook; it presumably comes from one of the star imports.
(Guan_train_all, Guan_test_all) = GuanRankSplit(GuanRank_all_PD)

In [None]:
# Read the precomputed REFINED initial mapping for the chosen manifold-learning method.
# NOTE(review): pickle.load can execute arbitrary code — only load pickles from a trusted source.
mf = 'TSNE'  # method tag; selects the pickle below and is also passed to Dataset(...) later
# The double underscore in "All_Disease__Init_" comes from concatenating the two literals and
# is presumably what the file on disk is called — do not "fix" it without renaming the file.
with open(Dir + 'All_Disease_' + "_Init_" + mf + ".pickle", 'rb') as file:
    feature_names, dist_matr, init_map = pickle.load(file)
eatur_names = feature_names  # backward-compatible alias for the previous, misspelled name
Nn = init_map.shape[0]  # size of the first axis of the initial map

In [None]:
# Target cancer cohorts; the experiment loop runs the full
# pre-train / transfer / fine-tune pipeline once per entry, in this order.
Disease = [
    "GBMLGG",
    "KIPAN",
    "HNSC",
    "GBM",
    "LUAD",
    "LUSC",
    "BLCA",
    "BRCA",
    "COADREAD",
]

In [None]:
def weights_init(m):
    """Xavier-normal initialise the weight of Conv2d and Linear layers; intended for ``Module.apply``.

    Any other module type is left untouched, and biases are never modified.

    Args:
        m: a single sub-module, as handed over by ``nn.Module.apply``.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    torch.nn.init.xavier_normal_(m.weight)


In [None]:
# Keep the first RNA-seq row per Patient_ID so no patient is counted twice.
RNAseq_PD = RNAseq_PD.drop_duplicates(subset='Patient_ID', keep="first")

# ---- Experiment configuration ----------------------------------------------------------
n_folds = 5                 # fold indices passed to Dataset(..., fold, ...)
n_epochs = 600              # epochs for pre-training on the "*_rest" data (train label "all")
n_epochs_tune = 300         # epochs for each disease-specific stage
ES = 1                      # forwarded to train(); presumably an early-stopping flag — TODO confirm
lr_pretrain = 0.001         # Adam LR for pre-training and head-only transfer learning
lr_finetune = 0.1 * 0.001   # smaller Adam LR for whole-network fine-tuning
val_frac = 0.1              # fraction of the disease training set held out (from the tail) for validation


def _split_tail(X, y, frac):
    """Split paired inputs/targets into (X_fit, y_fit, X_val, y_val), holding out the last rows.

    The hold-out has int(frac * n) rows but never fewer than one. A naive ``X[:-k]`` /
    ``X[-k:]`` split with k == 0 would train on an EMPTY set and validate on everything.
    """
    n_total = X.shape[0]
    n_val = max(1, int(frac * n_total))
    n_fit = n_total - n_val
    return X[:n_fit], y[:n_fit], X[n_fit:], y[n_fit:]


def _summary_frame(values, label):
    """Return a one-row DataFrame indexed by `label` with the 'Mean' and 'STD' of `values`.

    average_ls / std_ls are not defined in this notebook; they presumably come from one of
    the star imports (e.g. Utils).
    """
    stats_row = np.array([average_ls(values), std_ls(values)]).reshape(1, -1)
    return pd.DataFrame(data=stats_row, index=[label], columns=['Mean', 'STD'])


def _set_requires_grad(params, flag):
    """Set ``requires_grad = flag`` on every tensor yielded by `params`."""
    for param in params:
        param.requires_grad = flag


cind_ls_pre_tl_disease = []   # per-disease summaries: pre-trained model applied directly
cind_ls_tl_disease = []       # per-disease summaries: after head-only transfer learning
cind_ls_disease = []          # per-disease summaries: after whole-network fine-tuning
for cns in Disease:
    print("Cancer type: " + cns)
    cind_ls = []
    cind_ls_tl = []
    cind_ls_pre_tl = []
    for fold in range(n_folds):
        print("Fold: " , fold)

        # Dataset, RCNN, train and evaluate_model_reg presumably come from the star imports.
        data = Dataset(Guan_train_all, Guan_test_all, RNAseq_PD, fold, cns, Dir, mf)
        data.read_REFINED()
        data.Split_Data(device)

        # Alternative model: RCNN_NRMSE(32, 128).to(device). If used, the loss inside train()
        # (utils_nn) must be switched to the NRMSE loss from utils_nn.
        net = RCNN(32, 64, 128).to(device)
        net = net.apply(weights_init)

        ############## Stage 1: pre-train the REFINED CNN on the "*_rest" data
        optimizer = optim.Adam(net.parameters(), lr=lr_pretrain)
        net, loss_train_all, loss_val_all = train(net, optimizer, n_epochs, "all", data.X_REFINED_Train_rest, data.Guan_train_rest,
                                                  data.X_REFINED_Test_rest, data.Guan_test_rest, ES, device)

        print('Fold : ', fold)
        C_ind_test = evaluate_model_reg(net, data.X_REFINED_Test_rest, data.Guan_test_rest, device)
        print("C-ind test all: " + str(C_ind_test[0]))

        ############## Stage 2: transfer learning — freeze everything except fc_l3 (and dropout2)
        _set_requires_grad(net.parameters(), False)
        _set_requires_grad(net.fc_l3.parameters(), True)
        # NOTE(review): if dropout2 is an nn.Dropout it owns no parameters and this is a no-op.
        _set_requires_grad(net.dropout2.parameters(), True)

        C_ind_disease = evaluate_model_reg(net, data.X_REFINED_Test_disease, data.Guan_test_disease, device)
        print("C-ind test pre-transfer learning " + cns + " : " + str(C_ind_disease[0]))
        cind_ls_pre_tl.append(C_ind_disease[0])

        # Tail of the disease training set is the validation split for stages 2 and 3.
        X_tune, y_tune, X_tune_val, y_tune_val = _split_tail(data.X_REFINED_Train_disease,
                                                             data.Guan_train_disease, val_frac)

        # Hand only the unfrozen tensors to the optimizer so the frozen layers are guaranteed
        # not to be updated, regardless of how train() clears gradients.
        trainable_params = [p for p in net.parameters() if p.requires_grad]
        optimizer_disease = optim.Adam(trainable_params, lr=lr_pretrain)
        net, loss_train_disease, loss_val_disease = train(net, optimizer_disease, n_epochs_tune, cns, X_tune, y_tune,
                                                          X_tune_val, y_tune_val, ES, device)

        C_ind_disease = evaluate_model_reg(net, data.X_REFINED_Test_disease, data.Guan_test_disease, device)
        print("C-ind test post-transfer learning " + cns + " : " + str(C_ind_disease[0]))
        cind_ls_tl.append(C_ind_disease[0])

        ############## Stage 3: unfreeze all layers and fine-tune with a smaller learning rate
        _set_requires_grad(net.parameters(), True)
        optimizer_finetune = optim.Adam(net.parameters(), lr=lr_finetune)
        net, loss_train_disease, loss_val_disease = train(net, optimizer_finetune, n_epochs_tune, cns, X_tune, y_tune,
                                                          X_tune_val, y_tune_val, ES, device)

        C_ind_disease = evaluate_model_reg(net, data.X_REFINED_Test_disease, data.Guan_test_disease, device)
        print("C-ind test post fine-tuning " + cns + " : " + str(C_ind_disease[0]))
        cind_ls.append(C_ind_disease[0])

        # Final predictions on the disease test set, kept for the (currently disabled) bootstrap test.
        # eval() disables dropout; no_grad() avoids building an autograd graph for pure inference.
        # as_tensor accepts either an array or an existing tensor without a redundant copy/warning.
        net.eval()
        with torch.no_grad():
            X_torch = torch.as_tensor(data.X_REFINED_Test_disease, dtype=torch.float, device=device)
            y_pred_test = net(X_torch)
        # TODO: re-enable once StatistTest is wired up:
        # p_value_disease = StatistTest(C_ind_disease, y_pred_test, data.Guan_test_disease, n_bootstrap = 1000)
        # print("p-value: ", p_value_disease)
        del net

    Res_PD = _summary_frame(cind_ls, cns)
    Res_TL_PD = _summary_frame(cind_ls_tl, cns)
    Res_Pre_TL_PD = _summary_frame(cind_ls_pre_tl, cns)

    cind_ls_disease.append(Res_PD)
    cind_ls_tl_disease.append(Res_TL_PD)
    cind_ls_pre_tl_disease.append(Res_Pre_TL_PD)
    print(Res_PD)

Res_all_PD = pd.concat(cind_ls_disease)
Res_TL_PD = pd.concat(cind_ls_tl_disease)
Res_Pre_TL_PD = pd.concat(cind_ls_pre_tl_disease)
print("pre_transfer_learning")
print(Res_Pre_TL_PD)
print("post_transfer_learning:")
print(Res_TL_PD)
Res_all_PD