//
//  main.cpp
//  counter
//
//  Created by Franco Milicchio on 01/03/21.
//  Copyright © 2021 Franco Milicchio. All rights reserved.
//

#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <string>
#include <string_view>

#include "../libseq/accelerator_string.hpp"
#include "../libseq/accelerator_uint128.hpp"
#include "../libseq/accelerator_sse.hpp"
#include "../libseq/partitioner_hash.hpp"
#include "../libseq/logger.hpp"
#include "../libseq/veboas.hpp"


// Length of every k-mer extracted from the reads (and of the demo k-mers in main).
constexpr int k = 5;


// Treats `bitmap` as a 16x16 bit matrix: the low nibble of `byte` picks one of
// the 16 rows, the high nibble picks a bit inside that row. Returns true when
// that bit is set.
bool in_set(uint16_t bitmap[16], uint8_t byte)
{
    const uint16_t row    = bitmap[byte & 0x0f];
    const unsigned column = static_cast<unsigned>(byte >> 4);

    return ((row >> column) & 1u) != 0;
}


/*
 * Returns true when all 16 bytes starting at `ptr` are DNA nucleotides
 * ('A', 'C', 'G', 'T', 'a', 'c', 'g', 't'), false if any of them is anything
 * else. Exactly 16 bytes are read with an unaligned load, so `ptr` needs no
 * particular alignment but must point to at least 16 readable bytes.
 *
 * Technique (nibble-split set membership): every allowed character owns one
 * bit K..R. lo_nibbles_lookup[n] is the OR of the bits of the allowed
 * characters whose low nibble is n; hi_nibbles_lookup[n] is the same for the
 * high nibble. A byte is in the set exactly when the two masks selected by
 * its nibbles share a bit:
 *
 *     char  hex   lo  hi  bit
 *     'A'   0x41   1   4   K
 *     'C'   0x43   3   4   L
 *     'G'   0x47   7   4   M
 *     'T'   0x54   4   5   N
 *     'a'   0x61   1   6   O
 *     'c'   0x63   3   6   P
 *     'g'   0x67   7   6   Q
 *     't'   0x74   4   7   R
 *
 * Example: 'D' (0x44) selects lo[4] = N|R and hi[4] = K|L|M; the masks are
 * disjoint, so 'D' is rejected.
 *
 * Requires SSSE3 (_mm_shuffle_epi8); everything else is SSE2.
 */
inline bool in_set_sse(const uint8_t *ptr)
{
    constexpr uint8_t K = (1 << 0); // 'A'
    constexpr uint8_t L = (1 << 1); // 'C'
    constexpr uint8_t M = (1 << 2); // 'G'
    constexpr uint8_t N = (1 << 3); // 'T'
    constexpr uint8_t O = (1 << 4); // 'a'
    constexpr uint8_t P = (1 << 5); // 'c'
    constexpr uint8_t Q = (1 << 6); // 'g'
    constexpr uint8_t R = (1 << 7); // 't'

    // alignas(16): the tables are read with the *aligned* _mm_load_si128
    // below, and a plain uint8_t[16] is only guaranteed 1-byte alignment.
    // static: the tables are built at compile time, not on every call.
    alignas(16) static constexpr uint8_t lo_nibbles_lookup[16] = {
        /* 0 */ 0x00,
        /* 1 */ K | O,
        /* 2 */ 0x00,
        /* 3 */ L | P,
        /* 4 */ N | R,
        /* 5 */ 0x00,
        /* 6 */ 0x00,
        /* 7 */ M | Q,
        /* 8 */ 0x00,
        /* 9 */ 0x00,
        /* a */ 0x00,
        /* b */ 0x00,
        /* c */ 0x00,
        /* d */ 0x00,
        /* e */ 0x00,
        /* f */ 0x00
    };

    alignas(16) static constexpr uint8_t hi_nibbles_lookup[16] = {
        /* 0 */ 0x00,
        /* 1 */ 0x00,
        /* 2 */ 0x00,
        /* 3 */ 0x00,
        /* 4 */ K | L | M,
        /* 5 */ N,
        /* 6 */ O | P | Q,
        /* 7 */ R,
        /* 8 */ 0x00,
        /* 9 */ 0x00,
        /* a */ 0x00,
        /* b */ 0x00,
        /* c */ 0x00,
        /* d */ 0x00,
        /* e */ 0x00,
        /* f */ 0x00
    };

    const __m128i nibble_mask = _mm_set1_epi8(0x0f);

    const __m128i input = _mm_loadu_si128(reinterpret_cast<const __m128i*>(ptr));

    // Split every byte into its low and high nibble (0..15). The 16-bit shift
    // leaks the neighbouring byte's bits into the top of each low byte; the
    // mask removes them.
    const __m128i lower_nibbles  = _mm_and_si128(input, nibble_mask);
    const __m128i higher_nibbles = _mm_and_si128(_mm_srli_epi16(input, 4), nibble_mask);

    // Translate each nibble into its candidate mask with a 16-entry lookup.
    const __m128i lo_translated = _mm_shuffle_epi8(
        _mm_load_si128(reinterpret_cast<const __m128i*>(lo_nibbles_lookup)), lower_nibbles);
    const __m128i hi_translated = _mm_shuffle_epi8(
        _mm_load_si128(reinterpret_cast<const __m128i*>(hi_nibbles_lookup)), higher_nibbles);

    // Non-zero lane <=> that byte is one of the eight allowed characters.
    const __m128i intersection = _mm_and_si128(lo_translated, hi_translated);

    // 0xFF in every lane whose byte is NOT allowed, 0x00 elsewhere.
    const __m128i rejected = _mm_cmpeq_epi8(intersection, _mm_setzero_si128());

    // SSE2 movemask instead of SSE4.1 _mm_test_all_zeros: same answer, lower
    // ISA requirement.
    return _mm_movemask_epi8(rejected) == 0;
}

