% Model selection over skylines for different square waves
% Reset workspace, command window and figures, then start the run timer
clearvars; clc;
close all; tic;

% Assumptions and notes
% - square wave model selection, different periods
% - underlying log N(t) trajectories fixed but random

% Additional plotting/partition package (absolute path, added recursively)
addpath(genpath('/Users/kp10/Documents/MATLAB'));

% Figure defaults: LaTeX interpreters for tick labels, legends and text,
% plus a larger axes font
set(groot, 'defaultAxesTickLabelInterpreter', 'latex', ...
    'defaultLegendInterpreter', 'latex', ...
    'defaultTextInterpreter', 'latex', ...
    'defaultAxesFontSize', 16);
% Two grey RGB triplets (lighter and darker)
grey1 = repmat(0.8, 1, 3); grey2 = repmat(0.5, 1, 3);

% Home and folders to save/load
thisDir = cd;
folload = 'sqwave_';
folsave = 'sq_results';
% Figures to plot, data to save
allFigs = 0;
savetrue = 0;

% Metrics to be evaluated (one column per metric in all result arrays)
metrics = {'$-\log L$', 'BIC', 'FIA', 'QK', 'AIC'};
nmet = length(metrics);

% Folder with all data from R sims (relative to the working directory)
batchDir = 'using phylodyn/sq_batch/';
if exist(batchDir, 'dir') ~= 7
    error('Data folder not found: %s', batchDir);
end
% All folders to load with shift data. Listing by path (rather than cd-ing
% into the folder and back) keeps the working directory untouched, and the
% isdir filter drops plain files that happen to share the prefix.
fols = dir(fullfile(batchDir, [folload '*']));
fols = fols([fols.isdir]);
nFiles = length(fols);
if nFiles == 0
    error('No folders matching ''%s*'' found in %s', folload, batchDir);
end

% Max for FIA volume: V is the log of the largest population size allowed
Nmax = 1000; V = log(Nmax);
disp(['Nmax for FIA = ' num2str(Nmax)]);

%% Load main data per run

% Performance variables across sample runs (one entry/column per folder)
PSucc = cell(1, nFiles); Ranks = PSucc;
Trus = PSucc; Fal1s = PSucc; Fal2s = PSucc;
ks = PSucc; cshifts = PSucc; wavNo = PSucc;
ncs = zeros(1, nFiles); Ms = ncs; tshift = ncs;
Nrange = zeros(2, nFiles); modEst = PSucc; modTrue = PSucc;

