'''
This module defines functions that operate on the MSC (masks, signs, coefficients)
representation of a matrix.

The msc_to_numpy function is the reference implementation that defines the MSC format.
'''

from itertools import chain
import numpy as np
import scipy.sparse
from .bitwise import parity, intlog2

from ._backend.bbuild import dnm_int_t

# Record layout of a single MSC term: the `masks` and `signs` bit fields, stored
# in dynamite's integer type, followed by the complex coefficient `coeffs`.
msc_dtype = np.dtype({
    'names': ['masks', 'signs', 'coeffs'],
    'formats': [dnm_int_t, dnm_int_t, np.complex128],
})


def msc_to_numpy(msc, dims, idx_to_state=None, state_to_idx=None, sparse=True):
    '''
    Build a NumPy array from an MSC array. This function defines the MSC
    representation.

    For each row index ``r`` with ``ket = idx_to_state(r)``, every term
    ``(mask, sign, coeff)`` whose column state ``bra = ket ^ mask`` has a valid
    index ``c = state_to_idx(bra)`` (i.e. ``c != -1``) contributes
    ``coeff * (1 - 2*parity(sign & bra))`` to the matrix element ``(r, c)``.
    Contributions landing on the same element are summed.

    Parameters
    ----------

    msc : np.ndarray(dtype = msc_dtype) or array_like
        An MSC array. Anything convertible to ``msc_dtype`` is accepted.

    dims : (int, int)
        The dimensions (M, N) of the matrix.

    idx_to_state : function(int), optional
        If working in a subspace, a function to map indices to states for
        the *left* subspace. Defaults to the identity.

    state_to_idx : function(np.ndarray), optional
        If working in a subspace, a function to map states to indices for
        the *right* subspace. It is called with an array of states and must
        return an array of indices, using -1 for states outside the subspace.
        Defaults to the identity.

    sparse : bool, optional
        Whether to return a scipy sparse matrix or a dense numpy array.

    Returns
    -------
    scipy.sparse.csc_matrix or np.ndarray (dtype = np.complex128)
        The matrix
    '''
    # np.asarray rather than np.array(..., copy=False): since NumPy 2.0 the latter
    # raises ValueError whenever a copy is unavoidable (e.g. a list of tuples, or an
    # array whose dtype differs from msc_dtype).
    msc = np.asarray(msc, dtype=msc_dtype)

    # Upper bound on stored entries: each term yields at most one entry per row and,
    # presumably assuming idx_to_state and state_to_idx are one-to-one, at most one
    # per column as well (ket -> ket ^ mask is a bijection).
    nmax = msc.size * np.min(dims)
    data = np.empty(nmax, dtype=np.complex128)
    row_idxs = np.empty(nmax, dtype=dnm_int_t)
    col_idxs = np.empty(nmax, dtype=dnm_int_t)
    mat_idx = 0

    # if these aren't supplied, they are the identity
    if idx_to_state is None:
        idx_to_state = lambda x: x

    if state_to_idx is None:
        state_to_idx = lambda x: x

    masks = msc['masks']
    signs = msc['signs']
    coeffs = msc['coeffs']

    for row_idx in range(dims[0]):
        ket = idx_to_state(row_idx)
        bra = masks ^ ket
        col_idx = state_to_idx(bra)

        # keep only terms whose column state lies in the right subspace
        good = np.nonzero(col_idx != -1)[0]
        nnew = len(good)
        if nnew == 0:
            continue

        good_bras = bra[good]
        sign = 1 - 2*(parity(signs[good] & good_bras))

        data[mat_idx:mat_idx+nnew] = sign * coeffs[good]
        row_idxs[mat_idx:mat_idx+nnew] = row_idx
        col_idxs[mat_idx:mat_idx+nnew] = col_idx[good]
        mat_idx += nnew

    # trim to the amount we used
    data = data[:mat_idx]
    row_idxs = row_idxs[:mat_idx]
    col_idxs = col_idxs[:mat_idx]

    # duplicate (row, col) entries are summed when building the CSC matrix
    ary = scipy.sparse.csc_matrix((data, (row_idxs, col_idxs)), shape = dims)

    if not sparse:
        ary = ary.toarray()

    return ary

