% Compute the risk function of the soft-thresholding adaptive estimator,
% with its worst-case risk constrained to be no more than 70% above the risk of Y_U
clear all
%% Shared state
% These names are declared global so that functions called from this script
% can see them (presumably the st7_risk constraint handed to fminimax below,
% which is defined elsewhere -- confirm which of these it actually reads).
global b_grid y_grid Sigma Ky Kb Pi omega_grid options rho_tbl
global policy Eb dEb d2Eb l0 mu

%% Solver settings for fmincon: SQP, iterative display, tight tolerances
options = optimoptions(@fmincon, ...
    'Algorithm', 'sqp', ...
    'Display', 'iter', ...
    'MaxFunctionEvaluations', 2000000, ...
    'MaxIterations', 40000, ...
    'ConstraintTolerance', 1e-7, ...
    'OptimalityTolerance', 1e-9);

%% Scaling: look up the minimax risk for the bounded normal mean
% Column 1 is later interpolated against |b| and column 2 supplies the
% corresponding minimax risk.
rho_tbl = readmatrix('../Matlab/sim_results/minimax_rho_B9.csv');

%% Grid of correlation values to loop over
% tanh of negative arguments, so every value lies strictly inside (-1, 0).
tanh_args = -3:0.05:-0.05;
Sigma_UO_grid = tanh(tanh_args);
Kcorr = numel(Sigma_UO_grid);
%% Loop over upper bounds
% Upper bounds B on |b| to process. Note: rho_tbl was read from
% minimax_rho_B9.csv, so the lookup table only matches B = 9.
B_grid = [9];

% corr2 is shared through a global (presumably read by st7_risk, which is
% defined elsewhere -- confirm). Declare it once, outside the loops, instead
% of re-declaring it on every iteration.
global corr2

for i = 1:length(B_grid)
    B = B_grid(i);
    B   % progress display
    %% Specify grids
    b_grid = (-B:0.025:B)';
    y_grid = (-(B+3):0.05:(B+3))';
    % Identical to b_grid here; would differ if b_grid were reparameterized
    % by 1/sqrt(Sigma_b).
    bias_grid = b_grid;

    %% Grid sizes
    Ky = length(y_grid);
    Kb = length(b_grid);

    %% Risk function for soft thresholding at l
    % Eb(l) returns a column vector: the risk at each point of bias_grid.
    Eb = @(l) 1+l^2+...
     (bias_grid.^2-1-l^2).*(normcdf(l-bias_grid)-normcdf(-l-bias_grid))+...
     (-bias_grid-l).*normpdf(l-bias_grid) - (l-bias_grid).*normpdf(-l-bias_grid);
    % The same expression evaluated at bias 0 (not used further in this script).
    Eb0 = @(l) 1+l^2+...
     (-1-l^2) *(normcdf(l)-normcdf(-l))+...
     (-l) *normpdf(l) - (l) *normpdf(-l);
    % Minimax risk at each |b|, spline-interpolated from the lookup table.
    rho_grid = interp1(rho_tbl(:,1),rho_tbl(:,2),abs(bias_grid),'spline');

    %% Constrained thresholds, one per correlation value
    st7_mat = zeros(Kcorr,1);
    for idx = 1:Kcorr
        idx   % progress display
        Sigma_UO = Sigma_UO_grid(idx);
        Sigma = [1 Sigma_UO;Sigma_UO 1];
        Sigma_t = Sigma(1,1);
        Sigma_tb = Sigma(1,2); % equals the correlation since both variances are 1
        Sigma_b = Sigma(2,2);
        corr2 = Sigma_tb^2/(Sigma_t*Sigma_b); % squared correlation coefficient

        % Oracle risk at each b; its reciprocal is used as a weight.
        risk_oracle = corr2*rho_grid + 1 - corr2;
        omega_grid = risk_oracle.^(-1);

        %% Scaling
        % Soft-threshold risk divided by the oracle risk at each b. fminimax
        % minimizes the maximum entry of this vector over the threshold l,
        % with lower bound l >= 0 and the nonlinear constraint st7_risk.
        Eb_scaled = @(l) (corr2*Eb(l)+1-corr2).*omega_grid;
        st7 = fminimax(Eb_scaled,0,[],[],[],[],0,[],@st7_risk);
        st7_mat(idx) = st7;
    end
end
% save('sim_results/const7_thresholds.mat','st7_mat');

