#!/usr/bin/python3
#
# Calculate one- and two-photon photoionization cross section
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Typical usage:
#
#     python3 weak-field-photoionization.py  [OPTIONS]
#
# Reads dipole elements from the file 'molecular_data' and scattering coefficients (written by COMPAK) from
# the files whose filenames are provided on the command line. Uses the inner region dipoles, wave function
# coefficients and channel information to evaluate the one-, two- or four-photon photoionization cross sections
# for all scattering energies available in the Ak files. Performs no compatibility checks between the files.
#
# Order     First Ak file             Other Ak files           Output
# --------  ------------------------  -----------------------  --------------------------------------------
# 1         initial symmetry          final symmetries         CS averaged over final polarizations
# 2         initial & final symmetry  intermediate symmetries  CS for twice the same polarizations (summed)
# 4         initial & final symmetry  intermediate symmetry    CS for linear polarization consistent with i -> n -> f
#
# Assumes that the file 'molecular_data' is a stream file and uses int32 and real64 types.
# Assumes that the Ak files are sequential unformatted files written by COMPAK and use int64 and real64 types.
#
# The intermediate transition matrix elements do not contain the phase factor i^(-l) exp(i sigma_l(k)). It must be
# added manually if complete transition elements are required. This element is not needed for calculation of
# unoriented cross sections performed by this script.
#
# Options
#
#    -h, --help        Print help
#    --order   N       Perturbation order (1 = one-photon, 2 = two-photon, 4 = four-photon)
#    --akfiles ...     List of Ak files
#    --mgvns   ...     List of MGVNs present in the molecular data file (default "0 1 2 3 4 5 6 7" for full D2h)

from array  import array
from struct import unpack

import argparse
import math
import numpy
import sys

# Physical constants and unit conversion factors used throughout the script.
a0 = 5.2917721092e-11           # Bohr radius [m]
alpha = 1 / 137.03599907        # fine-structure constant

# Prefactors 2 pi (2 pi alpha)^N applied to the N-photon cross sections (N = 1, 2, 4).
Aconst = 2*math.pi * (2*math.pi * alpha)**1
Bconst = 2*math.pi * (2*math.pi * alpha)**2
Dconst = 2*math.pi * (2*math.pi * alpha)**4

# Hartree energy in eV (CODATA 2010, the same vintage as 'a0' above). The previous value 27.21 was
# truncated to four digits, which is coarser than the three decimals printed for the energies.
au_to_eV = 27.21138505
au_to_Mb = a0**2 * 1.0e+22  # a0^2 expressed in megabarns (1 Mb = 1e-22 m^2), ~ 28

