% Renewal model estimation and model selection
%
% Assumptions and notes
% - only does FIA and BIC
% - simulates an epidemic via renewal models
% - for a single trajectory
% - uses a renewal model from Fraser 2011

% Start from a clean workspace, command window and figure set
clearvars;
clc; close all;

% Additional plotting/partition package (machine-specific location)
extraPath = genpath('/Users/kp10/Documents/MATLAB');
addpath(extraPath);

% Figure defaults: latex interpreters everywhere and a larger axes font.
% All four defaults live on the graphics root, so they are set in one call.
set(groot, ...
    'defaultAxesTickLabelInterpreter', 'latex', ...
    'defaultLegendInterpreter', 'latex', ...
    'defaultTextInterpreter', 'latex', ...
    'defaultAxesFontSize', 16);

% Light and dark grey RGB triplets used for the incidence curves
grey1 = [0.8 0.8 0.8];
grey2 = [0.5 0.5 0.5];

% Start the timer for the whole run
tic;

% Flag for saving data (0 means nothing is saved)
saveTrue = 0;

% Scenario to simulate (index into the list of available scenarios)
scenNo = 6;
disp(['Scenario: ', num2str(scenNo)]);

%% Simulate epidemic scenarios
% Builds the true daily reproduction number Rtrue for the chosen scenario,
% then simulates daily incidence Iday from a Poisson renewal process and
% smooths it with a moving average (Ifil).

% Time for epidemic observation (days)
tday = 1:201; nday = length(tday);

% Possible scenarios available
% NOTE(review): the first label is 'const' but case 1 below generates a
% sinusoidal R - the label is left unchanged here, confirm which is intended
scenNam = {'const', 'boom-bust', 'bottle', 'linear', 'piecewise', 'fixed-width'};

% Functions for scenarios: R on a daily basis
switch(scenNo)
    case 1
        % Sinusoidal R across time (sind takes its argument in degrees)
        Rtrue = 1.2 + 0.5*sind(tday);
    case 2
        % Exponential rise and fall
        Rtrue = zeros(size(tday)); tchange = floor(nday/2);
        trise = 1:tchange; tfall = tchange+1:nday;
        % Exponential rise to max at tchange
        Rtrue(trise) =  exp(0.009*(1:tchange)); Rmax = Rtrue(tchange);
        % Exponential decay from max
        Rtrue(tfall) = Rmax*exp(-0.009*(tfall - tchange));
    case 3
        % A constant, then a fall, then a rise to a new constant
        Rtrue = zeros(size(tday));
        tchange = [floor(nday/3), floor(2*nday/3)];
        % First (high) segment
        Rtrue(1:tchange(1)) = 2;
        % Second (low) segment
        Rtrue(tchange(1)+1:tchange(2)) = 0.5;
        % Third (intermediate) segment
        Rtrue(tchange(2)+1:nday) = 1;
    case 4
        % Linear increasing R
        Rtrue = 0.01*tday + 1;
    case 5
        % Piecewise-constant case with randomly placed change points
        nSegs = 8; Rsegs = unifrnd(0.5, 10, [1 nSegs+1]);
        tchange = sort(datasample(2:nday-1, nSegs, 'Replace',false));
        % Segment i covers days edges(i)+1 to edges(i+1)
        edges = [0 tchange nday];
        % Assign random R intensities to each segment
        Rtrue = zeros(size(tday));
        for i = 1:nSegs+1
            % R indices for current segment
            idR = edges(i)+1:edges(i+1);
            % Values over this segment
            Rtrue(idR) = Rsegs(i);
        end
        disp(['No. random segments = ' num2str(nSegs+1)]);
    case 6
        % Fixed width windows
        ktrue = 50; disp(['True k = ' num2str(ktrue)]);
        % Change points (as day 1 not estimated start at 2)
        tchange = ktrue+1:ktrue:nday;
        % Truncate to integer number of ktrue days
        nday = tchange(end); tday = 1:nday;
        % Segments with different R
        nSegs = length(tchange);
        Rsegs = unifrnd(0.25, 2.5, [1 nSegs]);
        % Segment i covers days edges(i)+1 to edges(i+1); the last edge is
        % nday because nday was truncated to tchange(end) above
        edges = [0 tchange];
        % Vector for true R
        Rtrue = zeros(size(tday));
        for i = 1:nSegs
            % R indices for current segment
            idR = edges(i)+1:edges(i+1);
            % Values over this segment
            Rtrue(idR) = Rsegs(i);
        end
    otherwise
        % Fail early instead of leaving Rtrue undefined
        error('Unknown scenario number: %d', scenNo);
