#!/usr/bin/rb

# File locations (all paths are relative to the working directory)

# -- inputs --
# range data, read with readDiscreteCharacterData below
range_fn = "data/__data/elmidae.range.nex"
# tree file, read with readTrees below
tree_fn  = "data/__data/elmidae.newick"
# common prefix of the geography files (times, connectivity, distances)
geo_fn   = "data/Antarctica/elmidae"
times_fn = geo_fn + ".times.txt"
# NOTE(review): dist_fn is not read anywhere in this script; the per-epoch
# files <geo_fn>.distances.<epoch>.txt are loaded instead -- confirm it can go
dist_fn  = geo_fn + ".distances.txt"

# -- outputs --
# prefix for the monitor and summary files written by this script
out_fn   = "output/Antarctica_2areas_Xno0/OUT"

########
# data #
########

# Load the binary (0/1) presence-absence ranges; the number of characters
# in the matrix is taken as the number of areas
dat_range_01 = readDiscreteCharacterData(range_fn)
n_areas <- dat_range_01.nchar()

# Size of the state space: the number of ranges containing at most
# max_areas areas, i.e. the sum of choose(n_areas, r) for r = 0..max_areas
max_areas <- 2
n_states <- 0
for (range_size in 0:max_areas) {
    n_states += choose(n_areas, range_size)
}

# Recode the binary ranges as NaturalNumbers states ("DEC" format)
dat_range_n = formatDiscreteCharacterData(dat_range_01, "DEC", n_states)

# Epoch time bounds: one row of the times file per epoch
time_bounds <- readDataDelimitedFile(file=times_fn, delimiter=" ")
n_epochs <- time_bounds.nrows()

# Connectivity for each epoch, from <geo_fn>.connectivity.<epoch>.txt
for (ep in 1:n_epochs) {
    conn_fn = geo_fn + ".connectivity." + ep + ".txt"
    connectivity[ep] <- readDataDelimitedFile(file=conn_fn, delimiter=" ")
}

# Distances for each epoch, from <geo_fn>.distances.<epoch>.txt
for (ep in 1:n_epochs) {
    epoch_dist_fn = geo_fn + ".distances." + ep + ".txt"
    distances[ep] <- readDataDelimitedFile(file=epoch_dist_fn, delimiter=" ")
}

# Descriptions of the recoded states, as reported by the data object
state_desc = dat_range_n.getStateDescriptions()

# Save a "state,range" lookup table next to the other outputs.
# States are written starting from 0 while the vector is indexed from 1,
# hence the (idx-1).
state_labels_csv = "state,range\n"
for (idx in 1:state_desc.size()) {
    state_labels_csv += (idx-1) + "," + state_desc[idx] + "\n"
}
write(state_labels_csv, file=out_fn+".state_labels.txt")

# Number of MCMC generations, plus the containers that the model sections
# below fill with moves and monitors
n_gen = 30000
moves = VectorMoves()
monitors = VectorMonitors()


###############
# Tree models #
###############

# The tree is read once as a constant; only the first tree in the file is used
tree <- readTrees(tree_fn)[1]

#######################
# Biogeography models #
#######################

# the biogeographic rate scaler
# Stochastic node with a log-uniform prior on [1E-4, 1E2], started at 1E-2
# and updated by a scale move; it is later passed as branchRates to the CTMC.
rate_bg ~ dnLoguniform(1E-4,1E2)
rate_bg.setValue(1E-2)
moves.append( mvScale(rate_bg, weight=4) )

# relative dispersal rate is fixed to one
# (constant node: no prior and no move, so it is never updated)
dispersal_rate <- 1.0

# the geographical distance scaling factor
# Uniform prior on [0.0005, 20], started at 0.01. Larger values make the
# dispersal rate below fall off faster with distance.
distance_scale ~ dnUnif(0.0005,20)
distance_scale.setValue(0.01)
moves.append( mvScale(distance_scale, weight=3) )

# then, the dispersal rate matrix
# dr[i][j][k] is indexed by epoch i and the area pair (j, k).
# Every entry is first set to a constant 0.0; entries whose connectivity
# value is positive are then replaced by a deterministic node equal to
# dispersal_rate * exp(-distance_scale * distance), so unconnected pairs
# keep a rate of zero.
for (i in 1:n_epochs) {
    for (j in 1:n_areas) {
        for (k in 1:n_areas) {
            dr[i][j][k] <- 0.0
            if (connectivity[i][j][k] > 0) {
                dr[i][j][k] := dispersal_rate * exp(-distance_scale * distances[i][j][k])
            }
        }
    }
}

