// reproduce.cpp
//
// Small self-contained reproduction example for the BCF10 data headers.
//
// To check another form, replace FORM below.
//
// Current form, catalogue row 1824:
//   f = 017+034+058+146+257+259+269+356+379+468+489+678
//   known values:
//     dim = 10
//     |St| = 384
//     hash = d58690f8dbfdec9e
//     d2 = 408
//     witness = 03+25  (one possible witness; the program may print another)
//     enumerator / 2^dim = 408:4160 416:134400 424:3012416
//                          432:47806464 440:543929216 448:4654207488
//                          456:30550505344 464:155604375552
//                          472:616329946816 480:1900153520896
//                          488:4559225219520 496:8515599478784
//                          504:12385533426944 512:14032253041664
//
// Compile from this directory with:
//   g++ -std=c++20 -O3 reproduce.cpp -o reproduce

#include "enumerator.h"
#include "invariant.h"
#include "stabilizer.h"

#include <algorithm>
#include <array>
#include <chrono>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>

// Shorthand type names used throughout this file.
using U16 = std::uint16_t;                // one matrix row (see Matrix)
using U64 = std::uint64_t;                // packed coefficient words, hashes and enumerator counts
using Clock = std::chrono::steady_clock;  // monotonic clock used for the timing table

// Number of Boolean variables x_0 .. x_{M-1}.
constexpr int M = 10;
// Number of cubic monomials x_i x_j x_k with i < j < k, i.e. C(M, 3).
constexpr int BCF10_DIM = M * (M - 1) * (M - 2) / 6;
// Width of one packed word; a PackedForm is two such words (lo, hi).
constexpr int WORD_BITS = 64;
// Number of quadratic monomials x_i x_j with i < j, i.e. C(M, 2).
constexpr int RM2_QUADRATIC_DIM = M * (M - 1) / 2;
// Coefficients of a polynomial of degree <= 2: quadratic terms, then the M
// linear terms, then the constant term in the highest used bit.
constexpr int RM2_DIM = RM2_QUADRATIC_DIM + M + 1;
constexpr int RM2_CONST_BIT = RM2_DIM - 1;
// Number of points of the truth table, and half of it.
constexpr int TRUTH_SIZE = 1 << M;
constexpr int HALF_TRUTH_SIZE = TRUTH_SIZE / 2;
// Enumerator bins are labelled 0, ENUM_STEP, 2 * ENUM_STEP, ..., HALF_TRUTH_SIZE.
constexpr int ENUM_STEP = 8;
constexpr int ENUM_BINS = HALF_TRUTH_SIZE / ENUM_STEP + 1;
// Entry 0 of an enumerator result holds the witness; the bins follow it.
constexpr int ENUM_RESULT_SIZE = ENUM_BINS + 1;

// The derived values must stay equal to the literals this file was written
// with; a change of M or ENUM_STEP has to be reviewed against the rest of the
// program (and, presumably, against the bcf10 headers -- confirm there).
static_assert(BCF10_DIM == 120, "BCF10_DIM is expected to be C(10,3) = 120");
static_assert(RM2_DIM == 56, "RM2_DIM is expected to be 1 + 10 + C(10,2) = 56");
static_assert(ENUM_BINS == 65, "ENUM_BINS is expected to cover weights 0, 8, ..., 512");
// Storage limits assumed elsewhere in this file.
static_assert(BCF10_DIM <= 2 * WORD_BITS, "a cubic form must fit in the two words of PackedForm");
static_assert(RM2_DIM <= WORD_BITS, "a degree <= 2 polynomial must fit in one 64-bit word");
static_assert(M <= 16, "matrix rows are stored in 16-bit words");
static_assert(HALF_TRUTH_SIZE % ENUM_STEP == 0, "bin labels must reach HALF_TRUTH_SIZE exactly");

// The cubic form under test, as a list of monomials: each entry {i, j, k}
// stands for the term x_i x_j x_k (printed as "ijk" by form_string()).
// Replace this table to check another form; every entry must consist of three
// distinct variable indices in the range 0 .. M-1.
constexpr std::array<std::array<int, 3>, 12> FORM = {{
    {{0, 1, 7}}, {{0, 3, 4}}, {{0, 5, 8}}, {{1, 4, 6}},
    {{2, 5, 7}}, {{2, 5, 9}}, {{2, 6, 9}}, {{3, 5, 6}},
    {{3, 7, 9}}, {{4, 6, 8}}, {{4, 8, 9}}, {{6, 7, 8}},
}};

