rm(list=ls())
library(R.matlab)
library(akima) # spline interpolation package
library(pracma) # meshgrid function
library(xtable)
library(ggplot2)
library(tidyverse)
library(latex2exp)
library(ggplot2)
# Plot styling shared by every figure below: colour palette and line types
cbp1 <- c("#999999", "#E69F00", "#56B4E9", "#009E73", "#F0E442")
my_linetypes <- c("solid", "dotted", "dashed", "dotdash", "longdash")
my_linetypes2 <- c("solid", "solid", "dashed", "dotdash", "longdash")

# Simulation inputs
B <- 9

# Grid of correlation magnitudes, |tanh(x)| for x running from -3 to -0.05
# in steps of 0.05 (so the grid is decreasing and lies strictly inside (0, 1))
Sigma_UO_grid <- abs(tanh(seq(from = -3, to = -0.05, by = 0.05)))

# Results produced by the Matlab code
minimax <- readMat("../Matlab/sim_results/risk.mat")   # minimax risk
mse <- readMat("../Matlab/sim_results/emse_corr.mat")

# Grid of b values on which all risk functions are evaluated
b_grid <- minimax$b.grid
Kb <- length(b_grid)
 

# Hard-threshold risk function.
#
# Closed-form mean squared error of the hard-threshold estimator
# x * 1{|x| > l} when x ~ N(b, 1).
#
# Args:
#   l: threshold.
#   b: numeric vector of means at which the risk is evaluated. Defaults to the
#      global `b_grid` (looked up lazily at call time), so existing
#      one-argument calls `Eb_ht(l)` behave exactly as before.
#
# Returns:
#   Numeric vector of risks, one per element of `b`.
Eb_ht <- function(l, b = b_grid) {
  upper <- l - b    # upper edge of the "set to zero" region, centred at b
  lower <- -l - b   # lower edge of the same region
  1 + (b^2 - 1) * (pnorm(upper) - pnorm(lower)) +
    upper * dnorm(upper) - lower * dnorm(lower)
}
 

# Risk curve (over b_grid) of the pre-test rule: hard thresholding at 1.96
risk_function_ht_ttest <- Eb_ht(1.96)

#####  Calculate the oracle risk function
# The CSV has no header row; column 1 holds the abscissae and column 2 the
# tabulated values. They are interpolated with an "fmm" cubic spline
# (duplicated abscissae are averaged) and evaluated at |b_grid|.
rho_tbl <- read.csv("../Matlab/sim_results/minimax_rho_B9.csv", header = FALSE)
rho_b_over_sigma_function <- splinefun(
  x = rho_tbl[, 1],
  y = rho_tbl[, 2],
  method = "fmm",
  ties = mean
)
rho_grid <- rho_b_over_sigma_function(abs(b_grid))

# Soft-threshold risk function.
#
# Closed-form mean squared error of the soft-threshold estimator
# sign(x) * max(|x| - l, 0) when x ~ N(b, 1).
#
# Args:
#   l: threshold.
#   b: numeric vector of means at which the risk is evaluated. Defaults to the
#      global `b_grid` (looked up lazily at call time), so existing
#      one-argument calls `Eb(l)` behave exactly as before.
#
# Returns:
#   Numeric vector of risks, one per element of `b`.
Eb <- function(l, b = b_grid) {
  upper <- l - b    # upper edge of the "shrunk to zero" region, centred at b
  lower <- -l - b   # lower edge of the same region
  1 + l^2 +
    (b^2 - 1 - l^2) * (pnorm(upper) - pnorm(lower)) +
    lower * dnorm(upper) - upper * dnorm(lower)
}



# Threshold values computed in Matlab; `st.mat` / `ht.mat` are indexed below
# by position in Sigma_UO_grid
thresholds <- readMat('../Matlab/sim_results/thresholds.mat')


# Initialize variables (one entry per correlation in Sigma_UO_grid)
penalty_ht <- numeric(length(Sigma_UO_grid))
penalty_pretest <- numeric(length(Sigma_UO_grid))
penalty_st <- numeric(length(Sigma_UO_grid))
penalty_adaptive <- numeric(length(Sigma_UO_grid))

max_risk_ht <- numeric(length(Sigma_UO_grid))
max_risk_pretest <- numeric(length(Sigma_UO_grid))
max_risk_st <- numeric(length(Sigma_UO_grid))
max_risk_adaptive <- numeric(length(Sigma_UO_grid))


