% Simulate multiple epidemics and investigate renewal model selection
clearvars;
clc;
close all;

% Assumptions and notes
% - the initial incidence I(0) = 10 is excluded from analysis (even at offst = 0)
% - epidemics are truncated to remove startup issues
% - every epidemic is used as data on the same R profile
% - N epidemics are simulated and a multi-stream likelihood is used

% Additional plotting/partition package
addpath(genpath('/Users/kp10/Documents/MATLAB'));

% Figure defaults: LaTeX interpreters everywhere and a larger axes font
set(groot, 'defaultAxesTickLabelInterpreter', 'latex');
set(groot, 'defaultLegendInterpreter', 'latex');
set(groot, 'defaultTextInterpreter', 'latex');
set(groot, 'defaultAxesFontSize', 16);
% Two grey shades available for plotting
grey1 = [0.8 0.8 0.8];
grey2 = [0.5 0.5 0.5];

% Start the run timer
tic;
% Saving switch, folder to return to and folder that receives the output
saveTrue = 0;
thisDir = pwd;
saveFol = 'indivRuns';
% Choice of true R (0 selects the piecewise-constant profile)
scenNo = 0;

%% Setup and simulate N epidemics from a fixed R

% Observation days and their count
tday = 1:401;
nday = numel(tday);
% Number of epidemic runs from the same R, and number of analysed days
N = 10;
T = nday - 1;

% Offset (days) added in front to absorb initial zeros
offst = 20;

% Names of the criteria to be evaluated
metrics = {'lik', 'bic', 'fia', 'qian', 'aic'};
nmet = numel(metrics);

% Candidate group sizes: divisors of T that are at least 10
ks = 1:T;
ks = ks(rem(T, ks) == 0 & ks >= 10);
nks = numel(ks);

% Grouping returned by getGrpSz for every candidate k
grps = arrayfun(@(k) getGrpSz(k, T), ks);
% Unpack the number of groups and the group sizes for each k
nGrps = zeros(1, nks);
grpSzs = cell(1, nks);
for i = 1:nks
    grpSzs{i} = grps(i).grpSz;
    nGrps(i) = grps(i).nGrp;
end
% Maximum R for FIA integral domain
Rmax = 100;

% True R profile: a profile from getRTrue2 for non-zero scenNo, otherwise
% a piecewise-constant function with fixed segment width ktrue
if scenNo ~= 0
    % Get true R based on some profile (offset added)
    Rtrue = getRTrue2(scenNo, 1:nday+offst, nday+offst);
else
    % Optional: force grouping to be successive coarsening
    %ks = ks([1 3 5 7 10]); nks = length(ks);

    % Width of the true segments and its position among the candidates
    ktrue = 200;
    %ktrue = datasample(ks, 1);
    idtrue = find(ks == ktrue);
    disp(['True k = ' num2str(ktrue)]);

    % Change points, the first one sitting at the offset
    tchange = [offst, ktrue+offst+1:ktrue:nday+offst];
    % The last change point must land on the final day
    if tchange(end) - offst ~= nday
        error('Check the day groupings');
    end

    % One R value per segment; the initial value is fixed above 1
    nSegs = length(tchange);
    Rsegs = [1.2 unifrnd(0.75, 1.75, [1 nSegs-1])];
    % Fill the true R vector segment by segment
    Rtrue = zeros(1, nday + offst);
    for i = 1:nSegs
        switch i
            case 1
                % Everything up to the first change point
                idR = 1:tchange(i);
            case nSegs
                % Last segment runs to the end of the vector
                idR = tchange(i-1)+1:nday+offst;
            otherwise
                idR = tchange(i-1)+1:tchange(i);
        end
        Rtrue(idR) = Rsegs(i);
    end
end


% Serial distribution over all time (type 1 is geom and 2 erlang)
distType = 2;
serial = serialDistrs(nday + offst, distType);
% Single omega controlling distribution
omega = 14.2;
% Actual distribution over all tday
Pomega = serial(1/omega);

% Simulate N incidence trajectories from Rtrue (one row per epidemic)
Iday = zeros(N, T+1+offst);
% Infectiousness and Poisson rate, same layout as Iday
Lam = Iday; rate = Iday;
% Every epidemic starts from 10 cases on the first day
Iday(:, 1) = 10;

