/*!
 *  \file graph.h
 *
 *  \copyright Copyright (c) 2014 Franco "Sensei" Milicchio. All rights reserved.
 *
 *  \license BSD Licensed.
 */

#ifndef sequence_graph_h
#define sequence_graph_h

#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <boost/graph/graph_traits.hpp>
#include <boost/graph/adjacency_list.hpp>
#include <boost/graph/edge_list.hpp>
#include <boost/graph/compressed_sparse_row_graph.hpp>
#include <boost/graph/labeled_graph.hpp>
#include <boost/graph/graphviz.hpp>

#include "common.h"
#include "sequence.h"
#include "kmer.h"
#include "timer.h"

namespace seq
{
    /*!
     * Namespace seq::graph brings all of namespace boost into view through a
     * using-directive, so the code below can name Boost Graph Library
     * entities as graph::adjacency_list, graph::vecS, graph::write_graphviz...
     */
    namespace graph
    {
        //! \brief Import boost graphs (every name of namespace boost becomes reachable as seq::graph::name)
        using namespace boost;
    }
    
    /*!
     * Alias for a graph with FASTA sequence pointers on nodes, an integer on edges.
     * Vertices and out-edges use vecS storage and the graph is bidirectionalS;
     * the make_graph helpers in this file address vertices by integer index.
     */
    using fasta_graph = graph::adjacency_list<graph::vecS, graph::vecS, graph::bidirectionalS, const fasta_string*, std::size_t>;

    /*!
     * Alias for a graph with FASTQ sequence pointers on nodes, an integer on edges.
     * Same storage/direction choices as fasta_graph.
     */
    using fastq_graph = graph::adjacency_list<graph::vecS, graph::vecS, graph::bidirectionalS, const fastq_string*, std::size_t>;

    /*!
     * Alias for a graph with paired-end forward/reverse sequence pointers on nodes, an integer on edges.
     * Same storage/direction choices as fasta_graph.
     */
    using paired_revfwd_graph = graph::adjacency_list<graph::vecS, graph::vecS, graph::bidirectionalS, const paired_revfwd*, std::size_t>;
    
    /*!
     * This class converts a sequence into a string by copying the value
     * returned by its get() member. This is just for testing, not to be used
     * in the real world.
     */
    class to_string
    {
    public:
        
        //! \brief Copy the content of s (as returned by s.get()) into p
        //! \param s the sequence to convert (may be const-qualified)
        //! \param p output string, overwritten
        template <class Sequence>
        void operator()(Sequence &s, std::string &p) const
        {
            p = s.get();
        }

    };
    
    /*!
     * This class converts a sequence to its memory location: the address of
     * the sequence is stored into the output pointer.
     */
    class to_pointer
    {
    public:
        
        //! \brief Store the address of s into p
        //! \param s the sequence (Sequence is deduced const when s is const)
        //! \param p output pointer, overwritten with &s
        template <class Sequence>
        void operator()(Sequence &s, Sequence* &p) const
        {
            p = &s;
        }
    };
    
    /*!
     * This functor checks for overlapping sequences, given a kmer length
     */
    template <std::size_t klength>
    class kmer_overlap
    {
    public:
        
        //! \brief Store the kmer overlap length
        const std::size_t length = klength;
        
        //! \brief Generic overlapping
        template <class SequenceFrom, class SequenceTo>
        bool operator()(const SequenceFrom &f, const SequenceTo &t) const
        {
            const auto &ff = f.get();
            const auto &tt = t.get();
            
            for (std::size_t j = 0; j < klength; j++)
                if (container_traits<SequenceFrom>::at(ff, j + klength) != container_traits<SequenceTo>::at(tt, j)) return false;
            
            return true;
        }
        
        //! \brief Hashing function for a sequence as a source
        template <class Sequence>
        std::size_t hash_source(const Sequence &s) const
        {
            typedef typename sequence_traits<Sequence>::container_type container_type;
            container_type t;
            
            std::size_t size = container_traits<container_type>::size(s.get());
            
            container_traits<container_type>::resize(t, size - klength, ' ');

            for (std::size_t j = 0; j < size - klength; j++)
                container_traits<container_type>::at(t, j) = container_traits<container_type>::at(s.get(), j + klength);

            return std::hash<container_type>()(t);
        }

        //! \brief Hashing function for a sequence as a target
        template <class Sequence>
        std::size_t hash_target(const Sequence &s) const
        {
            typedef typename sequence_traits<Sequence>::container_type container_type;
            container_type t;
            
            std::size_t size = container_traits<container_type>::size(s.get());
            
            container_traits<container_type>::resize(t, size - klength, ' ');
            
            for (std::size_t j = 0; j < size - klength; j++)
                container_traits<container_type>::at(t, j) = container_traits<container_type>::at(s.get(), j);
            
            return std::hash<container_type>()(t);
        }
        
