% Simulate multiple epidemics and investigate renewal model selection
clearvars;
clc;
close all;

% Assumptions and notes
% - allows for doubling the data randomly
% - runs batch over cases with fixed k and tests consistency
% - does not include I(0) = 10 in analysis (even at offst = 0)
% - truncate epidemics from startup issues
% - each epidemic used as data with different R profile

% Additional plotting/partition package
addpath(genpath('/Users/kp10/Documents/MATLAB'));

% Figure defaults: LaTeX interpreters everywhere and a larger axes font
set(groot, ...
    'defaultAxesTickLabelInterpreter', 'latex', ...
    'defaultLegendInterpreter', 'latex', ...
    'defaultTextInterpreter', 'latex', ...
    'defaultAxesFontSize', 16);
% Two grey shades as RGB triplets
grey1 = repmat(0.8, 1, 3);
grey2 = repmat(0.5, 1, 3);

% Start the run timer (read back with toc at the end of the script)
tic;

% Whether to save results, the folder the script started in, and the
% folder that receives the outputs
saveTrue = 1;
thisDir = pwd;
saveFol = 'batchRunsFix2';
% Decide if doubling data
doubleTrue = 0;

% Observation window (days) and its length
tday = 1:401;
nday = numel(tday);
% No. epidemic batch runs for each ktrue, and no. of analysed days
N = 2000;
T = nday - 1;

% Offset (days) added for initial zeros
offst = 50;
% Upper limit of R for the FIA integral domain
Rmax = 100;

% Model selection metrics to be evaluated
metrics = {'lik', 'bic', 'fia', 'qian', 'aic'};
nmet = numel(metrics);

% Candidate group sizes k: the divisors of T that are at least 20
ks = find(rem(T, 1:T) == 0);
ks(ks < 20) = [];
nks = numel(ks);

% Grouping structure returned by getGrpSz for each candidate k
grps = arrayfun(@(k) getGrpSz(k, T), ks);
% Number of groups and group sizes for each k
nGrps = zeros(size(ks));
grpSzs = cell(1, nks);
for i = 1:nks
    grpSzs{i} = grps(i).grpSz;
    nGrps(i) = grps(i).nGrp;
end

% Single omega controlling the serial distribution
omega = 14.2;
% Serial distribution over all time (type 1 is geom and 2 erlang)
distType = 2;
serial = serialDistrs(nday + offst, distType);
% Actual distribution over all tday
Pomega = serial(1/omega);

% Outputs of runs: selected k / index per true k, and per-metric
% fraction-correct and index-error tables (metrics x true k)
kbests = cell(1, nks);
idbests = cell(1, nks);
kfrac = zeros(nmet, nks);
kerr = zeros(nmet, nks);

%% Batch over trajectories and k values
% Each candidate k is taken in turn as the true value. For it the loop:
%   1. builds a piecewise-constant R profile with segment width ktrue,
%   2. simulates N renewal epidemics, re-drawing any that fail the checks,
%   3. drops trajectories flagged for long runs of zero incidence,
%   4. runs model selection on every remaining trajectory,
%   5. records how often each metric selects ktrue and the index error.