// 2^RM2_DIM: the total that main() expects coset_size() to report.
constexpr U64 RM2_COSET_SIZE = U64(1) << RM2_DIM;
// Number of random invertible matrices used for the hash-invariance check.
constexpr int GL_SAMPLES = 1000;
// Fixed seed for the random matrix generator used in main().
constexpr U64 GL_RANDOM_SEED = 0x4bcf10ULL;
// Column widths (in characters) of the printed tables.
constexpr int LABEL_WIDTH = 30;
constexpr int VALUE_WIDTH = 22;

// Builds the table that maps a coefficient bit index to its monomial.
// Entry number `next` is the triple (a, b, c) with a < b < c; triples are
// listed in lexicographic order, so the index of a triple is its rank in
// that order.
constexpr std::array<std::array<int, 3>, BCF10_DIM> make_triples() {
    std::array<std::array<int, 3>, BCF10_DIM> table{};
    int next = 0;
    for (int a = 0; a < M; ++a) {
        for (int b = a + 1; b < M; ++b) {
            for (int c = b + 1; c < M; ++c) {
                table[next] = {{a, b, c}};
                ++next;
            }
        }
    }
    return table;
}

constexpr auto TRIPLES = make_triples();

struct PackedForm {
    U64 lo = 0;
    U64 hi = 0;

    int bit(int i) const {
        return i < WORD_BITS ? int((lo >> i) & 1ULL) : int((hi >> (i - WORD_BITS)) & 1ULL);
    }

    void set(int i) {
        if (i < WORD_BITS) {
            lo |= U64(1) << i;
        } else {
            hi |= U64(1) << (i - WORD_BITS);
        }
    }
};

// An M x M matrix over GF(2): one 16-bit word per row, where bit c of A[r]
// is the entry in row r, column c (this is how matrix_rank and det3_rows
// read it).
using Matrix = std::array<U16, M>;

// Returns the position in TRIPLES of the monomial with variables a, b, c
// (given in any order), or -1 when no such entry exists -- which happens
// when two of the variables coincide or one lies outside 0 .. M-1.
int triple_index(int a, int b, int c) {
    std::array<int, 3> key = {{a, b, c}};
    std::sort(key.begin(), key.end());

    const auto hit = std::find(TRIPLES.begin(), TRIPLES.end(), key);
    if (hit == TRIPLES.end()) return -1;
    return static_cast<int>(hit - TRIPLES.begin());
}

// Converts the monomial list FORM into its packed coefficient vector.
//
// Throws std::invalid_argument when an entry of FORM is not a valid monomial
// (a repeated variable, or a variable outside 0 .. M-1). For such an entry
// triple_index() returns -1, and passing that on to PackedForm::set() would
// shift by a negative amount, which is undefined behaviour.
//
// Note: a monomial listed more than once is simply set once (coefficients are
// OR-ed in), it is not cancelled.
PackedForm pack_form() {
    PackedForm f;
    for (const auto& t : FORM) {
        const int index = triple_index(t[0], t[1], t[2]);
        if (index < 0) {
            throw std::invalid_argument(
                "FORM contains an invalid monomial {" + std::to_string(t[0]) + ", " +
                std::to_string(t[1]) + ", " + std::to_string(t[2]) +
                "}: expected three distinct variable indices in 0.." + std::to_string(M - 1));
        }
        f.set(index);
    }
    return f;
}

// Renders FORM the way it is written in the file header: the three variable
// indices of each monomial concatenated, monomials joined by '+'.
std::string form_string() {
    std::ostringstream text;
    const char* separator = "";
    for (const auto& term : FORM) {
        text << separator;
        for (int variable : term) text << variable;
        separator = "+";
    }
    return text.str();
}

// Formats x as lowercase hexadecimal, zero-padded on the left to 16 digits.
std::string hex64(U64 x) {
    std::ostringstream text;
    text.setf(std::ios_base::hex, std::ios_base::basefield);
    text.fill('0');
    text.width(16);
    text << x;
    return text.str();
}