        //! \brief As a source string
        template <class Sequence>
        std::string as_source(const Sequence &s) const
        {
            typedef typename sequence_traits<Sequence>::container_type container_type;
            std::string t;
            
            std::size_t size = container_traits<container_type>::size(s.get());
            
            container_traits<container_type>::resize(t, size - klength, ' ');
            
            for (std::size_t j = 0; j < size - klength; j++)
                container_traits<container_type>::at(t, j) = container_traits<container_type>::at(s.get(), j + klength);
            
            return t;
        }
        
        //! \brief As a target string
        template <class Sequence>
        std::string as_target(const Sequence &s) const
        {
            typedef typename sequence_traits<Sequence>::container_type container_type;
            container_type t;
            
            std::size_t size = container_traits<container_type>::size(s.get());
            
            container_traits<container_type>::resize(t, size - klength, ' ');
            
            for (std::size_t j = 0; j < size - klength; j++)
                container_traits<container_type>::at(t, j) = container_traits<container_type>::at(s.get(), j);
            
            return t;
        }
        
    };

    /*!
     * This functor returns a string describing the edge between two given
     * sequences. This is a <b>test function</b> and should never be used.
     *
     * The overlap q is found by dropping i = 1, 2, ... leading characters
     * from the source and i trailing characters from the target, stopping at
     * the first i for which the two remainders are equal (empty if none).
     */
    class dummy_weight
    {
    public:
        
        //! \brief Return "edge(f -> t, |q|, q)" where q is the overlap described above
        template<class SequenceFrom, class SequenceTo>
        std::string operator()(const SequenceFrom &f, const SequenceTo &t) const
        {
            const std::string from(f.get()), to(t.get());
            
            // Never shift past the end of either string: trimming i characters
            // from a target shorter than i would be out of range.
            const std::size_t limit = std::min(from.size(), to.size());
            
            std::string q;
            
            for (std::size_t i = 1; i <= limit; i++)
            {
                // Source without its first i characters vs. target without its last i characters
                if (from.compare(i, std::string::npos, to, 0, to.size() - i) == 0)
                {
                    q = from.substr(i);
                    break;
                }
            }
            
            return "edge(" + from + " -> " + to + ", " + std::to_string(q.size()) + ", " + q + ")";
        }
    };
    
    /*!
     * Tag type. Passing an instance as the weight argument of
     * make_multikmerhash_graph selects the overloads that weight every edge
     * with an overlap_weight built from the corresponding hasher.
     */
    class forward_overlap_weight
    {
        // NOP: empty on purpose, used only to drive overload selection
    };
    
    /*! 
     * This functor returns the overlap between two sequences
     */
    class overlap_weight
    {
    public:
        
        //! \brief Store the kmer overlap length
        const std::size_t length;

        //! \brief Constructor
        template <std::size_t L>
        overlap_weight(const kmer_overlap<L> &k) : length{ k.length }
        {
            // NOP
        }
        
        template<class SequenceFrom, class SequenceTo>
        std::size_t operator()(const SequenceFrom &f, const SequenceTo &t) const
        {
            return length;
        }

    };

    //! \brief This function generates a weighted graph from two given containers.
    template <class ContainerFrom, class ContainerTo, class Graph, class Fn, class Weight, class Node>
    void make_graph(Graph &g, const ContainerFrom& from, const ContainerTo& to, Fn is_edge, Weight weight, Node convert_node)
    {
        using namespace graph;
        
        typedef typename graph_traits<Graph>::vertex_descriptor      node_type;
        typedef typename vertex_bundle_type<Graph>::type             node_content;
        typedef typename container_traits<ContainerFrom>::value_type from_type;
        typedef typename container_traits<ContainerTo>::value_type   to_type;
        
        // Maps to graph nodes, used only to check for existing nodes (their ID)
        std::unordered_map<node_content, node_type> cache;
        
        // The local maximum index
        node_type idx = 0;
        
        for (const auto& f : from.get())
            for (const auto &t : to.get())
            {
                // If there cannot be an arc (f -> t), then skip this pair
                if (!is_edge(f, t)) continue;
                
                // Convert to graph node properties
                node_content from_node, to_node;
                convert_node(f, from_node);
                convert_node(t, to_node);
                
                node_type from_idx, to_idx;
                
                // Find FROM indices
                auto from_find = cache.find(from_node);
                if (from_find != cache.end())
                    from_idx = from_find->second;
                else
                {
                    cache[from_node] = ++idx;
                    from_idx         = idx;
                }
                
                // Find TO indices
                auto to_find = cache.find(to_node);
                if (to_find != cache.end())
                    to_idx = to_find->second;
                else
                {
                    cache[to_node] = ++idx;
                    to_idx         = idx;
                }
                
                // Weight
                auto w = weight(f, t);
                
                // Add information
                add_edge(from_idx, to_idx, w, g);
                g[from_idx] = from_node;
                g[to_idx]   = to_node;
            }
    }
    
