// Cytosim was created by Francois Nedelec. Copyright 2021 Cambridge University.

#include "random.h"
#include <algorithm>
#include <bitset>
#include <climits>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <iostream>
#include "timer.h"

/**
 Write the bit pattern of `val` to `f` as '0'/'1' characters.
 Bytes are printed from the highest address down to the lowest,
 and each byte from its most significant bit to its least significant bit.
 If `spc` is non-zero, it is written after every byte; a newline ends the output.
 */
template < typename T >
void print_bits(FILE* f, const T& val, char spc)
{
    const unsigned char* bytes = reinterpret_cast<const unsigned char*>(&val);
    for ( size_t n = sizeof(T); n > 0; --n )
    {
        const unsigned char byte = bytes[n-1];
        for ( int bit = CHAR_BIT - 1; bit >= 0; --bit )
            putc('0' + ((byte >> bit) & 1), f);
        if ( spc )
            putc(spc, f);
    }
    putc('\n', f);
}


/**
 Time 2^30 rounds, each drawing a value with `RNG.pint32(1024)`
 and then calling `RNG.pint32` with that value as argument.
 Prints the value returned by `toc(rounds)`.
 */
void speed_test()
{
    const size_t rounds = size_t(1) << 30;
    tic();
    for ( size_t r = rounds; r > 0; --r )
    {
        const uint32_t bound = RNG.pint32(1024);
        RNG.pint32(bound);
    }
    printf("int %5.2f\n", toc(rounds));
}


/**
 Print tables of integers drawn from the integer generators of `RNG`:
 pint32(), sint32(), pint32(100), pint32_fair(100) and pint32_slow(99).
 Each table has 8 rows and is followed by an empty line.
 */
void test_int()
{
    // 8 rows made of `cols` calls to `cell()`, then a blank line
    auto table = [](int cols, auto cell)
    {
        for ( int row = 0; row < 8; ++row )
        {
            for ( int col = 0; col < cols; ++col )
                cell();
            printf("\n");
        }
        printf("\n");
    };

    table(8,  []{ printf(" %12u", RNG.pint32()); });
    table(8,  []{ printf(" %+12i", RNG.sint32()); });
    table(32, []{ printf(" %2u", RNG.pint32(100)); });
    table(32, []{ printf(" %2u", RNG.pint32_fair(100)); });
    table(32, []{ printf(" %2u", RNG.pint32_slow(99)); });
}


/**
 Draw 2^24 values with `RNG.pint32()` and print the fraction below 2^30.
 For integers uniform over the full 32-bit range, this fraction should be close to 1/4.
 */
void silly_test()
{
    const uint32_t up = 1U << 30;
    
    const uint32_t cnt = 1U << 24;
    uint32_t hit = 0;
    
    for (uint32_t j=0; j<cnt; ++j)
        hit += ( RNG.pint32() < up );

    // the threshold is 2 to the power 30; the former label `1^30` read as XOR (= 31) in C
    printf(" prob( pint32() < 2^30 ) = %f\n", hit/(float)cnt);
}


/**
 Build a float in [0, 1) from the 23 low-order bits of `x`.
 This assumes IEEE Standard 754 single-precision floating point numbers:
 32 bits: 1 for sign, 8 for exponent, 23 for fraction.
 The fraction bits of `x` are combined with the exponent of 1.0,
 giving a float in [1, 2), from which 1 is subtracted.
 */
float convertFix(uint32_t x)
{
    static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32 bits");
    constexpr uint32_t FRAC  = 0x7FFFFFU;
    constexpr uint32_t EXPON = 127U << 23;
    const uint32_t bits = EXPON | ( x & FRAC );
    float res;
    // memcpy instead of `*(float*)&bits`: reading a uint32_t through a float* violates strict aliasing (UB)
    std::memcpy(&res, &bits, sizeof(res));
    return res - 1.0f;
}


/**
 Print the bit patterns of a few fractions (0, 0.5, 1),
 then of 16 floats built by convertFix() from random integers.
 */
