#!/usr/bin/env python

"""
    Utility to generate GEOS-5 file spec information, a.k.a.
    the Rob Lucchesi emulator.

    Arlindo da Silva, February 2014.

"""

import os
import sys
import re

from optparse     import OptionParser
from netCDF4      import Dataset
from MAPL         import Config

# RTF output is optional.  PyRTF is tried first, then rtfw, which is
# expected to export the same names (both are star-imported and used
# interchangeably below).  HAS_RTF records whether either one loaded.
try:
    from PyRTF    import *
    HAS_RTF = True
except ImportError:
    try:
        from rtfw import *
        HAS_RTF = True
    except Exception:
        # Any failure to load the fallback just disables RTF output.
        # Unlike a bare "except:", this lets KeyboardInterrupt and
        # SystemExit propagate.
        HAS_RTF = False

if HAS_RTF:

    # Column widths of the variable table (name, dims, description, units)
    # as multiples of the library's default tab width.
    nNAME = int(TabPS.DEFAULT_WIDTH * 2.5)
    nDIMS = int(TabPS.DEFAULT_WIDTH * 1)
    nDESC = int(TabPS.DEFAULT_WIDTH * 7.5)
    nUNIT = int(TabPS.DEFAULT_WIDTH * 2)

    # Border weights used to assemble the cell frames below.
    thin_edge   = BorderPS( width=5, style=BorderPS.SINGLE )
    normal_edge = BorderPS( width=20, style=BorderPS.SINGLE )
    thick_edge  = BorderPS( width=40, style=BorderPS.SINGLE )
    none_edge   = BorderPS( width=0,  style=BorderPS.SINGLE )

    # Cell frames: header_* are used for the table header row, thin_* for
    # the variable rows.  The L/R variants pass None for one edge; presumably
    # that is the outer left/right table border -- confirm against the
    # FramePS argument order.
    header_frame  = FramePS( None,      normal_edge, thick_edge, normal_edge )
    header_frameL = FramePS( None,      None,        thick_edge, normal_edge )
    header_frameR = FramePS( None,      normal_edge, thick_edge, None )
    thin_frame    = FramePS( thin_edge, normal_edge, thin_edge,  normal_edge )
    thin_frameL   = FramePS( thin_edge, None,        thin_edge,  normal_edge )
    thin_frameR   = FramePS( thin_edge, normal_edge, thin_edge,  None )

    
#     -----
class Spec(object):
    """
    Plain attribute bag describing one collection.

    Only the name, the type and a section counter are set here; callers
    attach further attributes to the instance as they gather metadata.
    """
    def __init__(self,name,mytype):
        # name:   collection name
        # mytype: collection type string
        self.name, self.type = name, mytype
        self.section = 0

#.................................................................................................