class MolecularData:

    '''
    Class MolecularData serves for reading the RMT input file 'molecular_data'.
    Most of the data are ignored at the moment, only those relevant for this script
    are kept in memory.

    The file is read as a plain byte stream: integers with the native 'i' format (4 bytes),
    reals with the native 'd' format (8 bytes), dipole blocks as little-endian 8-byte floats.

    Data kept after construction:
        dipsto  ...  dipole blocks per 'm' (m = -1, 0, 1); numpy array of shape (s, s2, s1), where 's' is the
                     number of blocks and 's1', 's2' are the block dimensions as stored in the file
        iidip   ...  bra symmetries per 'm', per dipole block (1-based)
        ifdip   ...  ket symmetries per 'm', per dipole block (1-based)
        mnp1    ...  number of (N+1)-electron states per symmetry
        eig     ...  eigen-energies per symmetry
        etarg   ...  target (ionic) energies
    '''

    isize = 4; iform = 'i'  # integer data type byte size
    rsize = 8; rform = 'd'  # real data type byte size

    # Class-level defaults only. The constructor rebinds every container on the instance, so that
    # two MolecularData objects never write into the same dictionaries.
    dipsto = {}  # dipole blocks per 'm'
    iidip  = {}  # bra symmetries per 'm', per dipole block (1+)
    ifdip  = {}  # ket symmetries per 'm', per dipole block (1+)
    mnp1   = {}  # number of (N+1)-electron states per symmetry
    eig    = {}  # eigen-energies per symmetry

    etarg  = None   # target (ionic) energies

    def __init__(self, filename, mgvnmap):

        '''
        Read the file 'filename' and store the dipoles, state counts and energies.

        The input parameters are:
            filename  ...  path to the binary 'molecular_data' file
            mgvnmap   ...  sequence of symmetry labels (MGVNs); its i-th element is the key under which the data
                           of the i-th symmetry section of the file are stored in 'mnp1' and 'eig'

        Raises the usual I/O errors when the file cannot be opened, struct.error or EOFError when it is
        shorter than its own header implies, and IndexError when 'mgvnmap' has fewer entries than the file
        has symmetry sections.
        '''

        # Fresh per-instance containers (see the note at the class-level defaults).
        self.dipsto = {}
        self.iidip = {}
        self.ifdip = {}
        self.mnp1 = {}
        self.eig = {}
        self.etarg = None

        # The context manager guarantees that the file is closed even if parsing fails half-way.
        with open(filename, 'rb') as f:
            print('\nReading file \'{}\'\n'.format(filename))

            # Dipole storage: for every 'm' a header (s, s1, s2), the two symmetry index lists and the data.
            for m in [ -1, 0, 1 ]:
                s, s1, s2 = self._read_ints(f, 3)
                self.iidip[m] = self._read_array(f, self.iform, s)
                self.ifdip[m] = self._read_array(f, self.iform, s)
                flat = numpy.fromfile(f, dtype = '<f' + str(self.rsize), count = s * s1 * s2)
                self.dipsto[m] = numpy.reshape(flat, (s, s2, s1))
                print('  [m = {:2}] iidip = {}'.format(m, list(self.iidip[m])))
                print('  [m = {:2}] ifdip = {}'.format(m, list(self.ifdip[m])))

            ntarg, = self._read_ints(f, 1)
            print('\n  ntarg = ', ntarg)

            # Everything that is read but not stored below is consumed only to advance the file position.
            for m in [ -1, 0, 1 ]:
                self._read_array(f, self.rform, ntarg * ntarg)      # crlv
            n_rg, = self._read_ints(f, 1)
            self._read_array(f, self.rform, n_rg)                   # rg
            self._read_array(f, self.iform, 6 * n_rg)               # lm_rg

            nelc, nz, lrang2, lamax, ntarg, inast, nchmx, nstmx, lmaxp1 = self._read_ints(f, 9)
            rmatr, bbloch = self._read_reals(f, 2)

            self.etarg = self._read_array(f, self.rform, ntarg)
            self._read_array(f, self.iform, ntarg)                  # ltarg
            self._read_array(f, self.iform, ntarg)                  # starg

            nfdm, = self._read_ints(f, 1)
            delta_r, = self._read_reals(f, 1)
            self._read_array(f, self.rform, nfdm + 1)               # r_points

            # One section per (N+1)-electron symmetry; only the state count and eigen-energies are kept.
            print('\n  number of inner eigenstates per symmetry')
            for i in range(0, inast):
                mgvn = mgvnmap[i]
                lrgl, nspn, npty, nchan, self.mnp1[mgvn] = self._read_ints(f, 5)
                print('  {:3} {:8}'.format(mgvn, self.mnp1[mgvn]))
                self._read_array(f, self.iform, ntarg)              # nconat
                self._read_array(f, self.iform, nchmx)              # l2p
                self._read_array(f, self.iform, nchmx)              # m2p
                self.eig[mgvn] = self._read_array(f, self.rform, nstmx)
                self._read_array(f, self.rform, nchmx * nstmx)      # wmat
                self._read_array(f, self.rform, nchmx * nchmx * lamax)  # cf
                self._read_array(f, self.iform, nchmx)              # ichl
                s1, s2 = self._read_ints(f, 2)
                for j in range(0, nfdm):
                    self._read_array(f, self.rform, s1 * s2)        # wmat2

    def _read_ints(self, f, count):

        ''' Read 'count' integers from 'f' and return them as a tuple. '''

        return unpack(self.iform * count, f.read(count * self.isize))

    def _read_reals(self, f, count):

        ''' Read 'count' reals from 'f' and return them as a tuple. '''

        return unpack(self.rform * count, f.read(count * self.rsize))

    @staticmethod
    def _read_array(f, form, count):

        ''' Read 'count' items of type code 'form' from 'f' into a new array.array (EOFError if too short). '''

        data = array(form)
        data.fromfile(f, count)
        return data

    def find_dipole_block(self, mgvn_i, mgvn_f):

        """
        Find this transition (mgvn_i -> mgvn_f) in the dipole block storage. Return tuple (m, idx, trans),
        where 'm' is the magnetic quantum number, 'idx' is the index of the block in the 'm'-specific
        storage, and 'trans' is a logical flag indicating whether the block is transposed with respect
        to the order mgvn_i -> mgvn_f.

        The symmetry arguments are 0-based, whereas the stored bra/ket indices are 1-based. An Exception
        is raised unless exactly one matching block exists.
        """

        def matches(bra_sym, ket_sym, flag):
            # All blocks (over all 'm') stored with the given 1-based bra and ket symmetry.
            return [ (m, idx, flag)
                     for m in [ -1, 0, 1 ]
                     for idx, (bra, ket) in enumerate(zip(self.iidip[m], self.ifdip[m]))
                     if bra == bra_sym and ket == ket_sym ]

        trans = matches(mgvn_i + 1, mgvn_f + 1, False)

        # For a transition within one symmetry the transposed search would find the very same block again
        # and make a unique match look ambiguous, so it is only performed for distinct symmetries.
        if mgvn_i != mgvn_f:
            trans += matches(mgvn_f + 1, mgvn_i + 1, True)

        if len(trans) != 1:
            raise Exception('Transition not found (or ambiguous)!')

        return trans[0]