% Iteratively generate renewal epidemics. A trajectory that fails the
% checks below is simulated again in the same row, so j only advances on
% success.
j = 1; warnTrue = 0;
while j <= N
    for i = 2:T+1+offst
        % Relevant part of serial distribution
        Pomegat = Pomega(1:i-1);
        %Pomegat = Pomegat/sum(Pomegat);
        % Total infectiousness: past incidence (most recent first)
        % weighted by the serial distribution
        Lam(j, i) = Iday(j, i-1:-1:1)*Pomegat';
        % Rate for ith day incidence
        rate(j, i) = Lam(j, i)*Rtrue(i);
        % Renewal incidence
        Iday(j, i) = poissrnd(rate(j, i));
    end

    % Ensure epidemic survives at minimum group size: sum incidence and
    % infectiousness over every successive group of the first grouping
    % NOTE(review): groups start at day 1+offst whereas the later
    % truncation keeps days offst+2 onwards - confirm intended alignment
    jstart = 1 + offst;
    A0 = zeros(1, nGrps(1)); B0 = A0;
    for l = 1:nGrps(1)
        % End-indices of group
        jstop = jstart + grpSzs{1}(l) - 1;
        ids = jstart:jstop;
        % Incidence and Lam sum over group
        B0(l) = sum(Iday(j, ids));
        A0(l) = sum(Lam(j, ids));
        % Move on to the next group; without this step every group would
        % re-test the first window only
        jstart = jstop + 1;
    end
    % Fisher information log product; non-finite whenever some group has
    % zero total incidence or zero total infectiousness
    FIprod0 = sum(log((A0.^2)./B0));
    if isfinite(FIprod0)
        % Also test on zero incidence points
        if nnz(Iday(j, :) == 0) <= nday/4
            % Current trajectory valid
            j = j + 1;
        end
    end
end

% Inspect runs of zero incidence in every trajectory, then truncate
idStart = zeros(1, N);
for j = 1:N
    % Spacing between consecutive days with non-zero incidence
    zerogaps = diff(find(Iday(j, :)));
    % First gap is the startup one, the rest belong to the epidemic body
    z1 = zerogaps(1);
    zrest = zerogaps(2:end);
    if z1 <= 8
        % No long startup gap: flag any zero incidence region
        if max(zrest) > 8
            warning('Sequences of zero incidence');
        end
        % Un-truncated day set
        idStart(j) = 2;
    else
        % Long startup gap: start from the day that ends it
        idStart(j) = z1+1;
        % Flag zero incidence regions after startup
        if max(zrest) > 8
            warning('Zero incidences beyond startup');
        end
    end
end

% Every start index must fall inside the offset that is cut away
if max(idStart) > offst
    error('Offset too small');
end
% Truncate vectors to have length T
Rtrue = Rtrue(offst+2:end);
rate = rate(:, offst+2:end);
Iday = Iday(:, offst+2:end);
Lam = Lam(:, offst+2:end);

%% Examine model selection as add more data

% R estimate chosen by each criterion, one cell per number of streams
Rlik = cell(1, N);
Rbic = cell(1, N);
Rqian = cell(1, N);
Rfia = cell(1, N);
Raic = cell(1, N);
% Selected k and number of groups (nmet rows, one column per stream count)
kbest = zeros(nmet, N);
nGrpbest = zeros(nmet, N);
% Fisher information sum and criteria values over the candidate ks
FIsum = zeros(N, nks);
lik = zeros(N, nks);
fia = zeros(N, nks);
bic = zeros(N, nks);
aic = zeros(N, nks);

% Select best k as data increases: step i uses the first i trajectories
for i = 1:N
    nStream = i;
    IdaySel = Iday(1:nStream, :);
    LamSel = Lam(1:nStream, :);

    % MLE/selection from multiple streams
    [Rlik{i}, Rbic{i}, Rfia{i}, Rqian{i}, Raic{i}, ...
        kbest(:, i), nGrpbest(:, i), ...
        FIsum(i, :), lik(i, :), fia(i, :), bic(i, :), aic(i, :)] = ...
        fraserSubMultFn2(T, ks, nks, Rmax, nGrps, grpSzs, ...
        IdaySel, LamSel, nStream);
    %disp(['Completed: ' num2str(i) ' of ' num2str(N)]);
end

% Row-wise minima (value and index of k) of -lik, bic, fia and aic
[vallik, idlik] = min(-lik, [], 2);
[valbic, idbic] = min(bic, [], 2);
[valfia, idfia] = min(fia, [], 2);
[valaic, idaic] = min(aic, [], 2);
% Best k selected
disp('Selected k = ');
disp(kbest);

% Mean squared error of each selected R estimate against the true R, as
% the number of streams grows. Rows follow the estimate order
% lik, bic, fia, qian, aic; columns are the number of streams used.
mseR = zeros(nmet, N);
for i = 1:N
    mseR(1, i) = mean((Rlik{i} - Rtrue).^2);
    mseR(2, i) = mean((Rbic{i} - Rtrue).^2);
    mseR(3, i) = mean((Rfia{i} - Rtrue).^2);
    mseR(4, i) = mean((Rqian{i} - Rtrue).^2);
    % AIC has its own row; writing it to row 4 would overwrite the Qian
    % error and leave row 5 at zero
    mseR(5, i) = mean((Raic{i} - Rtrue).^2);
end

%% Plotting and saving

% All simulated incidence curves against the analysed days
figure;
plot(tday(2:end), Iday, 'LineWidth', 2);
hold off; grid off;
xlabel('time (days)');
ylabel('incidence curves');
xlim(tday([1 end]));
% Optionally store the figure inside the save folder
if saveTrue
    cd(saveFol);
    saveas(gcf, ['incid_' num2str(N) '_' num2str(scenNo)], 'fig');
    cd(thisDir);
