% RAPIDD challenge Fraser-Cori-Nouvellet method
clearvars;
clc; close all;

% Assumptions and notes
% - includes the FIA (Fisher information approximation) to MDL for the
%   renewal model
% - only uses sinusoids for R at different freqs
% - simulates an epidemic via renewal models
% - for a single trajectory
% - acausal Fraser estimator used in model selection
% - uses a renewal model from Fraser 2011

% Additional plotting/partition package. The folder is machine-specific, so
% it is only added when it exists; elsewhere the helpers this script calls
% (presumably serialDistrs, getGrpSz, getLogLikFraserSingle) must already be
% on the MATLAB path.
extraPath = '/Users/kp10/Documents/MATLAB';
if exist(extraPath, 'dir')
    addpath(genpath(extraPath));
else
    warning('Helper folder %s not found; relying on current path', extraPath);
end

% Set figure defaults (LaTeX interpreters, larger axes font)
set(groot, 'defaultAxesTickLabelInterpreter', 'latex');
set(groot, 'defaultLegendInterpreter', 'latex');
set(groot, 'defaultTextInterpreter', 'latex');
set(groot, 'defaultAxesFontSize', 16);
% Light and dark greys used for the incidence plots
grey1 = 0.8*ones(1, 3); grey2 = 0.5*ones(1, 3);

% Time code
tic;
% Save data flag (NOTE(review): not read in the visible part of the script)
saveTrue = 0;

% Define a scenario (sets the frequency of the sinusoidal R)
scenNo = 2;
disp(['Scenario: ' num2str(scenNo)]);

%% Simulate epidemic scenarios

% Time for epidemic observation (days)
tday = 1:500; nday = length(tday);

% R varies sinusoidally with freq = scenNo/2 (sind takes degrees)
Rtrue = 1 + 0.5*sind((scenNo/2)*tday);

% Serial distribution over all time (type 1 is geom and 2 erlang)
distType = 2;
serial = serialDistrs(nday, distType);
% Single omega controlling distribution
omega = 14.2;
% Actual distribution over all tday (serial is presumably a function of the
% rate 1/omega returning a row vector of length nday)
Pomega = serial(1/omega);

% Daily incidence, preallocated as a 1 x nday row vector
Iday = zeros(1, nday);
% Infectiousness, Poisson rate (Lam(1) stays 0: no earlier cases on day 1)
Lam = Iday; rate = Iday;
% Initialise epidemic
Iday(1) = 20;

% Iteratively generate renewal epidemic
for i = 2:nday
    % Relevant part of serial distribution
    Pomegat = Pomega(1:i-1);
    % Total infectiousness: past incidence weighted by serial distribution
    Lam(i) = Iday(i-1:-1:1)*Pomegat';
    % Rate for ith day incidence
    rate(i) = Lam(i)*Rtrue(i);
    % Renewal incidence
    Iday(i) = poissrnd(rate(i));
end

% Smooth the incidence with an m-point causal moving average
m = 7; filtCoef = ones(m, 1)/m;
Ifil = filter(filtCoef, 1, Iday);


%% Fraser-Cori MLEs for R and model selection

% Time vector over which R holds (length is data size); R is estimated from
% day 2 onwards because day 1 has no past infectiousness
tR = tday(2:end); lenRt = nday-1;

% Parameters for model selection
ks = 20:250; nks = length(ks);
disp(['k varies from ' num2str(ks(1)) ' to ' num2str(ks(end))]);

% Get group sizes for each k (struct array with fields nGrp and grpSz)
grps = @(k) getGrpSz(k, lenRt);
grps = arrayfun(grps, ks);

% Reproductive numbers and log-likelihood with time and grouping
Rkset = zeros(nks, lenRt); likset = Rkset;
% Total likelihood of each trajectory, no. groups for each k
liksum = zeros(1, nks); nGrpset = liksum;
% Components of Qian MDL approximation
sc1 = zeros(1, nks); sc2 = sc1;
% Component for FIA MDL approximation
fia0 = zeros(1, nks);
% Fisher information
FIcomp = cell(1, nks); FIsum = zeros(1, nks);

% Model selection metrics used
metrics = {'aic', 'bic', 'hqc', 'qian', 'fia'};
nmet = length(metrics);

% Constant for R integral in FIA: int_0^Rmax R^(-1/2) dR = 2*sqrt(Rmax)
Rmax = 1000; V = 2*sqrt(Rmax);

% Group numbers and likelihood on each k
nGrps = zeros(size(ks)); liks = cell(1, nks);
% R values across k and time
Rks = zeros(nks, lenRt); Rgrps = liks;
% Sums of lambda and I over groups for each k
Lsum = liks; Isum = liks;

% Constant components of log-likelihoods over days 2:nday. gammaln is used
% because gamma(n+1) overflows to Inf for counts above 170.
C1 = -sum(gammaln(Iday(2:end) + 1));
C2 = sum(Iday(2:end).*log(Lam(2:end)));

