% Examine the classification ROC curve for detecting a change using
% model selection criteria
clearvars;
close all; clc;

% Assumptions and notes
% - two models: constant or single change
% - examines misclassification rate
% - adds min epidemic max size and Rmax = 100
% - does not include I(0) = 10 in analysis (even at offst = 0)
% - truncates epidemics from startup issues

% Additional plotting/partition package
addpath(genpath('/Users/kp10/Documents/MATLAB'));

% Figure defaults: LaTeX interpreters for tick labels, legends and text,
% plus a larger axes font
set(groot, 'defaultAxesTickLabelInterpreter', 'latex', ...
    'defaultLegendInterpreter', 'latex', ...
    'defaultTextInterpreter', 'latex', ...
    'defaultAxesFontSize', 16);
% Two grey shades used for the R profile plots
grey1 = repmat(0.8, 1, 3);
grey2 = repmat(0.5, 1, 3);

% Start timer for the whole script
tic;
% Flag for saving results, the starting folder and the output subfolder
saveTrue = 1;
thisDir = pwd;
saveFol = 'rocData';

%% Define models and generate epidemics

% Observation days of the epidemic and their count
tday = 1:201;
nday = numel(tday);
% Number of runs per test and data size T
N = 1000;
T = nday - 1;

% Offset for initial zeros and minimum required peak incidence
offst = 20;
Imin = 30;
% Maximum R for FIA integral domain
Rmax = 100;

% Metrics to be evaluated
metrics = {'lik', 'bic', 'fia', 'qian', 'aic'};
nmet = numel(metrics);

% Candidate k for model selection: divisors of T that are at least 100
ks = find(mod(T, 1:T) == 0);
ks = ks(ks >= 100);
nks = numel(ks);

% Grouping structure returned by getGrpSz for every candidate k
grps = arrayfun(@(k) getGrpSz(k, T), ks);
% Number of groups and the group sizes for each k
nGrps = zeros(1, nks);
grpSzs = cell(1, nks);
for i = 1:nks
    grpSzs{i} = grps(i).grpSz;
    nGrps(i) = grps(i).nGrp;
end

% Serial distribution over all time (type 1 is geom and 2 erlang)
distType = 2;
serial = serialDistrs(nday + offst, distType);
% Single omega controlling distribution
omega = 14.2;
% Distribution evaluated at 1/omega, used over all simulated days
Pomega = serial(1/omega);

% Model indices and number of null probabilities examined
modIdx = [1 2];
n0 = 100;
P0s = linspace(0.01, 0.99, n0);
% k of the null (first) and alternative (second) model
kval = [T, T/2];
% True R profiles: both models start from a constant R of 1.5
RTrue = repmat(1.5, 2, nday + offst);
% Alternative model drops to 0.5 after the midpoint change
tchange = kval(2) + offst + 1;
RTrue(2, tchange+1:end) = 0.5;

%% Generate specific epidemics and model select

% True model index and k, one row per null probability and one column per run
modTrue = zeros(n0, N); kTrue = modTrue;
% Selected model indices
% NOTE(review): these five arrays are not written again in this section
modAIC = modTrue; modBIC = modAIC; modQian = modAIC;
modFIA = modAIC; modLik = modAIC;
% Chosen k values and correct prob
kEst = cell(1, n0); PTrue = kEst;
% Positive and negative counts
nNeg = zeros(1, n0); nPos = nNeg;
% True and false positive rates (metric by null probability)
TPR = zeros(nmet, n0); FPR = TPR;