def is_hermitian(msc):
    '''
    Checks whether a given MSC matrix represents a Hermitian operator or not.

    Coefficients of terms sharing the same ``(masks, signs)`` pair are summed
    first. Then each combined term whose ``masks & signs`` has odd parity must
    have a purely imaginary coefficient, and every other term a purely real one.
    Comparisons are exact (no tolerance).

    Parameters
    ----------
    msc : np.ndarray
        The MSC matrix

    Returns
    -------

    bool
        Whether the matrix is Hermitian
    '''

    # Combine like terms before checking. For example, coefficients 1+1j and 1-1j
    # on the same (mask, sign) pair sum to a real coefficient even though neither
    # is real on its own, so checking the raw terms would give a false negative.
    keys, inverse = np.unique(msc[['masks', 'signs']], return_inverse=True)
    coeffs = np.zeros(keys.size, dtype=np.complex128)
    np.add.at(coeffs, inverse.reshape(-1), np.ravel(msc['coeffs']))

    should_be_complex = parity(keys['masks'] & keys['signs']) == 1

    if np.any(np.real(coeffs[should_be_complex])):
        return False

    if np.any(np.imag(coeffs[~should_be_complex])):
        return False

    return True

def msc_sum(iterable):
    '''
    Add together any number of MSC matrices.

    In the MSC representation, addition is simply concatenation of the terms of
    every operand; like terms are not combined here.

    Parameters
    ----------
    iterable : iter
        Yields the MSC representations of the matrices to add.

    Returns
    -------
    np.ndarray
        The sum as an MSC matrix (an empty ``msc_dtype`` array if ``iterable``
        yields nothing).
    '''
    terms = [*iterable]
    if len(terms) == 0:
        return np.ndarray(0, dtype=msc_dtype)
    return np.hstack(terms)

def msc_product(iterable):
    '''
    Multiply MSC matrices together, left to right.

    Every combination of one term from each factor produces one term of the
    result, so the output size is the product of the factor sizes.

    Parameters
    ----------
    iterable : iter
        Yields the MSC matrices to multiply, in order.

    Returns
    -------
    np.ndarray
        The product as an MSC matrix
    '''
    factors = list(iterable)

    # Cartesian product of the factors' terms: row i of `grid` holds the terms
    # drawn from factors[i], one column per combination.
    grid = np.array(np.meshgrid(*factors)).reshape(len(factors), -1)

    result = grid[0]

    # some factor had no terms, so neither does the product
    if grid.size == 0:
        return result

    for rhs in grid[1:]:
        # phase from moving the right factor's masks past the accumulated signs;
        # computed before `result['signs']` is updated below
        phase = (-1)**parity(rhs['masks'] & result['signs'])
        result['masks'] ^= rhs['masks']
        result['signs'] ^= rhs['signs']
        result['coeffs'] *= phase * rhs['coeffs']

    return result

def shift(msc, shift_idx, wrap_idx):
    '''
    Shift an MSC representation along the spin chain. Guaranteed to not modify input,
    but not guaranteed to return a copy (could return the same object).

    Parameters
    ----------
    msc : np.ndarray
        The input MSC representation.

    shift_idx : int
        The number of spins to shift by. When ``wrap_idx`` is given, this is taken
        modulo ``wrap_idx``, so shifts of any size (including negative shifts)
        rotate the chain periodically.

    wrap_idx : int or None
        The index at which to wrap around to the beginning of the spin chain.
        If None, do not wrap.

    Returns
    -------
    np.ndarray
        The shifted MSC representation.
    '''

    if wrap_idx is not None and wrap_idx > 0:
        # rotating by a multiple of the chain length is the identity; this also
        # turns negative shifts into equivalent positive ones
        shift_idx %= wrap_idx

    if shift_idx == 0:
        return msc

    msc = msc.copy()

    if wrap_idx is None:
        msc['masks'] <<= shift_idx
        msc['signs'] <<= shift_idx
        return msc

    # Rotate the low `wrap_idx` bits left by `shift_idx`. Bits that pass `wrap_idx`
    # are taken directly from the top of the original value with a right shift,
    # rather than shifting left first and folding the overflow back. That fold only
    # handles a single wrap, and it loses spins pushed past the integer's top bit on
    # long chains.
    keep = (1 << wrap_idx) - 1
    for field in ('masks', 'signs'):
        v = msc[field]
        msc[field] = ((v << shift_idx) | (v >> (wrap_idx - shift_idx))) & keep

    return msc

