"""
Last updated on 2023 September 12

@author: Shunya Kaneki (AIST)
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from utilmod import index as idx


class bg():
    """Estimate and subtract a linear background from a spectrum.

    A straight line is fitted with ``scipy.stats.linregress`` to the points
    falling in two anchor windows of the shift axis (1100-1150 and 1700-1750)
    and subtracted from the raw intensity.  If, after that correction, the
    maximum in the 1320-1380 window is not greater than half the maximum in
    the 1570-1600 window, the fit is redone with the lower anchor moved to
    1200-1250.

    Windows are converted to array positions with ``idx`` (``utilmod.index``);
    NOTE(review): assumes ``idx(sft, v)`` returns the array index matching
    shift value ``v`` and that ``sft`` is ordered so each window is a
    contiguous slice — confirm against utilmod.

    Parameters
    ----------
    sft : numpy.ndarray
        Shift axis.
    rint : numpy.ndarray
        Raw intensity sampled on ``sft``.

    Attributes
    ----------
    slope, intercept : float
        Coefficients of the fitted background line ``slope*sft + intercept``.
    bgint : numpy.ndarray
        Background line evaluated on ``sft``.
    bint : numpy.ndarray
        Background-corrected intensity, ``rint - bgint``.
    """

    # (low, high) shift windows whose points the background line is fitted to.
    _ANCHORS = ((1100, 1150), (1700, 1750))
    # Anchors used for the refit when the band-ratio check below fails.
    _FALLBACK_ANCHORS = ((1200, 1250), (1700, 1750))

    def __init__(self, sft,rint):
        self._fit(sft, rint, self._ANCHORS)
        peak_low = self.bint[idx(sft,1320):idx(sft,1380)].max()
        peak_high = self.bint[idx(sft,1570):idx(sft,1600)].max()
        # Written as `not (a > b)` so a NaN maximum also triggers the refit.
        if not (peak_low > 0.5*peak_high):
            self._fit(sft, rint, self._FALLBACK_ANCHORS)

    def _fit(self, sft, rint, anchors):
        """Fit a line to ``rint`` over the *anchors* windows and store it."""
        windows = [slice(idx(sft, lo), idx(sft, hi)) for lo, hi in anchors]
        x = np.concatenate([sft[w] for w in windows], 0)
        y = np.concatenate([rint[w] for w in windows], 0)
        res = stats.linregress(x, y)
        self.intercept = res.intercept
        self.slope = res.slope
        self.bgint = self.slope*sft+self.intercept
        self.bint = rint-self.bgint

class norm():
    """Scale a background-corrected spectrum by its maximum in 1100-1750.

    The window bounds are converted to array positions with ``idx``
    (``utilmod.index``).

    Parameters
    ----------
    sft : numpy.ndarray
        Shift axis used to locate the window.
    bint : numpy.ndarray
        Background-corrected intensity.

    Attributes
    ----------
    nbint : numpy.ndarray
        ``bint`` divided by its maximum over the window.  NOTE(review): there
        is no guard for a zero or negative maximum — confirm it cannot occur.
    """
    def __init__(self, sft,bint):
        start, stop = idx(sft, 1100), idx(sft, 1750)
        scale = bint[start:stop].max()
        self.nbint = bint / scale

class treat():
    """Background-correct and normalize one spectrum, optionally saving a CSV.

    Parameters
    ----------
    file : str
        Path of the source spectrum file.
    dtype : str
        Extension of the source file, without the leading dot.
    datadir : str
        Directory prefix of *file*; removed when building the output name.
    mkdir : object
        Object exposing ``specdir``, the directory the CSV is written to.
    sft, rint : numpy.ndarray
        Shift axis and raw intensity.
    specdat : bool, optional
        When true (default), write ``<specdir>/<stem>_treat.csv`` with columns
        shift, raw, background-corrected and normalized intensity.

    Attributes
    ----------
    sft, rint : numpy.ndarray
        The inputs, stored unchanged.
    bint, bgint : numpy.ndarray
        Background-corrected intensity and fitted background (from ``bg``).
    nbint : numpy.ndarray
        Normalized background-corrected intensity (from ``norm``).
    bgslope, bgintercept : float
        Coefficients of the fitted background line.
    """

    @staticmethod
    def _file_stem(file, datadir, dtype):
        """Return *file* without the *datadir* prefix and ``'.'+dtype`` suffix.

        Only an exact leading/trailing match is removed.  ``str.lstrip`` /
        ``str.rstrip`` would instead strip any of the given *characters* and
        can eat into the sample name (``'data/abc'.lstrip('data/') == 'bc'``).
        """
        stem = file[len(datadir):] if file.startswith(datadir) else file
        suffix = '.' + dtype
        if stem.endswith(suffix):
            stem = stem[:-len(suffix)]
        return stem

    def __init__(self, file,dtype,datadir,mkdir,sft,rint,specdat=True):
        self.sft = sft
        self.rint = rint
        # Fit the background once; every attribute below comes from this fit.
        background = bg(sft, rint)
        self.bint = background.bint
        self.bgint = background.bgint
        self.nbint = norm(sft,self.bint).nbint
        self.bgslope = background.slope
        self.bgintercept = background.intercept

        if specdat:
            value = np.c_[sft,rint,self.bint,self.nbint]
            stem = self._file_stem(file, datadir, dtype)
            csvname = mkdir.specdir+'/'+stem+'_treat.csv'
            np.savetxt(csvname,value,delimiter=',',comments='',header='shift,raw,bg,bg+norm')

def savefig(file,dtype,datadir,mkdir,tdat,figurePDF=True,figurePNG=True):
    """Save three diagnostic figures for a treated spectrum.

    The figures are written to ``mkdir.imgdir + '/fit/'`` as ``<stem>_raw``,
    ``<stem>_raw_bg`` and ``<stem>_norm``, where ``<stem>`` is *file* with the
    exact *datadir* prefix and ``'.' + dtype`` suffix removed.  Each figure is
    saved as PDF and/or PNG (PDF first) depending on the flags; when both
    flags are false nothing is written.

    Parameters
    ----------
    file : str
        Path of the source spectrum file.
    dtype : str
        Extension of the source file, without the leading dot.
    datadir : str
        Directory prefix of *file*; removed when building output names.
    mkdir : object
        Object exposing ``imgdir``, the image output directory.
    tdat : object
        Treated spectrum (e.g. a ``treat`` instance); its ``sft``, ``rint``,
        ``bgint``, ``bint``, ``nbint``, ``bgslope`` and ``bgintercept``
        attributes are plotted.
    figurePDF, figurePNG : bool, optional
        Whether to save PDF and/or PNG versions of each figure.
    """
    formats = [ext for ext, wanted in (('pdf', figurePDF), ('png', figurePNG)) if wanted]
    if not formats:
        return

    # Remove only an exact prefix/suffix: str.lstrip/rstrip strip any of the
    # given *characters* and can eat into the sample name.
    stem = file[len(datadir):] if file.startswith(datadir) else file
    suffix = '.' + dtype
    if stem.endswith(suffix):
        stem = stem[:-len(suffix)]
    outbase = mkdir.imgdir + '/fit/' + stem

    c1, c2 = 'black', 'gray'
    ls1, ls2 = 'solid', 'dashed'
    lw1 = 1

    def _finish(tag):
        """Save the current figure in every requested format, then close it."""
        try:
            for ext in formats:
                plt.savefig(outbase + '_' + tag + '.' + ext)
        finally:
            plt.close()

    # 1) Raw spectrum over the full 400-2400 range.
    plt.figure(figsize=(7,7),dpi=200)
    plt.plot(tdat.sft,tdat.rint,lw=lw1,color=c2,linestyle=ls1,label='raw')
    plt.legend()
    plt.xlim(400,2400)
    plt.xticks(np.linspace(400,2400,11))
    plt.xlabel('Raman shift [/cm]')
    plt.ylabel('Intensity')
    plt.title(stem+', Raw')
    _finish('raw')

    # 2) Raw spectrum, fitted linear background and corrected spectrum.
    plt.figure(figsize=(7,7),dpi=200)
    plt.plot(tdat.sft,tdat.rint,lw=lw1,color=c2,linestyle=ls1,label='raw')
    plt.plot(tdat.sft,tdat.bgint,lw=lw1,color=c2,linestyle=ls2,label='bg, y=ax+b\na={}\nb={}'.format(tdat.bgslope,tdat.bgintercept))
    plt.plot(tdat.sft,tdat.bint,lw=lw1,color=c1,linestyle=ls1,label='bg-corrected')
    plt.legend()
    plt.xlim(1100,1750)
    plt.xlabel('Raman shift [/cm]')
    plt.ylabel('Intensity')
    plt.title(stem+', BG correction')
    _finish('raw_bg')

    # 3) Normalized background-corrected spectrum.
    plt.figure(figsize=(7,7),dpi=200)
    plt.plot(tdat.sft,tdat.nbint,lw=lw1,color=c1,linestyle=ls1)
    plt.xlim(1100,1750)
    plt.ylim(-0.1,1.1)
    plt.xlabel('Raman shift [/cm]')
    plt.ylabel('Normalized intensity')
    plt.title(stem+', normalization')
    _finish('norm')