%% This script analyzes how stable the cluster signatures are with respect to the number of patients analyzed.
% We repeatedly subsample pairs, triplets and groups of 4, 5 and 6 patients, and calculate the
% mean gene expression of each of the 17 clusters. We then show a PCA and
% calculate an F score for each cluster and number of patients - the mean distance
% within a cluster of the sub-sampled groups divided by the mean distance
% between clusters. Saturation of this F score as the number of patients
% increases indicates stability and generality.

%% Load the scRNAseq data
% Put the helper functions on the path and read the two variables used by
% this script (hint, tbl) from the input MAT-file.
addpath(genpath('./helperFunctions'));
load ./inputData.mat hint tbl

%% extract the cell type identities of all the atlas cells
% patient_ind(c) is the position in patientU of the patient that cell c
% belongs to (names are matched case-insensitively).
patient=hint.metadata.human;
patientU=unique(patient);
clear patient_ind;
for p=1:numel(patientU)
    cells_of_patient=find(strcmpi(patient,patientU{p}));
    patient_ind(cells_of_patient)=p;
end

% Look up the name of every cluster number present in the data; tbl pairs
% cluster numbers (cellNumber) with names (cellName).
cell_type=hint.metadata.seurat_clusters;
cell_typeU=unique(cell_type);
clear cell_type_annotation
for c=1:numel(cell_typeU)
    tbl_row=find(tbl.cellNumber==cell_typeU(c));
    cell_type_annotation{c}=tbl.cellName{tbl_row};
end

% Cells split by the 'treatment' metadata field
far_ind=find(strcmpi(hint.metadata.treatment,'far'));
tumor_ind=find(strcmpi(hint.metadata.treatment,'tumor'));

% For T cells and carcinoma cells add 2 new clusters that include only the
% far cells. The new clusters get the ids M+1 and M+2 and are appended to
% the end of cell_typeU / cell_type_annotation, so existing positions in
% both lists are not shifted.
M=max(cell_typeU);

% 'far' cells of the carcinoma cluster -> new cluster 'cholangiocytes'
index=find(strcmpi(cell_type_annotation,'carcinoma'));
ind_cholangiocytes=intersect(find(cell_type==cell_typeU(index)),far_ind);
cell_type(ind_cholangiocytes)=M+1;
cell_type_annotation{end+1}='cholangiocytes';
cell_typeU(end+1)=M+1;

% 'far' cells of the T cell cluster -> new cluster 'T far'; the cells that
% stay in the original cluster are relabelled 'T tumor'
index=find(strcmpi(cell_type_annotation,'T cells'));
ind_Tfar=intersect(find(cell_type==cell_typeU(index)),far_ind);
cell_type(ind_Tfar)=M+2;
cell_type_annotation{end+1}='T far';
cell_type_annotation{index}='T tumor';
cell_typeU(end+1)=M+2;

%%
% Create all patient combinations of sizes 1..6 (each repeated 3 times) and
% compute, for every bootstrapped sub-sample, the mean expression of each
% cluster. Result: mean_exp{i} is #genes * #subsamples * #populations for
% sub-samples of i patients. A cluster with no cells in a sub-sample yields
% a NaN column (mean over an empty selection).
clear combos
clear mean_exp; % reset the accumulator that is actually filled below
N_FULL_BOOT=50; % number of bootstrap replicates of the complete atlas (i=6)
n_genes=length(hint.all_genes);
n_types=length(cell_type_annotation);
for i=1:6,
    i % echo progress to the command window
    % nchoosek is the recommended replacement for combntns and does not
    % require the Mapping Toolbox
    combos{i} = repmat(nchoosek(1:6,i),3,1); % 3 bootstrap iterations for each combo
    if i==6
        % the 6-patient class is filled by the dedicated bootstrap loop
        % below; allocate it at its final size instead of growing it
        mean_exp{i}=zeros(n_genes,N_FULL_BOOT,n_types);
        continue;
    end
    % create a table of #genes*#subsamples*#populations
    mean_exp{i}=zeros(n_genes,size(combos{i},1),n_types);
    for j=1:size(combos{i},1)
        % all cells of the patients in this combination, as a column vector
        ind2include=find(ismember(patient_ind,combos{i}(j,:)));
        ind2include=ind2include(:);
        % in addition bootstrap this patient data, so as to allow fair
        % comparison with the i=6 class
        ind2include=datasample(ind2include,length(ind2include));
        for k=1:n_types
            indin=ind2include(cell_type(ind2include)==cell_typeU(k));
            mean_exp{i}(:,j,k)=mean(hint.mat_norm(:,indin),2);
        end
    end
