% Batch fraserEst for simulating and estimating epidemics
clearvars;
clc; close all;

% Assumptions and notes
% - has a different Rtrue for each of the N runs
% - examines fixed-width windows of k days, scored by each entry of `metrics`
% - simulates N epidemics via renewal models
% - acausal Fraser estimator used in model selection

% Additional plotting/partition package (machine-specific path)
addpath(genpath('/Users/kp10/Documents/MATLAB'));

% Set figure defaults (groot is the root graphics object, i.e. handle 0)
set(groot, 'defaultAxesTickLabelInterpreter', 'latex');
set(groot, 'defaultLegendInterpreter', 'latex');
set(groot, 'defaultTextInterpreter', 'latex');
set(groot, 'defaultAxesFontSize', 16);
% NOTE(review): grey1, grey2 and diagFigs are not used in this script; they
% are kept because save() at the end stores the whole workspace
grey1 = 0.8*ones(1, 3); grey2 = 0.5*ones(1, 3);

% Time code
tic;
% Diagnostic figs
diagFigs = 0;
% Save figures (1) or leave them open (0); output folder for saved files
saveTrue = 1;
thisDir = cd;
saveFol = 'skipdata400';
% Files are first written to saveFol only after a full batch of simulations,
% so create the folder now instead of failing after hours of computation
if exist(saveFol, 'dir') ~= 7
    mkdir(saveFol);
end

%% Setup each of N simulations

% No. epidemic runs
N = 10000;
% Time for epidemic observation (days)
tday0 = 1:401; nday0 = length(tday0);
% Actual no. data points
T = nday0 - 1;

% Metrics to be evaluated
metrics = {'lik', 'bic', 'fia', 'qian'};
nmet = length(metrics);

% Parameters for model selection: candidate window widths k are the
% divisors of T that are at least 10 days
ks = 1:T; ks = ks(rem(T, ks) == 0); ks = ks(ks >= 10);
if isempty(ks)
    error('No divisor of T = %d is >= 10, so no window width k exists', T);
end
nks = length(ks);
disp(['k from ' num2str(ks(1)) ' to ' num2str(ks(end))]);

% Force grouping to be successive coarsening: starting from the smallest
% width, greedily keep the smallest candidate that is a multiple of the last
% kept width, so every window of one k nests inside a window of the next.
% The chain always ends at T (T is a multiple of every divisor). For T = 400
% this yields [10 20 40 80 400], the same set as ks([1 3 5 7 10]).
kchain = ks(1);
nxt = ks(rem(ks, kchain(end)) == 0 & ks > kchain(end));
while ~isempty(nxt)
    kchain(end+1) = nxt(1); %#ok<AGROW>
    nxt = ks(rem(ks, kchain(end)) == 0 & ks > kchain(end));
end
ks = kchain; nks = length(ks);
clear kchain nxt;
ktrues = ks; lentrue = nks;

% Alternative: use every divisor of T that is >= 30 as a true k
%ktrues = 1:T; ktrues = ktrues(rem(T, ktrues) == 0);
%ktrues = ktrues(ktrues >= 30); lentrue = length(ktrues);

% Maximum R for FIA integral domain
Rmax = 10000;

% Fraction of runs selecting the true k, and normalised MSE of k,
% per metric (rows) and true k (columns)
kfrac = zeros(nmet, lentrue); kmse = kfrac;
% Relative squared error of the R estimate chosen by each criterion,
% per true k (rows) and run (columns)
Mlik = zeros(lentrue, N); Mbic = Mlik; Mfia = Mlik; Mqian = Mlik;

%% Model select at each true k, over all possible k values