#     ------
class Writer(object):
    """
    Base class for File Spec writers.

    This class turns collection metadata into a sequence of calls on four
    output primitives that concrete writers must provide:

        section(collection, description)   start a new collection
        property(name, value)              one "name: value" line
        table(choice)                      called with 'header' / 'footer'
        row(vname, dim, descr, units)      one variable row
    """

    # First four characters of the collection name -> sampling label.
    _SAMPLING = { 'TAVG': 'Time-Averaged',
                  'INST': 'Instantaneous',
                  'CONS': 'Invariant' }

    # First character of the grid field of the collection name.
    _RESOLUTION = { 'C': 'Coarsened Horizontal Resolution',
                    'N': 'Full Horizontal Resolution',
                    'M': 'Full Horizontal Resolution On Cube' }

    # Second character of the grid field of the collection name.
    _LEVEL = { 'v': 'Model-Level',
               'e': 'Model-Edge-Level',
               'p': 'Pressure-Level',
               'x': 'Single-Level',
               'z': 'Height-Level' }

    def writeVariable(self,vname,var,newUnits=None):
        """
        Write row for a single variable.

        vname    -- variable name shown in the first column
        var      -- object with `dimensions`, `long_name` and `units`
                    attributes (the variables stored by getCollection)
        newUnits -- optional units string; when truthy it replaces
                    var.units and any "[...]" text is removed from the
                    description

        Raises ValueError unless the variable has 1 to 4 dimensions.
        """
        dims = var.dimensions
        if not 1 <= len(dims) <= 4:
            raise ValueError('invalid dimensions: %s for <%s>'%(str(dims),vname))
        # One short code per dimension, concatenated in file order.
        dim = ''.join(get1d(d) for d in dims)
        descr = var.long_name.replace('__ENSEMBLE__','').replace('_',' ')
        if newUnits:
            descr = re.sub(r'\[.*\]','',descr)
            units = newUnits
        else:
            units = var.units
        self.row(vname,dim,descr,units)

    def doCollection(self,Coll,options):
        """
        Write spec for a given collection.

        Coll    -- collection metadata; the attributes read here are title,
                   type, name, ShortName, lev_name, tbeg, nx, ny, nz, nt,
                   size, variables and newUnits
        options -- parsed command line options; dyamond, nature, merraobs
                   and merra select the naming/title convention

        Defaults are first derived from the '_'-separated fields of
        Coll.type and then overridden by the ','-separated fields of
        Coll.title when the title has enough of them.
        """

        # A missing title (None) is treated like an empty one.
        Title = (Coll.title or '').split(',')
        Name = Coll.type.split('_')

        # Defaults derived from collection name
        # -------------------------------------
        kind = Name[0]
        sampling = self._SAMPLING.get(kind[0:4].upper(),'unknown')

        # Not robust: only the single character after the 4-letter prefix
        # is taken as the frequency in hours.
        if len(kind) > 4:
            freq = kind[4]+'-Hourly'
        else:
            freq = 'unknown'

        dim = Name[1]

        # With --dyamond, non-invariant names have their grid code in
        # field 4 rather than field 3.
        if options.dyamond and 'Invariant' not in sampling:
            resfield = 4
        else:
            resfield = 3

        gridcode = Name[resfield]
        onCube = gridcode[0] == 'M'
        res = self._RESOLUTION.get(gridcode[0],'Unknown resolution')
        level = self._LEVEL.get(gridcode[1],'unknown')

        have_channels = bool(Coll.lev_name) and 'CHANNEL' in Coll.lev_name.upper()

        # Overrides taken from the title
        # ------------------------------
        descr = 'fix me, please'   # placeholder whenever the title is too short
        if Title[0] == 'Invariants':
            dim, freq, sampling, level = ('2d','invariants','time independent', 'surface')
            if len(Title) > 1:
                descr = Title[1]
        elif options.nature:
            if len(Title) >= 5:
                # Title[0] (dimensionality) is ignored; dim comes from the name.
                freq, sampling, level, res = Title[1:5]
                descr = ','.join(Title[5:])[:-2]
        elif len(Title) >= 4:
            freq, sampling, level = Title[1:4]
            if options.merra and not options.merraobs:
                # With --merra, Title[4] is not part of the description.
                descr = ','.join(Title[5:])
            else:
                descr = ','.join(Title[4:])

        descr = descr.replace('Forecast,','').replace('Assimilation,','').replace('Analysis,','').replace('Model,','')

        # Section Header
        # --------------
        if Coll.ShortName is None:
            self.section(Coll.name,descr)
        else:
            cname = Coll.name + ' (%s)'%str(Coll.ShortName)
            self.section(cname,descr)

        # Properties
        # ----------
        # Coll.tbeg looks like <date>T<time>; only HH:MM is shown.  Without
        # a time-of-day part the start time is simply left out.
        if Coll.tbeg:
            start = Coll.tbeg.partition('T')[2][:5]
        else:
            start = ''
        if start:
            self.property('Frequency',"%s from %s (%s)"%(freq.lower(),start+' UTC',sampling.lower()))
        else:
            self.property('Frequency',"%s (%s)"%(freq.lower(),sampling.lower()))

        self.property("Spatial Grid","%s, %s, %s"%(dim.upper(),level.lower(),res.lower()))

        if Coll.nz<1:
            if onCube:
                self.property("Dimensions","grid resolution=%d, time=%d "%(Coll.nx,Coll.nt))
            else:
                self.property("Dimensions","longitude=%d, latitude=%d, time=%d "%(Coll.nx,Coll.ny,Coll.nt))
        elif have_channels:
            self.property("Dimensions","longitude=%d, latitude=%d, channels=%d, time=%d "%(Coll.nx,Coll.ny,Coll.nz,Coll.nt))
        elif onCube:
            self.property("Dimensions","grid resolution=%d, level=%d, time=%d "%(Coll.nx,Coll.nz,Coll.nt))
        else:
            self.property("Dimensions","longitude=%d, latitude=%d, level=%d, time=%d "%(Coll.nx,Coll.ny,Coll.nz,Coll.nt))

        # Granule size: whole MB up to 1024 MB, otherwise GB with one decimal.
        mbytes = float(Coll.size)/(1024.*1024.)
        if int(mbytes+0.5) > 1024:
            # '%3.1f' already rounds to the nearest tenth; adding 0.05 first
            # would round up instead (1.91 GB would print as 2.0).
            self.property("Granule Size","~%3.1f %s"%(mbytes/1024.,'GB'))
        else:
            self.property("Granule Size","~%d %s"%(int(mbytes+0.5),'MB'))

        # Variable table
        # --------------
        self.table('header')
        for vname in sorted(Coll.variables):
            # Variables without an entry in newUnits keep their own units.
            self.writeVariable(vname, Coll.variables[vname], Coll.newUnits.get(vname))
        self.table('footer')