##### Calculate penalties and risk
# For every correlation:
#   penalty_*  = max over b_grid of (estimator risk / oracle risk)
#   max_risk_* = corr^2 * max over b_grid of the estimator risk
for (i in seq_along(Sigma_UO_grid)) {
  corr2 <- Sigma_UO_grid[i]^2
  risk_oracle <- rho_grid + 1/corr2 - 1

  # Risk curves over b_grid, each computed once and reused for both summaries
  risk_ht <- Eb_ht(thresholds$ht.mat[i]) + 1/corr2 - 1
  risk_pretest <- risk_function_ht_ttest + 1/corr2 - 1
  risk_st <- Eb(thresholds$st.mat[i]) + 1/corr2 - 1
  # NOTE(review): used without adding 1/corr2 - 1, so column i of risk.mat is
  # presumably already on the same scale as the curves above -- confirm
  risk_adaptive <- minimax$risk.mat[, i]

  penalty_ht[i] <- max(risk_ht / risk_oracle)
  penalty_pretest[i] <- max(risk_pretest / risk_oracle)
  penalty_adaptive[i] <- max(risk_adaptive / risk_oracle)
  # penalty_erm[i] <- max((risk_function_erm+ 1/corr2 - 1) / risk_oracle)
  penalty_st[i] <- max(risk_st / risk_oracle)

  max_risk_ht[i] <- corr2 * max(risk_ht)
  max_risk_pretest[i] <- corr2 * max(risk_pretest)
  max_risk_adaptive[i] <- corr2 * max(risk_adaptive)
  # max_risk_erm[i] <- corr2*max((risk_function_erm+ 1/corr2 - 1))
  max_risk_st[i] <- corr2 * max(risk_st)
}
 

##### ONLINE APPENDIX FIGURE A1

##### Adaptive penalty plot
# Worst-case ratio of each estimator's risk to the oracle risk, plotted
# against 1 - corr^2. Columns a-d are, in legend order: adaptive,
# soft threshold, hard threshold, pre-test.
df <- data.frame(Sigma_UO_grid = 1 - Sigma_UO_grid^2,
                 a = penalty_adaptive,
                 b = penalty_st,
                 c = penalty_ht,
                 d = penalty_pretest)

# figurename <- paste("./sim_results/penalty_against_corr", ".png", sep = "")
figurename <- paste0("../../figures/figureA1", ".png")

# Long format: one row per (x value, series) pair
df_long <- tidyr::pivot_longer(df, cols = -Sigma_UO_grid,
                               names_to = "series", values_to = "y")
#[Sigma_UO_grid>0.85,]

# Legend labels, shared by the colour and linetype scales so the two legends
# stay identical
series_labels <- c(TeX('Opt. adaptive'), 'Opt. soft threshold',
                   'Opt. hard threshold', 'Pre-test')

plot <- ggplot(df_long, aes(x = Sigma_UO_grid, y = y, linetype = series, color = series)) +
  geom_line(linewidth = 1.2) +
  labs(x = TeX('Relative efficiency $\\sigma^2_{R,GMM}/\\sigma^2_{U}$'),
       y = "Worst-case adaptation regret",
       title = "") +
  # y holds risk ratios; a ratio r is labelled as the percentage (r - 1) * 100
  scale_y_log10(breaks = c(1, 2, 3, 6, 11),
                labels = c("0%", "100%", "200%", "500%", "1000%")) +
  # scale_y_continuous(labels = scales::percent_format(accuracy = 1),limits = c(0,10)) +
  theme_minimal() +
  theme(
    legend.position = "bottom",
    legend.direction = "horizontal",
    legend.box = "horizontal",
    legend.text = element_text(size = 12),  # Adjust legend text size
    axis.text = element_text(size = 14),
    axis.title = element_text(size = 14)
  ) +
  scale_color_manual(values = cbp1, name = "", labels = series_labels) +
  # TeX('$\\log(\\rho^2/(1-\\rho^2)$')
  scale_linetype_manual(values = my_linetypes, name = "", labels = series_labels) +
  guides(color = guide_legend(ncol = 3))

print(plot)
# Save the plot
ggsave(filename = figurename, plot = plot,
       width = 15, height = 10, units = "cm")

##### ONLINE APPENDIX FIGURE A2

# Worst-case risk of each estimator (scaled by corr^2), plotted against
# 1 - corr^2. Columns a-d are, in legend order: adaptive, soft threshold,
# hard threshold, pre-test.
df_risk <- data.frame(Sigma_UO_grid = 1 - Sigma_UO_grid^2,
                      a = max_risk_adaptive,
                      b = max_risk_st,
                      c = max_risk_ht,
                      d = max_risk_pretest)

# Long format: one row per (x value, series) pair
df_long <- tidyr::pivot_longer(df_risk, cols = -Sigma_UO_grid,
                               names_to = "series", values_to = "y")

# figurename <- paste("./sim_results/ht_st_risk_against_corr", ".png", sep = "")
figurename <- paste0("../../figures/figureA2", ".png")