# then extirpation rates (or per-area extinction rates)
# A single extirpation rate with a lognormal prior, updated by a scale move.
# log_mean = ln(1) - 0.5*log_sd^2; assuming dnLognormal's `mean` is the mean
# on the log scale (confirm against the RevBayes docs), the prior expectation
# is exp(log_mean + 0.5*log_sd^2) = 1.
log_sd <- 0.5
log_mean <- ln(1) - 0.5*log_sd^2
extirpation_rate ~ dnLognormal(mean=log_mean, sd=log_sd)
moves.append( mvScale(extirpation_rate, weight=2) )

# the extirpation rate matrix
# er[i][j][k] is indexed by epoch i and the area pair (j, k).
# The inner loop sets the whole row j to a constant 0.0; the diagonal entry
# is then overwritten with a deterministic node tied to extirpation_rate,
# so every area in every epoch shares the same rate.
for (i in 1:n_epochs) {
    for (j in 1:n_areas) {
        for (k in 1:n_areas) {
            er[i][j][k] <- 0.0
        }
        er[i][j][j] := extirpation_rate
    }
}

# build DEC rate matrices
# One rate matrix per epoch, built from that epoch's dispersal and
# extirpation rates and limited to ranges of at most max_areas areas.
# The active loop runs from n_epochs down to 1; the forward version is
# kept commented out just below.
#for (i in 1:n_epochs) {
for (i in n_epochs:1) {
    Q_DEC[i] := fnDECRateMatrix(dispersalRates=dr[i],
                                extirpationRates=er[i],
                                maxRangeSize=max_areas)
}

# build the times
# For each epoch, column 1 of time_bounds is used as the upper bound and
# column 2 as the lower bound. The last epoch's time is fixed at 0.0; every
# other epoch time is a stochastic node, uniform between its two bounds,
# updated by a slide move whose delta is half the width of the interval.
for (i in 1:n_epochs) {
    time_max[i] <- time_bounds[i][1]
    time_min[i] <- time_bounds[i][2]

    if (i == n_epochs) {
        epoch_times[i] <- 0.0
    } else {
        epoch_times[i] ~ dnUniform(time_min[i], time_max[i])
        moves.append( mvSlide(epoch_times[i], delta=(time_max[i]-time_min[i])/2) )
    }
}

# combine the epoch rate matrices and times
# Every epoch gets the same relative rate multiplier of 1.
Q_DEC_epoch := fnEpoch(Q=Q_DEC, times=epoch_times, rates=rep(1, n_epochs))

    
# build cladogenetic transition probabilities
# Two cladogenetic event types, labelled "s" and "a"; going by the variable
# names below these stand for sympatry and allopatry.
clado_event_types <- [ "s", "a" ]
# p_sympatry is the only free parameter (uniform on [0, 1], slide move);
# p_allopatry is derived as its complement, so the pair sums to one.
p_sympatry ~ dnUniform(0,1)
# NOTE(review): abs() leaves the value unchanged for p_sympatry in [0, 1];
# presumably it is there to give the result a non-negative type -- confirm.
p_allopatry := abs(1.0 - p_sympatry)
clado_type_probs := simplex(p_sympatry, p_allopatry)
moves.append( mvSlide(p_sympatry, weight=2) )
# Cladogenetic probabilities over the same state space as the rate
# matrices (n_areas areas, ranges of at most max_areas areas).
P_DEC := fnDECCladoProbs(eventProbs=clado_type_probs,
                         eventTypes=clado_event_types,
                         numCharacters=n_areas,
                         maxRangeSize=max_areas)


# root frequencies
# all states equally probable
# (a vector of n_states ones, normalised to a simplex)
rf_DEC_tmp    <- rep(1, n_states)
rf_DEC    <- simplex(rf_DEC_tmp)

# the phylogenetic CTMC with cladogenetic events
# Ties together the pieces built above: the fixed tree, the epoch rate
# matrix, the cladogenetic probabilities, rate_bg as the branch-rate scaler
# and the flat root frequencies. The data are a single NaturalNumbers site.
m_bg ~ dnPhyloCTMCClado(tree=tree,
                           Q=Q_DEC_epoch,
                           cladoProbs=P_DEC,
                           branchRates=rate_bg,
                           rootFrequencies=rf_DEC,
                           type="NaturalNumbers",
                           nSites=1)
    
# attach the range data
# (clamping fixes m_bg to the observed, recoded ranges)
m_bg.clamp(dat_range_n)

############
# Monitors #
############