% For each set of ks get MLE and log-likelihoods.
% Group positions ids index tR = tday(2:end) (and the columns of Rks), so
% the matching data live on days ids+1 of Lam and Iday. This keeps the
% day-1 seed cases (whose infectiousness Lam(1) is 0) out of the fit.
for i = 1:nks
    % Group properties at this k
    nGrp = grps(i).nGrp; grpSz = grps(i).grpSz;
    % Number of groups used at this k
    nGrps(i) = nGrp;
    % Components of group MLEs and liks
    A = zeros(1, nGrp); B = A; lik = A; Rgrp = A;

    % Grouped MLEs for Rt and max log-likelihoods
    jstart = 1;
    for j = 1:nGrp
        % End-index of group (positions within tR / columns of Rks)
        jstop = jstart + grpSz(j) - 1;
        ids = jstart:jstop;
        % Corresponding days in the data vectors
        dayIds = ids + 1;

        % Component sums: A (Lamj) and B (Ij)
        A(j) = sum(Lam(dayIds));
        B(j) = sum(Iday(dayIds));

        % Grouped MLE (replicated over t points in group i.e. ids)
        Rgrp(j) = B(j)/A(j); Rks(i, ids) = Rgrp(j);
        % Maximum log-likelihood of this group
        lik(j) = getLogLikFraserSingle(Rgrp(j), Iday(dayIds), Lam(dayIds));

        % Next group starts after this one
        jstart = jstop + 1;
    end
    % Segment log-likelihood
    liksum(i) = sum(lik);
    liks{i} = lik;
    Rgrps{i} = Rgrp;
    % Group sums for each k
    Lsum{i} = A;
    Isum{i} = B;

    % Fisher information of each group's R at its MLE (Poisson: A/R with
    % R = B/A, i.e. A^2/B)
    FI = (A.^2)./B;
    FIcomp{i} = FI; FIsum(i) = sum(FI);
    % Components for Qian criteria; summing logs avoids prod(FI) overflowing
    % or underflowing when there are many groups
    sc1(i) = sum(log(abs(Rgrp) + lenRt^(-0.25)));
    sc2(i) = 0.5*sum(log(FI));

    % FIA term
    fia0(i) = sum(0.5*log(A));
end

% Check for issues with FI (Qian stochastic complexity term)
if any(~isfinite(sc2))
    warning('Stoch complexity term not sensible');
end
% Check on R estimates: Rks holds the grouped MLEs written in the k loop
if any(~isfinite(Rks(:)))
    error('MLEs for R are inf or nan');
end

% Negative maximised log-likelihood at each k
L = -liksum;
% Penalised criteria on the -log-likelihood scale (smaller is better)
aic = L + nGrps;
bic = L + 0.5*nGrps*log(lenRt);
hqc = L + nGrps*log(log(lenRt));
qian = L + sc1 + sc2;
fia = L + 0.5*nGrps*log(1/(2*pi)) + nGrps*log(V) + fia0;


% Stack criteria as rows (aic, bic, hqc, qian, fia); each row's min is best
critAll = [aic; bic; hqc; qian; fia];
[critMin, modID] = min(critAll, [], 2);
modID = modID.';
aicVal = critMin(1);  aicMod = modID(1);
bicVal = critMin(2);  bicMod = modID(2);
hqcVal = critMin(3);  hqcMod = modID(3);
qianVal = critMin(4); qianMod = modID(4);
fiaVal = critMin(5);  fiaMod = modID(5);
% Ad-hoc max of Fisher information
[fishVal, fishMod] = max(FIsum);

% Selected k and group count per criterion
kbest = ks(modID);
nGrpbest = nGrps(modID);
disp(['k: [aic bic hqc qian fia] = [' num2str(kbest) ']' ]);
disp(['nGrp: [aic bic hqc qian fia] = [' num2str(nGrpbest) ']' ]);


%% Visualisation and post processing

% Noisy (augmented) and filtered daily incidence
figure;
hold on;
plot(1:nday, Iday, 'color', grey1, 'linewidth', 2);
plot(1:nday, Ifil, 'color', grey2, 'linewidth', 2);
hold off; grid off; box off;
xlabel('time (days)');
ylabel('EVD cases (augmented)');
xlim([1 nday]);

% Model selections across trajectories
figure;
hold on;
plot(ks, L, 'k', 'linewidth', 2);
plot(ks, aic, 'b', 'linewidth', 2);
plot(ks, bic, 'r', 'linewidth', 2);
plot(ks, qian, 'c', 'linewidth', 2);
plot(ks, fia, 'm', 'linewidth', 2);
% Mark the minimum of every plotted criterion (aic, bic, qian, fia)
plot(kbest([1 2 4 5]), [aicVal, bicVal, qianVal, fiaVal], 'ko', 'MarkerSize', 8);
h = gca; h.YTickLabel = '';
% Vertical dashed lines at each selected k
plot(kbest(1)*ones(1, 2), h.YLim, 'b--', 'linewidth', 2);
plot(kbest(2)*ones(1, 2), h.YLim, 'r--', 'linewidth', 2);
plot(kbest(4)*ones(1, 2), h.YLim, 'c--', 'linewidth', 2);
plot(kbest(5)*ones(1, 2), h.YLim, 'm--', 'linewidth', 2);
hold off; grid off; box off;
legend('lik', 'aic', 'bic', 'qian', 'fia', 'location', 'best');
xlabel('$k$ (days)');
ylabel('model criteria');
xlim([ks(1) ks(end)]);

% Total Fisher information summed over groups, as a function of k
figure;
plot(ks, FIsum, 'c', 'linewidth', 2);
box off; grid off; hold off;
xlim([ks(1) ks(end)]);
xlabel('$k$ (days)');
ylabel('$\sum I(R)$');

% Overlay true R with the grouped MLE trajectory picked by each criterion
figure;
hold on;
stairs(tR, Rtrue(2:end), 'k', 'linewidth', 2);
% Criterion order: aic (1), bic (2), qian (4), fia (5)
for critIdx = [1 2 4 5]
    stairs(tR, Rks(modID(critIdx), :), 'linewidth', 2);
end
hold off; grid off; box off;
xlabel('time (days)');
ylabel('best $R_t$');
legend('true', 'aic', 'bic', 'qian', 'fia', 'location', 'best');
xlim([1 nday]);


%% Timing and data saving
% Elapsed time since tic (seconds), reported in minutes
elapsedSec = toc;
tsim = elapsedSec/60;
disp(['Run time = ', num2str(tsim)]);