end

% For the 6 patients create random bootstrap versions (resampling all cells
% with replacement) and average the gene expression
for j=1:N_FULL_BOOT,
    ind2include=datasample(1:size(hint.mat_norm,2),size(hint.mat_norm,2)); % bootstrap
    for k=1:n_types
        indin=ind2include(cell_type(ind2include)==cell_typeU(k));
        mean_exp{6}(:,j,k)=mean(hint.mat_norm(:,indin),2);
    end
end


%%
% Create 'randomized' sets with the same sizes as the sub-sampled patients (but random cells).
% Result: mean_exp_rnd{i} is #genes * #subsamples * #populations, where each
% sub-sample has as many cells as the corresponding i-patient combination
% but the cells are drawn at random from the whole atlas.
clear mean_exp_rnd; % reset the accumulator that is actually filled below
N_FULL_BOOT=50; % number of bootstrap replicates of the complete atlas (i=6)
n_genes=length(hint.all_genes);
n_types=length(cell_type_annotation);
n_cells=size(hint.mat_norm,2);
for i=1:6,
    i % echo progress to the command window
    % nchoosek is the recommended replacement for combntns and does not
    % require the Mapping Toolbox
    combos{i} = repmat(nchoosek(1:6,i),3,1); % 3 bootstrap iterations for each combo
    if i==6
        % the 6-patient class is filled by the dedicated bootstrap loop
        % below; allocate it at its final size instead of growing it
        mean_exp_rnd{i}=zeros(n_genes,N_FULL_BOOT,n_types);
        continue;
    end
    % create a table of #genes*#subsamples*#populations
    mean_exp_rnd{i}=zeros(n_genes,size(combos{i},1),n_types);
    for j=1:size(combos{i},1)
        % number of cells contributed by the patients of this combination
        L=nnz(ismember(patient_ind,combos{i}(j,:)));
        % now randomly sample the same number of cells (without
        % replacement) from the whole atlas
        ind2include=randperm(n_cells,L);
        % in addition bootstrap this patient data, so as to allow fair
        % comparison with the i=6 class
        ind2include=datasample(ind2include,length(ind2include));
        for k=1:n_types
            indin=ind2include(cell_type(ind2include)==cell_typeU(k));
            mean_exp_rnd{i}(:,j,k)=mean(hint.mat_norm(:,indin),2);
        end
    end
end

% For the 6 patients create random bootstrap versions (resampling all cells
% with replacement) and average the gene expression
for j=1:N_FULL_BOOT,
    ind2include=datasample(1:n_cells,n_cells); % bootstrap
    for k=1:n_types
        indin=ind2include(cell_type(ind2include)==cell_typeU(k));
        mean_exp_rnd{6}(:,j,k)=mean(hint.mat_norm(:,indin),2);
    end
end


%% Now analyze for each population and sub-sampling the similarity to the full atlas signature
% For every sub-sample (i patients, replicate j) and cluster k, compute the
% Spearman correlation between the sub-sample mean expression and the mean
% expression of the same cluster over all atlas cells, restricted to genes
% whose sub-sample mean exceeds EXP_THRESH. The values are stored in 'dist'
% (name kept for compatibility although it holds correlations).
EXP_THRESH=5*10^-6;
n_types=length(cell_type_annotation); % number of clusters (was hard-coded as 19)
full_atlas_sig=zeros(size(mean_exp{6},1),size(mean_exp{6},3)); % #genes*#cell types
for k=1:n_types
    indin=find(cell_type==cell_typeU(k));
    full_atlas_sig(:,k)=mean(hint.mat_norm(:,indin),2);
end

% Preallocate the flat result vectors: one entry per (i,j,k) triplet
n_total=0;
for i=1:6
    n_total=n_total+size(mean_exp{i},2)*size(mean_exp{i},3);
