#include "bits.h"

/* Shared lookup tables and block geometry used by every bits instance.
 * As objects with static storage duration they start out zero-initialized;
 * StaticInitializer() then fills them when StaticInitialized (last line
 * below) is dynamically initialized. */
INDEX_T bits::NBITS_B;                  /* number of bits in one BLOCK_T */
INDEX_T bits::NBYTES_B;                 /* number of bytes in one BLOCK_T */
BLOCK_T bits::ONES[BLOCK_S][BLOCK_S];   /* ONES[i][j]: bits i..j set; only entries with i <= j are filled */
BLOCK_T bits::ZEROS[BLOCK_S][BLOCK_S];  /* ZEROS[i][j]: complement of ONES[i][j]; only entries with i <= j are filled */
BLOCK_T bits::ALL_ONES;                 /* every bit of a block set */
BLOCK_T bits::ALL_ZEROS;                /* every bit of a block cleared */
BLOCK_T bits::SINGLE_ONE[BLOCK_S];      /* SINGLE_ONE[i]: only bit i set */
BLOCK_T bits::SINGLE_ZERO[BLOCK_S];     /* SINGLE_ZERO[i]: only bit i cleared */
BLOCK_T bits::MASK_MOD;                 /* NBITS_B - 1, used for i % NBITS_B */
INDEX_T bits::DIVISION_SHIFT;           /* log2(NBITS_B), used for i / NBITS_B */
/* NOTE(review): this is dynamic initialization. A bits object constructed
 * during static initialization of another translation unit could run before
 * the tables above are filled (NBITS_B would still be 0) — confirm no such
 * global/static bits objects exist. */
bool bits::StaticInitialized = bits::StaticInitializer();

bool bits::StaticInitializer() {
    /* compute the numbers of bits and bytes in one block */
    NBYTES_B = sizeof (BLOCK_T);
    NBITS_B = NBYTES_B * 8;
    assert(NBITS_B == BLOCK_S);
    /* compute ONES and ZEROS */
    for (int i = 0; i < NBITS_B; i++) {
        ONES[i][i] = (BLOCK_T) 1 << i;
        ZEROS[i][i] = ~ONES[i][i];
    }
    for (int i = 0; i < NBITS_B; i++) {
        for (int j = i + 1; j < NBITS_B; j++) {
            for (int k = i; k <= j; k++) {
                ONES[i][j] |= ONES[k][k];
                ZEROS[i][j] = ~ONES[i][j];
            }
        }
    }
    /* derive ALL_ONE and ALL_ZERO */
    ALL_ONES = ONES[0][NBITS_B - 1];
    ALL_ZEROS = ZEROS[0][NBITS_B - 1];
    /* derive SINGLE_ONE and SINGLE_ZERO */
    for (int i = 0; i < NBITS_B; i++) {
        SINGLE_ONE[i] = ONES[i][i];
        SINGLE_ZERO[i] = ZEROS[i][i];
    }
    /* compute the mask for modular operation  */
    MASK_MOD = NBITS_B - 1;
    /* compute the shift for division operation */
    DIVISION_SHIFT = __builtin_ctz(NBITS_B);
    return true;
}

INDEX_T bits::TI(INDEX_T i) {
    /* Position of global bit index i inside its block (i mod NBITS_B).
       NBITS_B is a power of two, so the modulo reduces to masking with
       MASK_MOD == NBITS_B - 1. */
    const INDEX_T offset = i & MASK_MOD;
    return offset;
}

INDEX_T bits::KI(INDEX_T i) {
    /* Index of the block that holds global bit index i (i / NBITS_B).
       NBITS_B is a power of two, so the division reduces to a right shift
       by DIVISION_SHIFT == log2(NBITS_B). */
    const INDEX_T block = i >> DIVISION_SHIFT;
    return block;
}

/* Create a bitset holding n bits, all cleared.
 * num_blocks is ceil(n / NBITS_B). tail_ones_mask selects bits 0..TI(n-1) of
 * the last block (the bits that belong to the set), and tail_zeros_mask is
 * its complement. n must be at least 1: for n == 0, (n - 1) underflows and
 * num_blocks/tail masks become meaningless (with an unsigned INDEX_T this
 * requests an enormous allocation). */
bits::bits(const INDEX_T n)
: num_bits(n), num_blocks(1 + (n - 1) / NBITS_B), tail_ones_mask(ONES[0][TI(n - 1)]), tail_zeros_mask(ZEROS[0][TI(n - 1)]) {
    assert(n > 0);
    data = (BLOCK_T*) malloc(num_blocks * NBYTES_B);
    assert(data != NULL);
    clear();
}

/* Deep copy: allocates a fresh block array and copies orig's blocks into it.
 * The allocation is checked BEFORE memcpy; copying into a NULL pointer is
 * undefined behavior and would happen before any later check could fire.
 * NOTE(review): data is managed manually (freed in ~bits). Confirm the header
 * also declares a copy-assignment operator (or deletes it). Otherwise the
 * implicit one copies the pointer and leads to a double free. */
bits::bits(const bits& orig)
: num_bits(orig.num_bits), num_blocks(orig.num_blocks), tail_ones_mask(orig.tail_ones_mask), tail_zeros_mask(orig.tail_zeros_mask) {
    const auto nbytes = num_blocks * NBYTES_B;
    data = (BLOCK_T*) malloc(nbytes);
    assert(data != NULL);
    memcpy(data, orig.data, nbytes);
}