void testbits()
{
    const int steps = 2;
    for ( int s = 0; s <= steps; ++s )
    {
        const float val = s / float(steps);
        printf(" %f :", val);
        print_bits(stdout, val, 0);
    }

    for ( int r = 16; r > 0; --r )
    {
        const float val = convertFix(RNG.pint32());
        printf(" %f :", val);
        print_bits(stdout, val, ' ');
    }
}


#define TEST test
/**
 Call `RNG.TEST(prob)` 12*MAX times and print the fraction of successes
 next to the requested probability `prob`.
 The macro TEST selects which method of RNG is exercised.
 */
void test_test( const real prob, const size_t MAX )
{
    // size_t: an `int` counter would overflow (UB) once 12*MAX exceeds INT_MAX
    size_t cnt = 0;
    for ( size_t jj=0; jj < MAX; ++jj )
    {
        const int a = RNG.TEST(prob) + RNG.TEST(prob) + RNG.TEST(prob) + RNG.TEST(prob);
        const int b = RNG.TEST(prob) + RNG.TEST(prob) + RNG.TEST(prob) + RNG.TEST(prob);
        const int c = RNG.TEST(prob) + RNG.TEST(prob) + RNG.TEST(prob) + RNG.TEST(prob);
        cnt += a + b + c;
    }
    printf("prob = %f measured = %f cnt = %zu\n", prob, cnt / double(12*MAX), cnt);
}

/**
 Call `RNG.preal()` 10*MAX times, discarding the results.
 */
void test_RNG(const size_t MAX)
{
    for ( size_t n = MAX; n > 0; --n )
    {
        for ( int k = 0; k < 10; ++k )
            RNG.preal();
    }
}


/**
 Print samples from the floating-point generators of `RNG`:
 three 8x8 tables (sreal, preal, shalf), then one labelled line of
 10 values for each of pfloat, sfloat, pdouble, sdouble and sflip.
 */
void test_real()
{
    // 8 rows of 8 values, each value obtained from `draw()`
    auto grid = [](auto draw)
    {
        for ( int row = 0; row < 8; ++row )
        {
            for ( int col = 0; col < 8; ++col )
                printf(" %10f", draw());
            printf("\n");
        }
    };

    grid([]{ return RNG.sreal(); });
    printf("\n");
    grid([]{ return RNG.preal(); });
    printf("\n");
    grid([]{ return RNG.shalf(); });

    // a label followed by 10 values; the casts reproduce the storage precision
    auto line = [](const char* label, auto draw)
    {
        printf("%s", label);
        for ( int k = 0; k < 10; ++k )
            printf(" %+f", draw());
    };

    line("\npfloat:     ", []{ return float(RNG.pfloat()); });
    line("\nsfloat:     ", []{ return float(RNG.sfloat()); });
    line("\npdouble:    ", []{ return double(RNG.pdouble()); });
    line("\nsdouble:    ", []{ return double(RNG.sdouble()); });
    line("\nsflip:      ", []{ return double(RNG.sflip()); });
    printf("\n");
}

//==========================================================================

void test_uniform(size_t cnt)
{
    const double off = 0.5;
    double avg = 0, var = 0;
    for ( size_t i = 0; i < cnt; ++i )
    {
        real x = RNG.sreal() - off;
        real y = RNG.sreal() - off;
        real z = RNG.sreal() - off;
        real t = RNG.sreal() - off;
        avg += x + y + z + t;
        var += x*x + y*y + z*z + t*t;
    }
    cnt *= 4;
    avg /= cnt;
    var = ( var - square(avg) * cnt ) / real(cnt-1);
    printf("UNIFORM      avg = %.12e   var = %.12e\n", avg+off, var);
}