%% Max/min risk across the bias grid for each correlation value
% st_mat (plotted below as "Nearly adaptive") is loaded from file, and
% st7_mat holds the constrained thresholds computed above. Loading into a
% struct stops the MAT-file from silently overwriting workspace variables
% (e.g. Sigma_UO_grid, Eb, st7_mat) that are used below.
thr_file = load('../Matlab/sim_results/thresholds.mat');
st_mat = thr_file.st_mat;

nCorr = numel(Sigma_UO_grid);
assert(numel(st_mat) >= nCorr, ...
    'thresholds.mat: st_mat has %d entries but %d correlation values are needed.', ...
    numel(st_mat), nCorr);

% Preallocate as row vectors, matching the orientation of Sigma_UO_grid.
max_st_v_unbiased  = zeros(1, nCorr);
min_st_v_unbiased  = zeros(1, nCorr);
max_st_v_unbiased7 = zeros(1, nCorr);
min_st_v_unbiased7 = zeros(1, nCorr);

for i = 1:nCorr
    % Squared correlation, as in the threshold loop where Sigma has unit variances.
    corr2 = Sigma_UO_grid(i)^2;

    % Evaluate each risk curve over bias_grid once, then reuse it for max and min.
    risk_st  = Eb(st_mat(i));
    risk_st7 = Eb(st7_mat(i));

    max_st_v_unbiased(i) = corr2*max(risk_st) + 1 - corr2;
    min_st_v_unbiased(i) = corr2*min(risk_st) + 1 - corr2;

    max_st_v_unbiased7(i) = corr2*max(risk_st7) + 1 - corr2;
    min_st_v_unbiased7(i) = corr2*min(risk_st7) + 1 - corr2;
end
%% Constrained soft-threshold risk relative to Y_U, plotted against corr^2
fig = figure(1);

% Left axis: 100*(risk - 1) for the max and min risk of both estimators.
yyaxis left
p1 = plot(Sigma_UO_grid.^2,100*(max_st_v_unbiased7-1),'-',...
    'DisplayName','Constrained nearly adaptive \newline (max and min risk)','LineJoin','miter','LineWidth',2 );
hold on
p2 = plot(Sigma_UO_grid.^2,100*(max_st_v_unbiased-1),'-.',...
    'DisplayName','Nearly adaptive \newline (max and min risk)','LineJoin','miter','LineWidth',2 );
p3 = plot(Sigma_UO_grid.^2,100*(min_st_v_unbiased7-1),'-',...
    'DisplayName','Constrained nearly adaptive \newline (min risk)','LineJoin','miter','LineWidth',2 );
p4 = plot(Sigma_UO_grid.^2,100*(min_st_v_unbiased-1),'-.',...
    'DisplayName','Nearly adaptive \newline (min risk)','LineJoin','miter','LineWidth',2 );
xlabel('Correlation coefficient^2');
ylabel('Max and min risk','FontSize',14 );

% Right axis: (1 - min risk)/(max risk - 1) for each estimator; the dashed
% grey line marks the level 1.
yyaxis right
p5 = plot(Sigma_UO_grid.^2,(1-min_st_v_unbiased7)./(max_st_v_unbiased7-1),':',...
    'DisplayName','Relative risk reduction \newline (Constrained nearly adaptive)','LineJoin','miter','LineWidth',2 );
hold on
p6 = plot(Sigma_UO_grid.^2,(1-min_st_v_unbiased)./(max_st_v_unbiased-1),'--',...
    'DisplayName','Relative risk reduction \newline (Nearly adaptive)','LineJoin','miter','LineWidth',2 );
yline(1,'--','Color',[0.5 0.5 0.5]); ylim([0 3])
ylabel('- Min risk / Max risk','FontSize',14,'rotation',-90,'VerticalAlignment','bottom');

% Only p1, p5, p2, p6 get legend entries; p3/p4 reuse the line styles of p1/p2.
legend([p1 p5 p2 p6]...
    ,'Location','southoutside','Orientation','horizontal',...
    'FontSize',14 ,'NumColumns', 2);
legend('boxoff');

  %% Save figures
% figurename = strcat( 'sim_results/penalty_cons_st_against_corr.png');
figurename = '../../figures/sec44.png';
% Match the paper size to the on-screen figure size (in inches) so the
% exported PNG has no extra margins; '-r0' prints at screen resolution.
fig.Units = 'Inches';
pos = fig.Position;
set(fig, 'PaperPositionMode','Auto', 'PaperUnits','Inches', ...
    'PaperSize',[pos(3), pos(4)])
print(fig, figurename, '-dpng', '-r0')