% Replicate the Wald and OLS estimates of Angrist and Krueger (1991).
% Start from an empty workspace and read the cohort data.
clear all
raw = readmatrix('../../data/Angrist and Krueger (1991)/cohort3039.csv');
% Subtract the column means, which partials out the constant term.
raw = raw - mean(raw);
% Column 1 is used as the outcome, column 2 as the regressor and column 3
% as the instrument in the sections that follow.
Y = raw(:,1);
X = raw(:,2);
Z = raw(:,3);
clear raw
n = length(Y); % sample size
%% Replicate the OLS estimate
% X and Y are single demeaned columns, so the slope YR is a scalar.

XX = X'*X;                 % regressor cross-product
XY = X'*Y;
YR = XX \ XY;              % OLS slope
e = Y - X*YR;              % OLS residuals
v_e = (e'*e)/n;            % homoskedastic specification
v_ols = (XX/n) \ v_e;      % asymptotic variance
% Left unsuppressed on purpose: prints the s.e. so it can be compared with
% the reported one.
sqrt(v_ols/n)

%% Replicate the Wald estimate
% X is projected on the instrument Z; the cross-products shared by XPX and
% XPY are formed once.

XZ = X'*Z;
ZZ = Z'*Z;
XPX = XZ*(ZZ\(Z'*X));
XPY = XZ*(ZZ\(Z'*Y));
YU = XPX \ XPY;            % Wald (IV) slope
u = Y - X*YU;              % residuals evaluated at the Wald slope
v_u = (u'*u)/n;            % homoskedastic specification
v_wald = (XPX/n) \ v_u;    % asymptotic variance
% Left unsuppressed on purpose: prints the s.e. so it can be compared with
% the reported one.
sqrt(v_wald/n)
%% Covariance between the OLS and Wald estimate
% All three displayed quantities are scaled by n (asymptotic scale).
v_wald_ols = (XX/n) \ ((e'*u)/n);
VUR = v_wald_ols/n;

nvar_YO = v_ols - 2*v_wald_ols + v_wald;   % n * Var(Y_O)
ncov_YU_YO = v_wald_ols - v_wald;          % n * Cov(Y_U, Y_O)

disp('The n*variance of Y_U is')
disp(v_wald)
disp('The n*variance of Y_O is')
disp(nvar_YO)
disp('The n*covariance between Y_U and Y_O is')
disp(ncov_YU_YO)

%% Over-id test
% YO is the difference between the OLS (YR) and Wald (YU) estimates; its
% t-ratio tO is the over-identification statistic.
YO = YR - YU;
VU = v_wald/n;
VR = v_ols/n;
VO = (v_ols - 2*v_wald_ols + v_wald)/n;
VUO = (v_wald_ols - v_wald)/n;
se_O = sqrt(VO);

tO = YO/se_O;
disp('The over-id test statistic is')
disp(tO)

CUE = YU - VUO/VO * YO;
disp('The efficient estimator is')
disp(CUE)

disp('The correlation coefficient is')
% NOTE(review): the name `corr` hides the toolbox function of the same name
% for the rest of the script; later sections read this variable, so it is
% kept. Left unsuppressed so the value is echoed after the label.
corr = VUO/se_O/sqrt(VU)
%% Compute the adaptive estimator
% NOTE(review): the original remark here was cut off mid-sentence ("corr is
% outside the bound and will need to manually ..."); presumably some manual
% handling of the correlation was intended -- confirm with the author.

% State shared with the helper functions through global variables
global b_grid y_grid Ky Kb Pi omega_grid options rho_tbl policy Eb corr2
global mu0 x0 x0_bimodal % starting values for functions
% starting value for tanh(-3) as initial guess
load('../Matlab/sim_results/init_priors.mat');

% fmincon options, assembled in a single call
options = optimoptions(@fmincon, ...
    'Algorithm','sqp', ...
    'Display','off', ...
    'MaxFunctionEvaluations',2000000, ...
    'MaxIterations',40000, ...
    'ConstraintTolerance',1e-7, ...
    'OptimalityTolerance',1e-9);

% Correlation computed in the over-id section, plus the tag used in the
% (commented-out) result file names.
Sigma_UO = corr;
corr_str = '1';

%  Form scaling - look up the minimax risk for the bounded normal mean
rho_tbl = readmatrix('../Matlab/sim_results/minimax_rho_B9.csv');
B = 9;

%  Grids (column vectors) and their sizes
b_grid = (-B:0.025:B)';
y_grid = (-(B+3):0.05:(B+3))';
Ky = numel(y_grid);
Kb = numel(b_grid);

%  2x2 matrix with unit diagonal and Sigma_UO off the diagonal
Sigma = [1 Sigma_UO;Sigma_UO 1];
Sigma_t = Sigma(1,1);
Sigma_tb = Sigma(1,2); % correlation coefficient
Sigma_b = Sigma(2,2);
bias_grid = b_grid/sqrt(Sigma_b);

%  Time the call to the project helper risk_calc (made reachable by addpath)
tic
addpath('../Matlab/')
[x,x_bimodal,prior,psi, bayes, psi_bimodal_grid,risk_function_adaptive, risk_function_bimodal] = risk_calc(Sigma);
toc

%  Spline-interpolate the tabulated rho at |b|/sqrt(Sigma_b)
rho_b_over_sigma = interp1(rho_tbl(:,1),rho_tbl(:,2),abs(bias_grid),'spline');

inv_corr_sq = 1/Sigma_tb^2;
risk_oracle = rho_b_over_sigma + inv_corr_sq - 1;
%  Constant at inv_corr_sq everywhere except the grid point b = 0
risk_bimodal = repmat(inv_corr_sq, Kb, 1);
risk_bimodal(b_grid == 0) = (inv_corr_sq - 1);
	     
% writetable(table(y_grid,psi,bayes,psi_bimodal_grid),strcat( 'sim_results/minimax_adaptive_psi_sigmatb_',corr_str,'_B',string(B),'.csv'));
% writetable(table(b_grid,risk_function_adaptive,risk_oracle,...
%                 risk_function_bimodal,risk_bimodal),...
%                 strcat( 'sim_results/risk_and_oracle_risk_sigmatb_',corr_str,'_B',string(B),'.csv'));
% writetable(table(b_grid,x,x_bimodal,prior),strcat( 'sim_results/minimax_mu_sigmatb_',corr_str,'_B',string(B),'.csv'));
% adaptive soft and hard threshold
% Eb(l): closed-form expression in normcdf/normpdf, evaluated at every point
% of bias_grid for a scalar threshold l. Per the heading this is the
% soft-threshold risk as a function of the bias -- presumably for a
% unit-variance normal observation; confirm against the paper.
Eb = @(l) 1+l^2+...
     (bias_grid.^2-1-l^2).*(normcdf(l-bias_grid)-normcdf(-l-bias_grid))+...
     (-bias_grid-l).*normpdf(l-bias_grid) - (l-bias_grid).*normpdf(-l-bias_grid);
% Tabulated rho (column 2 of rho_tbl against column 1), spline-interpolated
% at |bias_grid|.
rho_grid = interp1(rho_tbl(:,1),rho_tbl(:,2),abs(bias_grid),'spline');
corr2 = Sigma_tb^2/(Sigma_t*Sigma_b); % squared corr. coef. 
% NOTE(review): cons is not read again anywhere in the visible script.
cons = 1/corr2 - 1; 

% oracle risk function (minimax over |b|<=B with B set to true value of |b|)
% This overwrites the risk_oracle computed earlier in the script, using a
% different normalisation (multiplied through by corr2).
risk_oracle = corr2*rho_grid+ 1-corr2;

% Weights: elementwise reciprocal of the oracle risk.
omega_grid = (risk_oracle).^(-1); 
% Scaling

% Threshold risk divided by the oracle risk, pointwise on the grid.
Eb_scaled = @(l) (corr2*Eb(l)+1-corr2).*omega_grid;
  
% fminimax minimises the largest component of Eb_scaled over l, starting
% from 0 with lower bound 0 (7th argument); no upper bound and no nonlinear
% constraint are supplied.
[st,f_fun,f_max]  = fminimax(Eb_scaled,0,[],[],[],[],0,[],[]);
risk_function_st_adaptive = (Eb(st) + 1/corr2 - 1);

%  hard threshold
% NOTE(review): Eb_ht is written in terms of b_grid whereas Eb uses
% bias_grid; the two grids coincide only when Sigma_b equals 1.
Eb_ht = @(l) 1+(b_grid.^2-1).*(normcdf(l-b_grid)-normcdf(-l-b_grid))+...
 (l-b_grid).*normpdf(l-b_grid) - (-l-b_grid).*normpdf(-l-b_grid); % hard threshold
Eb_ht_scaled = @(l) (corr2*Eb_ht(l)+1-corr2).*omega_grid;
% Same minimax problem for the hard threshold, started from the soft
% threshold st; f_fun and f_max from the previous call are overwritten.
[ht,f_fun,f_max]  =fminimax(Eb_ht_scaled,st,[],[],[],[],0,[],[]);
risk_function_ht_adaptive = (Eb_ht(ht) + 1/corr2 - 1);
%  Save results

% writematrix([st;ht],...
% strcat( 'sim_results/minimax_mu_st_ht_sigmatb_',corr_str,'_B',string(B),'.csv'));

% writetable(table(bias_grid,risk_function_st_adaptive,...
% risk_function_ht_adaptive,...
% risk_oracle),...
% strcat( 'sim_results/risk_st_and_oracle_risk_sigmatb_',corr_str,'_B',string(B),'.csv'));
% Compute the constrained soft-threshold adaptive estimator
% Same objective and bounds as for st, plus the nonlinear constraint
% function st5_risk, which is defined outside this script.
st5 = fminimax(Eb_scaled,0,[],[],[],[],0,[],@st5_risk);
% save('sim_results/AK91_const_thresholds.mat','st5');

%% Compute the constrained adaptive estimator
% Tunes the scalar t passed to the project helper risk_calc_const_tuning
% until the worst-case value of risk_function_x1*Sigma_tb^2 is within
% risk_tol of target_risk. The step is incre/counter, so it shrinks
% harmonically; its direction depends on which side of the target the
% current worst-case risk lies.
global x1_init
% Tbl = readtable(strcat( 'sim_results/minimax_mu_sigmatb_1_B9.csv'));  % preliminary estimates
% x1_init = Tbl.x; % initial guess for AK91
x1_init = x; % warm start from the output x of risk_calc (presumably read by the helper through the global -- confirm)
t0 = 0.0094; % initial guess for AK91
[x1,x1_grid,risk_function_x1] = risk_calc_const_tuning(Sigma,t0);

target_risk = 1.5;   % targeted worst-case risk (AK91 setting)
risk_tol = 0.001;    % accepted distance from the target
max_tuning_iter = 500; % safeguard: every iteration is a full solver call

% check if the worst case risk is less than targeted
t_mat = [0.0094; 0.01]; idx = 1; % determine the increment
counter = 1; incre = (t_mat(idx+1,1)-t_mat(idx,1))/2;
worst_risk = max(risk_function_x1* Sigma_tb^2);
max_nonlinear_v_unbiased = abs(worst_risk - target_risk); % for AK91
while max_nonlinear_v_unbiased > risk_tol
    if counter > max_tuning_iter
        % Without this guard the loop never ends when the target cannot be
        % reached to within risk_tol.
        warning('AK91:tuningNotConverged', ...
            'Tuning stopped after %d iterations; distance to target risk is still %g.', ...
            max_tuning_iter, max_nonlinear_v_unbiased);
        break
    end
    if worst_risk > target_risk  % for AK91
        t = t0 - incre/counter;
    else
        t = t0 + incre/counter;
    end
    x1_init = x1; % warm start from the previous solution
    tic
    [x1,x1_grid,risk_function_x1] = risk_calc_const_tuning(Sigma,t);
    toc
    % check if the worst case risk is less than targeted
    worst_risk = max(risk_function_x1* Sigma_tb^2);
    max_nonlinear_v_unbiased = abs(worst_risk - target_risk); % for AK91
    t0 = t;
    fprintf('tuning iteration %d: t = %.6g, worst-case risk = %.6g, distance to target = %.3g\n', ...
        counter, t0, worst_risk, max_nonlinear_v_unbiased);
    counter = counter + 1;
end
% save(strcat('sim_results/const_priors_tuning_sigmatb_',corr_str,'.mat'),'b_grid','x1');
% save(strcat('sim_results/const_risk_tuning_sigmatb_',corr_str,'.mat'),'b_grid','risk_function_x1','t0');
% save(strcat('sim_results/const_policy_tuning_sigmatb_',corr_str,'.mat'),'y_grid','x1_grid');

%% Form nonlinear adaptive estimates
% Uses psi, st, ht and the risk functions computed earlier in this script;
% the commented-out readtable lines are the alternative of loading them
% from the /sim_results/ subdirectory.
% Every estimate has the form VUO/sqrt(VO) * g(tO) + CUE, where g is the
% interpolated rule psi, a soft threshold or a hard threshold.
% NOTE(review): the risk functions are rescaled in place by VUO^2/VO, so
% running this section twice without recomputing them scales them twice.
% Tbl = readtable(strcat('sim_results/minimax_adaptive_psi_sigmatb_',corr_str,'_B',string(B),'.csv'));
% t_grid = Tbl.y_grid; psi = Tbl.psi;
t_grid = y_grid;
% spline interpolate nonlinear estimate and apply scaling
t_tilde = interp1(t_grid,psi,tO,'spline');
adaptive_nonlinear = VUO/sqrt(VO) * t_tilde + CUE;
disp('The adaptive estimate is')
disp(adaptive_nonlinear) % the label used to be printed without its value

% 
% Tbl = readtable(strcat('sim_results/risk_and_oracle_risk_sigmatb_',corr_str,'_B',string(B),'.csv'));
% risk_function_adaptive =  VUO^2/VO* Tbl.risk_function_adaptive;
% risk_function_oracle = VUO^2/VO* Tbl.risk_oracle;
risk_function_adaptive =  VUO^2/VO* risk_function_adaptive;
risk_function_oracle = VUO^2/VO* (rho_b_over_sigma + 1/Sigma_tb^2 -1);

% soft threshold: shrink tO towards zero by st, zero inside [-st, st]
adaptive_st = VUO/sqrt(VO) * ((tO > st)*(tO - st) + (tO < -st)*(tO + st)) + CUE;

% Tbl = readtable(strcat('sim_results/risk_st_and_oracle_risk_sigmatb_',corr_str,'_B',string(B),'.csv'));
% bias_grid = Tbl.bias_grid; Kb = length(bias_grid);
% risk_function_st_adaptive = VUO^2/VO*  Tbl.risk_function_st_adaptive;
risk_function_st_adaptive = VUO^2/VO*  risk_function_st_adaptive;

% hard threshold: keep tO when |tO| exceeds ht, otherwise zero
adaptive_ht = VUO/sqrt(VO) * ((tO > ht)*(tO) + (tO < -ht)*(tO)) + CUE;

% risk_function_ht_adaptive = VUO^2/VO*  Tbl.risk_function_ht_adaptive;
risk_function_ht_adaptive = VUO^2/VO*  risk_function_ht_adaptive;
risk_function_YR = VO* bias_grid.^2 + VR;
%% Use simulation to calculate the risk function for the pre-test estimator that switches btw Y_U and Y_R (which is efficient)
%risk_function_ht_ttest = VUO^2/VO*  Tbl.risk_function_ht_ttest;
% The pre-test estimate subtracts YO = sqrt(VO)*tO from YR whenever |tO|
% exceeds the critical value, and leaves YR unchanged otherwise.
crit_val = 1.96;
pretest_ht = YR - sqrt(VO) * ((tO > crit_val)*(tO) + (tO < -crit_val)*(tO));

sims = 100000;
rng(1,'twister'); % fixed seed so the simulated risk is reproducible
% The draws get their own variable: the workspace variable x holds the
% prior returned by risk_calc and seeds x1_init, so overwriting it here
% would corrupt a later re-run of the constrained-tuning section.
z_draws = normrnd(0,1,[sims,1]);
% sims-by-Kb matrix of draws shifted by each bias (implicit expansion of a
% column plus a row, replacing the explicit ones(...) products).
x_b = z_draws + bias_grid';
% Mean squared deviation from the bias for each grid point, with draws
% inside (-l, l) multiplied by (1+VO/VUO) and the rest left unchanged.
Ebsims_ht = @(l) sum(((x_b > l).*x_b + (x_b < l & x_b > -l)*(1+VO/VUO).*x_b + (x_b < -l).*x_b...
    - bias_grid').^2,1)/sims;
risk_function_ht_ttest = VUO^2/VO*  (Ebsims_ht(crit_val) + 1/corr^2 - 1)';
%% Adaptive estimate with constraint on maximum risk
% Uses x1_grid and st5 computed earlier in this script; the commented-out
% load calls are the alternative of reading them from saved .mat files.
% load(strcat('sim_results/const_priors_tuning_sigmatb_',corr_str,'.mat'));
% load(strcat('sim_results/const_risk_tuning_sigmatb_',corr_str,'.mat'));
% load(strcat('sim_results/const_policy_tuning_sigmatb_',corr_str,'.mat'));
% Spline-interpolate x1_grid (tabulated on y_grid) at the statistic tO,
% then rescale and recentre at CUE.
t_tilde = interp1(y_grid,x1_grid,tO,'spline');%AK91
adaptive_estimate5 = VUO/sqrt(VO) * t_tilde + CUE;
 
% load('sim_results/AK91_const_thresholds.mat')
% Soft threshold at st5; left unsuppressed, so the value is echoed.
st_estimate5 = VUO/sqrt(VO) * ((tO > st5)*(tO-st5) + (tO < -st5)*(tO+st5) ) + CUE
% Same expression as the Eb defined in the soft-threshold section; Eb is
% declared global, so this assignment rewrites the global handle.
Eb = @(l) 1+l^2+...
(bias_grid.^2-1-l^2).*(normcdf(l-bias_grid)-normcdf(-l-bias_grid))+...
(-bias_grid-l).*normpdf(l-bias_grid) - (l-bias_grid).*normpdf(-l-bias_grid);
% Risk at st5 on bias_grid, rescaled by VUO^2/VO.
risk_function_st5 = VUO^2/VO* (Eb(st5) + 1/corr^2 - 1); 
%% Export the results for Table 3 Unconstrained and constrained adaptation results
% Writes the table rows to ../../tables/table3.tex. Each row has nine cells
% (eight '&' separators): a label followed by Y_U, Y_R, Pre-test, Adaptive,
% Soft-threshold, Hard-threshold, Const. Adaptive and Const. Soft-threshold.
% NOTE(review): the rows carry no LaTeX row terminator, so the file
% presumably gets edited before it is included -- confirm.

table_path = '../../tables/table3.tex';
[fid, open_msg] = fopen(table_path,'w');
if fid == -1
    % fprintf on an invalid identifier fails with a far less helpful message
    error('AK91:cannotOpenTable', 'Could not open %s for writing: %s', table_path, open_msg);
end

try
    fprintf(fid, '%s\n',' & $Y_{U}$ & $Y_{R}$ &  Pre-test & Adaptive & Soft-threshold & Hard-threshold & Const. Adaptive & Const. Soft-threshold ');
    fprintf(fid, '%s\n', '\hline');
    fprintf(fid, '%s\n', '\hline');
    fprintf(fid, 'Estimate & %4.4f  & %4.3f & %4.3f  & %4.3f& %4.3f& %4.3f  & %4.3f & %4.3f \n',...
        YU,YR, pretest_ht, adaptive_nonlinear, adaptive_st,adaptive_ht, ...
        adaptive_estimate5,st_estimate5  );
    fprintf(fid, '%s\n', '\hline');
    % Eight separators, like every other row (there used to be nine here).
    fprintf(fid, 'Std error & %4.4f  & %4.4f  &   &   &   & &  &  \n',...
        sqrt(VU),sqrt(VR) );

    fprintf(fid, '%s\n', '\hline');
    fprintf(fid, '%s\n', 'Max Risk rel. to $Y_U$ &&&&&&');
    fprintf(fid, '%s\n', '\hline');
    % Worst-case risk of each estimator relative to VU, in percent.
    fprintf(fid, ' & 0\\%% &  %s  & %4.0f\\%% & %4.0f\\%% & %4.0f\\%% & %4.0f\\%% & %4.0f\\%% & %4.0f\\%% \n', ...
        '$\infty$',100* (max(risk_function_ht_ttest)/VU -1), 100*(max(risk_function_adaptive)/VU - 1),...
        100*(max(risk_function_st_adaptive)/VU - 1), 100*(max(risk_function_ht_adaptive)/VU - 1),...
        100*(max( VUO^2/VO*risk_function_x1./VU)- 1), 100*(max(risk_function_st5./VU)- 1) );
    fprintf(fid, '%s\n', '\hline');

    % Worst-case ratio of each risk function to the oracle risk, in percent.
    fprintf(fid, 'Max Regret & %4.0f\\%%  &  %s  & %4.0f\\%% & %4.0f\\%% & %4.0f\\%% & %4.0f\\%% & %4.0f\\%% & %4.0f\\%% \n',...
        100* ( VU/(VU  - VUO^2/VO)-1 ), '$\infty$', 100*(max(risk_function_ht_ttest./risk_function_oracle)-1), 100*(max(risk_function_adaptive./risk_function_oracle) - 1),...
        100*(max(risk_function_st_adaptive./risk_function_oracle) - 1), 100*(max(risk_function_ht_adaptive./risk_function_oracle) - 1),...
        100*(max( VUO^2/VO*risk_function_x1./(risk_function_oracle))-1), 100*(max( risk_function_st5./risk_function_oracle)-1));

    fprintf(fid, '%s\n', '\hline');

    fprintf(fid, 'Threshold &  &    & %4.2f  &  & %4.2f& %4.2f  & & %4.2f \n',...
        1.96, st, ht, st5);
    % Each summary line ends with a newline; they used to run together.
    fprintf(fid,'The relative efficiency is %4.4f\n', 1-corr^2);
    fprintf(fid,'The t-stat %4.4f\n', tO);
catch write_err
    % Do not leak the file handle when a value is missing or a write fails.
    fclose(fid);
    rethrow(write_err);
end

fclose(fid);

          