#.................................................................................................

#     ---------
class stdoutWriter(Writer):
    """
    Implements a writer that prints the File Spec to standard output.

    Every print below takes exactly one parenthesized string, so each line
    means the same thing as a Python 2 print statement and as a Python 3
    function call.
    """
    def __init__(self,filename=None):
        self.filename = filename # accepted for interface parity; nothing is written to it
        self.secn = 0            # running collection number

    def section(self,collection,description):
        """Print the numbered heading of a new collection."""
        self.secn += 1
        print("")
        print("")
        print("%d) Collection %s: %s"%(self.secn,collection,description))
        print("")

    def property(self,name,value):
        """Print one right-aligned "name: value" line."""
        print("%19s: %s"%(name,value))

    def table(self,choice='header'):
        """
        Print the table header (choice == 'header') or, for any other
        choice, just the closing rule.
        """
        if choice == 'header':
            print("")
            print(" -------------|------|--------------------------------------------------------------------------------|------------")
            print("     Name     | Dims |                               Description                                      |    Units   ")
        print(" -------------|------|--------------------------------------------------------------------------------|------------")
        return

    def row(self,vname,dim,descr,units):
        """Print one variable row."""
        print(' %-12s | %4s | %-78s | %-10s '%(vname,dim,descr,units))

    def close(self):
        """Nothing to release for standard output."""
        pass

#.................................................................................................

#     ---------
class txtWriter(Writer):
    """
    Writer that stores the File Spec as plain text in a file.

    The file is opened for writing when the writer is created and stays
    open until close() is called.
    """
    def __init__(self,filename=None):
        self.filename = filename
        self.doc = open(filename,'w')
        self.secn = 0  # running collection number

    def section(self,collection,description):
        """Start a new numbered collection, framed by blank lines."""
        self.secn += 1
        heading = "%d) Collection %s: %s\n"%(self.secn,collection,description)
        self.doc.write("\n" + "\n" + heading + "\n")

    def property(self,name,value):
        """Write one right-aligned "name: value" line."""
        line = "%19s: %s\n"%(name,value)
        self.doc.write(line)

    def table(self,choice='header'):
        """
        Write the table header when choice is 'header'; for any other
        choice only the closing rule is written.
        """
        rule = " -------------|------|--------------------------------------------------------------------------------|------------\n"
        if choice == 'header':
            titles = "     Name     | Dims |                               Description                                      |    Units   \n"
            self.doc.write("\n" + rule + titles)
        self.doc.write(rule)
        return

    def row(self,vname,dim,descr,units):
        """Write one variable row."""
        line = ' %-12s | %4s | %-78s | %-10s \n'%(vname,dim,descr,units)
        self.doc.write(line)

    def close(self):
        """Close the output file."""
        self.doc.close()

#.................................................................................................