% Generate N epidemics at each of the n0 null probabilities
for ii = 1:n0
    % Prob of null model
    P0 = P0s(ii);
    % Model indices weighted by P0
    modTrue(ii, :) = datasample(modIdx, N, 'Weights', [P0 1-P0]);
    ktrue = kval(modTrue(ii, :)); kTrue(ii, :) = ktrue;
    
    % Simulate N incidence trajectories
    Iday = zeros(N, T+1+offst);
    % Infectiousness, Poisson rate
    Lam = Iday; rate = Iday;
    % Initialise epidemics
    Iday(:, 1) = 10;
    
    % Iteratively generate renewal epidemics; row j is re-simulated until
    % it passes the validity checks below
    % NOTE(review): warnTrue is never read in this section
    j = 1; warnTrue = 0;
    while j <= N
        % True R for chosen model
        Rtrue = RTrue(modTrue(ii, j), :);
        
        for i = 2:T+1+offst
            % Relevant part of serial distribution
            Pomegat = Pomega(1:i-1);
            % Total infectiousness
            Lam(j, i) = Iday(j, i-1:-1:1)*Pomegat';
            % Rate for ith day incidence
            rate(j, i) = Lam(j, i)*Rtrue(i);
            % Renewal incidence
            Iday(j, i) = poissrnd(rate(j, i));
        end
        
        % Ensure epidemic survives at the grouping of the first candidate k
        jstart = 1 + offst;
        A0 = zeros(1, nGrps(1)); B0 = A0;
        for l = 1:nGrps(1)
            % End-indices of group
            % NOTE(review): jstart is not advanced inside this loop, so every
            % group window starts at 1 + offst - confirm this is intended
            jstop = jstart + grpSzs{1}(l) - 1;
            ids = jstart:jstop;
            % Incidence and Lam sum over group
            B0(l) = sum(Iday(j, ids));
            A0(l) = sum(Lam(j, ids));
        end
        % Fisher information log product
        FIprod0 = sum(log((A0.^2)./B0));
        if ~(any(isinf(FIprod0) | isnan(FIprod0)))
            % Also require few zero-incidence days and a large enough peak
            if nnz(Iday(j, :) == 0) <= nday/4 && max(Iday(j, :)) > Imin
                % Current trajectory valid
                j = j + 1;
            end
        end
    end
    % Truncate incidence and remove starting zeros
    % NOTE(review): Rtrue here is the profile of the last simulated run only;
    % the outputs that depend on it are discarded - confirm against remGaps
    [~, ~, Iday, Lam] = remGaps(Iday, N, offst, Rtrue, rate, Lam);
    
    % Selected k for every metric (rows) and run (columns)
    kbest = zeros(nmet, N);
    % MLE based model selection
    for j = 1:N 
        [~, ~, ~, ~, ~, kbest(:, j), ~, ~, ~, ~, ~, ~] = fraserSubMultFn2(T,...
          ks, nks, Rmax, nGrps, grpSzs, Iday(j, :), Lam(j, :), 1);
    end
    % Get selected models and ks
    kEst{ii} = kbest;
    % Prob correct for each criterion
    PTrue{ii} = sum(kbest == kTrue(ii, :), 2)/N;
    
    % Convert to logical arrays with 0 for the null model
    logtrue = ktrue == kval(2);
    nPos(ii) = sum(logtrue); nNeg(ii) = N - nPos(ii);
    for j = 1:nmet
        % Classification of a metric
        logest = kbest(j, :) == kval(2);
        % Confusion matrix for TPR and FPR; the class order is fixed so the
        % matrix stays 2x2 even when only one class occurs
        C = confusionmat(logtrue, logest, 'Order', [false true]); 
        % Store rates for this null probability only (column ii)
        TPR(j, ii) = C(2, 2)/nPos(ii);
        FPR(j, ii) = 1 - C(1, 1)/nNeg(ii);
    end
    disp(['Completed: ' num2str(ii) ' of ' num2str(n0)]);
end

% True probability by metric (rows) and null probability (columns)
PTrue = cell2mat(PTrue);
% Avg success rate
Psucc = sum(PTrue, 2)/n0; 

%% Plotting and saving

% Output folder and base file name shared by the saved figure and data.
% Full paths are used instead of cd so the working directory never changes,
% and the folder is created if it does not exist yet.
if saveTrue
    outDir = fullfile(thisDir, saveFol);
    if ~exist(outDir, 'dir')
        mkdir(outDir);
    end
    outName = ['P0_' num2str(N) '_' num2str(n0) '_' num2str(Rmax) '_' num2str(T)];
    % Alternative base name used previously: [outName 'rev']
end

% Null and alternative models
figure;
plot(2:nday, RTrue(1, offst+2:end), 'color', grey1, 'linewidth', 2);
hold on;
plot(2:nday, RTrue(2, offst+2:end), '--', 'color', grey2, 'linewidth', 2);
hold off; grid off; box off;
xlim([2, nday]);
%ylim([RTrue(2, tchange+offst)-0.1, RTrue(1, tchange+offst)+0.1]);
xlabel('time (days)'); ylabel('$R$');

% Success probabilities (ignoring lik)
figure;
plot(P0s, PTrue(2:end, :), 'linewidth', 2);
h = gca; h.YLim(2) = 1.02;
grid off; box off;
xlabel('P(null)'); ylabel('P(correct)');
legend(metrics(2:end), 'location', 'best');
if saveTrue
    % Only this figure is written to disk
    saveas(gcf, fullfile(outDir, outName), 'fig');
end

% Success probs with models inset
figure;
plot(P0s, PTrue(2:end, :), 'linewidth', 2);
h = gca; h.YLim(2) = 1.02;
grid off; box off;
xlabel('P(null)'); ylabel('P(correct)');
legend('BIC', 'FIA', 'QK', 'AIC', 'location', 'best');
% Inset axes showing the two true R profiles
h = axes('Position', [.23, .23, .25, .25]);
plot(h, 2:nday, RTrue(1, offst+2:end), 'color', grey1, 'linewidth', 2);
hold(h, 'on');
plot(h, 2:nday, RTrue(2, offst+2:end), '--', 'color', grey2, 'linewidth', 2);
box(h, 'off'); h.YLim = [0.45 1.55];
ylabel(h, '$R(t)$'); xlabel(h, '$t$ (days)'); hold(h, 'off');


% Time and save data; the run time is taken first so it is stored too
tsim = toc/60;
if saveTrue
    % Drop the large per-run arrays before saving the workspace
    clearvars('rate', 'Lam', 'Iday', 'Rqian', 'Raic', 'Rbic', 'Rlik', 'Rfia');
    % NOTE(review): the axes handle h is still in the workspace here and is
    % saved with everything else - consider clearing it if that is unwanted
    save(fullfile(outDir, [outName '.mat']));
end
disp(['Run time = ' num2str(tsim)]);