for ii = 1:nFiles
    % Current data folder, relative to the working directory
    dataload = ['using phylodyn/sq_batch/' fols(ii).name];
    % Read files by full path instead of cd-ing into the folder, so a
    % failed read does not leave the session inside the data folder
    readCsv = @(fname) csvread(fullfile(dataload, fname));
    
    % Trajectory id (true model index per run) and number of trajectories.
    % Force a column so the comparison against the Ms x nmet estimates
    % below expands along the metric dimension.
    idTraj = readCsv('idTraj.csv');
    idTraj = idTraj(:);
    numTraj = readCsv('numTraj.csv');
    
    % Wave types and repetitions
    Ms(ii) = numTraj;
    wave = readCsv('wave.csv');
    
    % Fixed tree characteristics across M (unique() errors on assignment
    % if the file holds more than one distinct value)
    ncs(ii) = unique(readCsv('nc.csv'));
    sampIntro = readCsv('sampIntro.csv');
    tsamp = readCsv('samptimes.csv');
    
    % Shift time of segments and amp
    % NOTE(review): assumes exactly equal spacing in tsamp - confirm that
    % floating-point differences cannot yield several near-equal values
    tshift(ii) = unique(diff(tsamp));
    Nrange(:, ii) = readCsv('Nrange.csv');
    
    % Set space of group sizes based on wave (one candidate per wave row)
    nwave = size(wave, 1);
    knorm = 2.^(0:nwave-1);
    
    % Same no. samps on each segment
    k = knorm*unique(sampIntro);
    ks{ii} = k;
        
    % No. coalescences in each segment
    nshift = length(sampIntro);
    cshift = zeros(Ms(ii), nshift);
    
    % Store frac of change, best models and criteria
    criteria = cell(1, Ms(ii)); nGrp = criteria;
    kbest = zeros(Ms(ii), nmet); nGrpbest = kbest;
    modEstID = kbest;
    for i = 1:Ms(ii)
        % Main function for selection models
        [criteria{i}, kbest(i, :), nGrpbest(i, :), nGrp{i}, cshift(i, :), modEstID(i, :)] = ...
            sqWaveFn(i, dataload, k, nwave, sampIntro, V, tsamp);
    end
    
    % True model indices and, per metric, the fraction of runs whose
    % estimated model index matches. Summing along dim 1 keeps a
    % 1 x nmet result even when there is a single run.
    modTrue{ii} = idTraj; modEst{ii} = modEstID;
    Psucc = sum(modEst{ii} == modTrue{ii}, 1)/Ms(ii);
    
    % No. of each type of wave
    typeWv = zeros(1, nwave);
    for i = 1:nwave
        typeWv(i) = sum(idTraj == i);
    end
    wavNo{ii} = typeWv;

    % Rank from best to worst
    [~, rankid] = sort(Psucc);
    ranks = metrics(rankid);
    ranks = ranks(end:-1:1);
    disp(ranks);
    
    % Classification performance
    Tru = zeros(1, nmet); Fal1 = Tru; Fal2 = Tru;
    for i = 1:nmet
        % Confusion matrix: rows follow the first argument (true ids),
        % columns the second (estimated ids)
        Ctemp = confusionmat(modTrue{ii}, modEst{ii}(:, i));
        % Correct classifications lie on the main diagonal
        Tru(i) = sum(diag(Ctemp));
        % Errors: everything below (Fal1) and above (Fal2) the diagonal.
        % tril/triu cover every off-diagonal whatever the matrix size,
        % which depends on the number of classes, not on nmet.
        Fal1(i) = sum(sum(tril(Ctemp, -1)));
        Fal2(i) = sum(sum(triu(Ctemp, 1)));
    end
    % Normalise by num runs
    Tru = Tru/Ms(ii); Fal1 = Fal1/Ms(ii); Fal2 = Fal2/Ms(ii);
    
    % Store data
    PSucc{ii} = Psucc; cshifts{ii} = cshift;
    Ranks{ii} = ranks; Trus{ii} = Tru; 
    Fal1s{ii} = Fal1; Fal2s{ii} = Fal2;
    disp(['Processed ' num2str(ii) ' of ' num2str(nFiles)]);
end

% Re-order in increasing no. coalescent events. The permutation is applied
% to the per-folder cells before conversion, so every array is reordered
% along its folder dimension. Other per-folder variables (e.g. Ms, tshift)
% are left in load order.
[ncs, idc] = sort(ncs);

% Convert cells to matrices: one row per folder for the per-metric and
% per-wave summaries
PSucc = cell2mat(PSucc(idc)');
wavNo = cell2mat(wavNo(idc)'); Trus = cell2mat(Trus(idc)');
Fal1s = cell2mat(Fal1s(idc)'); Fal2s = cell2mat(Fal2s(idc)');
% cshifts holds one (runs x segments) matrix per folder and is joined
% side by side, so its folder dimension is the column blocks. Indexing
% its rows with idc would permute runs rather than folders.
cshifts = cell2mat(cshifts(idc));

% Avg detection rate and off-diagonal failures across folders. The explicit
% dimension keeps a 1 x nmet result when only one folder was loaded.
Pdet = sum(Trus, 1)/nFiles;
F1 = sum(Fal1s, 1)/nFiles;
F2 = sum(Fal2s, 1)/nFiles;

% Examine successful detection: P(correct) per folder against ncs, for
% every metric except the first
figure;
plot(ncs, PSucc(:, 2:end), '.-', 'linewidth', 2, 'MarkerSize', 30);
grid off; box off;
xlabel(['$m$ ($\tau$ = ' num2str(unique(tshift)) ')']);
ylabel('P(correct)');
% xlim requires strictly increasing limits, so only tighten the axis when
% ncs spans more than one value (e.g. not for a single folder)
ncLims = [min(ncs) max(ncs)];
if numel(ncLims) == 2 && ncLims(1) < ncLims(2)
    xlim(ncLims);
end
legend(metrics(2:end), 'location', 'best');

% Overall performance of each metric: average correct rate and the two
% off-diagonal error rates, one x-position per metric
figure;
plot(1:nmet, [Pdet' F1' F2'], '.-', 'linewidth', 2, 'MarkerSize', 30);
grid off; box off;
h = gca; h.XTick = 1:nmet;
h.XTickLabel = metrics;
legend('P(true)', 'P(FP)', 'P(FN)', 'location', 'best');

% Report elapsed time in minutes since the tic at the start of the script
tsim = toc/60;
fprintf('%s\n', ['Execution time: ' num2str(tsim) ' mins']);