#     ---------
class rtfWriter(Writer):
    """
    Implements RTF file writer.

    The document is assembled in memory from PyRTF/rtfw elements and only
    written to `filename` when close() is called.
    """
    def __init__(self,filename):
        self.filename = filename # output path, opened by close()
        self.doc = Document(style_sheet=MakeMyStyleSheet())
        self.ss = self.doc.StyleSheet
        self.secn = 0            # running collection number
        self.Section = Section()
        self.doc.Sections.append(self.Section)

    def skip(self):
        """Append an empty paragraph as vertical spacing."""
        p  = Paragraph(self.ss.ParagraphStyles.Normal)
        p.append('')
        self.Section.append(p)

    def section(self,collection,description):
        """Append a 'Heading 3' paragraph: collection name in blue, then the description."""
        self.secn += 1
        p  = Paragraph(self.ss.ParagraphStyles.Heading3)
        p.append(TEXT(collection,colour=self.ss.Colours.Blue))
        p.append(TEXT(': %s'%str(description)))
        self.Section.append(p)
        self.skip()

    def property(self,name,value):
        """Append one "name: value" paragraph (bold name, italic value)."""
        p  = Paragraph(self.ss.ParagraphStyles.Normal)
        p.append(TEXT('     '+name+': ',bold=True),TEXT(str(value),bold=False,italic=True))
        self.Section.append(p)

    def table(self,choice='header'):
        """
        With choice == 'header', start a new table and add its header row.
        With any other choice, append the table built so far to the
        document; this requires a preceding 'header' call.
        """
        if choice == 'header':
            self.Table = Table(nNAME,nDIMS,nDESC,nUNIT)
            c1 = Cell( Paragraph(TEXT('       Name',bold=True,italic=True)),header_frameL)
            c2 = Cell( Paragraph(TEXT(' Dim',bold=True,italic=True)),header_frame)
            c3 = Cell( Paragraph(TEXT('                               Description',bold=True,italic=True)),header_frame)
            c4 = Cell( Paragraph(TEXT('      Units',bold=True,italic=True)),header_frameR)
            self.Table.AddRow(c1, c2, c3, c4)
        else:
            self.skip()
            self.Section.append(self.Table)
            self.skip()

    def row(self,vname,dim,descr,units):
        """Add one variable row to the current table."""
        c1 = Cell(Paragraph(TEXT(str(vname))),thin_frameL)
        c2 = Cell(Paragraph(TEXT(str(dim))),thin_frame)
        c3 = Cell(Paragraph(TEXT(str(descr))),thin_frame)
        c4 = Cell(Paragraph(TEXT(str(units))),thin_frameR)
        self.Table.AddRow(c1, c2, c3, c4)

    def close(self):
        """Render the document to self.filename."""
        r = Renderer()
        # open() works on Python 2 and 3 (file() is Python 2 only); the
        # with-block guarantees the output is flushed and closed.
        with open(self.filename,'w') as f:
            r.Write(self.doc,f)
        
#   -------------
def _dimLength(nc,names):
    """
    Return the length of the first dimension of `nc` whose name appears in
    `names`, or None when the file has none of them.
    """
    for dname in names:
        if dname in nc.dimensions:
            return len(nc.dimensions[dname])
    return None

