//
//  lmf_graph.cpp
//  sequence
//
//  Created by Sensei on 9/24/15.
//  Copyright © 2015 Franco "Sensei" Milicchio. All rights reserved.
//

#include "lmf_graph.hpp"

#include <algorithm>
#include <fstream>
#include <functional>
#include <iostream>
#include <mutex>
#include <numeric>
#include <utility>
#include <tbb/parallel_for.h>
#include <tbb/parallel_reduce.h>

#include "../libseq/timer.h"

// CSV
// Write every stored kmer and its associated value as one "kmer,value" line.
// NOTE(review): the value column is p.second as currently stored in hash_
// (build() overwrites it with edge bits) — confirm the intended call order.
void lmf_graph::dump_csv(const std::string &name)
{
    seq::timer t;
    
    std::ofstream f(name);
    if (!f)
    {
        // Report instead of silently producing nothing
        std::cerr << "Unable to open " << name << std::endl;
        return;
    }
    
    t.start("Dumping kmers and frequencies in CSV format");
    std::cout << ">>> " << name << std::endl;
    
    // Dump list; '\n' rather than std::endl avoids flushing the file on every row
    for (const auto &p : hash_)
    {
        f << tostring(p.first) << "," << p.second << '\n';
    }
    
    // Flush once so the timing still covers the actual write
    f.flush();
    t.stop();
}

// Remove useless kmers
// Currently this only computes, in parallel, the total of all values stored in
// hash_ (the kmer frequencies); the actual removal is not implemented yet and
// the `error` threshold is unused.
void lmf_graph::cleanup(double error)
{
    seq::timer t;

    t.start("Computing kmer frequencies");

    // Compute the sum of all frequencies concurrently.
    // The identity must be a std::size_t: parallel_reduce deduces its
    // accumulator type from it, and a plain `0` would make it `int`,
    // narrowing every partial sum and overflowing on large inputs.
    const std::size_t sum = tbb::parallel_reduce(hash_.range(),
                                                 std::size_t(0),
                                                 [](hashtype::range_type &r, std::size_t init) -> std::size_t
                                                 {
                                                     std::size_t q = init;
                                                     for (const auto &p : r)
                                                     {
                                                         q += p.second;
                                                     }
                                                     return q;
                                                 },
                                                 std::plus<std::size_t>());
    t.stop();
    
    // Clean not implemented: silence unused-variable warnings until it is
    (void)sum;
    (void)error;
}