class ScatAkCoeffs:

    '''
    Read scattering coefficients from a file written by COMPAK. The file is expected to be sequential
    unformatted (binary), and to follow the usual Fortran bookmarking storage scheme: every record is
    enclosed between two 32-bit integer bookmarks. This reader consumes the bookmarks but does not
    inspect or validate their values. Support for above-2GiB records (i.e. record chunking) is not
    implemented at the moment, but should not be necessary given that matrices are written one row
    per record.

    Integers inside the records are 8 bytes long and reals are 8-byte floats.
    '''

    bsize = 4; bform = 'i'  # bookmark integer type byte size
    # 'q' is 8 bytes on every platform; the native 'l' used previously is only 4 bytes on some
    # (e.g. Windows), which contradicts 'isize' and breaks the header unpacking there.
    isize = 8; iform = 'q'  # integer data type byte size
    rsize = 8; rform = 'd'  # real data type byte size

    Re_A = None     # real part of the Ak coefficients, shape (nesc, nchan, nstat)
    Im_A = None     # imag part of the Ak coefficients, shape (nesc, nchan, nstat)

    mgvn = 0        # symmetry of the (N+1)-electron system
    spin = 0        # total spin multiplicity of the (N+1)-electron system (never filled by the reader)
    stot = 0        # spin value exactly as stored in the file header

    ichl  = None    # ionic state per channel
    lvchl = None    # angular momentum per channel
    mvchl = None    # angular momentum projection per channel
    evchl = None    # ionic energy w.r.t. to ionic ground state per channel

    nesc  = 0   # number of scattering energies
    nchan = 0   # number of scattering channels
    nstat = 0   # number of inner region eigenstates

    escat = None    # photoelectron energies

    def __init__(self, filename):

        '''
        Read the Ak file 'filename' and fill in all the data members.

        Raises the usual I/O errors when the file cannot be opened and struct.error (or ValueError
        from the array assignment) when the file is shorter than its header implies.
        '''

        # The context manager guarantees that the file is closed even if parsing fails half-way.
        with open(filename, 'rb') as f:
            print('\nReading file \'{}\'\n'.format(filename))

            # Record 1: data set header (not used).
            self._bookmark(f)
            keysc, nset, nrec, ninfo, ndata = unpack(5 * self.iform, f.read(5 * self.isize))
            self._bookmark(f)

            # Record 2: 80-character title (not used).
            self._bookmark(f)
            title = f.read(80)
            self._bookmark(f)

            # Record 3: symmetry and dimensions.
            self._bookmark(f)
            nscat, self.mgvn, self.stot, gutot, self.nstat, self.nchan, self.nesc = \
                unpack(7 * self.iform, f.read(7 * self.isize))
            self._bookmark(f)

            print('  mgvn  = ', self.mgvn)
            print('  stot  = ', self.stot)
            print('  nstat = ', self.nstat)
            print('  nchan = ', self.nchan)
            print('  nesc  = ', self.nesc)

            # Record 4: a single real number (not used).
            self._bookmark(f)
            rr, = unpack(self.rform, f.read(self.rsize))
            self._bookmark(f)

            # Record 5: channel table, four consecutive arrays of length 'nchan'.
            self._bookmark(f)
            self.ichl  = self._read_ints(f, self.nchan)
            self.lvchl = self._read_ints(f, self.nchan)
            self.mvchl = self._read_ints(f, self.nchan)
            self.evchl = self._read_reals(f, self.nchan)
            self._bookmark(f)

            print('\n  channel table')
            for i in range(0, self.nchan):
                print('  {:5} {:5} {:5} {:5} {:15.7f}'.format(i, self.ichl[i], self.lvchl[i], self.mvchl[i], self.evchl[i]))

            real_type = '<f' + str(self.rsize)
            self.Re_A = numpy.zeros((self.nesc, self.nchan, self.nstat), real_type)
            self.Im_A = numpy.zeros((self.nesc, self.nchan, self.nstat), real_type)
            self.escat = numpy.zeros(self.nesc, real_type)

            # Coefficients: for every energy and channel one record with the real parts followed by
            # one record with the imaginary parts. Both records repeat the scattering energy.
            for n in range(0, self.nesc):
                for i in range(0, self.nchan):
                    self.escat[n], self.Re_A[n, i, :] = self._read_row(f)
                    self.escat[n], self.Im_A[n, i, :] = self._read_row(f)

    def _bookmark(self, f):

        ''' Consume one record bookmark and return its value. '''

        value, = unpack(self.bform, f.read(self.bsize))
        return value

    def _read_ints(self, f, count):

        ''' Read 'count' little-endian integers of 'isize' bytes as a numpy array. '''

        return numpy.fromfile(f, dtype = '<i' + str(self.isize), count = count)

    def _read_reals(self, f, count):

        ''' Read 'count' little-endian reals of 'rsize' bytes as a numpy array. '''

        return numpy.fromfile(f, dtype = '<f' + str(self.rsize), count = count)

    def _read_row(self, f):

        '''
        Read one coefficient record (energy, integer index, 'nstat' reals) including its bookmarks.
        Return the tuple (energy, row); the integer index stored in the record is ignored.
        '''

        self._bookmark(f)
        energy, j = unpack(self.rform + self.iform, f.read(self.rsize + self.isize))
        row = self._read_reals(f, self.nstat)
        self._bookmark(f)
        return energy, row


