library(data.table)
library(ggplot2)
library(scales)

# Load the raw measurements unless a `data` data frame already exists in this
# environment (e.g. left over from an earlier interactive run).
# `inherits = FALSE` is essential: a plain exists("data") also searches the
# search path, where it always finds the `utils::data()` function, so the
# check was always TRUE and data.rds was never read (setDT() below then
# failed on a function).
if (!exists("data", inherits = FALSE) || !is.data.frame(data)) {
  data <- readRDS(file = "data.rds")
}

# Summary statistics of `coverage` for every (aut, size, strategy) group.
# setDT() converts `data` to a data.table in place. Median is forced to double
# because median() of an integer column can return integer or double depending
# on group size, and data.table requires one type per result column.
stats <- setDT(data)[
  ,
  .(
    Mean   = mean(coverage),
    Max    = max(coverage),
    Min    = min(coverage),
    Median = as.numeric(median(coverage)),
    Std    = sd(coverage)
  ),
  by = .(aut, size, strategy)
]

# Plot stats ----
# One faceted line plot (rows = strategy, columns = aut) per summary statistic,
# drawn in the order Std, Median, Mean. The three formerly copy-pasted plot
# definitions are generated by one loop; `p` is left holding the last plot.
for (stat_col in c("Std", "Median", "Mean")) {
  p <- ggplot(data = stats,
              aes(x = size, y = .data[[stat_col]], group = aut)) +
    geom_line() +
    geom_point() +
    # Keep the axis title as the bare column name, as with aes(y = Std).
    labs(y = stat_col) +
    facet_grid(strategy ~ aut) +
    theme_bw()
  print(p)
}

# Plot heatmap ----

# Prepare data frame: for every AUT, pair each UET size with each IET size
# (both restricted to 2..19) and record the IET - UET difference of the
# per-size median and mean coverage taken from `stats`.
#
# This replaces a triple loop that re-filtered `stats` four times per
# iteration and grew the result with rbind() (quadratic), and that stopped
# with "differing number of rows" whenever a size/strategy combination was
# absent. Absent combinations are now simply left out of the heatmap.
heatmapSizes <- 2:19

uetStats <- stats[strategy == "UET" & size %in% heatmapSizes,
                  .(AUT = as.character(aut), UETsize = size,
                    medianUET = Median, meanUET = Mean)]
ietStats <- stats[strategy == "IET" & size %in% heatmapSizes,
                  .(AUT = as.character(aut), IETsize = size,
                    medianIET = Median, meanIET = Mean)]

# Cartesian join within each AUT: every UET size x every IET size.
heatmapData <- merge(uetStats, ietStats, by = "AUT", allow.cartesian = TRUE)

# Same columns (and plain data.frame class) as the original loop produced.
heatmapData <- as.data.frame(heatmapData[, .(
  AUT,
  UETsize,
  IETsize,
  DifferenceOfMedians = medianIET - medianUET,
  DifferenceOfMeans   = meanIET - meanUET
)])

# Scale both difference columns by 100 to express them as percentages
# (presumably coverage is stored as a fraction in [0, 1] -- TODO confirm).
heatmapData[c("DifferenceOfMedians", "DifferenceOfMeans")] <-
  lapply(heatmapData[c("DifferenceOfMedians", "DifferenceOfMeans")],
         function(column) column * 100)


# Viridis heatmaps of the IET - UET difference for each (UET size, IET size)
# pair, one facet per AUT: first the medians, then the means.
for (fillColumn in c("DifferenceOfMedians", "DifferenceOfMeans")) {
  heatmap <- ggplot(data = heatmapData,
                    aes(x = UETsize, y = IETsize,
                        fill = .data[[fillColumn]])) +
    geom_tile() + # optional tile borders: color = "white", size = 0.1
    labs(fill = fillColumn) +
    scale_fill_viridis_c(breaks = seq(-3, 10, by = 1.5)) +
    facet_wrap(~AUT, ncol = 4)
  print(heatmap)
}


# GREEN/RED ----
# Diverging red/white/green heatmaps (white at the gradient2 default midpoint
# of 0), first for the medians and then for the means. An earlier variant
# used scale_fill_viridis_c(breaks = seq(-3, 10, by = 1.5)) instead.
for (fillColumn in c("DifferenceOfMedians", "DifferenceOfMeans")) {
  heatmap <- ggplot(data = heatmapData,
                    aes(x = UETsize, y = IETsize,
                        fill = .data[[fillColumn]])) +
    geom_tile() +
    labs(fill = fillColumn) +
    scale_fill_gradient2(
      low    = "red",
      high   = "green",
      mid    = "white",
      breaks = seq(-3, 10, by = 1.5)
    ) +
    facet_wrap(~AUT, ncol = 4)
  print(heatmap)
}


# ALTERNATE ----
# Heatmaps with hand-tuned scale_fill_gradientn() palettes. The `values`
# vectors are colour-stop positions on the rescaled [0, 1] fill range.
#
# meansMin / meansMax / meansAvg give the rescaled positions of the near-zero
# mean differences (-0.1 < diff < 0.1). The plots below do not use them;
# presumably they were used to pick the 0.12-0.2 stops by hand.
# NOTE(review): the same stops are reused for DifferenceOfMedians, whose range
# differs, so its white band need not sit at zero -- confirm intent.
# NOTE(review): with no near-zero values, min()/max() warn and return
# Inf/-Inf, and mean() returns NaN.
meansMin <- with(heatmapData,
                 min(rescale(DifferenceOfMeans)[which(abs(DifferenceOfMeans) < 0.1)]))
meansMax <- with(heatmapData,
                 max(rescale(DifferenceOfMeans)[which(abs(DifferenceOfMeans) < 0.1)]))
meansAvg <- with(heatmapData,
                 mean(rescale(DifferenceOfMeans)[which(abs(DifferenceOfMeans) < 0.1)]))

# Medians: dark red -> white band -> light green -> dark green.
heatmap <- ggplot(heatmapData,
                  aes(x = UETsize, y = IETsize, fill = DifferenceOfMedians)) +
  geom_tile() +
  scale_fill_gradientn(
    colours = c("red4", "white", "white", "lightgreen", "springgreen4"),
    values  = c(0, 0.12, 0.17, 0.2, 1),
    breaks  = seq(-3, 10, by = 1.5)
  ) +
  facet_wrap(~AUT, ncol = 4)
print(heatmap)

# Means: dark red -> pink -> white -> light green -> dark green. A tried
# alternative was scale_fill_gradient2(low = "red4", high = "springgreen4",
# mid = "white", breaks = seq(-3, 10, by = 1.5)).
heatmap <- ggplot(heatmapData,
                  aes(x = UETsize, y = IETsize, fill = DifferenceOfMeans)) +
  geom_tile() +
  scale_fill_gradientn(
    colours = c("red4", "#db9a9a", "white", "lightgreen", "#006130"),
    values  = c(0, 0.12, 0.17, 0.2, 1),
    breaks  = seq(-3, 10, by = 1.5)
  ) +
  facet_wrap(~AUT, ncol = 4)
print(heatmap)