end

% Serial distribution over all time (type 1 is geom and 2 erlang)
% NOTE(review): serialDistrs is defined outside this file; it is assumed to
% return a function handle giving a row vector of nday probabilities
distType = 2;
serial = serialDistrs(nday, distType);
% Single omega controlling distribution
omega = 14.2;
% Actual distribution over all tday
Pomega = serial(1/omega);

% Daily incidence, preallocated with one entry per day
Iday = zeros(1, nday);
% Infectiousness, Poisson rate
Lam = Iday; rate = Iday;
% Initialise epidemic with the seed cases on day 1
Iday(1) = 10;

% Iteratively generate renewal epidemic
for i = 2:nday
    % Relevant part of serial distribution
    Pomegat = Pomega(1:i-1);
    % Total infectiousness: past incidence (most recent first) weighted by
    % the serial distribution
    Lam(i) = Iday(i-1:-1:1)*Pomegat';
    % Rate for ith day incidence
    rate(i) = Lam(i)*Rtrue(i);
    % Renewal incidence
    Iday(i) = poissrnd(rate(i));
end

% Smooth the incidence with m point averager
m = 7; B = ones(m, 1)/m;
Ifil = filter(B, 1, Iday);
 

%% Fraser-Cori MLEs for R and model selection
% For every candidate group size k, the R time axis (days 2 to nday) is
% split into groups, R is estimated per group as sum(I)/sum(Lam), and the
% BIC and FIA criteria are accumulated. The k minimising each criterion is
% then reported.

% Time vector over which R holds (length is data size)
tR = tday(2:end); lenRt = nday-1;

% Parameters for model selection
ks = 20:200;
%ks = cumprod(2*ones(1, ceil(log2(nday-1)))); ks = ks(ks > 10);
nks = length(ks);
%disp(['k varies from ' num2str(ks(1)) ' to ' num2str(ks(end))]);

% Get group sizes for each k
% NOTE(review): getGrpSz is defined outside this file; it is assumed to
% return a struct with fields nGrp and grpSz where sum(grpSz) == lenRt
grps = @(k) getGrpSz(k, lenRt);
grps = arrayfun(grps, ks);

% Reproductive numbers and log-likelihood with time and grouping
% (allocated but not written anywhere in this script)
Rkset = zeros(nks, lenRt); likset = Rkset;
% Total likelihood of each trajectory, no. groups for each k
liksum = zeros(1, nks); nGrpset = liksum;
% Fisher information sum and product
FIsum = zeros(1, nks); FIprod = FIsum;

% Model selection metrics used
metrics = {'bic', 'fia'};
nmet = length(metrics);
% FIA and BIC selection criteria
fia = zeros(1, nks); bic = fia;
% Constant for R integral in FIA (2*sqrt(Rmax))
Rmax = 100/2; V = 2*sqrt(Rmax);

% Group numbers and likelihood on each k
nGrps = zeros(size(ks)); liks = cell(1, nks);
% R values across k and time
Rks = zeros(nks, lenRt); Rgrps = liks;
% Sums of lambda and I over groups for each k
Lsum = liks; Isum = liks;

% Ids for group end points at each k
iddays = cell(1, nks);

