% ROC examination for skylines in batch
%
% Setup section: resets the workspace, sets figure defaults, lists the batch
% folders to process and pre-allocates the per-folder result containers.
clearvars; clc;
close all; tic;

% Assumptions and notes
% - loops across trajectories and data (samples)
% - expects all data in a single batch folder
% - single shift binary hypothesis problem
% - select over skyline segment number using robust criteria
% - test with several criteria: MDL, BIC, likelihood

% Additional plotting/partition package (machine-specific location). Only
% added when the folder exists, so the script can start on other machines.
pkgDir = '/Users/kp10/Documents/MATLAB';
if exist(pkgDir, 'dir') == 7
    addpath(genpath(pkgDir));
else
    warning('rocSky:noPackageDir', 'Package folder not found, skipping addpath: %s', pkgDir);
end

% Set figure defaults (root object addressed consistently through groot)
set(groot, 'defaultAxesTickLabelInterpreter', 'latex');
set(groot, 'defaultLegendInterpreter', 'latex');
set(groot, 'defaultTextInterpreter', 'latex');
set(groot, 'defaultAxesFontSize', 16);
% Grey shades (not referenced in the visible part of this script)
grey1 = 0.8*ones(1, 3); grey2 = 0.5*ones(1, 3);

% Home and folder to save/load
savetrue = 0;               % save flag (not referenced in the visible part)
folload = 'shiftShort_';    % name prefix of the batch folders to load
folsave = 'rocData';        % save folder (not referenced in the visible part)
thisDir = cd;

% Metrics to be evaluated
metrics = {'$-\log L$', 'BIC', 'FIA', 'QK', 'AIC'};
nmet = length(metrics);

% Folder with all data from R sims, relative to the starting directory
dataRoot = 'using phylodyn/batch data/';
% All folders to load with shift data. Listing by path avoids changing the
% working directory; plain files that happen to share the prefix are dropped
% because every entry is later treated as a folder.
files = dir(fullfile(dataRoot, [folload '*']));
files = files([files.isdir]);
nFiles = length(files);
if nFiles == 0
    % Fail here with a clear message instead of much later in the plotting
    error('rocSky:noData', 'No folders matching "%s*" found in "%s".', folload, dataRoot);
end

% Performance variables across sample runs (one entry per batch folder)
TPRs = cell(1, nFiles); FPRs = TPRs;
PSucc = TPRs; dimCs = TPRs; Cset = TPRs;
Ranks = TPRs; fracs = TPRs; ks = TPRs;
cshifts = TPRs;
ncs = zeros(1, nFiles); Ms = ncs;

% Define Nmax for FIA
Nmax = 10^5; % actual max is 10^3