    //! \brief This function specializes for maps.
    template <class From, class To, class Graph, class Fn, class Weight, class Node>
    void make_graph(Graph &g, const kmerlist_hashmap<From>& from, const kmerlist_hashmap<To>& to, Fn is_edge, Weight weight, Node convert_node)
    {
        using namespace graph;
        
        typedef typename graph_traits<Graph>::vertex_descriptor      node_type;
        typedef typename vertex_bundle_type<Graph>::type             node_content;
        typedef From                                                 from_type;
        typedef To                                                   to_type;
        
        // Maps to graph nodes, used only to check for existing nodes (their ID)
        std::unordered_map<node_content, node_type> cache;
        
        // The local maximum index
        node_type idx = 0, cc = 0;

        // Timer
        timer timer;
        
        // Start me up
        if (logging) timer.start("Constructing graph");

        for (const auto& ff : from.get())
            for (const auto &tt : to.get())
            {
                auto &f = ff.second;
                auto &t = tt.second;
                                
                //std::cout << "check " << f.get() << " to " << t.get() << " " << ++cc << std::endl;

                // If there cannot be an arc (f -> t), then skip this pair
                if (!is_edge(f, t)) continue;
                
                // Convert to graph node properties
                node_content from_node, to_node;
                convert_node(f, from_node);
                convert_node(t, to_node);
                
                node_type from_idx, to_idx;
                
                // Find FROM indices
                auto from_find = cache.find(from_node);
                if (from_find != cache.end())
                    from_idx = from_find->second;
                else
                {
                    cache[from_node] = idx;
                    from_idx         = idx++;
                }
                
                // Find TO indices
                auto to_find = cache.find(to_node);
                if (to_find != cache.end())
                    to_idx = to_find->second;
                else
                {
                    cache[to_node] = idx;
                    to_idx         = idx++;
                }
                
                // Weight
                auto w = weight(f, t);
                
                // Add information
                add_edge(from_idx, to_idx, w, g);
                g[from_idx] = from_node;
                g[to_idx]   = to_node;
                
                //std::cout << "      > " << w << std::endl;
            }
        
        // If you start me up I'll never stop
        if (logging) timer.stop();
    }

    //! \brief This function generates a weighted graph from two given concurrent hashmap containers.
    template <class Kmer, class Graph, class Fn, class Weight, class Node>
    void make_kmerhash_graph(Graph &g, const kmerlist_concurrent_hashmap<Kmer>& list, Fn hasher, Weight weight, Node convert_node)
    {
        using namespace graph;
        
        typedef typename graph_traits<Graph>::vertex_descriptor      node_type;
        typedef typename vertex_bundle_type<Graph>::type             node_content;
        typedef Kmer                                                 kmer_type;
        
        // Node id
        node_type idx = 0;
        
        // Timer
        timer timer;

        // Map of kmers to their respective hashes and node index
        std::unordered_map<const Kmer*, std::tuple<std::size_t, std::size_t, node_type>> hashes;
        
        // Reserve space on hashmaps and graph
        hashes.reserve(1000000);
        //add_edge(0, num_vertices(g) + list.size() - 1, g);
        //remove_edge(0, num_vertices(g) + list.size() - 1, g);
        //g.m_vertices.resize(list.size() - 1);

        // Start hashing kmers
        if (logging) timer.start("Hashing " + std::to_string(hasher.length) + "-overlap kmers");
        for(auto &p : list.get())
        {
            hashes[&(p.second)] = std::make_tuple(hasher.hash_source(p.second), hasher.hash_target(p.second), idx++);
        }
        if (logging) timer.stop();
        
        // Maps hashes to (graph node id, kmer pointer), used only to check for existing nodes (their ID)
        std::unordered_multimap<std::size_t, std::tuple<std::size_t, const Kmer*>> bytarget;
        
        // Sort kmers into the unordered multimap, by target hash, mapped to their index
        if (logging) timer.start("Sorting " + std::to_string(hasher.length) + "-overlap kmers");
        for (auto &p : hashes)
        {            
            bytarget.emplace(std::make_pair(std::get<1>(p.second),
                                            std::make_tuple(std::get<2>(p.second), p.first))
                             );
        }
        if (logging) timer.stop();
        
        // Create the graph
        if (logging) timer.start("Creating " + std::to_string(hasher.length) + "-overlap kmers adjacency graph");
        for (auto &p : hashes)
        {
            // For each source p, find all that have the same target hash in q
            auto it = bytarget.equal_range(std::get<0>(p.second));
            
            // Add matching nodes
            for (auto q = it.first; q != it.second; ++q)
            {
                auto fidx = std::get<2>(p.second), tidx = std::get<0>(q->second);
                
                node_content f, t;
                
                convert_node(*p.first, f);
                convert_node(*std::get<1>(q->second), t);
                
                add_edge(fidx, tidx, weight(*p.first, *std::get<1>(q->second)), g);
                g[fidx] = f;
                g[tidx] = t;
            }
        }
        if (logging) timer.stop();
    }