def combine_and_sort(msc):
    '''
    Take an MSC representation, sort it, and combine like terms.

    Terms with equal ``(masks, signs)`` are merged by summing their coefficients.
    The result is sorted by ``masks`` and then ``signs``. Terms whose combined
    coefficient is exactly zero are dropped.

    Parameters
    ----------
    msc : np.ndarray
        The input MSC representation.

    Returns
    -------
    np.ndarray
        The reduced representation (may be of a smaller dimension).
    '''

    unique, inverse = np.unique(msc[['masks', 'signs']], return_inverse=True)

    rtn = np.ndarray(unique.size, dtype=msc.dtype)
    rtn['masks'] = unique['masks']
    rtn['signs'] = unique['signs']

    # Accumulate like terms in one vectorized pass instead of a Python loop.
    # np.add.at is unbuffered, so repeated indices all contribute. reshape/ravel
    # guard against NumPy versions that return `inverse` in the input's shape.
    rtn['coeffs'] = 0
    np.add.at(rtn['coeffs'], inverse.reshape(-1), np.ravel(msc['coeffs']))

    return rtn[rtn['coeffs'] != 0]

def truncate(msc, tol):
    '''
    Drop every term whose coefficient magnitude does not exceed `tol`.

    Parameters
    ----------
    msc : np.ndarray
        The input MSC representation.

    tol : float
        The cutoff; only terms with ``abs(coeff) > tol`` are kept.

    Returns
    -------
    np.ndarray
        The truncated MSC representation.

    Raises
    ------
    ValueError
        If `tol` is negative.
    '''
    if tol < 0:
        raise ValueError('tol cannot be less than zero')

    keep = np.abs(msc['coeffs']) > tol
    return msc[keep]

def serialize(msc):
    '''
    Serialize an MSC representation into a byte string.

    The format is
    `nterms int_size masks signs coefficients`
    where `nterms` and `int_size` are utf-8 text, each followed by a newline, and the
    others are each just a binary blob, one after the other. `int_size` is the size in
    bits of the integer type of the `masks` field (e.g. 32 or 64).

    Binary values are saved in big-endian format, to be compatible with PETSc defaults.

    Parameters
    ----------
    msc : np.array
        The MSC representation

    Returns
    -------
    bytes
        A byte string containing the serialized operator.
    '''

    int_t = msc.dtype['masks'].newbyteorder('B')
    cplx_t = np.dtype(np.complex128).newbyteorder('B')

    header = f'{msc.size}\n{int_t.itemsize*8}\n'.encode('utf-8')

    # build the result in one join rather than repeated bytes concatenation
    return b''.join([
        header,
        msc['masks'].astype(int_t, casting='equiv', copy=False).tobytes(),
        msc['signs'].astype(int_t, casting='equiv', copy=False).tobytes(),
        msc['coeffs'].astype(cplx_t, casting='equiv', copy=False).tobytes(),
    ])

def deserialize(data):
    '''
    Reverse the serialize operation.

    Parameters
    ----------
    data : bytes
        The byte string containing the serialized data, in the format written by
        ``serialize``.

    Returns
    -------
    np.ndarray (dtype = msc_dtype)
        The MSC representation.

    Raises
    ------
    ValueError
        If the header's integer size is neither 32 nor 64, or if the data was
        written with 64-bit integers but holds mask or sign values that do not fit
        in this build's 32-bit integers.
    '''

    start = 0
    stop = data.find(b'\n')
    msc_size = int(data[start:stop])

    start = stop + 1
    stop = data.find(b'\n', start)
    int_size = int(data[start:stop])
    if int_size == 32:
        int_t = np.int32
    elif int_size == 64:
        int_t = np.int64
    else:
        raise ValueError('Invalid int_size. Perhaps file is corrupt.')

    int_be = np.dtype(int_t).newbyteorder('B')
    cplx_be = np.dtype(np.complex128).newbyteorder('B')

    mv = memoryview(data)
    start = stop + 1
    int_msc_bytes = msc_size * int_size // 8

    masks = np.frombuffer(mv[start:start+int_msc_bytes], dtype=int_be)
    start += int_msc_bytes
    signs = np.frombuffer(mv[start:start+int_msc_bytes], dtype=int_be)
    start += int_msc_bytes
    coeffs = np.frombuffer(mv[start:], dtype=cplx_be)

    # Operator was saved using 64 bit dynamite but loaded using 32. Both masks and
    # signs must fit, otherwise assigning into the 32-bit fields would silently
    # truncate them.
    if int_size == 64 and msc_dtype['masks'].itemsize == 4:
        if np.count_nonzero(masks >> 31) or np.count_nonzero(signs >> 31):
            raise ValueError('dynamite must be built with 64-bit indices '
                             'to load operator on more than 31 spins.')

    msc = np.ndarray(msc_size, dtype=msc_dtype)
    msc['masks'] = masks
    msc['signs'] = signs
    msc['coeffs'] = coeffs

    return msc