end
indicator_comb=zeros(1,n_total);      % number of patients of each entry
indicator_cell_type=zeros(1,n_total); % cluster index of each entry
dist=nan(1,n_total);                  % stays NaN when no gene passes the threshold
tab=[];
pos=0;
for i=1:6,
    % NOTE(review): counter restarts for every i, so tab is overwritten per
    % sub-sample size and does not line up with indicator_comb /
    % indicator_cell_type. tab is not read in this section - confirm
    % whether it is still needed.
    counter=1;
    for j=1:size(mean_exp{i},2)
        for k=1:size(mean_exp{i},3)
            pos=pos+1;
            matt=mean_exp{i}(:,j,k);
            tab(:,counter)=matt;
            indicator_comb(pos)=i;
            indicator_cell_type(pos)=k;
            % NaN means (cluster absent from the sub-sample) never pass the
            % comparison, so such entries keep dist=NaN
            ind2include=find(matt>EXP_THRESH);
            if ~isempty(ind2include)
                ref=full_atlas_sig(:,k);
                dist(pos)=corr(ref(ind2include),matt(ind2include),'type','spearman');
            end
            counter=counter+1;
        end
    end
end

% Median correlation for every (number of patients, cluster) pair
median_dist=zeros(6,n_types);
for j=1:n_types
    for i=1:6,
        indin=find(indicator_comb==i & indicator_cell_type==j);
        median_dist(i,j)=nanmedian(dist(indin));
    end
end

% plot resulting cell populations sorted according to marginal increase in
% precision of cell-type signature
dd=diff(median_dist);
% ratio between the last step (5->6 patients) and the mean of the two
% preceding steps (3->4 and 4->5) of the median correlation
stat=dd(end,:)./mean(dd(end-2:end-1,:));
[~,ord]=sort(stat,'ascend');
n_cols=5;
n_rows=ceil((n_types+1)/n_cols); % one extra panel is reserved for a summary plot
f1=figure('Units','centimeters','Position',[54 4 25 25]);
for i=1:n_types,
    subplot(n_rows,n_cols,i);
    in_type=indicator_cell_type==ord(i);
    violinplot(dist(in_type),indicator_comb(in_type));
    hold on;
    plot(1:6,median_dist(1:end,ord(i)),'k-');
    title(cell_type_annotation{ord(i)});
    set(gca,'Tag',cell_type_annotation{ord(i)}); % tag the axes with the cluster name so it can be found later
    box on;
    axis square
    axis tight
    xlim([min(xlim)-0.3 max(xlim)+0.3]); %add margins from box sides
    ylim([min(ylim)-0.05 max(ylim)+0.05]); %add margins from box sides
end

%control marker size
icons = findobj(f1,'type','Scatter');
if ~isempty(icons)
    set(icons,'SizeData',25);
end


%% Add median_dist_rnd
% Same analysis as for the patient sub-samples, applied to the size-matched
% random cell sets (mean_exp_rnd): Spearman correlation of every random
% sub-sample signature with the full atlas signature of the same cluster.
% The per-size medians are overlaid as a red dashed line on the panels of
% figure f1, using the panel order 'ord' computed for the patient data.
EXP_THRESH=5*10^-6;
n_types=length(cell_type_annotation); % number of clusters (was hard-coded as 19)
full_atlas_sig=zeros(size(mean_exp{6},1),size(mean_exp{6},3)); % #genes*#cell types
for k=1:n_types
    indin=find(cell_type==cell_typeU(k));
    full_atlas_sig(:,k)=mean(hint.mat_norm(:,indin),2);
end

% Preallocate the flat result vectors: one entry per (i,j,k) triplet
n_total=0;
for i=1:6
    n_total=n_total+size(mean_exp_rnd{i},2)*size(mean_exp_rnd{i},3);
end
indicator_comb_rnd=zeros(1,n_total);      % number of patients of each entry
indicator_cell_type_rnd=zeros(1,n_total); % cluster index of each entry
dist_rnd=nan(1,n_total);                  % stays NaN when no gene passes the threshold
tab_rnd=[];
pos=0;
for i=1:6,
    % NOTE(review): counter restarts for every i, so tab_rnd is overwritten
    % per sub-sample size; it is not read in this section - confirm whether
    % it is still needed.
    counter=1;
    for j=1:size(mean_exp_rnd{i},2)
        for k=1:size(mean_exp_rnd{i},3)
            pos=pos+1;
            matt=mean_exp_rnd{i}(:,j,k);
            tab_rnd(:,counter)=matt;
            indicator_comb_rnd(pos)=i;
            indicator_cell_type_rnd(pos)=k;
            ind2include=find(matt>EXP_THRESH);
            if ~isempty(ind2include)
                ref=full_atlas_sig(:,k);
                dist_rnd(pos)=corr(ref(ind2include),matt(ind2include),'type','spearman');
            end
            counter=counter+1;
        end
    end
end

