# Input data. NOTE(review): this file is expected to provide
# otu.table.plausible, taxonomy.plausible, read.counts.plausible and meta,
# which are used below but not created anywhere in this script -- confirm.
load("data_meta_otu.Rdata")

library(jsonify)   # to_json() for the Hmsc-HPC initialisation object
library(Hmsc)      # Hmsc(), HmscRandomLevel(), sampleMcmc()
library(ape)       # phylo objects and plot() of the taxonomy tree
library(phytools)  # not called directly below; presumably needed by as.phylo.formula.R -- confirm
library(gllvm)     # latent-variable ordination (only when lv1/lv2 are refitted)
library(vioplot)   # violin plots of species richness
source("as.phylo.formula.R")  # presumably provides (or overrides ape's) as.phylo.formula() used below

# Work with the "plausible" versions of the OTU table, taxonomy and read counts.
otu.table <- otu.table.plausible
taxonomy <- taxonomy.plausible
read.counts <- read.counts.plausible

# Keep only samples with at least 10000 filtered reads. meta and read.counts
# are subset with the same mask so their rows stay aligned with otu.table.
sel <- read.counts$filt_nread >= 10000
otu.table <- otu.table[sel, ]
meta <- meta[sel, ]
read.counts <- read.counts[sel, ]

# Spike-in reads = filtered reads minus non-spike reads.
read.counts$spike_nread <- read.counts$filt_nread - read.counts$nospike_nread
# DNA-amount proxy: non-spike reads relative to spike-in reads.
# NOTE(review): a sample with spike_nread == 0 gets Inf here, which becomes
# log(Inf) in the facets model response -- confirm this cannot occur.
meta$dna_amount <- read.counts$nospike_nread / read.counts$spike_nread

# Convert each sample (row) to relative abundances. Rows summing to zero are
# left untouched (avoids 0/0 = NaN). Vectorised: dividing by a vector of
# length nrow() divides every column element-wise by the row totals, and
# unlike a 1:nrow() loop this also works for a table with no rows.
row.totals <- rowSums(otu.table)
nonzero <- row.totals > 0
otu.table[nonzero, ] <- otu.table[nonzero, , drop = FALSE] / row.totals[nonzero]

# Inoculated species, as named in the treatment design.
inocs <- c("Antrodia_piceata",
           "Antrodiella_citrinella",
           "Fomitopsis_rosea",
           "Perenniporia_subacida",
           "Physisporinus_crocatus",
           "Postia_guttulata",
           "Skeletocutis_odora",
           "Skeletocutis_stellae",
           "Steccherinum_collabens"
)

# The same species as labelled in taxonomy$species, in the same order as
# inocs (spp[i] corresponds to inocs[i]). Several names differ from the
# treatment names, presumably because of synonymy / updated nomenclature.
spp <- c("Antrodia_piceata_813073",
         "Flaviporus_citrinellus_106139",
         "Rhodofomes_roseus_127496",
         "Perenniporia_subacida_335816",
         "Rigidoporus_crocatus_107233",
         "Postia_guttulata_110917",
         "Skeletocutis_odora_106476",
         "Skeletocutis_stellae_323595",
         "Junghuhnia_collabens_315979")

ni <- length(inocs)
stopifnot(length(spp) == ni)

# Row of taxonomy for each inoculated species (first match, as before).
# A missing name is reported explicitly instead of failing with the cryptic
# "replacement has length zero" that the which()-based loop produced.
inoc.indices <- match(spp, taxonomy$species)
if (anyNA(inoc.indices)) {
  stop("Inoculated species not found in taxonomy$species: ",
       paste(spp[is.na(inoc.indices)], collapse = ", "), call. = FALSE)
}

# Split into non-inoculated OTUs (otu.ta) and the inoculated species (otu.io).
# Assumes taxonomy rows are in the same order as otu.table columns.
otu.ta <- otu.table[, -inoc.indices]
otu.io <- otu.table[, inoc.indices]
taxonomy <- taxonomy[-inoc.indices, ]

log.id <- meta$RunningLogID
log.ids <- levels(log.id)  # NULL (so nl == 0) if RunningLogID is not a factor
nl <- length(log.ids)
# NOTE(review): log.io.s is allocated but not used in the code visible here.
log.io.s <- matrix(0, nrow = nl, ncol = ni)
n <- dim(otu.ta)[1]  # number of samples

# Prevalence: in how many samples each non-inoculated OTU is present.
prev <- colSums(otu.ta > 0)
# Observed richness: how many non-inoculated OTUs each sample contains,
# shown against the DNA-amount proxy.
S <- rowSums(otu.ta > 0)
plot(x = meta$dna_amount, y = S)