def max_spin_idx(msc):
    '''
    Find the highest spin index touched by any term of the MSC operator.

    The index is ``intlog2`` of the largest value found in either the masks or the
    signs. An operator with no terms yields -1.

    Parameters
    ----------
    msc : np.array
        The MSC operator

    Returns
    -------
    int
        The index
    '''
    if not msc.size:
        return -1

    largest = max(msc['masks'].max(), msc['signs'].max())
    return intlog2(largest)

def nnz(msc):
    '''
    Count the distinct masks in an MSC operator.

    In the matrix built by ``msc_to_numpy``, each row's entries sit at columns
    ``row ^ mask``. So this count bounds the number of nonzero elements per row.
    '''
    distinct_masks = np.unique(msc['masks'])
    return distinct_masks.size

def table(msc, L):
    '''
    Build a table in string format that shows all of the terms in the MSC matrix.

    Each row shows a term's coefficient and its Pauli string over spins ``0..L-1``,
    with spin 0 leftmost. Identity is shown as ``-``; ``X``, ``Y`` and ``Z`` as
    themselves.
    Coefficients are formatted by ``_get_coeff_str(c, trunc=True)``, which shows
    the real part, the imaginary part, or both, depending on which are nonzero.
    A spin with both its mask and sign bit set is shown as ``Y``, and the displayed
    coefficient is multiplied by ``-1j`` for each such spin.

    Parameters
    ----------
    msc : np.ndarray
        The MSC matrix.

    L : int
        Number of spins to display; bits at index ``L`` and above are not shown.

    Returns
    -------
    str
        The formatted table.
    '''

    # indexed as pauli_chars[maskbit][signbit]
    pauli_chars = (('-', 'Z'),
                   ('X', 'Y'))

    coeff_strs = []
    pauli_strs = []
    for m, s, c in msc:

        chars = []
        for i in range(L):
            maskbit = (m >> i) & 1
            signbit = (s >> i) & 1

            chars.append(pauli_chars[maskbit][signbit])

            # phase converting this mask/sign pair into a Y operator
            if maskbit and signbit:
                c *= -1j

        coeff_strs.append(_get_coeff_str(c, trunc=True))
        pauli_strs.append(''.join(chars))

    coeff_just_len = max(7, max((len(s) for s in coeff_strs), default=0))

    rtn = f' {"coeff.".center(coeff_just_len)}'
    rtn += ' | '

    npad_operator = max(L - 8, 0)//2
    text_pad = ' '*npad_operator
    rtn += f'{text_pad}operator{text_pad} \n'
    rtn += '='*(len(rtn)-1)
    rtn += '\n'

    rtn += '\n'.join(f' {cstr.rjust(coeff_just_len)} | {pstr}'
                     for cstr, pstr in zip(coeff_strs, pauli_strs))

    return rtn


def _get_coeff_str(x, trunc=False, parens=False):
    if trunc:
        both_parts = x.real != 0 and x.imag != 0
        if both_parts:
            if 1E-2 <= abs(x) <= 1E2 or x == 0:
                rtn = f'{x:.2f}'
            else:
                rtn = f'{x:.2e}'

        else:
            big = not (1E-2 <= abs(x) <= 1E2) and not x == 0
            if x.imag:
                if big:
                    rtn = f'{x.imag:.2e}j'
                else:
                    rtn = f'{x.imag:.3f}j'
            else:
                if big:
                    rtn = f'{x.real:.2e}'
                else:
                    rtn = f'{x.real:.3f}'

        if parens and (both_parts or 'e' in rtn):
            rtn = f'({rtn})'

    else:
        rtn = str(x)
        if not parens and '(' in rtn:
            rtn = rtn[1:-1]

    return rtn
