//
//  parser.cpp
//
//  Copyright 2018 Franco Milicchio. All rights reserved.
//

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <stdexcept>
#include <string>
#include <string_view>

#include <x86intrin.h>

#include "../libseq/precomputed.hpp"


/**
 * Round-trip test for 2-bit nucleotide packing.
 *
 * For every window of `substringlength` characters of the test string `s`,
 * each base is encoded with libseq::bases_values::value() and OR-ed into its
 * 2-bit slot of a pair of 64-bit words (a "fake" 128-bit register). The slot
 * is immediately read back and decoded with libseq::bases_values::ascii().
 * The rebuilt window is compared with the original substring and printed:
 * "s" marks a success, "F" a failure.
 * (The packing assumes value() yields a code in 0..3 - TODO confirm in libseq.)
 *
 * Defining SIMD as true switches to an experimental SSSE3/SSE4.1 encoder.
 * That encoder throws std::domain_error on an unexpected byte pattern.
 *
 * @return 0 when every window round-trips, 1 otherwise.
 */
int main(int argc, const char * argv[])
{
//    char lookup[4] = {  0,   1,   2,   3  },
//         uplook[4] = { 'A', 'C', 'G', 'T' };
    
    //                |              | <---- parsing: 128-bits register length
//    std::string s  = "012301012233111001230123012301230123012301230123012301230123012301230123012301230123012301230123012301230123012301230123012301231100110011001100110011001100110011001100110011001100110011001100110011001100110011001100110011001100110011001100110011001100110032323232323232323232323232323232323232323232323232323232323232323232323232323232323232323232323232323232323232323232323232323232";
    std::string s("ATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCG");
//    std::string s("AAAAAAAAAAAAAAAATTTTTTTTTTTTTTTTCCCCCCCCCCCCCCCCGGGGGGGGGGGGGGGGAAAATTTTCCCCGGGG");

    // Window length, in bases
    constexpr std::size_t substringlength = 64;

    // Number of windows. Guard against size_t wrap-around when the input is
    // shorter than a single window.
    const std::size_t nsubstrings = s.length() >= substringlength
                                  ? s.length() - substringlength + 1
                                  : 0;

    // 2 bits per base => 32 bases fit in each 64-bit word
    constexpr std::size_t basesperword = 32;

    // SSE register size, in bits
    constexpr std::size_t reglen = 128;

    // No more than these bits, please: at 2 bits per base a 128-bit register
    // (and the two-word ymm[] below) holds at most reglen / 2 bases. This is
    // checked at compile time, so an NDEBUG build cannot silently overrun ymm[].
    static_assert(reglen / 2 >= substringlength,
                  "substring does not fit in a 128-bit register");
    static_assert(2 * basesperword >= substringlength,
                  "substring does not fit in the fake 128-bit register");

    std::size_t i, j, sub, off, c;
    std::uint64_t v;
    
    std::string substring;
    substring.reserve(substringlength);
    
    // Fake 128-bit register. Fixed-width words keep shifts of up to 62 bits
    // well-defined even on targets where size_t is only 32 bits wide.
    std::uint64_t ymm[2];
    
    // SIMD 128-bit registers, apparently I cannot use AVX
    __m128i sse[4], adj[4], res;
    // Only referenced by the commented-out numeric-string path in the SIMD branch
    auto zero = _mm_set1_epi8('0');
    (void) zero;
    
    res = _mm_setzero_si128();
    
    // *****************************************************
    // *** SSSE3 (pshufb, pmaddubsw) + SSE4.1 (pblendvb) ***
    // *****************************************************
    
    // 64 ASCII bases occupy four 16-byte SSE registers before being packed
    __m128i substringreg[4], maskreg[4], adjreg[4];
    
    // Mask to shift to the low bytes
    maskreg[0] = _mm_set_epi8(1<< 2, 1<<0, 1<< 2, 1<<0, 1<< 2, 1<<0, 1<< 2, 1<<0, 1<< 2, 1<<0, 1<< 2, 1<<0, 1<< 2, 1<<0, 1<< 2, 1<<0);//_mm_set1_epi16(0x0401);
    maskreg[1] = _mm_set_epi8(1<< 6, 1<<4, 1<< 6, 1<<4, 1<< 6, 1<<4, 1<< 6, 1<<4, 1<< 6, 1<<4, 1<< 6, 1<<4, 1<< 6, 1<<4, 1<< 6, 1<<4);//_mm_set1_epi16(0x1040);

    // Bitmask for FMA (Now all bytes are in order)
//    maskreg[0] = _mm_set_epi8(1<< 0, 1<< 2, 1<< 0, 1<< 2, 1<< 0, 1<< 2, 1<< 0, 1<< 2, 1<< 0, 1<< 2, 1<< 0, 1<< 2, 1<< 0, 1<< 2, 1<< 0, 1<< 2);
//    maskreg[1] = _mm_set_epi8(1<< 4, 1<< 6, 1<< 4, 1<< 6, 1<< 4, 1<< 6, 1<< 4, 1<< 6, 1<< 4, 1<< 6, 1<< 4, 1<< 6, 1<< 4, 1<< 6, 1<< 4, 1<< 6);

    // Equivalent bitmasks
    maskreg[2] = _mm_set1_epi16(0x0104); // _mm_set_epi8(1<< 0, 1<< 2, 1<< 0, 1<< 2, ...
    maskreg[3] = _mm_set1_epi16(0x1040); // _mm_set_epi8(1<< 4, 1<< 6, 1<< 4, 1<< 6, ...
    (void) adjreg;

    std::cout << "> STRING: " << s << '\n';
    
/*
 *********************************
 *** INTERLEAVED MEMORY LAYOUT ***
 *********************************
 
 string     01230123012301230123012301230123012301230123012301230123012301230123012301230123012301230123012301230123012301230123012301230123
            a0, a1, a2, ...

 Each string is encoded with 2-bit tokens as usual. With SSEx (128-bit
 registers) the ASCII characters end up encoded as:
 
 byte order: 0 1 2 3 ...
 
 r0 =       a0      a1      a2      a3
 r1 =     a4      a5      a6      a7
 r2 =   a8      a9      aA      aB
 r3 = aC      aD      aE      aF
 
 The tokens are interleaved rather than linear. OR-ing the registers yields a
 bitstring whose bit order is scrambled (it looks like a hashed version of
 the input) rather than sequential.
 */
    
    std::size_t failed = 0, succeeded = 0;
    
#define SIMD false
    
    // All substrings
    for (sub = 0; sub < nsubstrings; sub++)
    {
        substring.clear();
        
#if SIMD
        
        /*
         * PARSE TO INTEGERS
         *
         res -> substringreg[0]
         sse -> sse[x]
         adj -> adj[x]
         */
        //            res = _mm_loadu_si128((__m128i*) (test.data() +  0));               // load data
        //            sse = _mm_sub_epi8(res, _mm_set1_epi8('A'));                        // letters start at zero 'A' -> 0
        //            sse = _mm_and_si128(sse, _mm_set1_epi8(0x0F));                      // get lower bits only
        //            adj = _mm_cmpeq_epi8(sse, _mm_set1_epi8(0x06));                     // find 0x06 'G'
        //            res = _mm_sub_epi8(_mm_andnot_si128(adj, sse), adj);                // convert 0x06 to 0x01
        
        // Load data without caring about past-end-of-buffer
        substringreg[0] = _mm_loadu_si128((__m128i*) (s.data() + sub +   0));
        substringreg[1] = _mm_loadu_si128((__m128i*) (s.data() + sub +  16));
        substringreg[2] = _mm_loadu_si128((__m128i*) (s.data() + sub +  32));
        substringreg[3] = _mm_loadu_si128((__m128i*) (s.data() + sub +  48));
        
        // Offset ASCII 'A' to zero
        sse[0]          = _mm_sub_epi8(substringreg[0], _mm_set1_epi8('A'));
        sse[1]          = _mm_sub_epi8(substringreg[1], _mm_set1_epi8('A'));
        sse[2]          = _mm_sub_epi8(substringreg[2], _mm_set1_epi8('A'));
        sse[3]          = _mm_sub_epi8(substringreg[3], _mm_set1_epi8('A'));

        // Get lower 4 bits to distinguish letters
        sse[0]          = _mm_and_si128(sse[0], _mm_set1_epi8(0x0F));
        sse[1]          = _mm_and_si128(sse[1], _mm_set1_epi8(0x0F));
        sse[2]          = _mm_and_si128(sse[2], _mm_set1_epi8(0x0F));
        sse[3]          = _mm_and_si128(sse[3], _mm_set1_epi8(0x0F));

        // Find 'G' base 0x06
        adj[0]          = _mm_cmpeq_epi8(sse[0], _mm_set1_epi8(0x06));
        adj[1]          = _mm_cmpeq_epi8(sse[1], _mm_set1_epi8(0x06));
        adj[2]          = _mm_cmpeq_epi8(sse[2], _mm_set1_epi8(0x06));
        adj[3]          = _mm_cmpeq_epi8(sse[3], _mm_set1_epi8(0x06));

        // Convert 'G' to 0x01
        substringreg[0] = _mm_sub_epi8(_mm_andnot_si128(adj[0], sse[0]), adj[0]);
        substringreg[1] = _mm_sub_epi8(_mm_andnot_si128(adj[1], sse[1]), adj[1]);
        substringreg[2] = _mm_sub_epi8(_mm_andnot_si128(adj[2], sse[2]), adj[2]);
        substringreg[3] = _mm_sub_epi8(_mm_andnot_si128(adj[3], sse[3]), adj[3]);

        
        ////////////////////////////////////////////////////////////////////
        // useful breakpoint here //////////////////////////////////////////
        ////////////////////////////////////////////////////////////////////
        //            i = 0;

        /*
         * CONVERT TO BITS
         */
        
//            substringreg[0] = _mm_loadu_si128((__m128i*) (s.data() + sub +   0));
//            substringreg[1] = _mm_loadu_si128((__m128i*) (s.data() + sub +  16));
//            substringreg[2] = _mm_loadu_si128((__m128i*) (s.data() + sub +  32));
//            substringreg[3] = _mm_loadu_si128((__m128i*) (s.data() + sub +  48));
        
//            substringreg[0] = _mm_sub_epi8(substringreg[0], zero);
//            substringreg[1] = _mm_sub_epi8(substringreg[1], zero);
//            substringreg[2] = _mm_sub_epi8(substringreg[2], zero);
//            substringreg[3] = _mm_sub_epi8(substringreg[3], zero);

        // FMA
        substringreg[0] = _mm_maddubs_epi16(substringreg[0], maskreg[0]);
        substringreg[1] = _mm_maddubs_epi16(substringreg[1], maskreg[1]);
        substringreg[2] = _mm_maddubs_epi16(substringreg[2], maskreg[0]);
        substringreg[3] = _mm_maddubs_epi16(substringreg[3], maskreg[1]);
        
        // OR... I think I can remove one OP here
        res             = _mm_or_si128(substringreg[0], substringreg[1]);
        substringreg[2] = _mm_slli_epi16(substringreg[2], 8);
        res             = _mm_or_si128(res, substringreg[2]);
        substringreg[3] = _mm_slli_epi16(substringreg[3], 8);
        res             = _mm_or_si128(res, substringreg[3]);
        
        
        ////////////////////////////////////////////////////////////////////
        // useful breakpoint here //////////////////////////////////////////
        ////////////////////////////////////////////////////////////////////
        i = 0;
        
        // Convert back to string: substring += static_cast<char>(uplook[c]);
        
        //res = _mm_set1_epi8(0xff);

        // Get bases
        sse[0] = _mm_and_si128(res, _mm_set1_epi16(0x0003));    // first bits
        sse[1] = _mm_and_si128(res, _mm_set1_epi16(0x000C));    // second 8
        sse[2] = _mm_and_si128(res, _mm_set1_epi16(0x0030));    // ...
        sse[3] = _mm_and_si128(res, _mm_set1_epi16(0x00C0));    // you get the idea
        adj[0] = _mm_and_si128(res, _mm_set1_epi16(0x0300));    // ...
        adj[1] = _mm_and_si128(res, _mm_set1_epi16(0x0C00));
        adj[2] = _mm_and_si128(res, _mm_set1_epi16(0x3000));
        adj[3] = _mm_and_si128(res, _mm_set1_epi16(0xC000));    // done.
        
        
/*
LOWER BITS:
sse[0]    __m128i    (00 A, 00, 02 C, 00, 00 A, 00, 02 C, 00, 00 A, 00, 02 C, 00, 00 A, 00, 02 C, 00)
sse[1]    __m128i    (0c T, 00, 04 G, 00, 0c T, 00, 04 G, 00, 0c T, 00, 04 G, 00, 0c T, 00, 04 G, 00)

adj[0]    __m128i    (00, 00 A, 00, 02 C, 00, 00 A, 00, 02 C, 00, 00 A, 00, 02 C, 00, 00 A, 00, 02 C)
adj[1]    __m128i    (00, 0c T, 00, 04 G, 00, 0c T, 00, 04 G, 00, 0c T, 00, 04 G, 00, 0c T, 00, 04 G)

SAME THING WITH HIGHER BITS:
sse[2]    __m128i    (00, 00, 20, 00, 00, 00, 20, 00, 00, 00, 20, 00, 00, 00, 20, 00)
sse[3]    __m128i    (c0, 00, 40, 00, c0, 00, 40, 00, c0, 00, 40, 00, c0, 00, 40, 00)

now shuffle sse[1] i.e. SHR by one byte, then blend alternatively sse[0]/sse[1]
for adj, SHL adj[0], then blend
*/
        
        // Rearrange bytes
        sse[1] = _mm_shuffle_epi8(sse[1], _mm_set_epi8(0x0e, 0x0d, 0x0c, 0x0b, 0x0a, 0x09, 0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, 0x00, 0xff));
        sse[0] = _mm_blendv_epi8(sse[0], sse[1], _mm_set_epi8(0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00));
        sse[3] = _mm_shuffle_epi8(sse[3], _mm_set_epi8(0x0e, 0x0d, 0x0c, 0x0b, 0x0a, 0x09, 0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, 0x00, 0xff));
        sse[2] = _mm_blendv_epi8(sse[2], sse[3], _mm_set_epi8(0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00));

        adj[0] = _mm_shuffle_epi8(adj[0], _mm_set_epi8(0xff, 0x0f, 0x0e, 0x0d, 0x0c, 0x0b, 0x0a, 0x09, 0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01));
        adj[1] = _mm_blendv_epi8(adj[0], adj[1], _mm_set_epi8(0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00));
        adj[2] = _mm_shuffle_epi8(adj[2], _mm_set_epi8(0xff, 0x0f, 0x0e, 0x0d, 0x0c, 0x0b, 0x0a, 0x09, 0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01));
        adj[3] = _mm_blendv_epi8(adj[2], adj[3], _mm_set_epi8(0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00));

        // Lazy-man
        auto dumpsse = [&substring](__m128i r)
        {
            for (int j = 0; j < sizeof(__m128i); j++)
            {
                unsigned char c = *((unsigned char*)(&r) + j);
                
                switch (c)
                {
                    case 0x00:
                        substring += 'A';
                        break;

                    case 0x0c: case 0x03:
                    case 0xc0: case 0x30:
                        substring += 'T';
                        break;

                    case 0x02: case 0x08:
                    case 0x20: case 0x80:
                        substring += 'C';
                        break;
                        
                    case 0x04: case 0x01:
                    case 0x40: case 0x10:
                        substring += 'G';
                        break;
                        
                    default:
                        substring += '.';
                        throw std::domain_error("what the hell was that in my simd register?");
                }
            }
        };
        
        // Now in order dump sse[0,2], adj[1,3]
        dumpsse(sse[0]);
        dumpsse(sse[2]);
        dumpsse(adj[1]);
        dumpsse(adj[3]);
        
        // Cut the substring length
        substring = substring.substr(0, substringlength);
        
        // Good for debugging
        j = 0;

#else
        std::fill(std::begin(ymm), std::end(ymm), 0);

        // Non-SIMD: pack each base into its 2-bit slot, then read that slot
        // straight back from the packed word to verify the round trip.
        for (i = 0; i < substringlength; i++)
        {
            off = i / basesperword;                                          // which 64-bit word
            const unsigned shift = static_cast<unsigned>((i % basesperword) * 2);  // bit offset inside it

            c = libseq::bases_values::value(s[sub + i]);
            ymm[off] |= static_cast<std::uint64_t>(c) << shift;

            // Test
            v = ymm[off] >> shift;
            c = v & 3;
            substring += libseq::bases_values::ascii(c);
        }
        (void) j;
#endif
        // Substring (the real one)
        const auto q = s.substr(sub, substringlength);
        
        if (substring != q)
        {
            std::cout << "F " << q << '\n';
            failed++;
        }
        else
        {
            std::cout << "s " << q << '\n';
            succeeded++;
        }
        std::cout << "  " << substring << '\n';
    }
    
    std::cout << '\n';
    std::cout << "> substrings " << nsubstrings << '\n';
    std::cout << "> succeeded  " << succeeded << '\n';
    // Single flush at the very end instead of one per printed line
    std::cout << "> failed     " << failed << std::endl;

    // Non-zero exit status so scripts/CI can detect a broken encoder
    return failed == 0 ? 0 : 1;
}
