#!/usr/bin/env nemesis
# =================================================================================================
# This code is part of PyLith, developed through the Computational Infrastructure
# for Geodynamics (https://github.com/geodynamics/pylith).
#
# Copyright (c) 2010-2025, University of California, Davis and the PyLith Development Team.
# All rights reserved.
#
# See https://mit-license.org/ and LICENSE.md for license information.
# =================================================================================================
"""
Pyre application to compute power-law parameters used by PyLith, given
spatial databases describing the temperature and laboratory-derived
properties for the various materials. The output is a `SimpleDB` spatial
database containing the power-law parameters for PyLith.
"""

import numpy

from pythia.pyre.applications.Script import Script as Application


class PowerLawApp(Application):
    """Pyre application to generate a spatial database with power-law parameters for PyLith.

    The application queries input spatial databases for the power-law exponent,
    activation energy, flow constant, and temperature at the vertices of the
    output geometry, computes the reference stress and reference strain rate,
    queries the state-variable tensors, and writes everything to an ASCII
    spatial database file.
    """
    # Tensor components in the order they are written to the output database.
    # Only the first four components are used when the geometry is not 3D.
    TENSOR_COMPONENTS = ["xx", "yy", "zz", "xy", "yz", "xz"]

    import pythia.pyre.inventory
    from pythia.pyre.units.pressure import MPa
    from pythia.pyre.units.time import s

    refSelection = pythia.pyre.inventory.str("reference_value",
                                             default="strain_rate",
                                             validator=pythia.pyre.inventory.choice(['stress',
                                                                                     'strain_rate']))
    refSelection.meta['tip'] = "Indicates whether reference stress or " \
        "reference strain rate is provided as input."

    powerlawRefStress = pythia.pyre.inventory.dimensional("reference_stress", default=1.0 * MPa)
    powerlawRefStress.meta['tip'] = "Reference stress value."

    # NOTE: the attribute name is misspelled ("pwerlaw"); it is kept as-is so that
    # existing code referring to this inventory attribute keeps working.
    pwerlawRefStrainRate = pythia.pyre.inventory.dimensional("reference_strain_rate", default=1.0e-6 / s)
    pwerlawRefStrainRate.meta['tip'] = "Reference strain rate value."

    from spatialdata.spatialdb.SimpleDB import SimpleDB

    dbExponent = pythia.pyre.inventory.facility("db_exponent", family="spatial_database", factory=SimpleDB)
    dbExponent.meta['tip'] = "Spatial database for power-law exponent, n."

    dbActivationE = pythia.pyre.inventory.facility("db_activation_energy", family="spatial_database", factory=SimpleDB)
    dbActivationE.meta['tip'] = "Spatial database for activation energy, Q."

    dbTemperature = pythia.pyre.inventory.facility("db_temperature", family="spatial_database", factory=SimpleDB)
    dbTemperature.meta['tip'] = "Spatial database for temperature, T."

    dbAe = pythia.pyre.inventory.facility("db_powerlaw_coefficient", family="spatial_database", factory=SimpleDB)
    dbAe.meta['tip'] = "Spatial database for power-law coefficient, Ae."

    # main() also queries the deviatoric_stress tensor from this database.
    dbStateVars = pythia.pyre.inventory.facility("db_state_variables", family="spatial_database", factory=SimpleDB)
    dbStateVars.meta["tip"] = "Spatial database for power-law state variables, viscous_strain, reference_stress, reference_strain"

    from spatialdata.spatialdb.generator.Geometry import Geometry
    geometry = pythia.pyre.inventory.facility("geometry", family="geometry", factory=Geometry)
    geometry.meta['tip'] = "Geometry for output database."

    dbFilename = pythia.pyre.inventory.str("database_filename", default="powerlaw.spatialdb")
    dbFilename.meta['tip'] = "Filename for generated spatial database."

    def __init__(self, name="powerlaw_gendb"):
        """Create the application.

        Args:
            name: Name passed to the Pyre application base class.
        """
        Application.__init__(self, name)

    def main(self, *args, **kwds):
        """Application driver.

        Reads the output geometry, queries the input databases at its vertices,
        computes the power-law reference stress and strain rate, and writes the
        results together with the state-variable tensors to `dbFilename`.

        Raises:
            ValueError: If `refSelection` is neither "stress" nor "strain_rate",
                or if a database query fails at one or more points.
        """
        # Get output points
        self._info.log("Reading geometry.")
        self.geometry.read()
        points = self.geometry.vertices
        coordsys = self.geometry.coordsys

        (npoints, spaceDim) = points.shape

        # Query databases to get inputs at output points.
        # Each query returns an array of shape (npoints, 1).
        self._info.log("Querying for parameters at output points.")
        n = self._queryDB(self.dbExponent, ["power_law_exponent"], points, coordsys)
        Q = self._queryDB(self.dbActivationE, ["activation_energy"], points, coordsys)
        logAe = self._queryDB(self.dbAe, ["log_flow_constant"], points, coordsys)
        scaleAe = self._queryDB(self.dbAe, ["flow_constant_scale"], points, coordsys)
        T = self._queryDB(self.dbTemperature, ["temperature"], points, coordsys)

        # Compute power-law parameters
        self._info.log("Computing parameters at output points...")
        from pythia.pyre.handbook.constants.fundamental import R
        Ae = 10**(logAe - scaleAe * n)
        At = 3**(0.5 * (n + 1)) / 2.0 * Ae * numpy.exp(-Q / (R.value * T))

        # One reference quantity is given by the user (uniform over all points);
        # the other is derived from it through strain_rate = At * stress**n.
        if self.refSelection == "stress":
            stressValue = self.powerlawRefStress.value
            refStressValues = numpy.full((npoints,), stressValue, dtype=numpy.float64)
            refStrainRateValues = stressValue**n * At
        elif self.refSelection == "strain_rate":
            strainRateValue = self.pwerlawRefStrainRate.value
            refStrainRateValues = numpy.full((npoints,), strainRateValue, dtype=numpy.float64)
            refStressValues = (strainRateValue / At)**(1.0 / n)
        else:
            raise ValueError(f"Invalid value ({self.refSelection}) for reference value.")

        powerlawRefStressInfo = {
            'name': "power_law_reference_stress",
            'units': "Pa",
            'data': refStressValues.flatten()
        }
        pwerlawRefStrainRateInfo = {
            'name': "power_law_reference_strain_rate",
            'units': "1/s",
            'data': refStrainRateValues.flatten()
        }
        exponentInfo = {
            'name': "power_law_exponent",
            'units': "none",
            'data': n.flatten()
        }

        # State-variable tensors: 6 components in 3D, otherwise the first 4.
        tensorComponents = self.TENSOR_COMPONENTS if spaceDim == 3 else self.TENSOR_COMPONENTS[:4]
        viscousStrainInfo = self._queryTensor("viscous_strain", "none", tensorComponents, points, coordsys)
        devStressInfo = self._queryTensor("deviatoric_stress", "Pa", tensorComponents, points, coordsys)
        refStressInfo = self._queryTensor("reference_stress", "Pa", tensorComponents, points, coordsys)
        refStrainInfo = self._queryTensor("reference_strain", "none", tensorComponents, points, coordsys)

        # Write database
        self._info.log("Writing database.")
        data = {
            'points': points,
            'coordsys': coordsys,
            'data_dim': self.geometry.dataDim,
            'values': [powerlawRefStressInfo, pwerlawRefStrainRateInfo, exponentInfo]
        }
        data["values"] += viscousStrainInfo
        data["values"] += refStressInfo
        data["values"] += refStrainInfo
        data["values"] += devStressInfo
        from spatialdata.spatialdb.SimpleIOAscii import createWriter
        writer = createWriter(self.dbFilename)
        writer.write(data)

    def _queryTensor(self, baseName, units, components, points, cs):
        """Query the state-variable database for one tensor field.

        Args:
            baseName: Prefix of the value names, e.g. "viscous_strain"; each
                queried value is named "<baseName>_<component>".
            units: Units string stored with every component in the output.
            components: Sequence of component suffixes (e.g. "xx", "xy").
            points: Array of query point coordinates, shape (npoints, spaceDim).
            cs: Coordinate system of the query points.

        Returns:
            List of dicts with keys "name", "units", and "data" (a 1-D array
            over points), one per component, in the order of `components`.

        Raises:
            ValueError: If the query fails at one or more points.
        """
        valueNames = [f"{baseName}_{component}" for component in components]
        values = self._queryDB(self.dbStateVars, valueNames, points, cs)
        return [{
            "name": valueName,
            "units": units,
            "data": values[:, index].flatten(),
        } for index, valueName in enumerate(valueNames)]

    def _queryDB(self, db, valueNames, points, cs):
        """Query a spatial database at a set of points.

        Args:
            db: Spatial database to query; it is opened before and closed
                after the query.
            valueNames: List of names of the values to query.
            points: Array of query point coordinates, shape (npoints, spaceDim).
            cs: Coordinate system of the query points.

        Returns:
            Array of shape (npoints, len(valueNames)) with the queried values.

        Raises:
            ValueError: If the query reports a nonzero error code at any point.
        """
        npoints = points.shape[0]
        data = numpy.zeros((npoints, len(valueNames)), dtype=numpy.float64)
        err = numpy.zeros((npoints,), dtype=numpy.int32)

        db.open()
        try:
            db.setQueryValues(valueNames)
            db.multiquery(data, err, points, cs)
        finally:
            # Close the database even if the query raises.
            db.close()

        # A nonzero entry in `err` marks a point where the query failed. Use a
        # boolean mask so the reported coordinates are those of the failing
        # points rather than rows selected by the error codes themselves.
        failed = err != 0
        numFailed = numpy.count_nonzero(failed)
        if numFailed > 0:
            msg = f"Query for {', '.join(valueNames)} failed at {numFailed} points.\n" \
                "Coordinates of points:\n"
            msg += f"{points[failed, :]}"
            raise ValueError(msg)

        return data


# ----------------------------------------------------------------------
if __name__ == "__main__":
    # Script entry point: build the application and hand control to run().
    app = PowerLawApp()
    app.run()


# End of file
