% Model selection over skylines for square waves of different periods.
%
% Notes and assumptions:
% - simplified so that the true trajectory is fixed
% - uses a simpler folder structure
% - square wave model selection across different periods
% - underlying log N(t) trajectories are fixed but random

% Start from a clean workspace, command window and figure set; start timer
clearvars;
clc;
close all;
tic;

% Additional plotting/partition package
addpath(genpath('/Users/kp10/Documents/MATLAB'));

% Figure defaults: LaTeX interpreters for tick labels, legends and text,
% plus a larger axes font (groot is the graphics root, same object as 0)
set(groot, 'defaultAxesTickLabelInterpreter', 'latex', ...
    'defaultLegendInterpreter', 'latex', ...
    'defaultTextInterpreter', 'latex', ...
    'defaultAxesFontSize', 16);

% Light and dark grey as 1x3 RGB triplets
grey1 = repmat(0.8, 1, 3);
grey2 = repmat(0.5, 1, 3);

%% Load main data and parse from R

% Home and folders to save/load
thisDir = cd;
folload = 'sqwaves_5';
folsave = 'sq_results';
% Figures to plot, data to save
allFigs = 0; savetrue = 0;

% Metrics to be evaluated
metrics = {'$-\log L$', 'BIC', 'FIA', 'QK', 'AIC'};
nmet = length(metrics);

% Max for FIA volume
Nmax = 1000; V = log(Nmax);
disp(['Nmax assumed for FIA = ' num2str(Nmax)]);

% Folder with the data controlling all sims. Files are read through full
% paths rather than by cd-ing into the folder, so a failed read can never
% leave the session stranded in the data directory.
dataDir = fullfile(thisDir, 'using phylodyn', 'sq_short', folload);
if exist(dataDir, 'dir') ~= 7
    error('sqwaves:missingDataDir', 'Data folder not found: %s', dataDir);
end

% Trajectory, no. runs and trees
idTraj = csvread(fullfile(dataDir, 'idTraj.csv'));
M = csvread(fullfile(dataDir, 'numTraj.csv'));
nRuns = csvread(fullfile(dataDir, 'numRuns.csv'));

% Wave types and popul range
wave = csvread(fullfile(dataDir, 'wave.csv'));
Nrange = csvread(fullfile(dataDir, 'Nrange.csv'));
disp(['True [Nmax Nmin] = ' num2str(Nrange')]);

% Sample data (fixed no. samps intro each segment)
sampRuns = csvread(fullfile(dataDir, 'sampRuns.csv'));
tsamp = csvread(fullfile(dataDir, 'samptimes.csv')); nsegs = length(tsamp);

% Coalescent event counts
ncs = sampRuns*nsegs - 1;

% Main coalescent data, one cell per run
tcoals = cell(1, nRuns); coalLins = tcoals;
for i = 1:nRuns
    % Coalescent times for M trees
    tcoals{i} = csvread(fullfile(dataDir, ['coalT' num2str(i) '.csv']));
    % Corresponding lineage counts
    coalLins{i} = csvread(fullfile(dataDir, ['coalL' num2str(i) '.csv']));
end

% Normalised nGrp (powers of two) and k, one entry per row of wave
nwave = size(wave, 1);
nGrpnorm = 2.^(0:nwave-1);
% knorm is nGrpnorm reversed, so k gets larger as idTraj decreases
knorm = fliplr(nGrpnorm);

% True nGrp and k for each run, looked up via the distinct idTraj values
idTrue = unique(idTraj);
nGrpTrue = nGrpnorm(idTrue);
ktrue = knorm(idTrue)*sampRuns;

% Candidate k values at each sample run: the normalised k scaled by that
% run's sample count, in ascending order
% (a regular grid 5:5:ncs(i) was a previously used alternative)
kspace = cell(1, nRuns);
for i = 1:nRuns
    kspace{i} = sort(sampRuns(i)*knorm);
end

% Optionally show every candidate wave as a stem plot, two panels per row
if allFigs
    figure;
    for i = 1:nwave
        subplot(ceil(nwave/2), 2, i);
        stem(1:nsegs, wave(i, :), 'LineWidth', 2, 'MarkerSize', 10);
        xlabel(num2str(i));
        box('off');
        grid('off');
    end
end


%% Model select from coalescent events

% Performance for each run averaged over replicates
% (rows = runs, columns = metrics)
msek = zeros(nRuns, nmet); kavg = msek;
nGrpavg = msek; Pdet = msek;

for ii = 1:nRuns
    % Coalescent data for replicates (column i is passed for replicate i)
    tcoal = tcoals{ii}; lcoal = coalLins{ii};
    % Local kspace at this run and the true k to compare against
    k = kspace{ii}; k0 = ktrue(ii);

    % No. coalescences in each segment
    cshift = zeros(M, nsegs);
    % Samples introduced each segment (fixed), as an nsegs x 1 column
    sampIntro = sampRuns(ii)*ones(nsegs, 1);

    % Model selection criteria
    criteria = cell(1, M); nGrp = criteria;
    kbest = zeros(M, nmet); nGrpbest = kbest;
    for i = 1:M
        % Main function for selection models
        [criteria{i}, kbest(i, :), nGrpbest(i, :), nGrp{i}, cshift(i, :), ~] = ...
            sqWaveFn(tcoal(:, i), lcoal(:, i), k, nwave, sampIntro, V, tsamp);
    end

    % All reductions below are taken explicitly along dimension 1 (over
    % replicates). With the default dimension, M == 1 makes the operands
    % row vectors, which would collapse across metrics to a scalar that is
    % then silently broadcast into the whole result row.

    % Square (normalised) error to ktrue
    kerr2 = (kbest/k0 - 1).^2;
    msek(ii, :) = mean(kerr2, 1);
    % Avg chosen k and nGrp for each metric
    kavg(ii, :) = mean(kbest, 1);
    nGrpavg(ii, :) = mean(nGrpbest, 1);

    % Convert nGrps to an idTraj value (inverse of nGrp = 2^(id - 1))
    idEst = log2(nGrpbest) + 1;
    % Prob of correct detection
    % NOTE(review): assumes idTraj is a scalar (or otherwise compatible in
    % size with the M x nmet idEst) - confirm against the data files
    detID = idEst == idTraj;
    Pdet(ii, :) = sum(detID, 1)/M;
end


% Log time
tsim = toc/60;
disp(['Execution time: ' num2str(tsim) ' mins']);