# Legend labels, shared by the colour and linetype scales so the two legends
# stay identical
series_labels <- c(TeX('Opt. adaptive'), 'Opt. soft threshold',
                   'Opt. hard threshold', 'Pre-test')

plot <- ggplot(df_long, aes(x = Sigma_UO_grid, y = y, linetype = series, color = series)) +
  geom_line(linewidth = 1.2) +
  labs(x = TeX('Relative efficiency $\\sigma^2_{R,GMM}/\\sigma^2_{U}$'),
       y = "Worst-case risk increase",
       title = "") +
  # a value r on the y axis is labelled as the percentage (r - 1) * 100
  scale_y_log10(breaks = c(1, 2, 3),
                labels = c("0%", "100%", "200%")) +
  #  scale_y_continuous(labels = scales::percent_format(accuracy = 1)) +
  theme_minimal() +
  theme(
    legend.position = "bottom",
    legend.direction = "horizontal",
    legend.box = "horizontal",
    legend.text = element_text(size = 12),  # Adjust legend text size
    axis.text = element_text(size = 14),
    axis.title = element_text(size = 14)
  ) +
  scale_color_manual(values = cbp1, name = "", labels = series_labels) +
  # TeX('$\\log(\\rho^2/(1-\\rho^2)$')
  scale_linetype_manual(values = my_linetypes, name = "", labels = series_labels) +
  guides(color = guide_legend(ncol = 3))

print(plot)
# Save the plot
ggsave(filename = figurename, plot = plot,
       width = 15, height = 10, units = "cm")
 
##### FIGURE 4
# Threshold values (log2 y axis) against 1 - corr^2 for the three series
# a = thresholds$st.mat, b = thresholds$ht.mat, c = mse$MSE.lambda.mat,
# labelled soft-threshold, hard-threshold and adaptive ERM respectively.
# figurename <- paste("./sim_results/ht_against_corr", ".png", sep = "")
figurename <- paste0("../../figures/figure4", ".png")
data <- data.frame(one_minus_corr2 = 1 - Sigma_UO_grid^2,
                   a = thresholds$st.mat,
                   b = thresholds$ht.mat,
                   c = mse$MSE.lambda.mat)

# Long format: one row per (x value, series) pair
df_long <- tidyr::pivot_longer(data, cols = -one_minus_corr2,
                               names_to = "series", values_to = "y")

# Legend labels, shared by the colour and linetype scales so the two legends
# stay identical
lambda_labels <- c('Soft-threshold', 'Hard-threshold', 'Adaptive ERM')

plot <- ggplot(df_long, aes(x = one_minus_corr2, y = y, linetype = series, color = series)) +
  geom_line(linewidth = 1.2) +
  labs(x = TeX('Relative efficiency $\\sigma^2_{R,GMM}/\\sigma^2_{U}$'),
       y = TeX('Optimal $\\lambda$'),
       title = "") +
  theme_minimal() +
  scale_y_continuous(trans = 'log2') +
  theme(
    legend.position = "bottom",
    legend.direction = "horizontal",
    legend.box = "horizontal",
    legend.text = element_text(size = 12),  # Adjust legend text size
    axis.text = element_text(size = 14),
    axis.title = element_text(size = 14)
  ) +
  scale_color_manual(values = cbp1, name = "", labels = lambda_labels) +
  # TeX('$\\log(\\rho^2/(1-\\rho^2)$')
  scale_linetype_manual(values = my_linetypes, name = "", labels = lambda_labels) +
  guides(color = guide_legend(ncol = 3))
print(plot)
ggsave(filename = figurename, plot = plot,
       width = 15, height = 10, units = "cm")
 ##### Text surrounding Figure 4: approximating the soft threshold

 # No-intercept regression of (soft threshold - 0.45) on log(1 - corr^2)
 approx1 <- lm(thresholds$st.mat - 0.45 ~ -1 + log((1 - Sigma_UO_grid^2)))
 # Fitted approximation side by side with the actual thresholds.
 # NOTE(review): not used by any later statement visible in this file
 data <- data.frame(Sigma_UO_grid = log(1 - Sigma_UO_grid^2),
                    approx1 = predict(approx1) + 0.45,
                    st_mat = thresholds$st.mat)

 # Write the regression summary and the range of corr^2 to a text file.
 # print() is called explicitly because top-level values are only auto-printed
 # in some run modes (e.g. not under a plain source()), which would leave the
 # file empty. `finally` guarantees the diversion is closed even if printing
 # fails, so later console output is not swallowed.
 sink("../../figures/figure4text.txt")
 tryCatch({
   print(summary(approx1))
   print(c(min(Sigma_UO_grid^2), max(Sigma_UO_grid^2)))
 }, finally = sink())