#   -------------
def getCollection(filename,collname,options,colltype=None):
#   -------------
    """
    Parses netCDF file and gathers collection metadata.

    filename -- path of the netCDF file to inspect
    collname -- collection name, stored as Coll.name
    options  -- parsed command line options (not used here)
    colltype -- optional collection type; defaults to collname

    Returns a Spec carrying: size (file size from os.path.getsize), nc
    (the open Dataset), title / LongName / ShortName (global attributes,
    None when absent), nx, ny, nt, nz, lev_name, tbeg, variables (all
    variables except coordinate-like ones) and an empty newUnits dict.

    Raises KeyError when no longitude-like or latitude-like dimension is
    found.
    """
    Coll = Spec(collname, colltype if colltype else collname)

    Coll.size = os.path.getsize(filename)
    # The dataset is left open on purpose: Coll.variables holds references
    # to its variables, whose attributes are read later by the writers.
    Coll.nc = nc = Dataset(filename)
    Coll.variables = dict()
    Coll.newUnits = dict()

    # Optional global attributes
    # --------------------------
    Coll.title = getattr(nc,'Title',None)
    Coll.LongName = getattr(nc,'LongName',None)
    Coll.ShortName = getattr(nc,'ShortName',None)

    # Horizontal and time dimensions
    # ------------------------------
    Coll.nx = _dimLength(nc,('lon','longitude','Xdim'))
    if Coll.nx is None:
        raise KeyError("no 'lon', 'longitude' or 'Xdim' dimension in %s"%filename)
    Coll.ny = _dimLength(nc,('lat','latitude','Ydim'))
    if Coll.ny is None:
        raise KeyError("no 'lat', 'latitude' or 'Ydim' dimension in %s"%filename)
    Coll.nt = _dimLength(nc,('time',)) or 0

    # Vertical dimension.  The level count no longer depends on the
    # coordinate variable having a long_name; only lev_name does.
    # -------------------------------------------------------------
    Coll.nz = 0
    Coll.lev_name = None
    for lname in ('lev','levels'):
        if lname in nc.dimensions:
            Coll.nz = len(nc.dimensions[lname])
            if lname in nc.variables:
                Coll.lev_name = getattr(nc.variables[lname],'long_name',None)
            break

    # Coordinate-like variables are left out of the spec.  Names are
    # compared lower-cased, so every entry here must be lower case.
    # --------------------------------------------------------------
    ignore_vars = frozenset(['lon','lons','longitude','xdim',
                             'lat','lats','latitude','ydim',
                             'lev','levels',
                             'time','taitime',
                             'nf','ncontact','orientationstrlen','cubed_sphere','anchor','contacts','orientation'])
    for v in nc.variables:
        if v.lower() in ignore_vars:
            continue
        Coll.variables[v] = nc.variables[v]

    # Start time, taken from "<units> since <date> <time>" on the time
    # variable; None when there is no such variable or units string.
    # ---------------------------------------------------------------
    try:
        _units, offset = nc.variables['time'].units.split(' since ')
        Coll.tbeg = offset.replace(' ','T')
    except (KeyError, AttributeError, ValueError, TypeError):
        Coll.tbeg = None

    return Coll
        
def get1d(dname):
    """
    Abbreviate a dimension name to its short code.

    The substitutions are plain substring replacements applied in the
    order listed, longer names before their prefixes; 'nf' is removed
    altogether.  Text that matches none of them is returned unchanged.
    """
    substitutions = ( ('time',      't'),
                      ('longitude', 'x'),
                      ('latitude',  'y'),
                      ('levels',    'z'),
                      ('lon',       'x'),
                      ('lat',       'y'),
                      ('lev',       'z'),
                      ('Xdim',      'x'),
                      ('Ydim',      'y'),
                      ('nf',        '') )
    code = dname
    for long_form, short_form in substitutions:
        code = code.replace(long_form, short_form)
    return code