# Two latent variables (standardised) from a probit presence-absence gllvm
# fitted to OTUs present in at least 10 samples. The fit is cached: it is
# computed only when the cache file is missing, otherwise it is loaded.
# Delete the cache file to force a refit (e.g. after changing the filters).
# NOTE(review): gllvm fits may depend on random starting values; consider
# fixing a seed if exact reproducibility of lv1/lv2 matters.
lv.cache.file <- "processed_data/lv12.RData"
if (!file.exists(lv.cache.file)) {
  cm <- as.matrix(1 * (otu.ta > 0))
  cm <- cm[, colSums(cm) > 9, drop = FALSE]
  my.gllvm <- gllvm(cm, family = "binomial", link = "probit")
  lv1 <- as.numeric(scale(my.gllvm$lvs[, 1]))
  lv2 <- as.numeric(scale(my.gllvm$lvs[, 2]))
  dir.create(dirname(lv.cache.file), recursive = TRUE, showWarnings = FALSE)
  save(lv1, lv2, file = lv.cache.file)
} else {
  load(lv.cache.file)
}
# A stale cache from a different sample set would otherwise be silently
# recycled when lv1/lv2 are combined with per-sample quantities later.
if (length(lv1) != nrow(otu.ta) || length(lv2) != nrow(otu.ta)) {
  stop("lv1/lv2 in ", lv.cache.file, " do not match the number of samples (",
       nrow(otu.ta), "); delete the cache to refit.", call. = FALSE)
}

plot(lv1, lv2)

# Model only OTUs present in at least 100 samples. Richness is also split
# into the included and excluded OTU sets. drop = FALSE keeps the matrix
# shape even if exactly one OTU falls on one side of the cut (rowSums()
# requires two dimensions).
sel <- prev >= 100
S.included <- rowSums(otu.ta[, sel, drop = FALSE] > 0)
S.excluded <- rowSums(otu.ta[, !sel, drop = FALSE] > 0)
taxonomy <- taxonomy[sel, , drop = FALSE]
otu.ta <- otu.ta[, sel, drop = FALSE]
# 0/1 presence-absence response matrix for the probit model.
Y.pa <- 1 * (otu.ta > 0)

site <- meta$SiteNew
year <- meta$SamplingYear
# 0/1 indicators for the 2020 and 2021 samplings (both 0 for any other year).
year2020 <- 1 * (year == 2020)
year2021 <- 1 * (year == 2021)

# Treatment = inoculated species or "control". Controls are renamed
# "aa_control", presumably so they sort first and become the reference level.
# Alphabetical level order is locale-dependent (in the C locale uppercase
# names such as "Antrodia_piceata" sort before "aa_control"), so the
# reference level is also set explicitly with relevel().
treatment <- as.character(meta$InocSpecies)
treatment[treatment == "control"] <- "aa_control"
treatment <- as.factor(treatment)
if ("aa_control" %in% levels(treatment)) {
  treatment <- relevel(treatment, ref = "aa_control")
}

# Log type: translate the raw codes R/B/U into readable labels (any other
# code is kept as-is), then derive the two log-type contrasts.
LT <- local({
  codes <- as.character(meta$LogType)
  labels <- c(R = "felled", B = "natural-broken", U = "natural-uprooted")
  known <- codes %in% names(labels)
  codes[known] <- labels[codes[known]]
  as.factor(codes)
})

# NF: natural vs felled; "felled" sorts first and is the baseline level.
NF <- as.factor(ifelse(LT %in% "felled", "felled", "natural"))

# UB: uprooted (+1/2) vs broken (-1/2); every other log type gets 0, so the
# average of the two natural types is the baseline.
UB <- (LT %in% "natural-uprooted") / 2 - (LT %in% "natural-broken") / 2

# Decay stage at the time of sampling: the 2021 survey score for 2021
# samples, the mean of the 2019 and 2021 scores for 2020 samples, and the
# 2019 score otherwise.
DecayStage <- local({
  stage <- meta$DecayStage2019
  in2020 <- year == 2020
  in2021 <- year == 2021
  stage[in2020] <- (meta$DecayStage2021[in2020] + meta$DecayStage2019[in2020]) / 2
  stage[in2021] <- meta$DecayStage2021[in2021]
  stage
})

# Covariate for sequencing effort: log of the filtered read count.
seq.depth <- log(read.counts$filt_nread)

# Fixed-effect data and formula shared by both models. treatment enters
# only through its interactions with year2020, year2021 and NF; its main
# effect is removed by "- treatment".
XData <- data.frame(NF = NF, UB = UB, DecayStage = DecayStage, year = year,
                    year2020 = year2020, year2021 = year2021,
                    treatment = treatment, seq.depth = seq.depth)
XFormula <- ~NF * (year2020 + year2021) + UB + DecayStage +
  (year2020 + year2021 + NF) * treatment - treatment + seq.depth