def calc_1p_cs(moldat, aki, akf):

    '''
    Calculate photoionization cross section in the first order of the perturbation.

    For every final symmetry the inner-region dipoles between the initial state (the lowest state of the
    initial symmetry) and all states of the final symmetry are contracted with the Ak coefficients to get
    the partial wave dipoles. These are printed per channel (the imaginary part with flipped sign), turned
    into partial cross sections and summed over channels. The final table lists the cross section averaged
    over the dipole components 'm' that were processed, converted to megabarns.

    The input parameters are:
        moldat  ...  molecular data structure
        aki     ...  scattering coefficients data structure for initial state symmetry (e.g. Ag); only its
                     symmetry label is used
        akf     ...  scattering coefficients data structures for all considered final state symmetries (e.g. B1u, B2u, B3u)

    Nothing is returned, all results are printed to the standard output. The energy grids of the final
    symmetries are not checked for compatibility; the table uses the grid of the last one.
    '''

    mgvn_i = aki.mgvn   # symmetry of the initial state
    stat_i = 0          # index of the initial state

    # first ionization potential (does not depend on the final symmetry)
    IP = moldat.etarg[0] - moldat.eig[mgvn_i][stat_i]

    energies = None     # photon energies
    sigma = {}          # photoionization cross sections per polarization

    for ak in akf:

        mgvn_f = ak.mgvn    # symmetry of the final state

        # Find this transition (mgvn_i -> mgvn_f) in molecular data
        m, ib, transp = moldat.find_dipole_block(mgvn_i, mgvn_f)
        print('\nProcessing transition {} -> {} (m = {}, iblock = {})'.format(mgvn_i, mgvn_f, m, ib))

        # Get the inner region dipoles between the initial state and all states of mgvn_f
        nf = moldat.mnp1[mgvn_f]    # number of inner eigenstates in mgvn_f
        block = moldat.dipsto[m][ib]
        D = block[stat_i, 0:nf] if transp else block[0:nf, stat_i]

        # Calculate partial wave dipoles
        Re_pwd = ak.Re_A.dot(-D)
        Im_pwd = ak.Im_A.dot(-D)

        # Print 1-photon transition matrix elements. The loop runs over the energy grid of the final
        # symmetry, because that is the grid all the arrays indexed here are defined on.
        for ich in range(0, ak.nchan):
            print('\nOne-photon matrix elements (mgvn {}, channel {}):\n'.format(mgvn_f, ich))
            print('    {:10}   {:10}   {:10}'.format('E [eV]', 'Re M [a.u.]', 'Im M [a.u.]'))
            for ie in range(0, ak.nesc):
                print('    {:10.3f}   {:10.8e}   {:10.8e}'.format(ak.escat[ie] * au_to_eV, Re_pwd[ie,ich], -Im_pwd[ie,ich]))

        # Calculate the partial cross sections and corresponding photon energies
        energies = IP + ak.escat
        strength = numpy.power(Re_pwd, 2) + numpy.power(Im_pwd, 2)
        psigma = Aconst * numpy.multiply(strength, energies.reshape((ak.nesc, 1)))

        # Sum channel contributions
        sigma[m] = numpy.sum(psigma, axis = 1)

    print('\nOne-photon cross section:\n')
    print('    {:10}   {:10}'.format('E [eV]', 'cs [Mb]'))
    npol = max(1, len(sigma))
    for k in range(0, 0 if energies is None else len(energies)):
        total = sum(sigma[m][k] for m in sigma)
        print('    {:10.3f}   {:10.5e}'.format(energies[k] * au_to_eV, total / npol * au_to_Mb))