def MakeMyStyleSheet( ):
    """
    Build and return the RTF style sheet used by rtfWriter.

    Paragraph styles are appended in this order: 'Normal', 'Normal Short',
    'Heading 1', 'Heading 2', 'Heading 3', 'Normal Numbered',
    'Normal Numbered 2' and 'List 1' to 'List 3'.

    Indentation uses spaces only; the previous mix of tabs and spaces is
    rejected by Python 3 with a TabError.
    """

    result = StyleSheet()

    # One TextStyle is mutated step by step; every paragraph style takes
    # its own Copy() so later changes do not affect earlier styles.
    NormalText = TextStyle( TextPropertySet( result.Fonts.TimesNewRoman, 22 ) )
    ps = ParagraphStyle( 'Normal',
                         NormalText.Copy(),
                         ParagraphPropertySet( space_before = 60,
                                               space_after  = 60 ) )
    result.ParagraphStyles.append( ps )

    ps = ParagraphStyle( 'Normal Short',
                         NormalText.Copy() )
    result.ParagraphStyles.append( ps )

    NormalText.TextPropertySet.SetSize( 32 )
    ps = ParagraphStyle( 'Heading 1',
                         NormalText.Copy(),
                         ParagraphPropertySet( space_before = 240,
                                               space_after  = 60 ) )
    result.ParagraphStyles.append( ps )

    NormalText.TextPropertySet.SetSize( 24 ).SetBold( True )
    ps = ParagraphStyle( 'Heading 2',
                         NormalText.Copy(),
                         ParagraphPropertySet( space_before = 240,
                                               space_after  = 60 ) )
    result.ParagraphStyles.append( ps )

    NormalText.TextPropertySet.SetSize( 24 ).SetBold( True ).SetItalic(True)
    ps = ParagraphStyle( 'Heading 3',
                         NormalText.Copy(),
                         ParagraphPropertySet( space_before = 240,
                                               space_after  = 60 ) )
    result.ParagraphStyles.append( ps )

    # Variants of the Normal style with a hanging indent, suitable for
    # numbered paragraphs.
    normal_numbered = result.ParagraphStyles.Normal.Copy()
    normal_numbered.SetName( 'Normal Numbered' )
    normal_numbered.ParagraphPropertySet.SetFirstLineIndent( TabPropertySet.DEFAULT_WIDTH * -1 )
    normal_numbered.ParagraphPropertySet.SetLeftIndent     ( TabPropertySet.DEFAULT_WIDTH )

    result.ParagraphStyles.append( normal_numbered )

    normal_numbered2 = result.ParagraphStyles.Normal.Copy()
    normal_numbered2.SetName( 'Normal Numbered 2' )
    normal_numbered2.ParagraphPropertySet.SetFirstLineIndent( TabPropertySet.DEFAULT_WIDTH * -1 )
    normal_numbered2.ParagraphPropertySet.SetLeftIndent     ( TabPropertySet.DEFAULT_WIDTH *  2 )

    result.ParagraphStyles.append( normal_numbered2 )

    ## LIST STYLES
    # All three list levels share one indent of a single default tab width,
    # which is what this function has always produced.
    # NOTE(review): if deeper levels are meant to step in, use
    # TabPropertySet.DEFAULT_WIDTH * idx instead.
    indent = TabPropertySet.DEFAULT_WIDTH
    for idx in (1, 2, 3):
        ps = ParagraphStyle( 'List %s' % idx,
                             TextStyle( TextPropertySet( result.Fonts.Arial, 22 ) ),
                             ParagraphPropertySet( space_before = 60,
                                                   space_after  = 60,
                                                   first_line_indent = -indent,
                                                   left_indent       = indent) )
        result.ParagraphStyles.append( ps )

    return result