% Main loop: for every batch folder, read the simulation settings, run the
% model selection on each trajectory and score every metric against the truth.
for ii = 1:nFiles
    % Folder of the current batch, relative to the starting directory
    dataload = ['using phylodyn/batch data/' files(ii).name];
    % Inputs are read by path rather than by cd-ing in and out, so a failed
    % read cannot leave the working directory changed.
    % True shift time and models
    tshift = csvread(fullfile(dataload, 'tshift.csv'));
    frac = csvread(fullfile(dataload, 'frac.csv'));
    M = csvread(fullfile(dataload, 'numTraj.csv'));
    % No. coalescent events in each trajectory
    nc = csvread(fullfile(dataload, 'nc.csv'));
    % Store from runs
    Ms(ii) = M; fracs{ii} = frac; ncs(ii) = nc;
    % Samples introduced at each sample time
    sampIntro = csvread(fullfile(dataload, 'sampIntro.csv'));

    % Set group sizes possible (non-repeated): divisors of nc that are >= 4
    k = 1:nc; k = k(rem(nc, k) == 0);
    k = k(k >= 4); lenk = length(k);
    % Debug option kept from the original: k = k(end-1:end); lenk = 2;
    ks{ii} = k;

    % Number of coalescent events either side of tshift
    cshift = zeros(M, 2);

    % Store frac of change, best models and criteria
    criteria = cell(1, M); nGrp = criteria;
    kbest = zeros(M, nmet); nGrpbest = kbest;
    for i = 1:M
        % Main function for selection models (defined outside this script);
        % its outputs fill one row per trajectory and one column per metric
        [criteria{i}, kbest(i, :), nGrpbest(i, :), nGrp{i}, cshift(i, :)] = rocSkyFn2(i, ...
            dataload, k, lenk, sampIntro, tshift, Nmax);
    end
    cshifts{ii} = cshift;

    % Null model indices and true nGrp: frac == 1 marks the null model
    % (1 group), anything else the alternative (2 groups). Forced to a column
    % so it lines up with the rows of nGrpbest whatever the CSV orientation.
    modTrueID = ones(numel(frac), 1);
    modTrueID(frac(:) == 1) = 0;
    nGrpTrue = modTrueID + 1;

    % Boolean for if metric selected the correct model
    boolMet = nGrpbest == nGrpTrue;
    % Sum explicitly over trajectories (dim 1) so that a single trajectory
    % (M == 1) still gives one success rate per metric
    Psucc = sum(boolMet, 1)/M;
    % Rank from best to worst
    [~, rankid] = sort(Psucc);
    ranks = metrics(rankid);
    ranks = ranks(end:-1:1);
    disp(ranks);

    % No. alternative model choices
    nPos = sum(modTrueID);
    % Null model
    nNeg = M - nPos;
    % Classification performance
    TPR = zeros(1, nmet); FPR = TPR;
    C = cell(1, nmet); dimC = TPR;
    for i = 1:nmet
        nGrpSel = nGrpbest(:, i);
        % Confusion matrix and dimension (stored for later inspection)
        Ctemp = confusionmat(nGrpTrue, nGrpSel);
        C{i} = Ctemp; dimC(i) = length(Ctemp);
        % True and false positive rate, counted directly from the labels.
        % Indexing Ctemp(2, 2) / Ctemp(1, 1) is only valid when labels 1 and
        % 2 both occur; the direct counts give the same numbers in that case
        % and stay correct when a class is missing. A rate is NaN when its
        % denominator (nPos or nNeg) is zero.
        TPR(i) = sum(nGrpTrue == 2 & nGrpSel == 2)/nPos;
        FPR(i) = 1 - sum(nGrpTrue == 1 & nGrpSel == 1)/nNeg;
    end
    % Store data
    TPRs{ii} = TPR; FPRs{ii} = FPR;
    PSucc{ii} = Psucc; Cset{ii} = C;
    dimCs{ii} = dimC; Ranks{ii} = ranks;
    disp(['Processed ' num2str(ii) ' of ' num2str(nFiles)]);
end

% Stack the per-folder row vectors into matrices (one row per batch folder)
TPRs = vertcat(TPRs{:});
FPRs = vertcat(FPRs{:});
PSucc = vertcat(PSucc{:});
dimCs = vertcat(dimCs{:});

% Put the rows in order of increasing no. coalescent events
[ncs, idc] = sort(ncs);
TPRs = TPRs(idc, :);
FPRs = FPRs(idc, :);
PSucc = PSucc(idc, :);
dimCs = dimCs(idc, :);

% Every plot below leaves out the first metric column
succAlt = PSucc(:, 2:end);
fprAlt = FPRs(:, 2:end);
tprAlt = TPRs(:, 2:end);
altNames = metrics(2:end);

% Figure: probability of selecting the correct model against m
figure;
plot(ncs, succAlt, '.-', 'linewidth', 2, 'MarkerSize', 30);
grid('off'); box('off');
% Label uses the tshift value left in the workspace by the processing loop
mLabel = ['$m$ ($\tau$ = ' num2str(tshift) ')'];
xlabel(mLabel);
ylabel('P(correct)');
legend(altNames, 'location', 'best');

% Figure: ROC points (false vs true positive rate), one series per metric
figure;
plot(fprAlt, tprAlt, '.', 'MarkerSize', 35);
grid('off'); box('off');
xlabel('FPR');
ylabel('TPR');
legend(altNames, 'location', 'best');

% Figure: the same two views stacked as panels of a single figure
figure;
% Top panel: successful detection
subplot(2, 1, 1);
plot(ncs, succAlt, '.-', 'linewidth', 2, 'MarkerSize', 30);
grid('off'); box('off');
xlabel(mLabel);
ylabel('P(correct)');
% Bottom panel: ROC points (the legend is attached to this panel only)
subplot(2, 1, 2);
plot(fprAlt, tprAlt, '.', 'MarkerSize', 35);
grid('off'); box('off');
xlabel('FPR');
ylabel('TPR');
legend(altNames, 'location', 'best');

% Report elapsed time in minutes since the tic at the top of the script
tsim = toc/60;
timeMsg = ['Execution time: ' num2str(tsim) ' mins'];
disp(timeMsg);