def calc_2p_cs(moldat, aki, akn):

    '''
    Calculate photoionization cross section in the second order of the perturbation. At the moment assumes that
    both photons have the same polarization.

    For every intermediate symmetry the dipoles from the initial state to the intermediate states are divided
    by the energy denominators Ei + Ephoton - En and mapped back to the states of the initial (= final)
    symmetry by the same dipole block. The contributions of all intermediate symmetries are accumulated and
    then contracted with the Ak coefficients of the final symmetry. The photon energy is half of the sum of
    the first ionization potential and the photoelectron energy.

    The input parameters are:
        moldat  ...  molecular data structure
        aki     ...  scattering coefficients data structure for initial and final state symmetry (e.g. Ag)
        akn     ...  scattering coefficients data structures for all considered intermediate state symmetries (e.g. B1u, B2u, B3u)

    Nothing is returned, all results are printed to the standard output. Only the symmetry labels of the
    intermediate Ak files are used; all energies are taken from the final-symmetry file 'aki'.
    '''

    mgvn_i = aki.mgvn   # symmetry of the initial state (and also final symmetry)
    stat_i = 0          # index of the initial state

    Ei = moldat.eig[mgvn_i][stat_i]     # eigen-energy of the initial state
    IP = moldat.etarg[0] - Ei           # first ionization potential

    # Accumulator of the two-photon amplitude in the basis of inner eigenstates of the final symmetry.
    # NOTE(review): the initial -1 in the column of the initial state is kept exactly as in the original
    # implementation; its physical motivation is not documented in this file.
    Q = numpy.zeros((aki.nesc, aki.nstat))
    Q[:,stat_i] = -1

    for ak in akn:

        mgvn_n = ak.mgvn    # symmetry of the intermediate state

        # Find first transition (mgvn_i -> mgvn_n) in molecular data
        m, ib, transp = moldat.find_dipole_block(mgvn_i, mgvn_n)
        print('\nProcessing transition {} -> {} -> {} (m = {}, iblock = {})'.format(mgvn_i, mgvn_n, mgvn_i, m, ib))

        # Get the inner region dipoles between the initial state and all states of mgvn_n
        ni = moldat.mnp1[mgvn_i]    # number of inner eigenstates in mgvn_i
        nn = moldat.mnp1[mgvn_n]    # number of inner eigenstates in mgvn_n
        block = moldat.dipsto[m][ib]
        Di = block[stat_i, 0:nn] if transp else block[0:nn, stat_i]
        Df = block[0:ni, 0:nn] if transp else block[0:nn, 0:ni].T
        En = numpy.asarray(moldat.eig[mgvn_n][0:nn])    # intermediate state energies

        # Evaluate the projection operator for all energies. The grid of the final symmetry is used, because
        # Q is allocated on it and the photon energy is fixed by the final photoelectron energy (this is
        # also what the four-photon routine does). For matching grids nothing changes.
        for ie in range(0, aki.nesc):
            Ephoton = 0.5 * (IP + aki.escat[ie])
            Dn = Di / (Ei + Ephoton - En)
            Q[ie,:] = Q[ie,:] + Df.dot(Dn)

    # Calculate partial wave dipoles
    Re_pwd = numpy.zeros((aki.nesc, aki.nchan))
    Im_pwd = numpy.zeros((aki.nesc, aki.nchan))
    for ie in range(0, aki.nesc):
        Re_pwd[ie,:] = aki.Re_A[ie,:,:].dot(Q[ie,:])
        Im_pwd[ie,:] = aki.Im_A[ie,:,:].dot(Q[ie,:])

    # Print 2-photon transition matrix elements
    for ich in range(0, aki.nchan):
        print('\nTwo-photon matrix elements (channel {}):\n'.format(ich))
        print('    {:10}   {:10}   {:10}'.format('E [eV]', 'Re M [a.u.]', 'Im M [a.u.]'))
        for ie in range(0, aki.nesc):
            print('    {:10.3f}   {:10.8e}   {:10.8e}'.format(aki.escat[ie] * au_to_eV, Re_pwd[ie,ich], -Im_pwd[ie,ich]))

    # Calculate the partial cross sections and corresponding photon energies
    energies = 0.5 * (IP + numpy.repeat(aki.escat.reshape((aki.nesc, 1)), aki.nchan, axis = 1))
    psigma = Bconst * numpy.multiply(numpy.power(Re_pwd, 2) + numpy.power(Im_pwd, 2), numpy.power(energies, 2))

    # Sum channel contributions
    sigma = numpy.sum(psigma, axis = 1)

    print('\nTwo-photon cross section:\n')
    print('    {:10}   {:10}'.format('E [eV]', 'cs [a.u.]'))
    for k in range(0, aki.nesc):
        print('    {:10.3f}   {:10.5e}'.format(energies[k,0] * au_to_eV, sigma[k]))


