library(dplyr)
library(data.table)
library(ggplot2)
# Keep strings as character vectors when data frames are built or read.
# Spelled FALSE rather than F: F is an ordinary variable that can be reassigned.
options(stringsAsFactors = FALSE)
# Collapse labels to their first two '-'-separated fields.
#
# x: character vector of labels (strsplit() rejects non-character input).
# Returns a character vector with exactly one element per element of x:
#   - labels with two or more fields become 'field1-field2';
#   - labels without a '-' (and NA) are returned unchanged;
#   - empty strings stay '' instead of being dropped, so the result always
#     lines up with the column it was computed from.
parse_pop = function(x) {
  fields = strsplit(x, '-')
  # vapply() pins one string per input element; unlist(lapply()) would
  # silently shorten the result when an element splits into zero fields.
  vapply(fields, function(y) {
    if(length(y) == 0) {
      return('')
    }
    if(length(y) > 1) {
      return(paste0(y[1], '-', y[2]))
    }
    y
  }, character(1))
}
source('../code/rlib_doc.R')

1 Load results

1.1 PTRS with LD block based pruning

# Trait names: the Trait column of the table S6 CSV, lower-cased.
trait_table = read.csv('../external_data/martin_et_al_2019ng_table_s6.csv')
traits = tolower(trait_table$Trait)
diag_df = data.frame(trait = traits, index = 1 : 17)

# Read one tab-separated r2 table per (subset, trait) pair and tag each with
# the pair it came from. Subsets vary slowest, traits fastest.
result = unlist(lapply(1 : 17, function(subset_idx) {
  lapply(traits, function(trait_name) {
    path = paste0('~/Desktop/tmp/ptrs_ldblock_r2/ptrs-r2_subset', subset_idx, '_x_', trait_name, '.txt')
    read.table(path, header = TRUE, sep = '\t') %>%
      mutate(subset = subset_idx, trait = trait_name)
  })
}), recursive = FALSE)

# Stack all tables, shorten the population labels, and freeze ptrs_col as a
# factor whose levels follow first appearance.
df_result = do.call(rbind, result) %>%
  mutate(
    population = parse_pop(population),
    ptrs_col = factor(ptrs_col, levels = unique(ptrs_col))
  )

1.2 PTRS (naive implementation)

# Load the naive-PTRS r2 tables: one tab-separated file per (subset, trait)
# pair, each tagged with the subset index and trait it belongs to.
result = list()
for(i in 1 : 17) {
  for(t in traits) {
    filename = paste0('~/Desktop/tmp/ptrs_r2/ptrs-r2_subset', i, '_x_', t, '.txt')
    tmp = read.table(filename, header = TRUE, sep = '\t')
    result[[length(result) + 1]] = tmp %>% mutate(subset = i, trait = t)
  }
}
df_result2 = do.call(rbind, result)
# Normalise this table's own columns. Taking them from df_result would only be
# correct if both tables held the same rows in the same order.
df_result2$population = parse_pop(df_result2$population)
df_result2$ptrs_col = factor(df_result2$ptrs_col, levels = unique(df_result2$ptrs_col))

# df_ptrs = rbind(df_result %>% select(-SSE.wo, -SSE.with) %>% mutate(type = 'ldblock'), df_result2 %>% mutate(type = 'naive'))

1.3 PTRS with mashr p-value

# Load the mashr-based PTRS r2 tables: one tab-separated file per
# (subset, trait) pair, each tagged with the subset index and trait.
result = list()
for(i in 1 : 17) {
  for(t in traits) {
    filename = paste0('~/Desktop/tmp/ptrs_mashr_r2/ptrs-r2_subset', i, '_x_', t, '.txt')
    tmp = read.table(filename, header = TRUE, sep = '\t')
    result[[length(result) + 1]] = tmp %>% mutate(subset = i, trait = t)
  }
}
df_result3 = do.call(rbind, result)
# Normalise this table's own columns. Taking them from df_result would only be
# correct if both tables held the same rows in the same order.
df_result3$population = parse_pop(df_result3$population)
df_result3$ptrs_col = factor(df_result3$ptrs_col, levels = unique(df_result3$ptrs_col))

# Stack the three result sets, labelling each row with its PTRS variant.
# SSE.wo / SSE.with are dropped from the ldblock and mashr tables so the
# columns match for rbind(); presumably the naive table lacks them (confirm).
df_ptrs = rbind(
  df_result %>% select(-SSE.wo, -SSE.with) %>% mutate(type = 'ldblock'),
  df_result2 %>% mutate(type = 'naive'),
  df_result3 %>% select(-SSE.wo, -SSE.with) %>% mutate(type = 'mashr')
)

2 Plot

2.1 Side-by-side comparison

# Wide table: one row per (population, ptrs_col, subset, trait) and one r2
# column per PTRS type.
df_side_by_side = dcast(df_ptrs, population + ptrs_col + subset + trait ~ type, value.var = 'r2')

# Components shared by the three pairwise scatter plots: a trait-by-population
# grid with free scales and the y = x reference line.
side_by_side_layers = list(
  facet_grid(trait~population, scales = 'free'),
  geom_abline(slope = 1, intercept = 0)
)