    //! \brief This function generates a weighted graph from two given concurrent hashmap containers.
    template <class Kmer, class Graph, class Weight, class Node, class Fn>
    void make_multikmerhash_graph(Graph &g, const kmerlist_concurrent_hashmap<Kmer>& list, Weight weight, Node convert_node, Fn hasher)
    {
        make_kmerhash_graph(g, list, hasher, weight, convert_node);
    }

    //! \brief Recursive call
    template <class Kmer, class Graph, class Weight, class Node, class Fn, class... Fns>
    void make_multikmerhash_graph(Graph &g, const kmerlist_concurrent_hashmap<Kmer>& list, Weight weight, Node convert_node, Fn hasher, Fns... hashers)
    {
        make_kmerhash_graph(g, list, hasher, weight, convert_node);
        
        make_multikmerhash_graph(g, list, weight, convert_node, hashers...);
    }

    //! \brief This function forwards weight operators in order to use overlapping sizes as weights
    template <class Kmer, class Graph, class Node, class Fn>
    void make_multikmerhash_graph(Graph &g, const kmerlist_concurrent_hashmap<Kmer>& list, forward_overlap_weight weight, Node convert_node, Fn hasher)
    {
        make_kmerhash_graph(g, list, hasher, overlap_weight(hasher), convert_node);
    }
    
    //! \brief Recursive call
    template <class Kmer, class Graph, class Node, class Fn, class... Fns>
    void make_multikmerhash_graph(Graph &g, const kmerlist_concurrent_hashmap<Kmer>& list, forward_overlap_weight weight, Node convert_node, Fn hasher, Fns... hashers)
    {
        make_kmerhash_graph(g, list, hasher, overlap_weight(hasher), convert_node);
        
        make_multikmerhash_graph(g, list, weight, convert_node, hashers...);
    }
    
    /*!
     * This class is the default node printer for a graph dumper. It writes a
     * DOT label attribute holding the text returned by get() on the vertex
     * content, which must therefore be a non-null pointer-like object.
     */
    template <class Graph>
    class node_printer
    {
        const Graph &graph_;
        
    public:
        
        //! \brief Keep a reference to the graph; the graph must outlive the printer
        node_printer(const Graph &g) : graph_{g}
        {
            // NOP
        }
        
        //! \brief Write the label attribute of vertex v to stream s
        template <class Stream, class Node>
        void operator()(Stream &s, Node v) const
        {
            s << " [label=\"" << graph_[v]->get() << "\"]";
        }
    };
    
    /*!
     * This class is the default edge printer for a graph dumper. It writes a
     * DOT label attribute holding the edge property streamed with operator<<.
     */
    template <class Graph>
    class edge_printer
    {
        const Graph &graph_;
        
    public:
        
        //! \brief Keep a reference to the graph; the graph must outlive the printer
        edge_printer(const Graph &g) : graph_{g}
        {
            // NOP
        }
        
        //! \brief Write the label attribute of edge e to stream s
        template <class Stream, class Edge>
        void operator()(Stream &s, Edge e) const
        {
            s << " [label=\"" << graph_[e] << "\"]";
        }
    };

    /*!
     * \brief This function dumps a graph in a DOT (aka Graphviz) format.
     *
     * Vertices are labelled by node_printer, edges by edge_printer.
     *
     * \param g graph to dump
     * \param s output stream receiving the DOT text
     */
    template <class Stream, class Graph>
    void graph_dump(const Graph &g, Stream &s)
    {
        graph::write_graphviz(s, g, node_printer<Graph>(g), edge_printer<Graph>(g));
    }
}

#endif