def calc_4p_cs(moldat, aki, akn):

    '''
    Evaluate the four-photon ionization cross section (fourth order of the perturbation theory).

    Only one pathway is considered: every absorbed photon flips the total symmetry between the initial
    one and the single intermediate one, i.e. the chain is i -> n -> i -> n -> i. Each photon carries
    a quarter of the sum of the first ionization potential and the photoelectron energy.

    The input parameters are:
        moldat  ...  molecular data structure
        aki     ...  scattering coefficients data structure for the initial and final state symmetry (e.g. Ag)
        akn     ...  scattering coefficients data structure for the intermediate state symmetry (e.g. B1u)

    Nothing is returned; the matrix elements per channel (imaginary part with flipped sign) and the
    channel-summed cross section are printed to the standard output.
    '''

    stat_i = 0              # index of the initial state
    mgvn_i = aki.mgvn       # symmetry of the initial (and final) state
    mgvn_n = akn.mgvn       # symmetry of the intermediate states

    Ei = moldat.eig[mgvn_i][stat_i]     # eigen-energy of the initial state
    IP = moldat.etarg[0] - Ei           # first ionization potential

    # To override the first ionization potential by hand, assign the desired value to IP at this
    # point and reset the initial energy accordingly as Ei = moldat.etarg[0] - IP.

    # Amplitude accumulator in the basis of inner eigenstates of the final symmetry
    Q = numpy.zeros((aki.nesc, aki.nstat))
    Q[:, stat_i] = -1

    # Locate the dipole block connecting the two symmetries
    m, ib, transp = moldat.find_dipole_block(mgvn_i, mgvn_n)
    print('\nProcessing transition {0} -> {1} -> {0} -> {1} -> {0} (m = {2}, iblock = {3})'.format(mgvn_i, mgvn_n, m, ib))

    # Pick the dipoles in the orientation "rows = states of mgvn_i, columns = states of mgvn_n"
    ni = moldat.mnp1[mgvn_i]    # number of inner eigenstates in mgvn_i
    nn = moldat.mnp1[mgvn_n]    # number of inner eigenstates in mgvn_n
    block = moldat.dipsto[m][ib]
    if transp:
        Di = block[stat_i, 0:nn]
        Df = block[0:ni, 0:nn]
    else:
        Di = block[0:nn, stat_i]
        Df = block[0:nn, 0:ni].T

    # Three resolvent steps (after 1, 2 and 3 absorbed photons) for every photoelectron energy
    for ie, Ekin in enumerate(aki.escat):
        Ephoton = 0.25 * (IP + Ekin)
        amp = Df.dot(Di / (Ei + 1*Ephoton - moldat.eig[mgvn_n][0:nn]))
        amp = Df.T.dot(amp / (Ei + 2*Ephoton - moldat.eig[mgvn_i][0:ni]))
        amp = Df.dot(amp / (Ei + 3*Ephoton - moldat.eig[mgvn_n][0:nn]))
        Q[ie, :] += amp

    # Contract with the Ak coefficients to get the partial wave dipoles
    Re_pwd = numpy.zeros((aki.nesc, aki.nchan))
    Im_pwd = numpy.zeros((aki.nesc, aki.nchan))
    for ie in range(aki.nesc):
        Re_pwd[ie, :] = aki.Re_A[ie].dot(Q[ie])
        Im_pwd[ie, :] = aki.Im_A[ie].dot(Q[ie])

    # Table of 4-photon transition matrix elements, one per channel
    row_format = '    {:10.3f}   {:10.8e}   {:10.8e}'
    for ich in range(aki.nchan):
        print('\nFour-photon matrix elements (channel {}):\n'.format(ich))
        print('    {:10}   {:10}   {:10}'.format('E [eV]', 'Re M [a.u.]', 'Im M [a.u.]'))
        for ie in range(aki.nesc):
            print(row_format.format(aki.escat[ie] * au_to_eV, Re_pwd[ie, ich], -Im_pwd[ie, ich]))

    # Photon energies as a column vector, so that they broadcast over the channels
    photon = 0.25 * (IP + aki.escat.reshape((aki.nesc, 1)))
    strength = numpy.power(Re_pwd, 2) + numpy.power(Im_pwd, 2)
    psigma = Dconst * numpy.multiply(strength, numpy.power(photon, 4))

    # Sum channel contributions
    sigma = psigma.sum(axis = 1)

    print('\nFour-photon cross section:\n')
    print('    {:10}   {:10}'.format('E [eV]', 'cs [a.u.]'))
    for k in range(aki.nesc):
        print('    {:10.3f}   {:10.5e}'.format(photon[k, 0] * au_to_eV, sigma[k]))
