% Batch fraserEst: simulate epidemics and estimate R with model selection
clearvars; clc; close all;

% Assumptions and notes
% - a single fixed Rtrue profile per true k, shared by all N runs
% - examines fixed-width windows of size k, comparing FIA and BIC
% - simulates N epidemics via renewal models
% - the acausal Fraser estimator is used in model selection

% Additional plotting/partition package
addpath(genpath('/Users/kp10/Documents/MATLAB'));

% Figure defaults: LaTeX interpreters everywhere and a larger axes font
set(groot, 'defaultAxesTickLabelInterpreter', 'latex', ...
    'defaultLegendInterpreter', 'latex', ...
    'defaultTextInterpreter', 'latex', ...
    'defaultAxesFontSize', 16);
% Light and mid grey RGB triplets for plotting
grey1 = repmat(0.8, 1, 3);
grey2 = repmat(0.5, 1, 3);

% Time code
tic;
% Diagnostic figs flag
diagFigs = 0;
% Save flag, current directory and folder for saved figures
saveTrue = 0;
thisDir = cd;
saveFol = 'fixedkdata';
% Create the save folder up front so a long batch run does not fail at its
% first figure save (changing into / saving to a missing folder errors)
if saveTrue && exist(saveFol, 'dir') ~= 7
    mkdir(saveFol);
end

%% Setup each of N simulations

% Number of epidemic runs per true k
N = 200;
% Observation days of the epidemic and their count
tday0 = 1:201;
nday0 = numel(tday0);
% Number of data points actually estimated (day 1 is not estimated)
T = nday0 - 1;

% Model selection metrics being compared
metrics = {'lik', 'bic', 'fia'};
nmet = numel(metrics);

% Candidate window widths k for model selection
ks = 20:T;
nks = numel(ks);
disp(['k from ' num2str(ks(1)) ' to ' num2str(ks(end))]);

% True fixed window widths (chosen as factors of T)
% Alternative: use every factor of T
%ktrues = 1:T; ktrues = ktrues(rem(T, ktrues) == 0);
ktrues = [25 50 100 200];
ktrues = ktrues(ktrues >= 10);
lentrue = numel(ktrues);

% Upper limit of R for the FIA integral domain
Rmax = 100;

% Per metric (rows) and true k (columns): fraction of runs selecting the
% true k, and the corresponding error measure
kfrac = zeros(nmet, lentrue);
kmse = zeros(nmet, lentrue);

%% Model select at each true k, over all possible k values