void test_gauss(size_t CNT)
{
    size_t cnt = 0;
    double avg = 0, var = 0;
    const size_t n_max = 1<<6;
    real vec[n_max] = { 0 };
    for ( size_t i = 0; i < CNT; ++i )
    {
        size_t n = RNG.pint32(n_max);
        RNG.gauss_set(vec, n);
        cnt += n;
        for ( size_t u = 0; u < n; ++u )
        {
            avg += vec[u];
            var += vec[u] * vec[u];
        }
    }
    avg /= cnt;
    var = ( var - square(avg) * cnt ) / real(cnt-1);
    printf("GAUSSIAN     avg = %.12e   var = %.12e\n", avg, var);
}


/**
 Print the fraction of 2^28 calls to `RNG.flip_8th()` that return a non-zero value.
 */
void test_prob()
{
    const size_t trials = 1 << 28;
    size_t hits = 0;
    for ( size_t k = trials; k > 0; --k )
        hits += RNG.flip_8th();

    printf("8th      prob = %.6f\n", hits/(double)trials);
}


void test_exponential(size_t cnt)
{
    const double off = 1;
    double avg = 0, var = 0;
    for ( size_t i = 0; i < cnt; ++i )
    {
        real x = RNG.exponential() - off;
        real y = RNG.exponential() - off;
        real z = RNG.exponential() - off;
        real t = RNG.exponential() - off;
        avg += x + y + z + t;
        var += x*x + y*y + z*z + t*t;
    }
    cnt *= 4;
    avg /= cnt;
    var = ( var - square(avg) * cnt ) / real(cnt-1);
    printf("EXPONENTIAL  avg = %.12e   var = %.12e\n", avg+off, var);
}


/**
 For each n in [0, sup), print n, `RNG.poisson_knuth(n)`, `RNG.poisson(n)`,
 and `n + sqrt(n) * RNG.gauss()` truncated to int, for comparison.
 */
void test_poisson(size_t sup)
{
    for ( size_t n = 0; n < sup; ++n )
    {
        int x = (int)(RNG.gauss() * std::sqrt(n) + n);
        // %zu is the portable conversion for size_t (%lu is wrong where size_t is not unsigned long)
        printf("%10zu %9i %9i %9i\n", n, RNG.poisson_knuth(n), RNG.poisson(n), x);
    }
}


//==========================================================================


/**
 Fill array `dst[]` with Gaussian values ~ N(0,1), using the polar method
 on pairs of signed 32-bit integers read from `src[0..cnt)`.
 Pairs falling outside the unit disc (or at its center) are rejected,
 so fewer than `cnt` values may be stored; `dst` must have room for `cnt` values.
 Only complete pairs are used: if `cnt` is odd, the last value of `src` is ignored.
 @return pointer past the last value stored in `dst`
 */
template < typename REAL >
REAL * makeGaussians_(REAL dst[], size_t cnt, const int32_t src[])
{
    constexpr double scale = 0x1p-31; // 2^-31: maps int32 to [-1, 1)
    int32_t const*const end = src + cnt;
    // `src + 1 < end` guarantees that src[1] is inside the array, even if `cnt` is odd
    while ( src + 1 < end )
    {
        REAL x = REAL(src[0]) * scale;
        REAL y = REAL(src[1]) * scale;
#if 1
        if ( std::abs(x) + std::abs(y) >= M_SQRT2 )
        {
            constexpr REAL S = M_SQRT1_2 + 1;
            // subtract corner and scale to recover a square of size sqrt(1/2)
            REAL cx = S * x - std::copysign(S, x);
            REAL cy = S * y - std::copysign(S, y);
            // apply rotation, scaling by sqrt(2): x' = y + x;  y' = y - x
            x = cy + cx;
            y = cy - cx;
        }
#endif
        REAL w = x * x + y * y;
        // bitwise `&` on purpose: both conditions are evaluated without a branch
        if (( w <= 1 ) & ( 0 < w ))
        {
            // factor sqrt(-2 ln(w) / w) of the polar method
            w = std::sqrt( std::log(w) / ( -0.5 * w ) );
            dst[0] = w * x;
            dst[1] = w * y;
            dst += 2;
        }
        src += 2;
    }
    return dst;
}