for ii = 1:lentrue
    % Fixed width windows
    ktrue = ktrues(ii);
    disp(['True k = ' num2str(ktrue)]);

    % Change points (as day 1 not estimated start at 2)
    tchange = ktrue+1:ktrue:nday0;
    % Truncate to integer number of ktrue days
    nday = tchange(end); tday = 1:nday;
    % Segments with different R
    nSegs = length(tchange);

    % Serial distribution over all time (type 1 is geom and 2 erlang)
    distType = 2;
    serial = serialDistrs(nday, distType);
    % Single omega controlling distribution
    omega = 14.2;
    % Actual distribution over all tday
    Pomega = serial(1/omega);

    % Selected k and no. groups: one row per metric, one column per run
    kbest = zeros(nmet, N); nGrpbest = kbest;

    % Run trajectories until N valid ones are stored. Runs flagged by
    % warnTrue are redrawn; total attempts are capped so a setting in which
    % runs are (almost) always rejected errors out instead of looping forever
    maxTries = 1000*N; nTries = 0;
    i = 1;
    while i <= N
        nTries = nTries + 1;
        if nTries > maxTries
            error('Only %d of %d valid runs after %d attempts (true k = %d)', ...
                i - 1, N, maxTries, ktrue);
        end
        % Main code - estimates per trajectory
        [R0, R1, R2, R3, kval, nGval, warnTrue, Rtrue] = fraserEstFnIndep(T,...
            Pomega, ks, nks, Rmax, nSegs, tday, tchange);
        % Store results only if run is valid
        if ~warnTrue
            % Runs with non-zero FI values (conditioning on epidemic survival)
            kbest(:, i) = kval; nGrpbest(:, i) = nGval;
            % Relative squared error of each criterion's R estimate vs Rtrue
            % (Rtrue(1) skipped, presumably because day 1 is not estimated)
            Mlik(ii, i) = mean((R0./Rtrue(2:end) - 1).^2);
            Mbic(ii, i) = mean((R1./Rtrue(2:end) - 1).^2);
            Mfia(ii, i) = mean((R2./Rtrue(2:end) - 1).^2);
            Mqian(ii, i) = mean((R3./Rtrue(2:end) - 1).^2);

            % Update successful run
            i = i + 1;
        end
    end

    % Unique selected k and nGrp values per metric, with their counts
    % (unique's third output maps each run to its value; accumarray tallies)
    kvals = cell(1, nmet); kcount = kvals;
    nGvals = kvals; nGcount = kvals;
    for i = 1:nmet
        [kvals{i}, ~, idk] = unique(kbest(i, :));
        kcount{i} = accumarray(idk(:), 1)';
        [nGvals{i}, ~, idg] = unique(nGrpbest(i, :));
        nGcount{i} = accumarray(idg(:), 1)';
    end

    % Stem plot of the fraction of runs selecting each k, per metric
    figure;
    for i = 1:nmet
        subplot(ceil(nmet/2), 2, i);
        h = stem(kvals{i}, kcount{i}/N);
        h.LineWidth = 2; h.Marker = '.';
        h.MarkerSize = 20;
        hold on;
        % True k value
        plot([ktrue ktrue], [0 1], 'k--', 'linewidth', 2);
        hold off; grid off; box off;
        h = gca; h.YLim = [0 1];
        ylabel(metrics{i});
        xlim([ks(1) ks(end)]);
        xlabel('$k$');
    end
    if saveTrue
        % Write via a full path so the working directory is never changed
        saveas(gcf, fullfile(saveFol, ['kstem_' num2str(ktrue)]), 'fig');
        h = gcf; close(h);
    end

    % Fraction right and normalised mse for each metric
    kfrac(:, ii) = mean(kbest == ktrue, 2);
    kerr = (kbest - ktrue)/ktrue; ksq = kerr.^2;
    kmse(1:nmet, ii) = mean(ksq, 2);
end

% Run-averaged relative squared R error: one row per true k, one column per
% criterion in the order lik, bic, fia, qian
MSruns = cellfun(@(M) mean(M, 2), {Mlik, Mbic, Mfia, Mqian}, ...
    'UniformOutput', false);
MSruns = [MSruns{:}];

% Fraction of runs selecting the true k: one stem series per metric,
% plotted against the index of each true k
figure;
h = stem(kfrac');
for i = 1:nmet
    h(i).LineWidth = 2; h(i).Marker = '.';
    h(i).MarkerSize = 50;
end
% stem plots at x = 1:lentrue; pin one tick per true k before labelling,
% otherwise the labels are applied to MATLAB's automatic ticks
h = gca; h.XTick = 1:lentrue; h.XTickLabel = ktrues;
h.XLim = [0.5, lentrue + 0.5];
xlabel('$k$ (true)');
ylabel('$\frac{1}{N}\sum 1(k^* = k)$');
box off; grid off;
legend(metrics, 'location', 'best');
if saveTrue
    % Write via a full path so the working directory is never changed
    saveas(gcf, fullfile(saveFol, ['kfrac_' num2str(T)]), 'fig');
    h = gcf; close(h);
end

% Normalised MSE of the selected k, E[((k* - k)/k)^2], per metric and true k
figure;
h = stem(kmse');
for i = 1:nmet
    h(i).LineWidth = 2; h(i).Marker = '.';
    h(i).MarkerSize = 50;
end
% stem plots at x = 1:lentrue; pin one tick per true k before labelling,
% otherwise the labels are applied to MATLAB's automatic ticks
h = gca; h.XTick = 1:lentrue; h.XTickLabel = ktrues; h.YScale = 'log';
h.XLim = [0.5, lentrue + 0.5];
xlabel('$k$ (true)');
ylabel('E$[((k^* - k)/k)^2]$');
box off; grid off;
legend(metrics, 'location', 'best');
if saveTrue
    % Write via a full path so the working directory is never changed
    saveas(gcf, fullfile(saveFol, ['kmse_' num2str(T)]), 'fig');
    h = gcf; close(h);
end

% Save the whole workspace (not gated by saveTrue, so results are never
% lost) and report the run time in minutes. A full path is used so a
% failed save cannot leave the working directory inside saveFol.
save(fullfile(saveFol, 'kfracData.mat'));
tsim = toc/60;
disp(['Run time = ' num2str(tsim)]);

