
from sys import stderr

from dynamite import config
from dynamite.operators import sigmax, sigmay, sigmaz, index_sum
from dynamite.subspaces import SpinConserve
from dynamite.tools import mpi_print, MPI_COMM_WORLD

import numpy as np
from argparse import ArgumentParser


def main():
    '''
    Entry point: parse the command line, configure dynamite, then for each
    disorder realization and each disorder strength h print eigenstate
    statistics at the bottom, the top, and evenly spaced interior points of
    the spectrum.

    Run parameters go to stderr; the CSV data rows go to stdout (via mpi_print).
    '''
    args = parse_args()

    # we print this to stderr to separate it from the data output below
    mpi_print('== Run parameters: ==', file=stderr)
    for key, value in vars(args).items():
        if key == 'seed':
            continue  # we handle seed below
        mpi_print(f'  {key}, {value}', file=stderr)

    # set a random seed, and output it for reproducibility
    seed = get_shared_seed() if args.seed is None else args.seed
    mpi_print(f'  seed, {seed}', file=stderr)
    # Every rank seeds numpy's global RNG with the same (shared) seed, so every
    # rank draws the same random fields and builds the same Hamiltonian. The RNG
    # is seeded exactly once: each build_hamiltonian call below then consumes
    # fresh draws, giving a new disorder realization per iteration.
    np.random.seed(seed)

    # extra newline for readability of the output
    mpi_print(file=stderr)

    # set spin chain length globally for dynamite
    config.L = args.L

    # work in half-filling subspace
    config.subspace = SpinConserve(args.L, args.L//2)

    # column headers
    mpi_print('h,energy_point,entropy,ratio')

    for _ in range(args.iters):
        for h in np.linspace(args.h_min, args.h_max, args.h_points):
            # the random fields come from the global RNG seeded above; passing
            # the seed here would be misleading (build_hamiltonian ignores it)
            H = build_hamiltonian(h)

            # first solve for the exterior ones

            # by default eigsolve finds the lowest eigenpairs
            evals, evecs = H.eigsolve(nev=args.nev, getvecs=True)
            print_eig_stats(evals, evecs, h, energy_point=0)
            # use min/max rather than evals[0] so we don't depend on the order
            # in which eigsolve returns its eigenvalues
            min_eval = min(evals)

            # now the highest ones
            evals, evecs = H.eigsolve(nev=args.nev, which='highest', getvecs=True)
            print_eig_stats(evals, evecs, h, energy_point=1)
            max_eval = max(evals)

            # interior points: linearly interpolate between the spectrum edges,
            # skipping the endpoints 0 and 1 which were handled above
            for energy_point in np.linspace(0, 1, args.energy_points)[1:-1]:
                energy_target = min_eval + energy_point*(max_eval - min_eval)
                evals, evecs = H.eigsolve(nev=args.nev, target=energy_target, getvecs=True)
                print_eig_stats(evals, evecs, h=h, energy_point=energy_point)


def build_hamiltonian(h, seed=0xB0BACAFE):
    '''
    Build the random-field Heisenberg chain on config.L sites:

        H = sum_i S_i . S_{i+1}  +  sum_i h_i S^z_i,   h_i ~ Uniform[-h, h]

    where the nearest-neighbor sum is generated by index_sum.

    The random fields are drawn from numpy's global random state, one per site
    in site order, so the caller is responsible for seeding np.random. The
    `seed` argument is accepted for compatibility but is not used.
    '''
    # spin-1/2 operators are half the Paulis, so S.S = 0.25*(XX + YY + ZZ)
    bond = 0.25*sum(pauli(0)*pauli(1) for pauli in (sigmax, sigmay, sigmaz))
    heisenberg = index_sum(bond)

    # one field strength per site, drawn in site order from the global RNG
    field_strengths = [np.random.uniform(-h, h) for _ in range(config.L)]

    # again S^z = 0.5*sigma^z
    disorder = sum(0.5*strength*sigmaz(site)
                   for site, strength in enumerate(field_strengths))

    return heisenberg + disorder


def print_eig_stats(evals, evecs, h, energy_point):
    '''
    Compute the mean adjacent gap ratio and mean half-chain entanglement
    entropy for the provided eigenvalues and eigenstates, and print them as a
    single data row "h, energy_point, entropy, ratio".

    A statistic that is undefined for the given input is reported as nan:
    the entropy when no eigenvectors are given, and the gap ratio when there
    are fewer than three eigenvalues (or every adjacent pair of gaps is
    exactly zero).
    '''
    # average the entropy over all evecs
    # NOTE: entanglement_entropy returns the EE value only on MPI rank 0, and -1 on all other ranks.
    #       this is OK here because mpi_print below only prints on rank 0
    if len(evecs) > 0:
        entropy = sum(v.entanglement_entropy(keep=range(config.L//2)) for v in evecs)
        entropy /= len(evecs)
    else:
        entropy = float('nan')

    ratio = _mean_gap_ratio(evals)

    mpi_print(f'{h}, {energy_point}, {entropy}, {ratio}')


def _mean_gap_ratio(evals):
    '''
    Return the mean adjacent gap ratio

        r = mean_i [ min(d_i, d_{i+1}) / max(d_i, d_{i+1}) ]

    where d_i are the spacings between consecutive sorted eigenvalues.

    Pairs in which both gaps are exactly zero have an undefined ratio (0/0)
    and are skipped. If no pair remains (e.g. fewer than three eigenvalues),
    nan is returned.
    '''
    gaps = np.diff(np.sort(np.asarray(evals, dtype=float)))
    smaller = np.minimum(gaps[:-1], gaps[1:])
    larger = np.maximum(gaps[:-1], gaps[1:])

    # gaps of sorted values are >= 0, so larger == 0 only when both vanish
    defined = larger > 0
    if not defined.any():
        return float('nan')
    return float(np.mean(smaller[defined] / larger[defined]))


def get_shared_seed():
    '''
    Return a random-number-generator seed that is identical on every MPI rank.

    Rank 0 draws the seed from the operating system's hardware randomness
    source (random.SystemRandom) as an integer in [0, 2**32). With more than
    one rank, the value is broadcast from rank 0 via mpi4py. With a single
    rank, no communication happens, so mpi4py is not needed in that case.
    '''
    from random import SystemRandom

    # PETSc's MPI communicator object
    petsc_comm = MPI_COMM_WORLD()

    # only rank 0 picks the seed; everyone else waits to receive it
    seed = SystemRandom().randrange(2**32) if petsc_comm.rank == 0 else None

    # single rank: nothing to share, and we avoid touching mpi4py at all
    if petsc_comm.size == 1:
        return seed

    # multiple ranks: upgrade to a full mpi4py communicator and broadcast
    return petsc_comm.tompi4py().bcast(seed, root=0)


def parse_args(argv=None):
    '''
    Read arguments from the command line.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. If None (the default), argparse reads them from
        sys.argv, as when the script is run normally.

    Returns
    -------
    argparse.Namespace
        The parsed arguments. Invalid values make argparse print a usage
        message and exit with status 2, before any computation starts.
    '''
    from argparse import ArgumentTypeError

    def seed_int(text):
        # base 0 allows passing integers in both decimal ('42') and hex ('0xB0BACAFE')
        try:
            value = int(text, 0)
        except ValueError:
            raise ArgumentTypeError(f'invalid integer: {text!r}') from None
        # np.random.seed only accepts seeds in [0, 2**32); reject others up front
        # instead of failing later, after the run parameters were printed
        if not 0 <= value < 2**32:
            raise ArgumentTypeError(f'must be in the range [0, 2**32), got {value}')
        return value

    parser = ArgumentParser()

    parser.add_argument('-L', type=int, required=True, help='spin chain length')
    parser.add_argument('--seed', type=seed_int,
                        help='seed for random number generator. if omitted, a random '
                             'seed is chosen by querying system hardware randomness')
    parser.add_argument('--iters', type=int, default=16,
                        help='number of disorder realizations')

    parser.add_argument('--energy-points', type=int, default=3,
                        help='number of points in the spectrum to target')
    parser.add_argument('--h-points', type=int, default=5,
                        help='number of disorder strengths to test')
    parser.add_argument('--h-min', type=float, default=1,
                        help='minimum value of disorder strength h')
    parser.add_argument('--h-max', type=float, default=5,
                        help='maximum value of disorder strength h')
    parser.add_argument('--nev', type=int, default=32,
                        help='number of eigenpairs to compute at each point')

    args = parser.parse_args(argv)

    # the Heisenberg term couples neighboring sites, so at least two are needed
    if args.L < 2:
        parser.error('argument -L: must be at least 2')
    # the adjacent gap ratio needs at least three eigenvalues (two adjacent gaps)
    if args.nev < 3:
        parser.error('argument --nev: must be at least 3 to compute the adjacent gap ratio')

    return args


# run the calculation only when executed as a script, not when imported as a module
if __name__ == '__main__':
    main()