bits::~bits() {
    /* Release the block array obtained with malloc() in the constructors. */
    BLOCK_T* const storage = data;
    free(storage);
}

void bits::clear(INDEX_T i) {
    /* Clear bit i: AND its block with a mask that has only that bit at 0. */
    BLOCK_T& block = data[KI(i)];
    block &= SINGLE_ZERO[TI(i)];
}

/* Clear bits i..j (inclusive). Requires i <= j.
 * ZEROS[lo][hi] is only filled for lo <= hi; other entries stay zero. A
 * reversed range inside one block would therefore AND the block with 0 and
 * silently wipe every bit in it, hence the assertion. */
void bits::clear(INDEX_T i, INDEX_T j) {
    assert(i <= j);
    const INDEX_T first = KI(i);
    const INDEX_T last = KI(j);
    if (first == last) {
        /* both ends in the same block: clear bits TI(i)..TI(j) of it */
        data[first] &= ZEROS[TI(i)][TI(j)];
        return;
    }
    /* first block: clear from TI(i) up to its top bit */
    data[first] &= ZEROS[TI(i)][NBITS_B - 1];
    /* blocks strictly between the ends are cleared entirely */
    for (INDEX_T bi = first + 1; bi < last; bi++) {
        data[bi] = ALL_ZEROS;
    }
    /* last block: clear from its bit 0 up to TI(j) */
    data[last] &= ZEROS[0][TI(j)];
}

void bits::clear() {
    for (INDEX_T bi = 0; bi < num_blocks; bi++) {
        data[bi] = ALL_ZEROS;
    }
}

void bits::set(INDEX_T i) {
    /* Set bit i: OR its block with a mask that has only that bit at 1. */
    BLOCK_T& block = data[KI(i)];
    block |= SINGLE_ONE[TI(i)];
}

/* Set bits i..j (inclusive). Requires i <= j.
 * ONES[lo][hi] is only filled for lo <= hi; other entries stay zero. A
 * reversed range inside one block would silently do nothing, and across
 * blocks it would set bits outside the intended range, hence the assertion. */
void bits::set(INDEX_T i, INDEX_T j) {
    assert(i <= j);
    const INDEX_T first = KI(i);
    const INDEX_T last = KI(j);
    if (first == last) {
        /* both ends in the same block: set bits TI(i)..TI(j) of it */
        data[first] |= ONES[TI(i)][TI(j)];
        return;
    }
    /* first block: set from TI(i) up to its top bit */
    data[first] |= ONES[TI(i)][NBITS_B - 1];
    /* blocks strictly between the ends are set entirely */
    for (INDEX_T bi = first + 1; bi < last; bi++) {
        data[bi] = ALL_ONES;
    }
    /* last block: set from its bit 0 up to TI(j) */
    data[last] |= ONES[0][TI(j)];
}

void bits::set() {
    for (INDEX_T bi = 0; bi < num_blocks; bi++) data[bi] = ALL_ONES;
}

BLOCK_T bits::get(INDEX_T i) const {
    /* Returns the block with every bit except bit i masked off. The result is
       nonzero iff bit i is set; it is the bit's positional value, not 0/1. */
    const BLOCK_T block = data[KI(i)];
    return block & SINGLE_ONE[TI(i)];
}

INDEX_T bits::count_ones() const {
    /* Population count over all blocks. In the final block only the bits that
       belong to the set (tail_ones_mask) are counted. */
    const INDEX_T last = num_blocks - 1;
    INDEX_T total = __builtin_popcountll(data[last] & tail_ones_mask);
    for (INDEX_T bi = 0; bi < last; bi++) {
        total += __builtin_popcountll(data[bi]);
    }
    return total;
}

/* Index of the lowest set bit, or num_bits if no bit in the set is 1. */
INDEX_T bits::trailing_zeros() const {
    INDEX_T bi = 0;
    /* skip blocks whose in-range bits are all 0 */
    while (bi < num_blocks && zero_only(bi)) bi++;
    /* every bit is 0 */
    if (bi == num_blocks) return num_bits;
    /* zero_only() was false, so block bi has an in-range 1 bit. In-range
       bits are the low bits, so the lowest set bit is a valid index.
       __builtin_ctzll (not ctzl) is used because unsigned long is only 32 bits
       on LLP64 targets (e.g. 64-bit Windows). If BLOCK_T is 64 bits, ctzl would
       drop the high half, and would get 0, which is undefined, when all set
       bits are there. This matches the ll variant used by count_ones(). */
    return __builtin_ctzll(data[bi]) + bi * NBITS_B;
}

bool bits::zero_only(INDEX_T bi) const {
    /* True if block bi has no 1 among the bits that belong to the set. In the
       final block the padding bits (those in tail_zeros_mask) are ignored. */
    const BLOCK_T block = data[bi];
    if (bi != num_blocks - 1) {
        return block == ALL_ZEROS;
    }
    return (block | tail_zeros_mask) == tail_zeros_mask;
}

bool bits::ones_only(INDEX_T bi) const {
    /* True if every bit of block bi that belongs to the set is 1. In the final
       block only the bits in tail_ones_mask are checked. */
    const BLOCK_T block = data[bi];
    const bool is_last = (bi == num_blocks - 1);
    return is_last ? (block & tail_ones_mask) == tail_ones_mask
                   : block == ALL_ONES;
}