// Determinant over GF(2) of the 3 x 3 matrix formed by bits i, j, k of the
// three row words a, b, c. Expanded along the first row; signs do not matter
// modulo 2, so every term is combined with XOR.
int det3_rows(U16 a, U16 b, U16 c, int i, int j, int k) {
    const auto entry = [](U16 row, int column) { return (row >> column) & 1; };

    const int minor_i = (entry(b, j) & entry(c, k)) ^ (entry(b, k) & entry(c, j));
    const int minor_j = (entry(b, i) & entry(c, k)) ^ (entry(b, k) & entry(c, i));
    const int minor_k = (entry(b, i) & entry(c, j)) ^ (entry(b, j) & entry(c, i));

    return (entry(a, i) & minor_i) ^ (entry(a, j) & minor_j) ^ (entry(a, k) & minor_k);
}

// Transforms the cubic form f by the matrix A.
//
// The coefficient of the output monomial with variables (r0, r1, r2) is the
// XOR, over all monomials (i, j, k) present in f, of the 3 x 3 minor of A on
// rows r0, r1, r2 and columns i, j, k. This is the degree-3 part of f after
// replacing each x_i by the XOR of those x_r whose row A[r] has bit i set;
// terms of lower degree are not represented.
PackedForm pullback(PackedForm f, const Matrix& A) {
    // Gather the monomials of f once so the inner loop skips absent ones.
    std::array<int, BCF10_DIM> support{};
    int support_size = 0;
    for (int index = 0; index < BCF10_DIM; ++index) {
        if (f.bit(index) != 0) {
            support[support_size] = index;
            ++support_size;
        }
    }

    PackedForm image;
    for (int target = 0; target < BCF10_DIM; ++target) {
        const auto& rows = TRIPLES[target];
        const U16 row0 = A[rows[0]];
        const U16 row1 = A[rows[1]];
        const U16 row2 = A[rows[2]];

        int coefficient = 0;
        for (int s = 0; s < support_size; ++s) {
            const auto& columns = TRIPLES[support[s]];
            coefficient ^= det3_rows(row0, row1, row2, columns[0], columns[1], columns[2]);
        }
        if (coefficient != 0) image.set(target);
    }
    return image;
}

// Rank over GF(2) of the M x M matrix A, computed by Gauss-Jordan
// elimination on a private copy (A is taken by value).
int matrix_rank(Matrix A) {
    int pivots = 0;
    for (int column = 0; column < M; ++column) {
        const U16 column_mask = U16(1u << column);

        // First not-yet-used row that has a 1 in this column.
        int source = pivots;
        while (source < M && (A[source] & column_mask) == 0) ++source;
        if (source == M) continue;

        std::swap(A[pivots], A[source]);
        const U16 pivot_row = A[pivots];

        // Clear this column in every other row.
        for (int r = 0; r < M; ++r) {
            if (r == pivots) continue;
            if ((A[r] & column_mask) != 0) A[r] ^= pivot_row;
        }
        ++pivots;
    }
    return pivots;
}

// Draws an invertible M x M matrix over GF(2) by rejection sampling: every
// row is an independent uniform M-bit value, and the whole matrix is redrawn
// until it has full rank.
Matrix random_gl_matrix(std::mt19937_64& rng) {
    std::uniform_int_distribution<int> row_bits(0, (1 << M) - 1);

    Matrix candidate{};
    do {
        for (U16& row : candidate) row = U16(row_bits(rng));
    } while (matrix_rank(candidate) != M);
    return candidate;
}

// Evaluates the cubic form f at the point x, where bit v of x is the value
// of variable x_v. Returns 0 or 1.
int eval_cubic(PackedForm f, int x) {
    int parity = 0;
    for (int index = 0; index < BCF10_DIM; ++index) {
        if (!f.bit(index)) continue;
        const auto& vars = TRIPLES[index];
        // The monomial is 1 exactly when all three of its variables are 1.
        parity ^= (x >> vars[0]) & (x >> vars[1]) & (x >> vars[2]) & 1;
    }
    return parity;
}

// Evaluates the polynomial of degree <= 2 encoded in q at the point x.
// Bit layout of q: quadratic terms x_i x_j (i < j, in lexicographic order)
// first, then the M linear terms, then the constant term at RM2_CONST_BIT.
// Returns 0 or 1.
int eval_rm2(U64 q, int x) {
    const auto coefficient = [q](int position) { return int((q >> position) & 1ULL); };
    const auto variable = [x](int index) { return (x >> index) & 1; };

    int parity = coefficient(RM2_CONST_BIT);

    int position = 0;
    for (int i = 0; i < M; ++i) {
        for (int j = i + 1; j < M; ++j, ++position) {
            parity ^= coefficient(position) & variable(i) & variable(j);
        }
    }

    for (int i = 0; i < M; ++i) {
        parity ^= coefficient(RM2_QUADRATIC_DIM + i) & variable(i);
    }
    return parity;
}