end

% Selected R estimate with data for lik and fia: rows are the criteria,
% left column uses one stream and right column uses all N streams
figure;
panelEst = {Rlik{1}, Rlik{N}, Rfia{1}, Rfia{N}};
panelLab = {'$\hat{R}_{lik}^{1}$', '$\hat{R}_{lik}^{m}$', ...
    '$\hat{R}_{fia}^{1}$', '$\hat{R}_{fia}^{m}$'};
for ip = 1:4
    ax(ip) = subplot(2, 2, ip);
    % True profile in black with the estimate in cyan on top
    stairs(2:nday, Rtrue, 'k', 'linewidth', 2);
    hold on;
    stairs(2:nday, panelEst{ip}, 'c', 'linewidth', 2);
    ylabel(panelLab{ip});
    box off; grid off; hold off;
    xlim([tday(2) tday(end)]);
    % Only the bottom row carries the time label
    if ip > 2
        xlabel('time (days)');
    end
end
linkaxes(ax, 'xy');
% Drop the plotting temporaries so the workspace is unchanged
clear panelEst panelLab ip
if saveTrue
    cd(saveFol);
    saveas(gcf, ['R2_' num2str(N) '_' num2str(scenNo)], 'fig');
    cd(thisDir);
end

% Selected R estimate with data for several criteria: rows are lik, fia,
% bic and aic; left column uses one stream, right column all N streams
figure;
panelEst = {Rlik{1}, Rlik{N}, Rfia{1}, Rfia{N}, ...
    Rbic{1}, Rbic{N}, Raic{1}, Raic{N}};
panelLab = {'$\hat{R}_{lik}^{1}$', '$\hat{R}_{lik}^{m}$', ...
    '$\hat{R}_{fia}^{1}$', '$\hat{R}_{fia}^{m}$', ...
    '$\hat{R}_{bic}^{1}$', '$\hat{R}_{bic}^{m}$', ...
    '$\hat{R}_{aic}^{1}$', '$\hat{R}_{aic}^{m}$'};
for ip = 1:8
    ax(ip) = subplot(4, 2, ip);
    % True profile in black with the estimate in cyan on top
    stairs(2:nday, Rtrue, 'k', 'linewidth', 2);
    hold on;
    stairs(2:nday, panelEst{ip}, 'c', 'linewidth', 2);
    ylabel(panelLab{ip});
    box off; grid off; hold off;
    xlim([tday(2) tday(end)]);
    % Time label on both panels of the bottom row (previously it sat on
    % panel 5 in row 3 and was missing from panel 7)
    if ip >= 7
        xlabel('time (days)');
    end
end
linkaxes(ax, 'xy');
% Drop the plotting temporaries so the workspace is unchanged
clear panelEst panelLab ip
if saveTrue
    cd(saveFol);
    saveas(gcf, ['Rall_' num2str(N) '_' num2str(scenNo)], 'fig');
    cd(thisDir);
end

% Changes with data of -lik and the criteria: one panel per quantity,
% one curve per number of streams, circles on the row-wise minima
figure;
critVal = {-lik, aic, bic, fia};
critMin = {vallik, valaic, valbic, valfia};
critIdx = {idlik, idaic, idbic, idfia};
critLab = {'$-\log L$', 'AIC', 'BIC', 'FIA'};
for ic = 1:4
    subplot(2, 2, ic);
    plot(1:nks, critVal{ic}', 'linewidth', 2);
    hold on;
    plot(critIdx{ic}, critMin{ic}, 'o', 'MarkerSize', 8);
    % Handle to the current panel: always taken for the first panel, and
    % for every panel when the true k is marked
    if ic == 1 || scenNo == 0
        h = gca;
    end
    % Vertical marker at the index of the true k (piecewise scenario only)
    if scenNo == 0
        plot([idtrue idtrue], h.YLim, 'kx--', 'MarkerSize', 8, 'linewidth', 1);
    end
    box off; grid off; hold off;
    ylabel(critLab{ic});
    % Only the bottom row carries the x label
    if ic > 2
        xlabel('id of $k$');
    end
end
% Drop the plotting temporaries so the workspace is unchanged
clear critVal critMin critIdx critLab ic
if saveTrue
    cd(saveFol);
    saveas(gcf, ['criteria_' num2str(N) '_' num2str(scenNo)], 'fig');
    cd(thisDir);
end

% Time and save data
% When saving is enabled, write the workspace to ex_<N>_<scenNo>.mat inside
% saveFol and then return to the starting folder. save is called without a
% variable list, and tsim is computed afterwards so it is not in the file.
if saveTrue
    cd(saveFol);
    save(['ex_' num2str(N) '_' num2str(scenNo) '.mat']);
    cd(thisDir);
end
% Elapsed time since the tic at the top of the script, divided by 60
tsim = toc/60;
disp(['Run time = ' num2str(tsim)]);