// Build the graph
void lmf_graph::build()
{
    seq::timer t;
    
    t.start("Creating de Bruijn graph");
    
    // Checks if it exists a next-in-line kmer in the list (e.g., ATCA with T, checks for TCAT)
    auto checknext = [&](kmertype p, char amino) -> fwdrev
    {
        // Get the next kmer with last aminoacid == amino and its reverse
        auto f = getnext(p, amino);
        auto r = revfn(f);
        
        // Accessor for the concurrent hashmap
        typename hashtype::accessor accessor;

        // Find the next kmer as a forward and as its reverse (we don't know which form has been stored)
        bool F = hash_.find(accessor, f);
        bool R = hash_.find(accessor, r);
        if ((p.get() != f.get()) && (p.get() != r.get()) && (F || R))
        {
            return fwdrev(F, R);
        }
        
        return fwdrev(false, false);
    };
    
    // Checks if it exists a next-in-line kmer in the list (e.g., ATCA with T, checks for TCAT)
    auto checkprev = [&](kmertype p, char amino) -> fwdrev
    {
        // Get the next kmer with last aminoacid == amino and its reverse
        auto f = getprev(p, amino);
        auto r = revfn(f);
        
        // Accessor for the concurrent hashmap
        typename hashtype::accessor accessor;
        
        // Find the next kmer as a forward and as its reverse (we don't know which form has been stored)
        bool F = hash_.find(accessor, f);
        bool R = hash_.find(accessor, r);
        if ((p.get() != f.get()) && (p.get() != r.get()) && (F || R))
        {
            return fwdrev(F, R);
        }
        
        return fwdrev(false, false);
    };

#if 1
    tbb::parallel_for(hash_.range(20),
                      [&](hashtype::range_type &r)
                      {
                          kmerprop q;
                          
                          for (auto &p : r)
                          {
                              // Clear bits
                              q = 0;
                              
                              // Found in forward or reverse form
                              fwdrev   found;
                              
                              // The target kmer to search
                              kmertype target;
                              
                              char     amino;
                              
                              amino = kmertype::A;
                              target = getnext(p.first, amino);
                              found  = checknext(p.first, amino);
                              if (found.first || found.second)
                              {
                                  q = q | (found.first ? fwdA : revA);
                              }
                              // T
                              amino = kmertype::T;
                              target = getnext(p.first, amino);
                              found  = checknext(p.first, amino);
                              if (found.first || found.second)
                              {
                                  q = q | (found.first ? fwdT : revT);
                              }
                              // C
                              amino = kmertype::C;
                              target = getnext(p.first, amino);
                              found  = checknext(p.first, amino);
                              if (found.first || found.second)
                              {
                                  q = q | (found.first ? fwdC : revC);
                              }
                              // G
                              amino = kmertype::G;
                              target = getnext(p.first, amino);
                              found  = checknext(p.first, amino);
                              if (found.first || found.second)
                              {
                                  q = q | (found.first ? fwdG : revG);
                              }
                              
                              // === REV ===
                              
                              // A
                              amino = kmertype::A;
                              target = getprev(p.first, amino);
                              found  = checkprev(p.first, amino);
                              if (found.first || found.second)
                              {
                                  q = q | ((found.first ? fwdA : revA) << 8);
                              }
                              // T
                              amino = kmertype::T;
                              target = getprev(p.first, amino);
                              found  = checkprev(p.first, amino);
                              if (found.first || found.second)
                              {
                                  q = q | ((found.first ? fwdT : revT) << 8);
                              }
                              // C
                              amino = kmertype::C;
                              target = getprev(p.first, amino);
                              found  = checkprev(p.first, amino);
                              if (found.first || found.second)
                              {
                                  q = q | ((found.first ? fwdC : revC) << 8);
                              }
                              // G
                              amino = kmertype::G;
                              target = getprev(p.first, amino);
                              found  = checkprev(p.first, amino);
                              if (found.first || found.second)
                              {
                                  q = q | ((found.first ? fwdG : revG) << 8);
                              }
                              
                              // Set bits
                              typename hashtype::accessor accessor;
                              hash_.insert(accessor, p.first);
                              accessor->second = q;
                          }
                          
                      });
    
#else
    // For each kmer
    for (auto &p : hash_)
    {
        // Clear bits
        p.second = 0;
        
        // Found in forward or reverse form
        fwdrev   found;
        
        // The target kmer to search
        kmertype target;
        
        char     amino;
        
        amino = kmertype::A;
        target = getnext(p.first, amino);
        found  = checknext(p.first, amino);
        if (found.first || found.second)
        {
            p.second = p.second | (found.first ? fwdA : revA);
        }
        // T
        amino = kmertype::T;
        target = getnext(p.first, amino);
        found  = checknext(p.first, amino);
        if (found.first || found.second)
        {
            p.second = p.second | (found.first ? fwdT : revT);
        }
        // C
        amino = kmertype::C;
        target = getnext(p.first, amino);
        found  = checknext(p.first, amino);
        if (found.first || found.second)
        {
            p.second = p.second | (found.first ? fwdC : revC);
        }
        // G
        amino = kmertype::G;
        target = getnext(p.first, amino);
        found  = checknext(p.first, amino);
        if (found.first || found.second)
        {
            p.second = p.second | (found.first ? fwdG : revG);
        }
        
        // === REV ===
        
        // A
        amino = kmertype::A;
        target = getprev(p.first, amino);
        found  = checkprev(p.first, amino);
        if (found.first || found.second)
        {
            p.second = p.second | ((found.first ? fwdA : revA) << 8);
        }
        // T
        amino = kmertype::T;
        target = getprev(p.first, amino);
        found  = checkprev(p.first, amino);
        if (found.first || found.second)
        {
            p.second = p.second | ((found.first ? fwdT : revT) << 8);
        }
        // C
        amino = kmertype::C;
        target = getprev(p.first, amino);
        found  = checkprev(p.first, amino);
        if (found.first || found.second)
        {
            p.second = p.second | ((found.first ? fwdC : revC) << 8);
        }
        // G
        amino = kmertype::G;
        target = getprev(p.first, amino);
        found  = checkprev(p.first, amino);
        if (found.first || found.second)
        {
            p.second = p.second | ((found.first ? fwdG : revG) << 8);
        }

    }
#endif
    t.stop();
}