/**
 Fill `dst[0..cnt)` with exponentially distributed values (rate 1),
 one per integer of `src`: the 31 low-order bits of `src[i]` give an
 integer m in [0, 2^31), and the value stored is -log(1 - m * 2^-31).
 @return `dst + cnt`
 */
template < typename REAL >
REAL * makeExponentials_(REAL dst[], size_t cnt, const int32_t src[])
{
    constexpr double scale = 0x1p-31; // 2^-31
    for ( size_t i = 0; i < cnt; ++i )
    {
        // Masking keeps m < 2^31, and the argument of log is computed exactly
        // in double, so it stays in [2^-31, 1]. The former `fabs(src[i])` reached
        // 2^31 for INT32_MIN (and, in float, for values rounding up to 2^31),
        // which produced log(0) = -inf and an infinite output.
        const uint32_t m = static_cast<uint32_t>(src[i]) & 0x7FFFFFFFU;
        dst[i] = static_cast<REAL>(-std::log(1.0 - double(m) * scale));
    }
    return dst + cnt;
}


/**
 Print `vec[0..cnt)` with 32 values per line, in 4 groups of 8,
 each group introduced by " :".
 */
template < typename T >
void print_gaussian(size_t cnt, T const* vec)
{
    size_t i = 0;
    while ( i < cnt )
    {
        for ( int grp = 0; grp < 4; ++grp )
        {
            printf(" :");
            for ( int col = 0; col < 8 && i < cnt; ++col, ++i )
                printf(" %8.4f", vec[i]);
        }
        printf("\n");
    }
}

/**
 Print statistics of `vec[0..cnt)`, ignoring NaN entries:
 the number of valid values, the number of NaNs, the mean, the variance,
 and the covariance between consecutive (even, odd) entries.
 Statistics that cannot be computed (too few valid values) are printed as NaN.
 */
template < typename REAL >
void check_gaussian(size_t cnt, REAL* vec)
{
    const double off = 1; // shift subtracted before squaring, to reduce cancellation
    size_t nans = 0;
    double sum = 0, sum_sq = 0;
    for ( size_t i = 0; i < cnt; ++i )
    {
        if ( std::isnan(vec[i]) )
            ++nans;
        else
        {
            const double v = vec[i];
            sum += v;
            sum_sq += ( v - off ) * ( v - off );
        }
    }
    // normalize by the number of valid values, which is what the sums cover
    const size_t num = cnt - nans;
    const double avg = ( num > 0 ) ? sum / double(num) : NAN;
    const double dev = avg - off;
    const double var = ( num > 1 ) ? ( sum_sq - dev * dev * double(num) ) / double(num - 1) : NAN;
    // covariance of odd and even entries, over pairs where both values are valid:
    double cov = 0;
    size_t pairs = 0;
    for ( size_t i = 1; i < cnt; i += 2 )
    {
        if ( !std::isnan(vec[i-1]) && !std::isnan(vec[i]) )
        {
            cov += ( vec[i-1] - avg ) * ( vec[i] - avg );
            ++pairs;
        }
    }
    cov = ( pairs > 0 ) ? cov / double(pairs) : NAN;
    printf("%6zu + %6zu NaNs: avg %7.4f var %7.4f cov %7.4f ", num, nans, avg, var, cov);
}


//------------------------------------------------------------------------------
#pragma mark -


/**
 Time `cnt` rounds of SFMT refill + conversion by FUNC, followed by one more
 conversion of the last buffer; print the timing, the statistics of the values
 produced by that last conversion, and at most 16 of these values.
 */
template < float* (*FUNC)(float*, size_t, const int32_t*) >
void runGaussian(sfmt_t& sfmt, const char str[], int cnt)
{
    float flt[SFMT_N32] = { 0 };
    tic();
    for ( int i = 0; i < cnt; ++i )
    {
        sfmt_gen_rand_all(&sfmt);
        FUNC(flt, SFMT_N32, reinterpret_cast<const int32_t*>(sfmt.state));
    }
    float* end = FUNC(flt, SFMT_N32, reinterpret_cast<const int32_t*>(sfmt.state));
    printf("%-12s %5.2f :", str, toc(cnt));
    const std::ptrdiff_t num = end - flt;
    check_gaussian(num, flt);
    // explicit type: `std::min(num, 16l)` does not compile where ptrdiff_t is not `long`
    print_gaussian(std::min<std::ptrdiff_t>(num, 16), flt);
}


