% Simulate multiple epidemics and investigate renewal model selection
clearvars;
clc; close all;

% Assumptions and notes
% - only works for ktrue that are divisors of T = nday - 1 (the number of
%   estimated days, since day 1 seeds the epidemic and is not estimated)
% - each epidemic used as data on same R profile
% - examines how fixed width k changes with data
% - simulates N epidemics via renewal models
% - multiple stream log-likelihood used

% Additional plotting/partition package. This path is machine-specific, so
% only add it when it exists and otherwise warn that the helper functions
% used below (e.g. getGrpSz, serialDistrs, fraserSubMultFn) may be missing.
pkgDir = '/Users/kp10/Documents/MATLAB';
if exist(pkgDir, 'dir')
    addpath(genpath(pkgDir));
else
    warning('Package folder %s not found; helper functions may be missing', pkgDir);
end

% Set figure defaults
set(groot, 'defaultAxesTickLabelInterpreter', 'latex');
set(groot, 'defaultLegendInterpreter', 'latex');
set(0, 'defaultTextInterpreter', 'latex');
set(0, 'defaultAxesFontSize', 16);
grey1 = 0.8*ones(1, 3); grey2 = 0.5*ones(1, 3);

% Time code
tic;
% Diagnostic figs
diagFigs = 0;
% Save data and folder (folder is relative to the launch directory thisDir)
saveTrue = 0;
thisDir = cd;
saveFol = 'dataAccum';

%% Setup and simulate N epidemics from a fixed R

% Observation window (days); day 1 seeds the epidemic, so T days are estimated
tday = 1:401;
nday = numel(tday);
% No. epidemic runs sharing the same R profile
N = 10;
T = nday - 1;

% Model-selection criteria to be evaluated
metrics = {'lik', 'bic', 'fia', 'qian', 'aic'};
nmet = numel(metrics);

% Candidate fixed group widths: divisors of T that are at least 10 days
% (alternative range once used: ks = 20:150)
ks = find(rem(T, 1:T) == 0);
ks(ks < 10) = [];
nks = numel(ks);
disp(['k from ' num2str(ks(1)) ' to ' num2str(ks(end))]);

% Keep a nested subset so each grouping is a coarsening of the previous one
ks = ks([1 3 5 7 10]);
nks = numel(ks);

% Group structure (no. of groups and their sizes) for every candidate k
grps = arrayfun(@(k) getGrpSz(k, T), ks);
nGrps = zeros(size(ks));
nGrps(:) = [grps.nGrp];
grpSzs = {grps.grpSz};

% True R is piecewise constant over segments of ktrue days
%ktrue = datasample(ks, 1);
ktrue = 20;
disp(['True k = ' num2str(ktrue)]);
% Maximum R for FIA integral domain
Rmax = 10000;

% Last day of each segment (day 1 is not estimated, so the first ends at
% ktrue + 1)
tchange = ktrue+1:ktrue:nday;
% The segments must tile the days exactly, i.e. ktrue divides nday - 1
if tchange(end) ~= nday
    error('Check the day groupings');
end

% Draw one R per segment and expand it over its days: segment s covers
% tchange(s-1)+1 to tchange(s), and the first segment also covers day 1
nSegs = numel(tchange);
Rsegs = unifrnd(0.5, 2.5, [1 nSegs]);
Rtrue = repelem(Rsegs, diff([0 tchange]));

% Serial distribution over all time (type 1 is geom and 2 erlang)
distType = 2;
serial = serialDistrs(nday, distType);
% Single omega controlling distribution
omega = 14.2;
% Actual distribution over all tday
Pomega = serial(1/omega);

% Simulate N incidence trajectories from Rtrue
Iday = zeros(N, T+1);
% Infectiousness, Poisson rate
Lam = Iday; rate = Iday;
% Initialise epidemics
Iday(:, 1) = 10;

% Iteratively generate renewal epidemics. Trajectories that are too sparse
% to support estimation at the finest grouping are rejected and redrawn
% (row j is simply overwritten). maxTries bounds this rejection loop so an
% R profile that almost always drives extinction raises an error instead of
% hanging forever.
j = 1; warnTrue = 0;
nTries = 0; maxTries = 100*N;
while j <= N
    nTries = nTries + 1;
    if nTries > maxTries
        error('Could not simulate %d valid epidemics in %d attempts', N, maxTries);
    end

    for i = 2:T+1
        % Relevant part of serial distribution
        Pomegat = Pomega(1:i-1);
        %Pomegat = Pomegat/sum(Pomegat);
        % Total infectiousness
        Lam(j, i) = Iday(j, i-1:-1:1)*Pomegat';
        % Rate for ith day incidence
        rate(j, i) = Lam(j, i)*Rtrue(i);
        % Renewal incidence
        Iday(j, i) = poissrnd(rate(j, i));
    end

    % Ensure epidemic survives at minimum group size. The groups partition
    % the T estimated days 2:nday (day 1 is the seed), so group-relative
    % indices are offset by one day.
    jstart = 1; A0 = zeros(1, nGrps(1)); B0 = A0;
    for l = 1:nGrps(1)
        % End-index of group and the days it covers
        jstop = jstart + grpSzs{1}(l) - 1;
        ids = 1 + (jstart:jstop);
        % Incidence and Lam sum over group
        B0(l) = sum(Iday(j, ids));
        A0(l) = sum(Lam(j, ids));
        % Move on to the next group (previously never advanced, so every
        % group re-summed the first window)
        jstart = jstop + 1;
    end
    % Fisher information log product: -Inf/Inf/NaN flags a group with zero
    % infectiousness or zero incidence
    FIprod0 = sum(log((A0.^2)./B0));
    % Accept only if all groups are informative and at most a quarter of
    % days have zero incidence
    if isfinite(FIprod0) && nnz(Iday(j, :) == 0) <= nday/4
        % Current trajectory valid
        j = j + 1;
    end