# Screen report every 100 generations: the three rate parameters
monitors.append(
    mnScreen(printgen=100,
             rate_bg,
             extirpation_rate,
             distance_scale) )

# Model log, every 10 generations
monitors.append(
    mnModel(file=out_fn+".model.log",
            printgen=10) )

# The tree is logged twice, every 10 generations: to <out_fn>.trees
# (read back by the post-processing section) and to <out_fn>.tre
monitors.append(
    mnFile(tree,
           file=out_fn+".trees",
           printgen=10) )
monitors.append(
    mnFile(tree,
           filename=out_fn+".tre",
           printgen=10) )

# Ancestral states, including tips and start states, every 10 generations
# (read back by the post-processing section)
monitors.append(
    mnJointConditionalAncestralState(tree=tree,
                                     ctmc=m_bg,
                                     type="NaturalNumbers",
                                     withTips=true,
                                     withStartStates=true,
                                     filename=out_fn+".states.log",
                                     printgen=10) )

# Stochastic character maps, every 100 generations
monitors.append(
    mnStochasticCharacterMap(ctmc=m_bg,
                             filename=out_fn+".stoch.log",
                             printgen=100) )

############
# Analysis #
############

# build the model analysis object from the model graph
mymodel = model(m_bg)

# Files used by this section, each named exactly once so that the code that
# writes a file and the code that reads it back cannot drift apart.
#   power_fn      : power-posterior samples, read by the marginal-likelihood
#                   samplers below
#   checkpoint_fn : MCMC checkpoint, derived from the shared output prefix
power_fn = "output/Antarctica_2areas_Xno0/POWER.out"
checkpoint_fn = out_fn + ".state"

### Compute power posterior distributions ###
# 64 power categories, sampled every 10 generations.
# NOTE(review): after the 500-generation burn-in each run lasts only 25
# generations with sampleFreq=10 -- confirm this is not a left-over
# test setting.
pow_p = powerPosterior(mymodel, moves, monitors, power_fn, cats=64, sampleFreq=10)
pow_p.burnin(generations=500,tuningInterval=250)
pow_p.run(generations=25)


### Use stepping-stone sampling to calculate marginal likelihoods
ss = steppingStoneSampler(file=power_fn, powerColumnName="power", likelihoodColumnName="likelihood")
ss.marginal()



### Use path-sampling to calculate marginal likelihoods
# (disabled; reads the same power-posterior file as the stepping-stone sampler)
#ps = pathSampler(file=power_fn, powerColumnName="power", likelihoodColumnName="likelihood")
#ps.marginal()


# create the MCMC analysis object
mymcmc = mcmc(mymodel, monitors, moves)

# Uncomment to resume this analysis from its own checkpoint file
# mymcmc.initializeFromCheckpoint(checkpoint_fn)

# run the MCMC analysis, writing a checkpoint every 100 generations
mymcmc.run(n_gen, checkpointInterval=100, checkpointFile=checkpoint_fn)


# Post-processing: summarise the files written by the monitors above.
# The prefix is taken from out_fn (the prefix the monitors write under)
# instead of repeating the literal path, so the files read here are always
# the ones this run produced.
out_str = out_fn
out_state_fn = out_str + ".states.log"
out_tree_fn = out_str + ".trees"
out_mcc_fn = out_str + ".mcc.tre"

# Read the tree trace and set its burn-in to 0.1 (presumably a fraction of
# the samples -- confirm); n_burn is the burn-in the trace then reports,
# reused for the ancestral-state summary below.
tree_trace = readTreeTrace(file=out_tree_fn, treetype="clock")
tree_trace.setBurnin(0.1)
n_burn = tree_trace.getBurnin()

# Summary (MCC) tree, also written to <out_str>.mcc.tre
mcc_tree = mccTree(tree_trace, file=out_mcc_fn)

# Sampled ancestral states
state_trace = readAncestralStateTrace(file=out_state_fn)

# The same tree file read again, this time as an ancestral-state tree trace.
# It gets its own name so it is not confused with the plain tree trace that
# was used for the burn-in and the MCC tree.
anc_tree_trace = readAncestralStateTreeTrace(file=out_tree_fn, treetype="clock")

# Annotate the MCC tree with the ancestral-state summary for site 1
# (start states included) and write it to <out_str>.ase.tre
anc_tree = ancestralStateTree(tree=mcc_tree,
                              ancestral_state_trace_vector=state_trace,
                              tree_trace=anc_tree_trace,
                              include_start_states=true,
                              file=out_str+".ase.tre",
                              burnin=n_burn,
                              site=1)