% For each set of ks get MLE and log-likelihoods
for i = 1:nks
    % Group properties at this k
    nGrp = grps(i).nGrp; grpSz = grps(i).grpSz;
    % Grouping at a specific k (num times combined)
    k = ks(i); nGrps(i) = nGrp;

    % Components of group MLEs, liks, criteria
    A = zeros(1, nGrp); B = A; lik = A;
    Rgrp = A; FIAcomp = A; BICcomp = A;
    % Start and end indices (positions on the tR axis)
    idday = zeros(nGrp, 2);

    % Grouped MLEs for Rt and max log-likelihoods
    jstart = 1;
    for j = 1:nGrp
        % End-indices of group
        jstop = jstart + grpSz(j) - 1;
        ids = jstart:jstop;
        % Ids (start and end) constituting each group
        idday(j, :) = [jstart jstop];

        % ids index the R axis tR = tday(2:end), while Lam and Iday are
        % indexed by day, so shift by one to pick the matching days
        idt = ids + 1;

        % Component sums: A (Lamj) and B (Ij)
        A(j) = sum(Lam(idt));
        B(j) = sum(Iday(idt));

        % Grouped MLE (replicated over t points in group i.e. ids)
        Rks(i, ids) = B(j)/A(j); Rgrp(j) = B(j)/A(j);
        % Maximum log-likelihood (also replicated to match R)
        % NOTE(review): getLogLikEst is defined outside this file
        lik(j) = getLogLikEst(Rgrp(j), B(j));

        % Per-group components of FIA and BIC
        FIAcomp(j) = -lik(j) + 0.5*log(2*Rmax*A(j)/pi);
        BICcomp(j) = -lik(j) + 0.5*log(lenRt);

        % Update end indices
        jstart = jstop + 1;
    end

    % Segment log-likelihood
    liksum(i) = sum(lik);
    liks{i} = lik; Rgrps{i} = Rgrp;
    % Group sums for each k
    Lsum{i} = A; Isum{i} = B;
    % Ids (positions on tR) constituting each group
    iddays{i} = idday;

    % FI over groups and product
    FI = (A.^2)./B;
    FIprod(i) = prod(FI);

    % FIA and BIC for that model
    fia(i) = sum(FIAcomp);
    bic(i) = sum(BICcomp);
end

% Check for issues with FI
if any(isinf(FIprod) | isnan(FIprod))
    warning('Stoch complexity term not sensible');
end
% Check on R estimates (Rks holds the MLEs computed in the loop above)
if any(isinf(Rks(:)) | isnan(Rks(:)))
    error('MLEs for R are inf or nan');
end

% Select best model as min of criteria
[bicVal, bicMod] = min(bic);
[fiaVal, fiaMod] = min(fia);

% Min from log-likelihood alone
[LVal, LMod] = min(-liksum);
Lgrp = nGrps(LMod); Lk = ks(LMod);

% Best ks and nGrps
modID = [bicMod fiaMod];
kbest = ks(modID); nGrpbest = nGrps(modID);
disp(['k: [lik bic fia] = [' num2str([Lk kbest]) ']' ]);
disp(['nGrp: [lik bic fia] = [' num2str([Lgrp nGrpbest]) ']' ]);


%% Visualisation and post processing

% Figure 1: raw (light grey) and smoothed (dark grey) daily incidence
figure;
hold on;
plot(1:nday, Iday, 'Color', grey1, 'LineWidth', 2);
plot(1:nday, Ifil, 'Color', grey2, 'LineWidth', 2);
hold off; grid off; box off;
% Use a log axis once the counts become large
if max(Iday) > 2000
    set(gca, 'YScale', 'log');
end
xlabel('time (days)');
ylabel('EVD cases (augmented)');
xlim([1 nday]);


% Figure 2: negative log-likelihood, BIC and FIA against k, with the
% minimiser of each curve circled
figure;
hold on;
plot(ks, -liksum, 'k', 'LineWidth', 2);
plot(ks, bic, 'r', 'LineWidth', 2);
plot(ks, fia, 'm', 'LineWidth', 2);
plot([Lk kbest], [LVal bicVal fiaVal], 'ko', 'MarkerSize', 8);
% Only the shape of the curves matters, so the y tick labels are hidden
ax = gca;
set(ax, 'YTickLabel', '');
% Dashed vertical markers at the k chosen by BIC (red) and FIA (magenta)
plot([kbest(1) kbest(1)], get(ax, 'YLim'), 'r--', 'LineWidth', 2);
plot([kbest(2) kbest(2)], get(ax, 'YLim'), 'm--', 'LineWidth', 2);
hold off; grid off; box off;
legend('lik', 'bic', 'fia', 'location', 'best');
xlabel('$k$ (days)');
ylabel('model criteria');
xlim([ks(1) ks(end)]);

% Figure 3: true R against the estimates selected by BIC and FIA
figure;
hold on;
stairs(tR, Rtrue(2:end), 'k', 'LineWidth', 2);
for sel = 1:2
    % Row of Rks belonging to the model picked by criterion number sel
    stairs(tR, Rks(modID(sel), :), 'LineWidth', 2);
end
hold off; grid off; box off;
xlabel('time (days)');
ylabel('best $R_t$');
legend('true', 'bic', 'fia', 'location', 'best');
xlim([1 nday]);


%% Timing and data saving
% Elapsed time since tic, converted from seconds to minutes
tsim = toc/60;
disp(['Run time = ', num2str(tsim)]);