end

%% Examine model selection as add more data

% Best R estimate chosen by each criterion, one cell per no. of streams used
[Rlik, Rbic, Rqian, Rfia, Raic] = deal(cell(1, N));
% Selected k and no. groups (one row per metric, one column per no. streams)
kbest = zeros(nmet, N);
nGrpbest = zeros(nmet, N);
% Per-k summaries (one row per no. streams, one column per entry of ks)
[FIsum, lik, fia, bic] = deal(zeros(N, nks));

% Refit using the first nStream epidemics, growing nStream from 1 to N
for nStream = 1:N
    IdaySel = Iday(1:nStream, :);
    LamSel = Lam(1:nStream, :);

    % MLE/selection from multiple streams
    [Rlik{nStream}, Rbic{nStream}, Rfia{nStream}, Rqian{nStream}, Raic{nStream}, ...
        kbest(:, nStream), nGrpbest(:, nStream), FIsum(nStream, :), ...
        lik(nStream, :), fia(nStream, :), bic(nStream, :)] = fraserSubMultFn(T, ...
        ks, nks, Rmax, nGrps, grpSzs, IdaySel, LamSel, nStream);
    %disp(['Completed: ' num2str(nStream) ' of ' num2str(N)]);
end

% Optimum of -lik, bic and fia for each no. of streams: values and their
% column index into ks (one row per no. streams)
[vallik, idlik] = min(-lik, [], 2);
[valbic, idbic] = min(bic, [], 2);
[valfia, idfia] = min(fia, [], 2);
% Best k selected (rows presumably ordered as in metrics; columns are no. streams)
disp('Selected k = ');
disp(kbest);

% MSE of the likelihood R estimate against the truth on estimated days 2:nday,
% one entry per no. of streams
mseR = cellfun(@(Rest) mean((Rest - Rtrue(2:end)).^2), Rlik);
% Same measure could be applied to Rbic, Rfia or Rqian if needed

%% Plotting and saving

% All simulated incidence trajectories
figure;
plot(tday, Iday, 'linewidth', 2);
grid off; hold off;
ylabel('incidence curves');
xlabel('time (days)');
xlim([tday(1) tday(end)]);

% Selected R estimate with data: one stream (top-left/bottom-left) vs all N
% streams (right), for the likelihood (top) and FIA (bottom) selections
figure;
panelR = {Rlik{1}, Rlik{N}, Rfia{1}, Rfia{N}};
panelLab = {'$\hat{R}_{lik}^{1}$', '$\hat{R}_{lik}^{m}$', ...
    '$\hat{R}_{fia}^{1}$', '$\hat{R}_{fia}^{m}$'};
for p = 1:4
    ax(p) = subplot(2, 2, p);
    stairs(2:nday, Rtrue(2:end), 'k', 'linewidth', 2);
    hold on;
    stairs(2:nday, panelR{p}, 'c', 'linewidth', 2);
    ylabel(panelLab{p});
    box off; grid off; hold off;
    xlim([tday(1) tday(end)]);
    % Only the bottom row carries a time axis label
    if p > 2
        xlabel('time (days)');
    end
end
linkaxes(ax, 'xy');
% Drop plotting temporaries so they are not saved with the workspace
clear p panelR panelLab;

% Criteria against the index of k: one curve per no. of streams, with
% circles marking each curve's minimum (no marker for the FI sum)
figure;

subplot(2, 2, 1);
plot(1:nks, (-lik).', 'linewidth', 2);
hold('on');
plot(idlik, vallik, 'o', 'MarkerSize', 8);
box('off'); grid('off'); hold('off');
ylabel('$-\log L$');
h = gca;

subplot(2, 2, 2);
plot(1:nks, FIsum.', 'linewidth', 2);
box('off'); grid('off');
ylabel('$I_{sum}$');

subplot(2, 2, 3);
plot(1:nks, bic.', 'linewidth', 2);
hold('on');
plot(idbic, valbic, 'o', 'MarkerSize', 8);
box('off'); grid('off'); hold('off');
ylabel('BIC');
xlabel('id of $k$');

subplot(2, 2, 4);
plot(1:nks, fia.', 'linewidth', 2);
hold('on');
plot(idfia, valfia, 'o', 'MarkerSize', 8);
box('off'); grid('off'); hold('off');
ylabel('FIA');
xlabel('id of $k$');

% % Best k values against true one
% figure;
% plot(1:N, kbest, 'o', 'linewidth', 2);
% hold on;
% plot(1:N, ktrue*ones(1, N), 'k', 'linewidth', 2);
% xlabel('no. streams');
% ylabel('$k^*$');
% box off; grid off; hold off;
% legend('-lik', 'bic', 'fia', 'qian', 'location', 'best');
% xlim([1 N]);

% Time and save data. The workspace is written into saveFol under the
% launch directory thisDir (previously both were defined but ignored, so the
% file landed in whatever the current folder was).
if saveTrue
    saveDir = fullfile(thisDir, saveFol);
    if ~exist(saveDir, 'dir')
        mkdir(saveDir);
    end
    save(fullfile(saveDir, ['incrData' num2str(N) '.mat']));
end
% Run time in minutes
tsim = toc/60;
disp(['Run time = ' num2str(tsim)]);

