% Model selection over skyline segments
clearvars; clc;
close all; tic;

% Assumptions and notes
% - examines single cases of overfitting and underfitting
% - works directly in log population size
% - selects the number of skyline segments using robust criteria
% - tests several criteria: MDL, BIC, likelihood
% - examines monotone and cyclic dynamics

% Additional plotting/partition package
addpath(genpath('/Users/kp10/Documents/MATLAB'));

% Figure defaults, applied on the graphics root in a single call:
% LaTeX interpreters for tick labels, legends and text, plus axes font size
set(groot, 'defaultAxesTickLabelInterpreter', 'latex', ...
    'defaultLegendInterpreter', 'latex', ...
    'defaultTextInterpreter', 'latex', ...
    'defaultAxesFontSize', 16);
% Light and mid grey RGB triplets used by the plots
grey1 = repmat(0.8, 1, 3);
grey2 = repmat(0.5, 1, 3);

% Save flag and the folders to load from / save to
savetrue = 0;
folload = 'test';
folsave = 'testRes';
% Set to 1 to also draw the optional diagnostic figures
allFigs = 0;

% Labels of the selection metrics, in the column order used for the criteria
metrics = {'$-\log L$', 'BIC', 'FIA', 'QK', 'AIC'};
nmet = numel(metrics);

%% Simulated coalescent data

% Possible trajectories to select (type indexes into trajNames)
type = 8;
trajNames = {'logis', 'exp', 'steep', 'unif_low', 'unif_high', 'boom', 'cyc', 'bottle', 'mesa', 'binary'};
trajChoice = trajNames{type};

% Set data source
dataStr = folload;
% Data generated from the phylodyn package in R. The folder path is built
% explicitly instead of cd-ing into it, so a failed read can no longer leave
% the session stranded inside the data folder.
thisDir = pwd;
dataDir = fullfile(thisDir, 'using phylodyn', dataStr, [trajChoice '_test']);
if exist(dataDir, 'dir') ~= 7
    error('Data folder not found: %s', dataDir);
end

% Coalescent and sample times
tcoal = csvread(fullfile(dataDir, 'coaltimes.csv'));
tsamp = csvread(fullfile(dataDir, 'samptimes.csv'));
% Lineages driving each coalescent
coalLin = csvread(fullfile(dataDir, 'coalLin.csv'));
% Trajectory (values Nt on the time grid t)
Nt = csvread(fullfile(dataDir, 'trajy.csv'));
t = csvread(fullfile(dataDir, 'trajt.csv'));
% Heterochronous tree
tree = phytreeread(fullfile(dataDir, 'tree.txt'));
% Samples introduced at each sample time
sampIntro = csvread(fullfile(dataDir, 'sampIntro.csv'));

% Combine coalescent and sample times into one sorted row of event times
% (the transposes assume the files were read as column vectors)
tLin = sort([tcoal' tsamp']);
len = length(tLin);

% Define num of samples and coalescents
nc = length(tcoal); ns = length(tsamp);

% Flag each event as a sample or a coalescent time
isamp = ismember(tLin, tsamp);
icoal = ismember(tLin, tcoal);
% The two flags must be complementary: a time present in both lists would be
% flagged twice and corrupt the lineage bookkeeping that follows
if any(isamp & icoal)
    error('A sample time coincides with a coalescent time');
end

%% Heterochronous LTT construction

% The LTT must open with a sampling event introducing at least 2 lineages
if sampIntro(1) < 2
    error('Started with under 2 samples');
end

% Event types after the first one: true = sampling, false = coalescence
laterIsSamp = isamp(2:len);
% Counters for samples and coalescents (the opening sample is already counted)
c_samp = 1 + sum(laterIsSamp);
c_coal = sum(~laterIsSamp);