// Hamming weight of f + q: the number of points of the truth table at which
// the cubic form f and the degree <= 2 polynomial q take different values.
int weight_with_rm2(PackedForm f, U64 q) {
    int ones = 0;
    for (int point = 0; point < TRUTH_SIZE; ++point) {
        if (eval_cubic(f, point) != eval_rm2(q, point)) ++ones;
    }
    return ones;
}

// Total number of coset elements described by an enumerator result: every
// bin except the last one is counted twice, the last bin (index ENUM_BINS)
// once. Entry 0 of `result` is not a bin and is ignored.
U64 coset_size(const std::array<U64, ENUM_RESULT_SIZE>& result) {
    U64 doubled_bins = 0;
    for (int index = 1; index < ENUM_BINS; ++index) doubled_bins += result[index];
    return 2 * doubled_bins + result[ENUM_BINS];
}

int distance_from_result(const std::array<U64, ENUM_RESULT_SIZE>& result) {
    for (int k = 0; k < ENUM_BINS; k++) {
        if (result[1 + k] != 0) return ENUM_STEP * k;
    }
    return HALF_TRUTH_SIZE;
}

bool enumerator_divisible_by_dim(const std::array<U64, ENUM_RESULT_SIZE>& result, int n) {
    U64 mask = (U64(1) << n) - 1;
    for (int k = 0; k < ENUM_BINS; k++) {
        if ((result[1 + k] & mask) != 0) return false;
    }
    return true;
}

std::string normalized_enumerator_string(const std::array<U64, ENUM_RESULT_SIZE>& result, int n) {
    std::ostringstream out;
    bool first = true;
    for (int k = 0; k < ENUM_BINS; k++) {
        if (result[1 + k] == 0) continue;
        if (!first) out << ' ';
        out << ENUM_STEP * k << ':' << (result[1 + k] >> n);
        first = false;
    }
    return first ? std::string("(empty)") : out.str();
}

// Renders the degree <= 2 polynomial encoded in q as a '+'-separated list:
// "ij" for a quadratic term x_i x_j, "i" for a linear term x_i and "const"
// for the constant term, in that order. An all-zero q gives
// "zero polynomial".
std::string rm2_string(U64 q) {
    std::string text;
    const auto append_term = [&text](const std::string& term) {
        if (!text.empty()) text += '+';
        text += term;
    };
    const auto has_bit = [q](int position) { return ((q >> position) & 1ULL) != 0; };

    int position = 0;
    for (int i = 0; i < M; ++i) {
        for (int j = i + 1; j < M; ++j, ++position) {
            if (has_bit(position)) append_term(std::to_string(i) + std::to_string(j));
        }
    }

    for (int i = 0; i < M; ++i) {
        if (has_bit(RM2_QUADRATIC_DIM + i)) append_term(std::to_string(i));
    }

    if (has_bit(RM2_CONST_BIT)) append_term("const");

    if (text.empty()) return std::string("zero polynomial");
    return text;
}

// Prints one table row: the name left-aligned in a LABEL_WIDTH column, two
// spaces, then the value. The stream is flushed so that the row is visible
// before the next (possibly long) computation starts.
void print_row(const std::string& name, const std::string& value) {
    std::ostream& out = std::cout;
    out.setf(std::ios_base::left, std::ios_base::adjustfield);
    out.width(LABEL_WIDTH);
    out << name;
    out << "  " << value << "\n";
    out.flush();
}

// Prints the two-line heading of a table: the column titles, then a rule of
// dashes under each column (LABEL_WIDTH and VALUE_WIDTH characters wide).
void print_header(const std::string& left, const std::string& right) {
    const std::string label_rule(LABEL_WIDTH, '-');
    const std::string value_rule(VALUE_WIDTH, '-');
    std::cout << std::left << std::setw(LABEL_WIDTH) << left << "  " << right << "\n"
              << label_rule << "  " << value_rule << "\n";
}

// Time elapsed from `start` to `finish`, in seconds as a floating-point value.
double seconds_since(Clock::time_point start, Clock::time_point finish) {
    const std::chrono::duration<double> elapsed = finish - start;
    return elapsed.count();
}

