% Model selection over skylines for different square waves.
%
% Assumptions and notes
% - simple folder structure, one wave per folder
% - square-wave model selection across different periods

% Start from a clean workspace, command window and figure set, and start
% the run timer that is read back by toc at the end of the script
clearvars;
clc;
close all;
tic;

% Additional plotting/partition package (machine-specific path)
addpath(genpath('/Users/kp10/Documents/MATLAB'), '-begin');

% Figure defaults: LaTeX interpreter for tick labels, legends and text,
% plus a larger axes font (groot is the same root object as handle 0)
set(groot, ...
    'defaultAxesTickLabelInterpreter', 'latex', ...
    'defaultLegendInterpreter', 'latex', ...
    'defaultTextInterpreter', 'latex', ...
    'defaultAxesFontSize', 16);
% Light and mid grey RGB triplets
grey1 = [0.8 0.8 0.8];
grey2 = [0.5 0.5 0.5];

% Switches: allFigs enables the extra per-wave plots, savetrue saves the
% P(correct) figure and the workspace
allFigs = 0;
savetrue = 1;

% Home folder and the data folder that holds one 'sqwaves_*' sub-folder
% per square wave (path relative to the home folder)
thisDir = cd;
mainFol = 'using phylodyn/sq_double/';
% List the wave folders without changing directory, so a failure here
% cannot leave MATLAB sitting in the data folder. Only folders are kept,
% since each entry is later used as a directory (one wave per folder)
files = dir(fullfile(thisDir, mainFol, 'sqwaves_*'));
files = files([files.isdir]);
nwave = length(files);
% Fail early and clearly instead of on an undefined pathload later on
if nwave == 0
    error('No ''sqwaves_*'' folders found in %s', fullfile(thisDir, mainFol));
end

% Model-selection metrics that are compared (LaTeX labels, reused for the
% legend and the printed ranking)
metrics = {'$-\log L$', 'BIC', 'FIA', 'QK', 'AIC'};
nmet = numel(metrics);

% Upper bound Nmax for the FIA volume term; V = log(Nmax) is what gets
% passed to the model-selection function
Nmax = 1000;
V = log(Nmax);
disp(['Nmax assumed for FIA = ', num2str(Nmax)]);

% Per-wave outputs of sqWvFixFn, one cell/entry per wave folder
nGrps = cell(1, nwave);
Pdet = cell(1, nwave);
nGrpTrue = zeros(1, nwave);

% Run model selection on each wave folder in turn. After the loop,
% pathload still holds the last folder and is reused below
for i = 1:nwave
    pathload = [mainFol, files(i).name];
    [nGrps{i}, Pdet{i}, nGrpTrue(i)] = sqWvFixFn(thisDir, V, nmet, pathload);
end

% Settings assumed fixed across all trajectories, read from the last wave
% folder processed above. Files are read by absolute path instead of by
% cd-ing into the folder, so this no longer depends on sqWvFixFn having
% restored the working directory, and a read error cannot strand MATLAB
% in the data folder. The cd keeps the later figure/workspace saves in
% the home folder.
cd(thisDir);
wave = csvread(fullfile(thisDir, pathload, 'wave.csv'));
nRuns = csvread(fullfile(thisDir, pathload, 'numRuns.csv'));
M = csvread(fullfile(thisDir, pathload, 'numTraj.csv'));
sampRuns = csvread(fullfile(thisDir, pathload, 'sampRuns.csv'));
tsamp = csvread(fullfile(thisDir, pathload, 'samptimes.csv'));
nsegs = length(tsamp);

% Spacing between consecutive sample times. uniquetol (rather than unique)
% absorbs floating-point noise in the differences, so evenly spaced
% samples give a single scalar tau for the plot label. reshape keeps tau
% a row, so num2str(tau) stays a one-line label even if spacing varies.
tau = reshape(uniquetol(diff(tsamp)), 1, []);
if numel(tau) > 1
    warning('Sample times are not evenly spaced; tau has %d values', numel(tau));
end

% Overall P(correct): average the per-wave detection matrices. Columns
% correspond to metrics; rows presumably to the entries of sampRuns
% (they are plotted against sampRuns-1 below)
Psum = Pdet{1};
for i = 2:nwave
    Psum = Psum + Pdet{i};
end
Psum = Psum/nwave;
% Score each metric by its column total and list metrics best-first.
% Summing explicitly along dimension 1 keeps one score per metric even
% when Psum has a single row, where plain sum() would collapse that row
% into one scalar and break the ranking. The /nRuns scaling does not
% affect the order.
Pbest = sum(Psum, 1)/nRuns;
[~, rankID] = sort(Pbest);
ranks = metrics(rankID(end:-1:1));
disp(ranks);

% P(correct) against m = sampRuns - 1, with one line per metric; the
% first column ('$-\log L$') is left out of both the plot and the legend
figure();
plot(sampRuns-1, Psum(:, 2:end), '.-', 'linewidth', 2, 'MarkerSize', 30);
xlabel(['$m$ ($\tau$ = ', num2str(tau), ')']);
ylabel('P(correct)');
box off;
grid off;
legend(metrics(2:end), 'location', 'best');
% Optional inset axes (disabled):
%ax1 = axes('Position',[0.45 0.2 0.3 0.3]);
if savetrue
    saveas(gcf, ['Ptrue', num2str(Nmax)], 'fig');
end


% Optional: stem plot of every wave, two panels per row, using row i of
% wave as the profile of wave i over the nsegs segments.
% NOTE(review): wave is read from the last wave folder only; this assumes
% that file holds all waves as rows (at least nwave of them) - confirm.
if allFigs
    figure();
    for i = 1:nwave
        subplot(ceil(nwave/2), 2, i);
        stem(1:nsegs, wave(i, :), 'filled', 'linewidth', 2, 'markersize', 7);
        xlabel(sprintf('%d', i));
        grid off;
        box off;
    end
end


% Report the total run time (since tic at the top) in minutes
tsim = toc / 60;
disp(['Execution time: ', num2str(tsim), ' mins']);

% Save the whole workspace; the file name records the number of sample
% times (nsegs), the value read from numTraj.csv (M) and Nmax
if savetrue
    save(['sqRes_', num2str(nsegs), '_', num2str(M), '_', num2str(Nmax), '.mat']);
end