% Change in lineage count at each event: a coalescence removes one lineage,
% a sampling event adds however many lineages it introduces
dLin = -ones(size(tLin));
dLin(1) = sampIntro(1);
dLin([false laterIsSamp]) = sampIntro(2:c_samp);
% Lineage count immediately after each event
nLin = cumsum(dLin);

% Lineage count immediately before each event, i.e. the count driving it:
% undo the -1 of a coalescence and the addition made by a sampling event
nLinPre = nLin + icoal;
nLinPre(isamp) = nLinPre(isamp) - sampIntro';

% Consistency check: the count before event j+1 equals the count after event j
if any(nLinPre(2:end) ~= nLin(1:end-1))
    error('The nLin and nLinPre relationship is wrong');
end
% Every event must be consumed, and the counts driving the coalescents must
% agree with those read from file
if c_samp ~= ns || c_coal ~= nc || any(coalLin' ~= nLinPre(icoal))
    error('Computed LTT incorrectly');
end

% Inter-event interval lengths and how many there are
dtLin = diff(tLin);
numInt = numel(dtLin);
lendt = numInt;
% dtLin(i) lies between events i and i+1, so the lineage count over it is
% nLin(i), which equals nLinPre(i+1)
cumInt = cumsum(dtLin);
% Running count of coalescent events
icoalSum = cumsum(icoal);

%% Construct grouped skyline estimator

% Candidate group sizes (non-repeated): the divisors of nc
k = 1:nc; k = k(rem(nc, k) == 0);
% Lower bound on the group size (1 keeps every divisor)
k = k(k >= 1); lenk = length(k);
disp(['Group sizes span ' num2str(k(1)) ' to ' num2str(k(end))]);

% Compute components for grouped skyline on each interval; nLin(i) is the
% lineage count after event i, hence the count present over dtLin(i)
alpha = 0.5*nLin(1:lendt).*(nLin(1:lendt) - 1);
compCoal = alpha.*dtLin;
% Max for FIA volume
Nmax = 10^6; V = log(Nmax);

% Store results from skyline groups (one cell / entry per group size)
c = cell(1, lenk); logNavg = c; tEnd = c; 
likSeg = c; nGrp = zeros(1, lenk); idGrp = c;
% Model selection criteria for each grouping
lik = nGrp; fia = nGrp; bic = nGrp; qian = nGrp; aic = nGrp;

% Skyline plots for various k, and their log-likelihoods
for i = 1:lenk
    % Skyline MLE function
    [logNavg{i}, c{i}, nGrp(i), tEnd{i}, idGrp{i}] = getGrpSkyLog(k(i), compCoal, nc,...
        icoal, icoalSum, dtLin, tLin);
    
    % Max log-likelihood and selection criteria
    [lik(i), likSeg{i}, fia(i), bic(i), qian(i), aic(i)] = getLogNModSel2(nc, nGrp(i),...
        idGrp{i}, V, compCoal, icoal);
end

% Combine criteria: one row per group size, one column per metric (same
% order as the metrics labels)
criteria = [-lik', bic', fia', qian', aic']; 
% Best model per metric. The dimension is given explicitly: with a single
% candidate group size criteria is a row vector, and min(criteria) would
% then collapse across metrics instead of across models.
[modVal, modID] = min(criteria, [], 1);
% Best k and nGrp for skyline selection
kbest = k(modID); nGrpbest = nGrp(modID);

% Display selections
disp(metrics);
disp('k = '); disp(kbest);
disp('nGrp =  '); disp(nGrpbest);

%% Interpolation and plotting

% True trajectory in log space
logNt = log(Nt);

% Per group size: padded skyline vectors and their values on the t grid
logNavg0 = cell(1, lenk);
tEnd0 = cell(1, lenk);
logNavgt = zeros(lenk, numel(t));
for i = 1:lenk
    % Repeat the first estimate and prepend time 0 so the step function has
    % a starting point
    % NOTE(review): assumes tEnd{i} holds segment end times measured from 0
    % - confirm against getGrpSkyLog
    skyVals = logNavg{i};
    logNavg0{i} = [skyVals(1) skyVals];
    tEnd0{i} = [0 tEnd{i}];
    % Interpolation to Nt grid and stats
    logNavgt(i, :) = getInterp(tEnd0{i}, logNavg0{i}, t, logNt);
end


if allFigs
    % Time window shared by all panels below
    tSpan = [tLin(1) tLin(end)];

    % Figure: LTT (upper two thirds) above a trace of the event times
    figure;
    subplot(3, 1, 1:2); % spans two of the three rows
    stairs(tLin, nLin, 'color', 'c', 'linewidth', 2);
    xlim(tSpan);
    ylabel('LTT');
    % Event trace: sample times in grey, coalescent times in cyan
    subplot(3, 1, 3);
    stem(tsamp, ones(size(tsamp)), 'color', grey1, 'Marker', 'none');
    hold on;
    stem(tcoal, ones(size(tcoal)), 'color', 'c', 'Marker', 'none');
    hold off;
    grid off;
    box off;
    ylim([0 1.1]);
    h = gca;
    h.YTick = [0 1];
    xlim(tSpan);
    xlabel('time into past');

    % Figure: interpolated MLE skyline for every k against the true trajectory
    figure;
    plot(t, logNt, 'k--', 'linewidth', 2);
    hold all;
    for i = 1:lenk
        stairs(t, logNavgt(i, :), 'linewidth', 2);
    end
    hold off;
    grid off;
    box off;
    xlim(tSpan);
    xlabel('time into past');
    ylabel('$\hat{N}$');

    % Figure: per-segment log-likelihoods for every k
    figure;
    hold all
    for i = 1:lenk
        plot(tEnd{i}, likSeg{i}, 'linewidth', 2);
    end
    hold off;
    grid off;
    box off;
    ylabel('segment $\log L$');
    xlabel('time into past');
    xlim(tSpan);
end

% Time window shared by the trajectory figures
tSpan = [tLin(1) tLin(end)];

% Figure: model selection criteria, rows reversed so the x axis follows the
% reversed nGrp order used for the tick labels
figure;
hold all;
plot(flipud(criteria), 'linewidth', 2);
% Mark the minimum of each criterion at its position on the reversed axis
for i = 1:nmet
    plot(lenk-modID(i)+1, modVal(i), 'o', 'markersize', 10, 'linewidth', 2);
end
ylabel('criteria');
xlabel('$p$');
hold off;
grid off;
box off;
h = gca;
xlim([1 lenk]);
h.XTick = 1:numel(nGrp);
h.XTickLabel = fliplr(nGrp);
legend(metrics, 'location', 'best');

% Figure: skyline selected by each criterion against the true trajectory
figure;
hold all;
plot(t, logNt, 'k--', 'linewidth', 2);
for i = 1:nmet
    bestID = modID(i);
    stairs(tEnd0{bestID}, logNavg0{bestID}, 'linewidth', 2);
end
hold off;
grid off;
box off;
legend(['true', metrics], 'location', 'best');
xlim(tSpan);
xlabel('time into past');
ylabel('$\log(\hat{N})$');

% Figure: under and over-fitting, left panel uses the largest group size
% (last model), right panel the smallest (first model)
figure;
panelModel = [lenk, 1];
for p = 1:2
    subplot(1, 2, p);
    plot(t, logNt, 'k', 'linewidth', 2);
    hold on;
    stairs(tEnd0{panelModel(p)}, logNavg0{panelModel(p)}, 'c', 'linewidth', 2);
    hold off;
    grid off;
    box off;
    xlim(tSpan);
    ylabel('$\log(\hat{N})$');
    xlabel('$t$ (units)');
end

% Report the elapsed time since the tic at the top of the script, in minutes
tsim = toc/60;
elapsedMsg = ['Execution time: ' num2str(tsim) ' mins'];
disp(elapsedMsg);