#if defined(__AVX__)
/**
 Time `cnt` rounds of SFMT refill + conversion by FUNC (AVX variant),
 then print the timing, the statistics of the values produced by the
 last round, and at most 16 of these values.
 */
template < real* (*FUNC)(real*, size_t, const __m256i*) >
void runGaussian(sfmt_t& sfmt, const char str[], int cnt)
{
    real vec[SFMT_N32];
    for ( int i = 0; i < SFMT_N32; ++i )
        vec[i] = NAN;
    // initialized, so that `cnt <= 0` does not leave `end` indeterminate
    real* end = vec;
    tic();
    for ( int i = 0; i < cnt; ++i )
    {
        sfmt_gen_rand_all(&sfmt);
        end = FUNC(vec, SFMT_N256, (__m256i*)sfmt.state);
    }
    printf("%-12s %5.2f :", str, toc(cnt));
    const std::ptrdiff_t num = end - vec;
    check_gaussian(num, vec);
    // explicit type: `std::min(num, 16l)` does not compile where ptrdiff_t is not `long`
    print_gaussian(std::min<std::ptrdiff_t>(num, 16), vec);
}
#endif

/**
 Tests different implementations for speed:
 times `cnt` refills of an SFMT generator seeded from the clock,
 then runs the Gaussian and exponential conversions via runGaussian().
 */
void test_gaussian(int cnt)
{
    // %zu: sizeof yields size_t, for which %lu is not portable
    printf("test_gaussian --- %zu bytes real --- %s\n", sizeof(real), __VERSION__);
    sfmt_t sfmt;
    sfmt_init_gen_rand(&sfmt, time(nullptr));

    tic();
    for ( int i = 0; i < cnt; ++i )
        sfmt_gen_rand_all(&sfmt);
    printf("RNG.refill   %5.2f\n", toc(cnt));
    
    runGaussian<makeGaussians_>(sfmt, "Gauss_", cnt);
    runGaussian<makeExponentials_>(sfmt, "Exponential", cnt);
}

/**
 Print `cnt` values from `RNG.gauss()`, one per line
 (Gaussian-distributed random numbers).
 */
void print_gaussian(int cnt)
{
    for ( int n = cnt; n > 0; --n )
        printf("%10.5f\n", RNG.gauss());
}


//==========================================================================
/**
 Run one test, selected by the first command-line argument (default 4).
 The optional second argument sets `rate`, which scales the probabilities tried in mode 3.
 */
int main(int argc, char* argv[])
{
    int mode = 4;
    RNG.seed();

    if ( argc > 1 )
        mode = atoi(argv[1]);
    real rate = 1;
    if ( argc > 2 )
        rate = strtod(argv[2], nullptr);

    switch ( mode )
    {
        case 0:
            test_poisson(1024);
            test_prob();
            break;
            
        case 1:
            // integer counts: avoids implicit double -> size_t conversions of hexfloats
            test_exponential(size_t(1) << 26);
            test_uniform(size_t(1) << 20);
            test_gauss(size_t(1) << 20);
            break;

        case 3:
            for ( int kk=0; kk < 11; ++kk )
                test_test(rate*kk, 5000000);
            break;
            
        case 4:
            // %zu: sizeof yields size_t, for which %lu is not portable
            printf("sizeof(uint32_t) = %zu\n", sizeof(uint32_t));
            test_int();
            test_real();
            break;
            
        case 5:
            speed_test();
            break;
            
        case 6:
            silly_test();
            break;
            
        case 7:
            test_gaussian(1<<18);
            break;
            
        case 8:
            print_gaussian(1<<14);
            break;
    }
}