// Graph building
void lmf_graph::dump_graph(const std::string &name)
{
    seq::timer t;
    
    std::ofstream f(name);
    
    t.start("Dumping graph in graphviz format");

    std::cout << ">>> " << name << std::endl;

    f << "digraph dna {" << std::endl;

    // Accessor for the concurrent hashmap
    typename hashtype::accessor accessor;
    
    // Checks if it exists a next-in-line kmer in the list (e.g., ATCA with T, checks for TCAT)
    auto checknext = [&](const kmertype &p, char amino) -> fwdrev
    {
        // Get the next kmer with last aminoacid == amino and its reverse
        auto f = getnext(p, amino);
        auto r = revfn(f);
        
        // Find the next kmer as a forward and as its reverse (we don't know which form has been stored)
        bool F = hash_.find(accessor, f);
        bool R = hash_.find(accessor, r);
        if ((p.get() != f.get()) && (p.get() != r.get()) && (F || R))
        {
            return fwdrev(F, R);
        }
        
        return fwdrev(false, false);
    };

    // Checks if it exists a next-in-line kmer in the list (e.g., ATCA with T, checks for TCAT)
    auto checkprev = [&](const kmertype &p, char amino) -> fwdrev
    {
        // Get the next kmer with last aminoacid == amino and its reverse
        auto f = getprev(p, amino);
        auto r = revfn(f);
        
        // Find the next kmer as a forward and as its reverse (we don't know which form has been stored)
        bool F = hash_.find(accessor, f);
        bool R = hash_.find(accessor, r);
        if ((p.get() != f.get()) && (p.get() != r.get()) && (F || R))
        {
            return fwdrev(F, R);
        }
        
        return fwdrev(false, false);
    };

    // For each kmer
    for (auto &p : hash_)
    {
        p.second = 0;
        
        fwdrev   found;
        kmertype target;
        
        char     amino;
        
        f << std::endl;
        f << "// ========================================" << std::endl;
        f << "\"" << tostring(p.first) << "\" -> \"" << tostring(revfn(p.first)) << "\" [label=\"hasrev\"]" << std::endl;
        f << "// ========================================" << std::endl;
        
        
        
        // Unroll the loop from A to G, i.e., 0 to 3
        
        // === FWD ===

        f << "// fwd" << std::endl;
        // A
        amino = kmertype::A;
        target = getnext(p.first, amino);
        found  = checknext(p.first, amino);
        f << "// A: " << tostring(target) << std::endl;
        if (found.first || found.second)
        {
            // ---------
            // DEBUGGING
            // ---------
            std::string srcF, srcR, dstF, dstR;
            
            // Source
            srcF = tostring(p.first);
            srcR = tostring(revfn(p.first));
            
            // Destination
            dstF = found.first ? tostring(target)        : tostring(revfn(target));
            dstR = found.first ? tostring(revfn(target)) : tostring(target);
            
            f << "\"" << tostring(p.first) << "\" -> \"" << (found.first ? tostring(target) : tostring(revfn(target))) << "\" [label=\"fwd" << (found.first ? "F" : "") << (found.second ? "R" : "") << "\"]" <<std::endl;
        }
        // T
        amino = kmertype::T;
        target = getnext(p.first, amino);
        found  = checknext(p.first, amino);
        f << "// T: " << tostring(target) << std::endl;
        if (found.first || found.second)
        {
            // ---------
            // DEBUGGING
            // ---------
            std::string srcF, srcR, dstF, dstR;
            
            // Source
            srcF = tostring(p.first);
            srcR = tostring(revfn(p.first));
            
            // Destination
            dstF = found.first ? tostring(target)        : tostring(revfn(target));
            dstR = found.first ? tostring(revfn(target)) : tostring(target);
            
            f << "\"" << tostring(p.first) << "\" -> \"" << (found.first ? tostring(target) : tostring(revfn(target))) << "\" [label=\"fwd" << (found.first ? "F" : "") << (found.second ? "R" : "") << "\"]" <<std::endl;
        }
        // C
        amino = kmertype::C;
        target = getnext(p.first, amino);
        found  = checknext(p.first, amino);
        f << "// C: " << tostring(target) << std::endl;
        if (found.first || found.second)
        {
            // ---------
            // DEBUGGING
            // ---------
            std::string srcF, srcR, dstF, dstR;
            
            // Source
            srcF = tostring(p.first);
            srcR = tostring(revfn(p.first));
            
            // Destination
            dstF = found.first ? tostring(target)        : tostring(revfn(target));
            dstR = found.first ? tostring(revfn(target)) : tostring(target);
            
            f << "\"" << tostring(p.first) << "\" -> \"" << (found.first ? tostring(target) : tostring(revfn(target))) << "\" [label=\"fwd" << (found.first ? "F" : "") << (found.second ? "R" : "") << "\"]" <<std::endl;
        }
        // G
        amino = kmertype::G;
        target = getnext(p.first, amino);
        found  = checknext(p.first, amino);
        f << "// G: " << tostring(target) << std::endl;
        if (found.first || found.second)
        {
            // ---------
            // DEBUGGING
            // ---------
            std::string srcF, srcR, dstF, dstR;
            
            // Source
            srcF = tostring(p.first);
            srcR = tostring(revfn(p.first));
            
            // Destination
            dstF = found.first ? tostring(target)        : tostring(revfn(target));
            dstR = found.first ? tostring(revfn(target)) : tostring(target);
            
            f << "\"" << tostring(p.first) << "\" -> \"" << (found.first ? tostring(target) : tostring(revfn(target))) << "\" [label=\"fwd" << (found.first ? "F" : "") << (found.second ? "R" : "") << "\"]" <<std::endl;
        }
        
        // === REV ===

        f << "// rev" << std::endl;
        // A
        amino = kmertype::A;
        target = getprev(p.first, amino);
        found  = checkprev(p.first, amino);
        f << "// A: " << tostring(target) << std::endl;
        if (found.first || found.second)
        {
            // ---------
            // DEBUGGING
            // ---------
            std::string srcF, srcR, dstF, dstR;
            
            // Source
            srcF = tostring(p.first);
            srcR = tostring(revfn(p.first));
            
            // Destination
            dstF = found.first ? tostring(target)        : tostring(revfn(target));
            dstR = found.first ? tostring(revfn(target)) : tostring(target);
            
            f << "\"" << tostring(p.first) << "\" -> \"" << (found.first ? tostring(target) : tostring(revfn(target))) << "\" [label=\"rev" << (found.first ? "F" : "") << (found.second ? "R" : "") << "\"]" <<std::endl;
        }
        // T
        amino = kmertype::T;
        target = getprev(p.first, amino);
        found  = checkprev(p.first, amino);
        f << "// T: " << tostring(target) << std::endl;
        if (found.first || found.second)
        {
            // ---------
            // DEBUGGING
            // ---------
            std::string srcF, srcR, dstF, dstR;
            
            // Source
            srcF = tostring(p.first);
            srcR = tostring(revfn(p.first));
            
            // Destination
            dstF = found.first ? tostring(target)        : tostring(revfn(target));
            dstR = found.first ? tostring(revfn(target)) : tostring(target);
            
            f << "\"" << tostring(p.first) << "\" -> \"" << (found.first ? tostring(target) : tostring(revfn(target))) << "\" [label=\"rev" << (found.first ? "F" : "") << (found.second ? "R" : "") << "\"]" <<std::endl;
        }
        // C
        amino = kmertype::C;
        target = getprev(p.first, amino);
        found  = checkprev(p.first, amino);
        f << "// C: " << tostring(target) << std::endl;
        if (found.first || found.second)
        {
            // ---------
            // DEBUGGING
            // ---------
            std::string srcF, srcR, dstF, dstR;
            
            // Source
            srcF = tostring(p.first);
            srcR = tostring(revfn(p.first));
            
            // Destination
            dstF = found.first ? tostring(target)        : tostring(revfn(target));
            dstR = found.first ? tostring(revfn(target)) : tostring(target);
            
            f << "\"" << tostring(p.first) << "\" -> \"" << (found.first ? tostring(target) : tostring(revfn(target))) << "\" [label=\"rev" << (found.first ? "F" : "") << (found.second ? "R" : "") << "\"]" <<std::endl;
        }
        // G
        amino = kmertype::G;
        target = getprev(p.first, amino);
        found  = checkprev(p.first, amino);
        f << "// G: " << tostring(target) << std::endl;
        if (found.first || found.second)
        {
            // ---------
            // DEBUGGING
            // ---------
            std::string srcF, srcR, dstF, dstR;
            
            // Source
            srcF = tostring(p.first);
            srcR = tostring(revfn(p.first));
            
            // Destination
            dstF = found.first ? tostring(target)        : tostring(revfn(target));
            dstR = found.first ? tostring(revfn(target)) : tostring(target);
            
            f << "\"" << tostring(p.first) << "\" -> \"" << (found.first ? tostring(target) : tostring(revfn(target))) << "\" [label=\"rev" << (found.first ? "F" : "") << (found.second ? "R" : "") << "\"]" <<std::endl;
        }
        
    }
    
    f << "}" << std::endl;
    
    t.stop();
}

