% Model selection over skyline segments with increasing data
clearvars; clc;
close all; tic;

% Assumptions and notes
% - uses MDL, BIC etc. over skylines with increasing samples
% - underlying log N(t) trajectories are fixed
% - examines how the optimal k or nGrp changes with increasing data

% Additional plotting/partition package (user-specific absolute path)
addpath(genpath('/Users/kp10/Documents/MATLAB'));

% Figure defaults: LaTeX for tick labels, legends and text; larger axes font.
% groot is the graphics root object (the same object as handle 0).
set(groot, 'defaultAxesTickLabelInterpreter', 'latex');
set(groot, 'defaultLegendInterpreter', 'latex');
set(groot, 'defaultTextInterpreter', 'latex');
set(groot, 'defaultAxesFontSize', 16);
grey1 = repmat(0.8, 1, 3);
grey2 = repmat(0.5, 1, 3);

% Folders to load from / save to
savetrue = 0;
folload = 'samps';
folsave = 'sampsRes';
% Figures to plot (flag)
allFigs = 0;

% Metrics to be evaluated (order matches the columns of each criteria{j})
metrics = {'$-\log L$', 'BIC', 'FIA', 'QK', 'AIC'};
nmet = numel(metrics);

%% Simulated coalescent data

% Cyclic (exp) and bottleneck trajectories
type = 1;
trajNames = {'cyclic', 'bottle', 'boom', 'steep'};
trajChoice = trajNames{type};

% Set data source
dataStr = folload;
% Folder with data generated by the phylodyn package in R. Files are read via
% full paths (instead of cd-ing into the folder and back) so that a failed read
% can never leave the working directory changed.
thisDir = pwd;
dataDir = fullfile(thisDir, 'using phylodyn', dataStr, [trajChoice 'Samps_17']);
if ~exist(dataDir, 'dir')
    error('Data folder not found: %s', dataDir);
end
readData = @(fname) csvread(fullfile(dataDir, fname));

% Number of coalescences and samples (fixed)
nc = readData('nc.csv');
ns = readData('ns.csv');

% No. runs (data size)
nRuns = length(ns);
if nRuns ~= length(nc)
    error('Coalescences and samples do not match');
end

% Trajectory (true) log N(t)
Nt = readData('trajy.csv');
t = readData('trajt.csv');
logNt = log(Nt);
% TMRCAs
tmax = readData('tmax.csv');

% Load data per sampled tree
tcoal = cell(1, nRuns); coalLin = tcoal;
sampIntro = tcoal; tsamp = tcoal;
for i = 1:nRuns
    runStr = num2str(i);
    % Coalescent times and numbers (per run)
    tcoal{i} = readData(['coaltimes' runStr '.csv']);
    coalLin{i} = readData(['coalLin' runStr '.csv']);
    % Sample times and numbers (per run)
    tsamp{i} = readData(['samptimes' runStr '.csv']);
    sampIntro{i} = readData(['sampIntro' runStr '.csv']);
end

% Heterochronous skyline statistics, computed separately for each sampled tree
[icoal, icoalSum, compCoal, tLin, dtLin] = deal(cell(1, nRuns));
for r = 1:nRuns
    [icoal{r}, icoalSum{r}, ~, compCoal{r}, tLin{r}, dtLin{r}] = ...
        computeCoal(tcoal{r}, tsamp{r}, coalLin{r}, sampIntro{r});
end


%% Model select across grouped skylines with loci

% All possible group sizes per run: every integer k with 4 <= k <= nc(j)
kset = cell(1, nRuns); lenks = zeros(1, nRuns);
for j = 1:nRuns
    k = 1:nc(j); %k = k(rem(nc(j), k) == 0);
    k = k(k >= 4); kset{j} = k;
    lenks(j) = length(k);
end
% Max population size; its log V is passed to getLogNModSel2 for FIA
Nmax = 10^4; V = log(Nmax);

% Group sizes common to all runs, so criteria can be summed across runs
k = kset{1};
for j = 2:nRuns
    k = intersect(k, kset{j});
end
lenk = length(k);
% Without a shared candidate there is nothing to select; fail clearly here
% rather than with an index error further down
if lenk == 0
    error('No group size k >= 4 is common to all runs (min nc = %g)', min(nc));
end