/*
 * Returns true when all 64 bytes at `src` are one of 'f', 'h', 'o', 't', 'w'.
 *
 * Each byte indexes `table` by its low nibble (pshufb). An allowed character
 * sits in the slot equal to its own low nibble ('t'=0x74 -> 4, 'f'=0x66 -> 6,
 * 'w'=0x77 -> 7, 'h'=0x68 -> 8, 'o'=0x6F -> 15), so only those characters
 * read themselves back. Every other slot holds '_', which can never match:
 * '_' (0x5F) itself looks up slot 15 ('o'). Bytes with the top bit set are
 * mapped to 0 by pshufb and therefore never match either.
 */
bool is_valid_64bytes (uint8_t* src)
{
    // valid chars are f, h, o, t, and w
    const __m128i table = _mm_setr_epi8('_', '_', '_', '_', 't', '_', 'f', 'w',
                                        'h', '_', '_', '_', '_', '_', '_', 'o');

    __m128i all_match = _mm_set1_epi8(-1);
    for (int offset = 0; offset < 64; offset += 16)
    {
        const __m128i chunk    = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + offset));
        const __m128i expected = _mm_shuffle_epi8(table, chunk);
        all_match = _mm_and_si128(all_match, _mm_cmpeq_epi8(expected, chunk));
    }

    return _mm_movemask_epi8(all_match) == 0xFFFF;
}

/*
 * Returns true when every character of `s` is a DNA nucleotide ('A', 'C',
 * 'G', 'T' in upper or lower case), false otherwise. An empty view is valid.
 *
 * `s` is checked 16 bytes at a time with in_set_sse(). Each chunk is copied
 * into a local buffer pre-filled with 'A' (an allowed character), so:
 *   - in_set_sse() never reads past the end of `s`;
 *   - the padding can never cause a rejection;
 *   - inputs of any length are handled (no fixed-size buffer to overflow);
 *   - no state is kept between calls, so results never depend on a previous
 *     call and the function is reentrant.
 */
bool is_valid(std::string_view s)
{
    constexpr std::size_t lane = 16;

    for (std::size_t pos = 0; pos < s.size(); pos += lane)
    {
        const std::size_t remaining = s.size() - pos;
        const std::size_t n = remaining < lane ? remaining : lane;

        uint8_t chunk[lane];
        std::memset(chunk, 'A', lane);
        std::memcpy(chunk, s.data() + pos, n);

        if (!in_set_sse(chunk)) return false;
    }

    return true;
}