#------------------------------------ M A I N ------------------------------------
if __name__ == "__main__":

    # Defaults for the command line options
    defaultFormat = 'text'
    outFile = 'filespec.txt'

    # Parse command line options
    # --------------------------
    parser = OptionParser(usage="Usage: %prog [OPTIONS] netcdf_file(s)",
                          version='1.1.0' )

    parser.add_option("-o", "--output", dest="outFile", default=outFile,
              help="Output file name for the File Spec, ignored for format 'stdout' (default=%s)"\
                          %outFile )

    parser.add_option("-f", "--format", dest="format", default=defaultFormat,
              help="Output file format: one of 'stdout', 'text', 'rtf' (default=%s)"%defaultFormat )

    parser.add_option("-H", "--history", dest="history", default=None,
              help="Optional HISTORY.rc resource file (default=None)")

    parser.add_option("-M", "--merra",
                      action="store_true", dest="merra",
                      help="Assume MERRA-2 file name conventions")

    parser.add_option("-N", "--nature",
                      action="store_true", dest="nature",
                      help="Assume Nature Run file name conventions")

    parser.add_option("-O", "--merraobs",
                      action="store_true", dest="merraobs",
                      help="Assume merraobs Run file name conventions")

    parser.add_option("-F", "--forward_processing",
                      action="store_true", dest="fp",
                      help="Assume Forward Processing file name conventions")

    parser.add_option("-D", "--dyamond",
                      action="store_true", dest="dyamond",
                      help="Assume DYAMOND file name conventions")

    parser.add_option("-v", "--verbose",
                      action="store_true", dest="verbose",
                      help="Verbose mode.")

    (options, inFiles) = parser.parse_args()

    if len(inFiles) < 1:
        parser.error("must have at least 1 file name")

    # Only one type of files
    # ----------------------
    conventions = (options.merra, options.nature, options.merraobs,
                   options.dyamond, options.fp)
    if len([c for c in conventions if c]) > 1:
        parser.error('only one of -M, -N, -O, -D, -F can be specified')

    # Load HISTORY.rc if specified
    # ----------------------------
    Descr = dict()
    if options.history is not None:
        cf = Config(options.history)
        for v in cf.keys():
            v_ = v.replace('-.descr','.descr')
            Descr[v_] = cf(v)
    else:
        cf = None

    # The merraobs collection types are looked up in the resource file.
    if options.merraobs and cf is None:
        parser.error('-O/--merraobs requires a resource file (-H/--history)')

    # Instantiate writer
    # ------------------
    # Formats are matched by substring, in this order.
    name, ext = os.path.splitext(options.outFile)
    if 'text' in options.format or 'txt' in options.format:
        options.outFile = name + '.txt'
        writer = txtWriter(options.outFile)
    elif 'stdout' in options.format:
        options.outFile = name + '.txt'
        writer = stdoutWriter(options.outFile)
    elif 'rtf' in options.format or 'doc' in options.format:
        if not HAS_RTF:
            raise ValueError('PyRTF is not installed, therefore format RTF is not available')
        options.outFile = name + '.doc'
        writer = rtfWriter(options.outFile)
    else:
        raise ValueError("unsupported format <%s>"%options.format)

    # Unique name of collections
    # --------------------------
    Files = dict()      # collection name -> one file of that collection
    CollTypes = dict()  # collection name -> collection type (None unless merraobs)
    for fname in inFiles:
        fields = os.path.basename(fname).split('.')
        colltype = None
        if options.merraobs:
            # Drop the leading field and the trailing date stamp and
            # extension, then rejoin the rest with dots.
            collname = '.'.join(fields[1:-2])
            rcn = collname+'.colltype'
            colltype = Descr[rcn].replace("'","").replace(",","")
        elif options.fp:
            collname = fields[3]
        else:
            # nature, dyamond, merra and the default all use field 1
            collname = fields[1]

        # Only names with at least four '_'-separated fields are collections.
        Name = (colltype or collname).split('_')
        if len(Name) >= 4:
            Files[collname] = fname
            CollTypes[collname] = colltype

    # Sort collections: invariants first, then by grid code (N, M, other)
    # --------------------------------------------------------------------
    # The frequency is the first '_' field and the grid code the last one,
    # for 4-field names and 5-field DYAMOND names alike.  Each collection
    # is classified by its own type rather than the last file's.
    CollNames = dict()
    for collname in Files:
        parts = (CollTypes[collname] or collname).split('_')
        freq, res = parts[0], parts[-1]
        if   freq == 'const': prefix = 'a'
        elif res[:1] == 'N':  prefix = 'b'
        elif res[:1] == 'M':  prefix = 'c'
        else:                 prefix = 'd'
        CollNames[prefix + collname] = collname

    # Gather metadata for each collection
    # -----------------------------------
    for key in sorted(CollNames):
        collname = CollNames[key]
        if options.verbose:
            print("[] Working on collection <%s>"%collname)
        fname = Files[collname]
        Collection = getCollection(fname,collname,options,CollTypes[collname])
        if cf is not None:
            # A description in the resource file overrides the file's own
            # title; collections without one keep what is on file.
            try:
                rcn = collname+'.descr'
                Collection.title = Descr[rcn][:-2].replace("'","")
            except (KeyError, TypeError, AttributeError):
                pass
        if options.merraobs:
            # Units are taken from the "[...]" part of long_name.
            for vname in Collection.variables:
                long_name = Collection.variables[vname].long_name
                findUnits = re.search(r'.*\[(.*)\].*',long_name)
                if findUnits is None:
                    newUnits = '1'
                else:
                    newUnits = findUnits.group(1)
                    if newUnits == "hPA":
                        newUnits = "hPa"
                Collection.newUnits[vname] = newUnits
        writer.doCollection(Collection,options)

    # All done
    # --------
    writer.close()