for ii = 1:lentrue
    % Fixed width windows of the current true size
    ktrue = ktrues(ii);
    disp(['True k = ' num2str(ktrue)]);
    
    % Change points (as day 1 is not estimated, start at 2)
    tchange = ktrue+1:ktrue:nday0;
    % Truncate to an integer number of ktrue-day windows
    nday = tchange(end); tday = 1:nday;
    % NOTE(review): T (not nday - 1) is passed to fraserEstFn below; the two
    % only agree when ktrue divides T, as it does for the default ktrues.
    % Confirm how fraserEstFn uses T before trying non-factor ktrues.
    
    % Segments with different R, drawn once per true k and shared by all runs
    nSegs = length(tchange);
    Rsegs = unifrnd(0.5, 4, [1 nSegs]);
    
    % Piecewise-constant true R: segment i covers days edges(i)+1 to
    % edges(i+1), so the first segment also includes (unestimated) day 1
    edges = [0 tchange];
    Rtrue = zeros(size(tday));
    for i = 1:nSegs
        Rtrue(edges(i)+1:edges(i+1)) = Rsegs(i);
    end
    
    % Serial distribution over all time (type 1 is geom and 2 erlang)
    distType = 2;
    serial = serialDistrs(nday, distType);
    % Single omega controlling distribution
    omega = 14.2;
    % Actual distribution over all tday, normalised to sum to 1
    Pomega = serial(1/omega);
    Pomega = Pomega/sum(Pomega);
    
    % Best R estimates from each criterion
    Rlik = cell(1, N); Rbic = Rlik; Rfia = Rbic;
    % Selected k and no. groups (one row per metric, one column per run)
    kbest = zeros(nmet, N); nGrpbest = kbest;
    
    % Run trajectories until N valid ones are stored; runs flagged by
    % warnTrue are discarded and redrawn. Cap the attempts so a setting in
    % which valid runs (almost) never occur fails loudly instead of hanging.
    maxTries = 100*N; nTries = 0;
    i = 1;
    while i <= N
        nTries = nTries + 1;
        if nTries > maxTries
            error('Only %d of %d valid runs after %d attempts (true k = %d)', ...
                i-1, N, maxTries, ktrue);
        end
        % Main code - estimates per trajectory
        [R0, R1, R2, kval, nGval, warnTrue] = fraserEstFn(Rtrue, T, Pomega, ks, nks, Rmax);
        % Store results only if run is valid
        if ~warnTrue
            % Runs with non-zero FI values (conditioning on epidemic survival)
            Rlik{i} = R0; Rbic{i} = R1; Rfia{i} = R2;
            kbest(:, i) = kval; nGrpbest(:, i) = nGval;
            % Update successful run
            i = i + 1;
            disp(['Completed: ' num2str(i-1) ' of ' num2str(N)]);
        end
    end
    
    % Stack R estimates into arrays with one row per run (assumes each
    % estimate is a row vector over the estimated days - confirm)
    Rlik = cell2mat(Rlik');
    Rbic = cell2mat(Rbic');
    Rfia = cell2mat(Rfia');
    % 2.5%, 50% and 97.5% quantiles across runs (dim 1). The dimension is
    % explicit so N = 1 does not silently take quantiles across days instead
    Q = cell(1, nmet);
    Q{1} = quantile(Rlik, [0.025, 0.5, 0.975], 1);
    Q{2} = quantile(Rbic, [0.025, 0.5, 0.975], 1);
    Q{3} = quantile(Rfia, [0.025, 0.5, 0.975], 1);
    
    % Unique selected k and nGrp values for each metric, with their counts
    kvals = cell(1, nmet); kcount = kvals;
    nGvals = kvals; nGcount = kvals;
    for i = 1:nmet
        kvals{i} = unique(kbest(i, :));
        kcount{i} = arrayfun(@(v) sum(kbest(i, :) == v), kvals{i});
        nGvals{i} = unique(nGrpbest(i, :));
        nGcount{i} = arrayfun(@(v) sum(nGrpbest(i, :) == v), nGvals{i});
    end
    
    % Stem plot of the fraction of runs selecting each unique k
    figure;
    for i = 1:nmet
        subplot(ceil(nmet/2), 2, i);
        h = stem(kvals{i}, kcount{i}/N);
        h.LineWidth = 2; h.Marker = '.';
        h.MarkerSize = 20;
        hold on;
        % Dashed line at the true k value
        plot([ktrue ktrue], [0 1], 'k--', 'linewidth', 2);
        hold off; grid off; box off;
        h = gca; h.YLim = [0 1];
        ylabel(metrics{i});
        xlim([ks(1) ks(end)]);
        xlabel('$k$');
    end
    if saveTrue
        % Save straight into the folder instead of cd-ing in and out, so an
        % error here cannot leave the script in the wrong directory
        saveas(gcf, fullfile(saveFol, ['kstem_' num2str(ktrue)]), 'fig');
    end
    
    % True R against the quantiles of the best R estimates per metric
    figure;
    for i = 1:nmet
        subplot(nmet, 1, i);
        hold on;
        stairs(tday(2:end), Rtrue(2:end), 'k', 'linewidth', 2);
        stairs(tday(2:end)', Q{i}', 'linewidth', 2);
        hold off; grid off; box off;
        if i == nmet
            xlabel('$\hat{R}_t$ vs $t$');
        end
        ylabel(metrics{i});
    end
    if saveTrue
        saveas(gcf, fullfile(saveFol, ['R_' num2str(ktrue)]), 'fig');
    end
    
    % Per metric: fraction of runs selecting the true k, and the normalised
    % MSE of the selected k, i.e. the mean over runs of (k/ktrue - 1)^2.
    % Reducing along dim 2 of the nmet-by-N kbest keeps an nmet-by-1 result
    % for any N (including N = 1)
    kfrac(:, ii) = sum(kbest == ktrue, 2)/N;
    kmse(:, ii) = mean((kbest/ktrue - 1).^2, 2);
    
end

% Save the whole workspace, then report the elapsed run time in minutes
save kfracData.mat
tsim = toc / 60;
disp(['Run time = ' num2str(tsim)]);