% Per-run skyline fits and model selection criteria for each candidate k
criteria = cell(1, nRuns); nGrps = zeros(lenk, nRuns);
logNavg = cell(lenk, nRuns); tEnd = logNavg;
for r = 1:nRuns
    % One entry per candidate group size
    [lik, fia, bic, qian, aic] = deal(zeros(1, lenk));

    for ik = 1:lenk
        % Grouped skyline MLE for group size k(ik)
        [logNavg{ik, r}, ~, nGrps(ik, r), tEnd{ik, r}, idGrp] = getGrpSkyLog(k(ik), ...
            compCoal{r}, nc(r), icoal{r}, icoalSum{r}, dtLin{r}, tLin{r});

        % Max log-likelihood and selection criteria for this grouping
        [lik(ik), ~, fia(ik), bic(ik), qian(ik), aic(ik)] = getLogNModSel2(nc(r), ...
            nGrps(ik, r), idGrp, V, compCoal{r}, icoal{r});
    end

    % lenk x 5 matrix, columns in metrics order: -log L, BIC, FIA, QK, AIC
    criteria{r} = [-lik', bic', fia', qian', aic'];
end

% % Num groups should be unique
% nGrp = zeros(1, lenk);
% for i = 1:lenk
%     nGrp(i) = unique(nGrps(i, :));
% end

% Successively sum criteria over runs (growing data) and pick the best model
critSum = zeros(size(criteria{1}));
critComb = cell(1, nRuns);
modVal = critComb; modID = critComb;
% Rows index runs, columns index criteria (same order as metrics)
nGrpbest = zeros(nRuns, size(critSum, 2));
kbest = nGrpbest;
for j = 1:nRuns
    % Combine across loci
    critSum = critSum + criteria{j};
    critComb{j} = critSum;
    % Best candidate per criterion: minimise down each column. The explicit
    % dim is needed because with a single candidate k, critSum is a row vector
    % and min() would otherwise reduce across the criteria instead.
    [modVal{j}, modID{j}] = min(critSum, [], 1);
    % Best k and nGrp for skyline selection
    kbest(j, :) = k(modID{j});
    nGrpbest(j, :) = nGrps(modID{j}, j).';
end

% Best FIA and -log L trajectories at the data extremes: the first run and the
% last run (data presumably increase with run index)
dataid = [1 nRuns];
idfia = find(strcmp('FIA', metrics));
if isempty(idfia)
    error('FIA must be one of the evaluated metrics');
end
idlik = 1; % always start with lik
% Selected candidate index for each criterion at each extreme
modfia = [modID{1}(idfia), modID{end}(idfia)];
modlik = [modID{1}(idlik), modID{end}(idlik)];

% Get trajectories at these data extremes
logNfia = cell(1, 2); logNlik = logNfia;
tEndfia = logNfia; tEndlik = logNfia;
logNfiat = zeros(2, length(t)); logNlikt = logNfiat;
for i = 1:2
    % Raw trajectories of the selected models
    logNfia{i} = logNavg{modfia(i), dataid(i)};
    logNlik{i} = logNavg{modlik(i), dataid(i)};
    tEndfia{i} = tEnd{modfia(i), dataid(i)};
    tEndlik{i} = tEnd{modlik(i), dataid(i)};
    % Prepend time 0 (repeating the first value) so stairs starts at the origin
    logNfia{i} = [logNfia{i}(1) logNfia{i}];
    tEndfia{i} = [0 tEndfia{i}];
    logNlik{i} = [logNlik{i}(1) logNlik{i}];
    tEndlik{i} = [0 tEndlik{i}];
    % Interpolate trajectories (presumably onto the time grid t; see getInterp)
    logNfiat(i, :) = getInterp(tEndfia{i}, logNfia{i}, t, logNt);
    logNlikt(i, :) = getInterp(tEndlik{i}, logNlik{i}, t, logNt);
end

%% Visualisation and results storage

% Best group choices versus data size: p* (top) and k* (bottom)
figure;
bestVals = {nGrpbest, kbest};
yLabs = {'$p^*$', '$k^*$'};
for ip = 1:2
    subplot(2, 1, ip);
    plot(nc, bestVals{ip}, '.-', 'linewidth', 2, 'MarkerSize', 30);
    box off; grid off;
    legend(metrics, 'location', 'best');
    xlabel('$m$');
    ylabel(yLabs{ip});
end

% Selected trajectories against the true log N(t): left column uses -log L,
% right column uses FIA; top row is the first run, bottom row the last run
tlim = [0 150]; % time window shown on every panel
figure;
hax = gobjects(1, 4);
tStep = {tEndlik, tEndfia};
logNStep = {logNlik, logNfia};
for row = 1:2
    for col = 1:2
        ip = 2*(row - 1) + col;
        hax(ip) = subplot(2, 2, ip);
        plot(t, logNt, 'k', 'linewidth', 2);
        hold on;
        stairs(tStep{col}{row}, logNStep{col}{row}, 'c', 'linewidth', 2);
        if row == 2
            xlabel('$t$ (units)');
        end
        hold off; grid off; box off;
        xlim(tlim);
    end
end
linkaxes(hax, 'xy');

% Report total script run time in minutes (timer started by tic at the top)
tsim = toc/60;
fprintf('%s\n', ['Execution time: ' num2str(tsim) ' mins']);