// Formats a duration in fixed-point notation: six decimals for values below
// one second, three decimals otherwise.
std::string seconds_string(double seconds) {
    int decimals = 3;
    if (seconds < 1.0) decimals = 6;

    std::ostringstream text;
    text.setf(std::ios_base::fixed, std::ios_base::floatfield);
    text.precision(decimals);
    text << seconds;
    return text.str();
}

// Text shown in the check table: "OK" for a passed check, "FAILED" otherwise.
std::string status(bool ok) {
    if (ok) {
        return std::string("OK");
    }
    return std::string("FAILED");
}

// Prints one row of the check table: the description left-aligned in a
// LABEL_WIDTH column, two spaces, then the verdict from status() padded to
// VALUE_WIDTH.
void print_check(const std::string& check, bool ok) {
    const std::string verdict = status(ok);
    std::cout << std::left;
    std::cout << std::setw(LABEL_WIDTH) << check << "  ";
    std::cout << std::setw(VALUE_WIDTH) << verdict << "\n";
}

// Runs the reproduction for FORM: prints the quantities obtained from the
// bcf10 headers (dim, |St|, invariant hash, d2, a witness), runs four
// consistency checks, and prints timings and the normalized coset enumerator.
//
// Returns 0 when all four checks pass and 1 when at least one of them prints
// FAILED, so that a calling script can detect a failed reproduction from the
// exit status.
int main() {
    PackedForm f = pack_form();

    std::cout << "BCF10 reproduction example (single-threaded; may take a few minutes)\n\n";

    U64 h = bcf10::inv::hash(f.lo, f.hi);
    int n = bcf10::st::dim(f.lo, f.hi);

    // Only the stabilizer order computation is included in this timing.
    auto st_start = Clock::now();
    U64 s = bcf10::st::order(f.lo, f.hi);
    double st_seconds = seconds_since(st_start, Clock::now());

    print_header("quantity", "value");
    print_row("form", form_string());
    print_row("dim", std::to_string(n));
    print_row("|St|", std::to_string(s));
    print_row("invariant hash", hex64(h));

    // Images of f under GL_SAMPLES random invertible matrices. They are built
    // before the clock starts so that only the hash calls are timed.
    std::mt19937_64 rng(GL_RANDOM_SEED);
    std::array<PackedForm, GL_SAMPLES> images{};
    for (int i = 0; i < GL_SAMPLES; i++) {
        images[i] = pullback(f, random_gl_matrix(rng));
    }

    // Every image must hash to the same value as f itself.
    auto hash_start = Clock::now();
    bool hash_invariant = true;
    for (PackedForm g : images) {
        U64 gh = bcf10::inv::hash(g.lo, g.hi);
        hash_invariant = hash_invariant && (gh == h);
    }
    // Reported as the average time of one hash call.
    double hash_seconds = seconds_since(hash_start, Clock::now()) / GL_SAMPLES;

    auto enum_start = Clock::now();
    auto result = bcf10::enu::compute_coset_weight_enumerator(f.lo, f.hi);
    double enum_seconds = seconds_since(enum_start, Clock::now());
    int d2 = distance_from_result(result);
    U64 total = coset_size(result);
    // result[0] is used as the witness: a degree <= 2 polynomial whose sum
    // with f is checked below to have weight d2.
    int witness_weight = weight_with_rm2(f, result[0]);
    bool enumerator_divisible = enumerator_divisible_by_dim(result, n);

    print_row("d2", std::to_string(d2));
    print_row("witness", rm2_string(result[0]));

    const bool coset_size_ok = total == RM2_COSET_SIZE;
    const bool witness_ok = witness_weight == d2;

    std::cout << "\n";
    print_header("check", "status");
    print_check("GL(10,2) image hashes match", hash_invariant);
    print_check("coset size is 2^56", coset_size_ok);
    print_check("enumerator divisible by 2^dim", enumerator_divisible);
    print_check("witness attains d2", witness_ok);

    std::cout << "\n";
    print_header("operation", "time, s");
    print_row("invariant", seconds_string(hash_seconds));
    print_row("stabilizer", seconds_string(st_seconds));
    print_row("enumerator", seconds_string(enum_seconds));

    std::cout << "\ncoset enumerator / 2^dim\n";
    std::cout << normalized_enumerator_string(result, n) << "\n";

    // Report a failed check through the exit status as well as in the table.
    const bool all_ok = hash_invariant && coset_size_ok && enumerator_divisible && witness_ok;
    return all_ok ? 0 : 1;
}