for ii = 1:nks
    % True k value and its position in ks
    ktrue = ks(ii); idtrue = ii;
    
    % Simulate N incidence trajectories from Rtrue
    Iday = zeros(N, T+1+offst);
    % Infectiousness, Poisson rate
    Lam = Iday; rate = Iday;
    % Initialise epidemics
    Iday(:, 1) = 10;
    
    % Change points (start at offset)
    tchange = [offst ktrue+offst+1:ktrue:nday+offst];
    % Ensure integer number of ktrue days
    if nday ~= tchange(end) - offst
        error('Check the day groupings');
    end
    % Define true R profile with ktrue width: the first segment is fixed at
    % 1.2 and the rest are drawn uniformly from [0.75, 1.75]
    nSegs = length(tchange);
    Rsegs = [1.2 unifrnd(0.75, 1.75, [1 nSegs-1])];
    % Vector for true R
    Rtrue = zeros(1, nday + offst);
    for i = 1:nSegs
        % R indices for current segment
        if i == 1
            idR = 1:tchange(i);
        elseif i == nSegs
            idR = tchange(i-1)+1:nday+offst;
        else
            idR = tchange(i-1)+1:tchange(i);
        end
        % Values over this segment
        Rtrue(idR) = Rsegs(i);
    end
        
    % Iteratively generate renewal epidemic; j only advances once the
    % current trajectory passes the checks, otherwise row j is re-simulated
    j = 1; warnTrue = 0;
    while j <= N
        for i = 2:T+1+offst
            % Relevant part of serial distribution
            Pomegat = Pomega(1:i-1);
            %Pomegat = Pomegat/sum(Pomegat);
            % Total infectiousness: past incidence (most recent first)
            % weighted by the serial distribution
            Lam(j, i) = Iday(j, i-1:-1:1)*Pomegat';
            % Rate for ith day incidence
            rate(j, i) = Lam(j, i)*Rtrue(i);
            % Renewal incidence
            Iday(j, i) = poissrnd(rate(j, i));
        end
        
        % Ensure epidemic survives at minimum group size (first entry of ks)
        jstart = 1 + offst;
        A0 = zeros(1, nGrps(1)); B0 = A0;
        for l = 1:nGrps(1)
            % End-indices of group
            jstop = jstart + grpSzs{1}(l) - 1;
            ids = jstart:jstop;
            % Incidence and Lam sum over group
            B0(l) = sum(Iday(j, ids));
            A0(l) = sum(Lam(j, ids));
        end
        % Fisher information log product; non-finite when any group sum is zero
        FIprod0 = sum(log((A0.^2)./B0));
        if isfinite(FIprod0)
            % Also test on zero incidence points and on the peak incidence
            if nnz(Iday(j, :) == 0) <= nday/4 && max(Iday(j, :)) > 200
                % Current trajectory valid
                j = j + 1;
            end
        end
    end
    
    % Remove gaps of zeros from all trajectories and truncate
    idStart = zeros(1, N); jwarn = zeros(1, N);
    for j = 1:N
        % Gaps between non-zero incidence values (spacing of non-zero days)
        zerogaps = diff(find(Iday(j, :) ~= 0));
        % Remove startup sequence of zeros if big
        z1 = zerogaps(1); zrest = zerogaps(2:end);
        if z1 > 8
            % Update incidence and related vectors
            idStart(j) = z1+1;
            % Flag zero incidence regions after startup
            if max(zrest) > 8
                %warning('Zero incidences beyond startup');
                jwarn(j) = 1;
            end
        else
            % Flag any zero incidence region
            if max(zrest) > 8
                %warning('Sequences of zero incidence');
                jwarn(j) = 1;
            end
            % Un-truncated day set
            idStart(j) = 2;
        end
    end
    % Check starting beyond truncation id
    if max(idStart) > offst
        error('Offset too small');
    else
        % Remove trajectories with warnings
        idw = find(~jwarn); Ncurr = length(idw);
        disp(['N changed from ' num2str(N) ' to ' num2str(Ncurr)]);
        
        % Truncate vectors to have length T (drops the offset and the
        % first observation day)
        Rtrue = Rtrue(offst+2:end); rate = rate(idw, offst+2:end);
        Iday = Iday(idw, offst+2:end); Lam = Lam(idw, offst+2:end);        
    end
    
    % Best R estimates from each criteria
    Rlik = cell(1, Ncurr); Rbic = Rlik; Rqian = Rlik; 
    Rfia = Rlik; Raic = Rlik;
    % Selected k and no. groups (metrics x trajectories)
    kbest = zeros(nmet, Ncurr); nGrpbest = kbest;
    lik = zeros(Ncurr, nks); fia = lik; bic = lik; aic = lik;
    
    % Select best k for each stream
    for i = 1:Ncurr
        % MLE/selection from multiple streams
        if doubleTrue
            % Pair trajectory i with one other, randomly chosen, trajectory
            iuse = [i datasample(setdiff(1:Ncurr, i), 1)];
            nStream = 2;
        else
            iuse = i; nStream = 1;
        end
        [Rlik{i}, Rbic{i}, Rfia{i}, Rqian{i}, Raic{i}, kbest(:, i), nGrpbest(:, i), ~, ...
            lik(i, :), fia(i, :), bic(i, :), aic(i, :)] = fraserSubMultFn2(T, ks, nks, ...
            Rmax, nGrps, grpSzs, Iday(iuse, :), Lam(iuse, :), nStream);
    end
    % ID for each selection: map every selected k to its index in ks
    idbest = kbest;
    for i = 1:nks
        idbest(kbest == ks(i)) = i;
    end
    
    % Main data about performance
    kbests{ii} = kbest; idbests{ii} = idbest;
    % Fraction right for each metric. Summing along dimension 2 keeps one
    % entry per metric even when Ncurr == 1; sum(kbest' == ktrue) would
    % collapse the single row to a scalar in that case.
    kfrac(1:nmet, ii) = sum(kbest == ktrue, 2)/Ncurr;
    % Total index error for each metric
    % NOTE(review): unlike kfrac this is not divided by Ncurr, so it scales
    % with the number of retained trajectories - confirm this is intended
    kerr(1:nmet, ii) = sum(abs(idbest - idtrue), 2); 
    disp(['Completed: ' num2str(ii) ' of ' num2str(nks)]);
end

% Summaries of performance of metrics: totals over all true k, shown
% relative to the largest total
kf = sum(kfrac, 2); ke = sum(kerr, 2);
disp(metrics); disp([kf/max(kf) ke/max(ke)]');

%% Plotting and saving

% Fraction and error performance (the first metric, 'lik', is not plotted)
figure;
subplot(2, 1, 1);
plot(1:nks, kfrac(2:end, :), '.-', 'linewidth', 2, 'MarkerSize', 30);
grid off; box off;
ylabel('$P(k = k^*)$');
subplot(2, 1, 2);
% Normalise the index errors by their maximum; if every error is zero,
% divide by 1 instead so the curves stay at zero rather than becoming NaN
zz = max(max(kerr(2:end, :)));
if zz == 0
    zz = 1;
end
plot(1:nks, kerr(2:end, :)/zz, '.-', 'linewidth', 2, 'MarkerSize', 30);
grid off; box off;
xlabel('$k$ index');
ylabel('$\|$ id($k$) - id($k^*$)$\|$');
legend(metrics(2:end), 'location', 'best');
if saveTrue
    % Create the output folder on demand and write through an explicit path,
    % so a missing folder cannot abort the run and the working directory is
    % never changed
    if exist(saveFol, 'dir') ~= 7
        mkdir(saveFol);
    end
    if doubleTrue
        saveas(gcf, fullfile(saveFol, ['doub_' num2str(N) '_' num2str(nks)]), 'fig');
    else
        saveas(gcf, fullfile(saveFol, ['sing_' num2str(N) '_' num2str(nks)]), 'fig');
    end
end

% Time and save data
if saveTrue
    % Create the output folder on demand so the results of a long batch run
    % are not lost to a failed cd, and save through an explicit path so the
    % working directory is never changed
    if exist(saveFol, 'dir') ~= 7
        mkdir(saveFol);
    end
    % Drop the per-trajectory arrays before saving the remaining workspace
    clearvars('rate', 'Lam', 'Iday', 'Rqian', 'Raic', 'Rbic', 'Rlik', 'Rfia');
    if doubleTrue
        save(fullfile(saveFol, ['doub_' num2str(N) '_' num2str(nks) '.mat']));
    else
        save(fullfile(saveFol, ['sing_' num2str(N) '_' num2str(nks) '.mat']));
    end
end
% Elapsed time since tic, in minutes (includes the time spent saving)
tsim = toc/60;
disp(['Run time = ' num2str(tsim) ' mins']);