# naive vs. ldblock
ggplot(df_side_by_side) +
  geom_point(aes(x = naive, y = ldblock, color = ptrs_col), alpha = .5) +
  side_by_side_layers

# naive vs. mashr
ggplot(df_side_by_side) +
  geom_point(aes(x = naive, y = mashr, color = ptrs_col), alpha = .5) +
  side_by_side_layers

# ldblock vs. mashr
ggplot(df_side_by_side) +
  geom_point(aes(x = ldblock, y = mashr, color = ptrs_col), alpha = .5) +
  side_by_side_layers

2.2 Best British model

# For each PTRS type, run best_model_based_on_one() on that type's rows
# (population 'British-validation', model column 'ptrs_col', score 'r2') and
# keep its perf_in_all table, tagged with the type.
df_best_british = do.call(rbind, lapply(unique(df_ptrs$type), function(type_name) {
  type_rows = df_ptrs %>% filter(type == type_name)
  best = best_model_based_on_one(type_rows, 'British-validation', 'ptrs_col', 'r2')
  best$perf_in_all %>% mutate(type = type_name)
}))

# Wide table: one r2 column per type for every (trait, population, subset).
# Alternative view (disabled):
#   %>% mutate(diff = ldblock - naive) %>% ggplot() + geom_boxplot(aes(x = trait, y = diff, color = population)) + theme(legend.position = 'bottom', axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1))# + scale_y_log10()
tmp = dcast(df_best_british, trait + population + subset ~ type, value.var = 'r2')

# One pairs plot per population, with the y = x line added and the population
# as title.
for(pop_name in unique(tmp$population)) {
  pop_rows = tmp %>% filter(population == pop_name)
  pairs_plot = myggpairs(pop_rows %>% select(ldblock, naive, mashr), pop_rows %>% pull(trait)) +
    geom_abline(slope = 1, intercept = 0)
  print(pairs_plot + ggtitle(pop_name))
}

# naive vs. mashr r2, coloured by trait, one free-scale panel per population.
ggplot(tmp) +
  geom_point(aes(x = naive, y = mashr, color = trait)) +
  facet_wrap(~population, scales = 'free') +
  geom_abline(slope = 1, intercept = 0)

2.3 Best model within each population

# For each PTRS type, run best_model_for_each() on that type's rows (called
# with 'British-validation', model column 'ptrs_col', score 'r2') and keep its
# perf_in_all table, tagged with the type.
per_type_best = lapply(unique(df_ptrs$type), function(type_name) {
  type_rows = df_ptrs %>% filter(type == type_name)
  best = best_model_for_each(type_rows, 'British-validation', 'ptrs_col', 'r2')
  best$perf_in_all %>% mutate(type = type_name)
})
df_best_each = do.call(rbind, per_type_best)

# Wide table: one r2 column per type for every (trait, population, subset).
# Alternative view (disabled):
#   %>% mutate(diff = ldblock - naive) %>% ggplot() + geom_boxplot(aes(x = trait, y = diff, color = population)) + theme(legend.position = 'bottom', axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1))# + scale_y_log10()
tmp = dcast(df_best_each, trait + population + subset ~ type, value.var = 'r2')

# One pairs plot per population, with the y = x line added and the population
# as title.
for(pop_name in unique(tmp$population)) {
  pop_rows = tmp %>% filter(population == pop_name)
  pairs_plot = myggpairs(pop_rows %>% select(ldblock, naive, mashr), pop_rows %>% pull(trait)) +
    geom_abline(slope = 1, intercept = 0)
  print(pairs_plot + ggtitle(pop_name))
}

# naive vs. mashr r2, coloured by trait, one free-scale panel per population.
ggplot(tmp) +
  geom_point(aes(x = naive, y = mashr, color = trait)) +
  facet_wrap(~population, scales = 'free') +
  geom_abline(slope = 1, intercept = 0)

3 Regulability/heritability vs. PTRS/PRS

# Load the table stored in regulability_ctimp.rds, shorten its population
# labels the same way as the PTRS results, and set missing h_sq values to 0.
regu = readRDS('../analysis_output/regulability_ctimp.rds')
regu$population = parse_pop(regu$population)
regu$h_sq = replace(regu$h_sq, is.na(regu$h_sq), 0)

# Attach the regu columns to both best-model tables by (trait, population).
# Optional row filter, currently disabled:
#   filter(paste(trait, subset) %in% paste(diag_df$trait, diag_df$index))
df_best_each_cleaned = df_best_each %>%
  inner_join(regu, by = c('trait', 'population'))
df_best_british_cleaned = df_best_british %>%
  inner_join(regu, by = c('trait', 'population'))

# Components shared by both plots: r2 against h_sq, coloured by PTRS type,
# one panel per population.
h_sq_layers = list(
  geom_point(aes(x = h_sq, y = r2, color = type), alpha = .5),
  facet_wrap(~population)
)

ggplot(df_best_each_cleaned) + h_sq_layers + ggtitle('Best model within each population')

ggplot(df_best_british_cleaned) + h_sq_layers + ggtitle('Best British model')