# Study design for the random effects: sampled log and site.
# Both columns are coerced to factors, because levels() of a character column
# is NULL and would pass units = NULL to HmscRandomLevel. Levels with no
# remaining samples (e.g. logs removed by the read-depth filter) are dropped,
# so each random level lists only units that occur in the data.
studyDesign <- data.frame(log.id = droplevels(as.factor(log.id)),
                          site = droplevels(as.factor(site)))
rL.log.id <- HmscRandomLevel(units = levels(studyDesign$log.id))
rL.site <- HmscRandomLevel(units = levels(studyDesign$site))

# Taxonomic "tree" (kingdom > phylum > ... > genus > OTU), used as the
# phylogeny of the presence-absence model. OTU tips are named after the
# taxonomy row names.
taxonomy[["OTU"]] <- factor(rownames(taxonomy))
my.tree <- as.phylo.formula(
  ~kingdom/phylum/class/order/family/genus/OTU,
  data = taxonomy
)
plot(my.tree, cex = 0.1)

# Presence-absence model: probit Hmsc on Y.pa with the shared fixed effects,
# the taxonomy tree as phylogeny, and log and site random levels.
m.pa <- Hmsc(
  Y = Y.pa,
  XData = XData,
  XFormula = XFormula,
  distr = "probit",
  studyDesign = studyDesign,
  ranLevels = list(log.id = rL.log.id, site = rL.site),
  phyloTree = my.tree
)


# Richness of the modelled OTUs on control logs, split by log type
# (natural vs felled) and sampling year. The control-only subsets are kept
# in a separate data frame so the sample-level globals S and year are not
# overwritten. S is used again as a response of the facets model, where a
# shorter control-only vector would be silently recycled by cbind().
control <- which(m.pa$XData$treatment == "aa_control")
control.data <- data.frame(S = rowSums(m.pa$Y)[control],
                           natural = m.pa$XData$NF[control],
                           year = m.pa$XData$year[control])

pdf("S versus natural year.pdf")
vioplot(S ~ natural * year, data = control.data)
dev.off()

# Community facets model: total, included and excluded richness (counts,
# lognormal Poisson), the log DNA-amount proxy and the two gllvm latent
# variables (normal), with the same fixed and random effects as m.pa.
# Total richness over all non-inoculated OTUs is rebuilt from its two
# prevalence-based parts, so the response does not depend on the global S
# staying unchanged between its definition and this point.
S.total <- S.included + S.excluded
if (length(lv1) != length(S.total) || length(lv2) != length(S.total)) {
  stop("lv1/lv2 do not have one value per sample.", call. = FALSE)
}
Y.facets <- as.matrix(cbind(S.total, S.included, S.excluded,
                            log(meta$dna_amount), lv1, lv2))
colnames(Y.facets) <- c("S", "S.included", "S.excluded", "dna", "lv1", "lv2")

m.facets <- Hmsc(Y = Y.facets, XData = XData, XFormula = XFormula,
                 studyDesign = studyDesign,
                 ranLevels = list(log.id = rL.log.id, site = rL.site),
                 distr = c("lognormal poisson", "lognormal poisson",
                           "lognormal poisson", "normal", "normal", "normal"))

# Save both unfitted models as a named list. The output directory is created
# if needed, so save() does not fail on a fresh checkout.
dir.create("models", showWarnings = FALSE)
models <- list(m.pa = m.pa, m.facets = m.facets)
save(models, file = "models/unfitted_models.RData")

# MCMC settings for the initialisation object; transient (burn-in) is half
# of samples * thin.
thin <- 1
samples <- 250
nChains <- 4
transient <- round(0.5 * thin * samples)
verbose <- 10

# Build the Hmsc-HPC initialisation object for the presence-absence model and
# store it as JSON inside an .rds file.
dir.create("models", showWarnings = FALSE)
m <- m.pa
init_obj <- sampleMcmc(m, samples = samples, thin = thin,
                       transient = transient, nChains = nChains,
                       verbose = verbose, engine = "HPC")
init_file_path <- file.path(getwd(), "models/init_file_pa.rds")
saveRDS(to_json(init_obj), file = init_file_path)

# Shell script with one Gibbs-sampler run per thinning value (paths are
# relative to models/). Numbers are formatted with "%.0f" because
# as.character() switches to scientific notation for some values (e.g.
# 1e+05), which an integer command-line argument would presumably reject.
# A separate loop vector is used so the global `thin` keeps its value.
todofile <- "models/todo_pa.txt"
thin.values <- c(1, 10, 100, 1000)
commands <- sprintf(
  paste("python3 -m hmsc.run_gibbs_sampler",
        "--input ./init_file_pa.rds",
        "--output ./post_file_pa_%.0f.rds",
        "--samples %.0f",
        "--transient %.0f",
        "--thin %.0f",
        "--verbose %.0f"),
  thin.values, samples, round(samples * 0.5 * thin.values), thin.values, verbose
)
writeLines(c("#!/bin/bash", commands), todofile)