% Median correlation per (number of patients, cluster). The 5-vs-6
% statistic is written to stat_rnd; writing it to 'stat' would overwrite
% the statistic from which the panel order 'ord' was derived.
median_dist_rnd=zeros(6,n_types);
stat_rnd=zeros(1,n_types); % (median of 5 - mean of 6) divided by the std of 6
for j=1:n_types
    for i=1:6,
        indin=find(indicator_comb_rnd==i & indicator_cell_type_rnd==j);
        median_dist_rnd(i,j)=nanmedian(dist_rnd(indin));
        if i==6
            stat_rnd(j)=(median_dist_rnd(5,j)-mean(dist_rnd(indin)))./std(dist_rnd(indin));
        end
    end
end

% Overlay on the existing panels; make f1 current explicitly so the lines
% do not land in whichever figure happens to be active
set(0,'CurrentFigure',f1);
n_cols=5;
n_rows=ceil((n_types+1)/n_cols); % same grid as used when f1 was created
for i=1:n_types,
    subplot(n_rows,n_cols,i);
    hold on;
    plot(1:6,median_dist_rnd(1:end,ord(i)),'r--','linewidth',1.5);
    box on;
end

%calc delta between median_dist and median_dist_rand
% delta(:,c) is the difference (random control - patient sub-samples) of
% the median correlation for the cluster shown in panel c, i.e. the columns
% of delta follow the sorted panel order 'ord'.
n_types=length(cell_type_annotation); % number of clusters (was hard-coded as 19)
n_cols=5;
n_rows=ceil((n_types+1)/n_cols);
delta=median_dist_rnd(:,ord(1:n_types))-median_dist(:,ord(1:n_types));

set(0,'CurrentFigure',f1); % draw into the panel figure, not whichever figure is active
subplot(n_rows,n_cols,n_types+1); % the panel after the last cluster panel
hold all;
% Both index sets below refer to columns of delta (sorted panel order)
carcinoma_ind=find(strcmpi(cell_type_annotation(ord),'Carcinoma'));
other_cellTypes=setdiff(1:length(cell_type_annotation),carcinoma_ind);
h_other=plot(1:6,delta(:,other_cellTypes),'linewidth',1,'color',[.5 .5 .5]);
h_carcinoma=plot(1:6,delta(:,carcinoma_ind),'linewidth',1.2,'color','k');
% plot([1 6],[0 0],'-','linewidth',1,'color',[.2 .2 .2]);
set(gca,'xtick',1:6);
set(gca,'xticklabel',1:6);
xlim([1 6]);
box on;
axis square
axis tight
ylabel('Correlation divergence from subsampled 6 patients')
% Pass the line handles explicitly: the gray lines are drawn first, so a
% plain legend('Carcinoma','other') would label two gray lines
legend([h_carcinoma(1) h_other(1)],{'Carcinoma','other'})
grid on

% find carcinoma axes and make inset
% Copies the carcinoma panel of f1 into a new figure f2, zooms in on the
% 4-6 patient range and strips the ticks so it can be used as an inset.
prev_show_hidden=get(0,'showhiddenhandles');
set(0,'showhiddenhandles','on'); % also search objects with hidden handles
% Match the axes tag case-insensitively, as done for the cluster names
% elsewhere in this script (findobj's Tag match is case-sensitive)
all_axes=findobj(f1,'type','axes');
h=all_axes(strcmpi(get(all_axes,'Tag'),'Carcinoma'));
if isempty(h)
    set(0,'showhiddenhandles',prev_show_hidden);
    error('No axes tagged ''Carcinoma'' found in figure f1.');
end
h=h(1);
f2=figure('Units','centimeters','Position',[54 4 21 21]);
s = copyobj(h,f2);
% Address the copied axes explicitly instead of relying on gca
axis(s,'tight');
yylim=ylim(s);
ylim(s,[0.91 max(yylim)+0.005]);
xlim(s,[3.5 6.5]);
src_pos=get(h,'Position');
set(s,'Position',[src_pos(1) src_pos(2) src_pos(3)./1.4 src_pos(4)./1.4]);
axis(s,'square');
set(s,'XTick',[]);
set(s,'YTick',[]);
box(s,'on');
% Shrink the markers of the inset. NOTE(review): this used to target f1,
% whose markers were already sized when f1 was created; f2 is assumed to be
% the intended figure - confirm.
icons = findobj(f2,'type','Scatter');
if ~isempty(icons)
    set(icons,'SizeData',20);
end
icons = findobj(f2,'type','line');
if ~isempty(icons)
    set(icons,'linewidth',1);
end
set(0,'showhiddenhandles',prev_show_hidden); % restore the global setting