# Main program

parser = argparse.ArgumentParser(description = 'Calculate multi-photon cross sections.')
parser.add_argument('--order', nargs = 1, type = int, required = True, metavar = 'n',
                    help = 'Perturbation order (1 = one-photon, 2 = two-photon, 4 = four-photon)')
parser.add_argument('--mgvns', nargs = '+', type = int, default = [0, 1, 2, 3, 4, 5, 6, 7], metavar = 'M',
                    help = 'List of MGVNs present in the molecular data file (default "0 1 2 3 4 5 6 7" for full D2h)')
parser.add_argument('--akfiles', nargs = '+', required = True, metavar = 'F',
                    help = 'List of Ak files')
args = parser.parse_args()

# Validate the request before touching any (potentially large) input file, so that a wrong
# command line fails immediately instead of after all the data have been read.
order = args.order[0]   # 'nargs = 1' makes argparse deliver a one-element list

if order not in (1, 2, 4):
    raise Exception('Perturbation order must be 1, 2 or 4.')

if order == 4:
    # the four-photon mode handles exactly one intermediate symmetry
    if len(args.akfiles) != 2:
        raise Exception('Two Ak files required!')
elif len(args.akfiles) < 2:
    raise Exception('At least two Ak files required!')

moldat = MolecularData('molecular_data', args.mgvns)
Ak_files = [ ScatAkCoeffs(Akfile) for Akfile in args.akfiles ]

# The first Ak file always belongs to the initial symmetry, see the table in the file header.
if order == 1:
    calc_1p_cs(moldat, Ak_files[0], Ak_files[1:])
elif order == 2:
    calc_2p_cs(moldat, Ak_files[0], Ak_files[1:])
else:
    calc_4p_cs(moldat, Ak_files[0], Ak_files[1])

sys.exit(0)