// Global tree shared by testread(), which inserts the canonical form of every
// valid k-mer via veb.insert(), and print_veb(), which walks indices
// 1..veb.maxindex() and prints the nodes whose count is non-zero.
// NOTE(review): the meaning of the constructor argument 12 is not visible here
// (presumably a size/height parameter of vebbsttree) — confirm in
// libseq/veboas.hpp.
libseq::vebbsttree<libseq::accelerator_sse> veb(12);

/*
 * Memory-maps the FASTQ file at path `f` and, for every read, slides a window
 * of length `k` over it. Windows containing any character other than
 * A/C/G/T (either case) are skipped. The others are converted with
 * accelerator_sse::to_canonical() and inserted into the global tree `veb`.
 *
 * Returns the number of reads visited (not the number of k-mers). If
 * veb.insert() reports failure, "NOT INSERTED" is logged and the number of
 * reads visited so far (including the current one) is returned immediately.
 */
std::size_t testread(std::string f)
{
    libseq::path file(f);
    libseq::fastq_mmap fastq(file);

    libseq::accelerator_sse acc;

    // size_t copy of k so the window arithmetic below stays purely unsigned.
    const std::size_t kmer_length = static_cast<std::size_t>(k);

    std::size_t count = 0;

    for (auto &r : fastq)
    {
        count++;

        const std::size_t read_length = r.length();

        // `j + kmer_length <= read_length` instead of `j < read_length - k + 1`:
        // the latter wraps around (size_t) for reads shorter than k - 1 and
        // would then read far past the end of the read.
        for (std::size_t j = 0; j + kmer_length <= read_length; ++j)
        {
            std::string_view kmer(r.data() + j, kmer_length);

            if (!is_valid(kmer)) continue;

            if (!veb.insert(acc.to_canonical(kmer)))
            {
                libseq::logger::info("NOT INSERTED");
                return count;
            }
        }
    }

    return count;
}


/*
 * Writes every populated node of the global tree `veb` to stdout, one per
 * line, as "<k-mer> <count>". Indices 1..veb.maxindex() (inclusive) are
 * visited; nodes whose count is zero are skipped.
 */
void print_veb()
{
    libseq::accelerator_sse acc;

    for (std::size_t i = 1; i <= veb.maxindex(); ++i)
    {
        // '\n' rather than std::endl: flushing after every line makes dumping
        // a large tree I/O-bound. std::cout is still flushed at normal exit.
        if (veb[i].count > 0)
            std::cout << acc.to_string(veb[i].kmer, k) << ' ' << veb[i].count << '\n';
    }
}


int main(int argc, char *argv[])
{
    std::string s = "AAAAAATTTTTTTTTTTTTTTAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA";
    libseq::accelerator_sse acc;
    
    std::string_view kmer1(s.data() +  0, 5);
    std::string_view kmer2(s.data() +  5, 5);
    std::string_view kmer3(s.data() + 10, 5);
    std::string_view kmer4(s.data() + 15, 5);
    
    std::cout << kmer1 << " " << kmer2 << acc.compare(acc.to_canonical(kmer1), acc.to_canonical(kmer2)) << std::endl;
    std::cout << kmer1 << " " << kmer3 << acc.compare(acc.to_canonical(kmer1), acc.to_canonical(kmer3)) << std::endl;
    std::cout << kmer1 << " " << kmer4 << acc.compare(acc.to_canonical(kmer1), acc.to_canonical(kmer4)) << std::endl;
    std::cout << kmer2 << " " << kmer1 << acc.compare(acc.to_canonical(kmer2), acc.to_canonical(kmer1)) << std::endl;
    std::cout << kmer1 << " " << kmer1 << acc.compare(acc.to_canonical(kmer1), acc.to_canonical(kmer1)) << std::endl;
    return 0;
    
    libseq::path log("./benchmark.log");
    libseq::logger::init(log.absolute_path());

    auto start = std::chrono::high_resolution_clock::now();
    std::size_t kmers = testread(argv[1]);
    auto end = std::chrono::high_resolution_clock::now();

    double time_taken = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
    
    libseq::logger::info("Number of kmers: {}", kmers);
    libseq::logger::info("Time taken in milliseconds 63: {}", time_taken);
    
    print_veb();

    return 0;
}
