////////////////////////////////////////////////////////////////////////////////
//                                                                            //
// Classes for linear, logarithmic, inverse (1/x), and geometric grids in 1D  //
//                                                                            //
// Burkhard Militzer                                        Urbana 4-9-99     //
//                                                                            //
////////////////////////////////////////////////////////////////////////////////

#ifndef GRID_H
// #pragma interface
#define GRID_H

#include "Standard.h"
#include "Vector.h"

// Tag identifying the concrete Grid subclass. Each subclass's SetType()
// stores its value in Grid::type, and Grid::operator== compares it.
enum GridType {LINEAR, LOG, INVERSE, GEOMETRIC};

// Return true if a and b are equal or differ only by a tiny bit: by less than
// the relative amount acc, measured against the larger of |a| and |b|.
// The explicit a==b test is required. Without it, Compare(0,0) would be
// false, because the relative bound acc*max(|a|,|b|) is then 0 and 0<0 fails.
inline bool Compare(const double a, const double b, const double acc=1e-5) {
  if (a==b) return true;
  return fabs(a-b)<acc*max(fabs(a),fabs(b));
}

class Grid {
 public:
  Vector<double> grid;
  GridType type;
  string myTypeString;
  int nPoints; 
  int nExtra;
  double r1,r2;

  int Size() const {
    return nPoints;
  }
  double CutOff() const {
    return r2;
  }
  static string gridString;

  inline double operator()(int i) const {
    return grid[i];
  }
  
  virtual int Index (const double r) const = 0;
  virtual void Print(ostream & os) const = 0;
  virtual void Read(istream & is) = 0;
  virtual double GetExtra(const int i) const = 0;
  virtual void SetNPoints(const int n) = 0;
  // BEFORE virtual void ExtendRMax(const double rMax)
  // BUT now the new rMax can also be smaller than before
  virtual void ChangeRMax(const double rMax) = 0;
  virtual Grid * Clone() const = 0;
  virtual double InterpolateFunction(const double r, const int ii, const double f1, const double f2) const = 0;
  virtual ~Grid() {}

  bool operator==(Grid & grid) {
    //    Write2(nExtra,grid.nExtra);
    //    cout << (type!=grid.type) <<  (myTypeString!=grid.myTypeString)
    //	 << (nPoints!=grid.nPoints) << !Compare(r1,grid.r1) << !Compare(r2,grid.r2) << (nExtra!=grid.nExtra) << endl;
    if (type!=grid.type || myTypeString!=grid.myTypeString || 
	nPoints!=grid.nPoints || !Compare(r1,grid.r1) || !Compare(r2,grid.r2) || nExtra!=grid.nExtra) return false;
    for(int i=0; i<nExtra; i++) 
      if (GetExtra(i)!=grid.GetExtra(i)) return false;
    return true;
  }

  friend istream& operator>>(istream &is, Grid & grid) {
    grid.Read(is);
    return is;
  }
  
  friend ostream& operator<<(ostream &os, const Grid & grid ) {
    grid.Print(os);
    return os;
  }

  static Grid * Define(istream & is, const bool formatFlag=false);

  void CheckType(istream & is) {
    string s,type;
    is >> s >> type;
    if (s!=gridString) 
      error("Could not read grid parameters");
    if (type!=myTypeString)
      error("Wrong grid",type,myTypeString);
  }   
};



class LogGrid : public Grid {
 public:
  double GetExtra(const int i) const {
    error("log grid");
    return 0.0;
  }
  static const string typeString;

  double Elem(const int i) const {
    return r1 * pow(r2/r1, (double)i/(double)(nPoints-1));
  }

  void Read(istream & is) {
    CheckType(is);
    is >> r1 >> r2 >> nPoints;
    Init();
  }

  void Print(ostream &os) const {
    os << typeString << " " 
      //       << sci(12,r1) << "  " << sci(12,r2) << "  " << nPoints;
       << r1 << " " << r2 << " " << nPoints;
  }
  
  friend istream& operator>>(istream &is, LogGrid & logGrid ) {
    logGrid.Read(is);
    return is;
  }
  friend ostream& operator<<(ostream &os, const LogGrid & g ) {
    g.Print(os);
    return os;
  }

  void SetNPoints(const int n) {
    nPoints=n;
    Init();
  }

  // extend grid beyond rMax by changing r2 but keeping resolution the same
  void ChangeRMax(const double rMax) {
    if (rMax<r1) error("ChangeRMax",r1,rMax);
    int i=Index(rMax*(1.0+1e-12));
    double r=Elem(i);
    if (r*(1.0+1e-12)<rMax) {
      i++;
      r=Elem(i);
    }
    Init(r1,r,i+1);
  }

  void Init() {
    grid.resize(nPoints);
    for (int i=0; i<nPoints; i++)
      grid[i] = Elem(i);
  }

  void Init(const double r11, const double r21, const int n) {
    r1=r11;
    r2=r21;
    nPoints=n;
    Init();
  }

  int Index (const double r) const {
    int index;
    index = int (floor((double)(nPoints-1)*log(r/r1)/log(r2/r1)));
    return index;
  }

  void SetType() {
    type = LOG;
    myTypeString = typeString;
    nExtra =0;
  }

  LogGrid(const double r1, const double r2, const int numPoints) {
    SetType();
    Init(r1,r2,numPoints);
  }
  
  LogGrid() {
    SetType();
  }

  ~LogGrid() {}

  LogGrid* Clone() const {
    return new LogGrid(*this);
  }

  // can also be used for logarthmic EXTRApolation if ii=0 or ii==n-2
  double InterpolateFunction(const double r, const int ii, const double f1, const double f2) const {
    double logR1 = log(grid[ii]  );
    double logR2 = log(grid[ii+1]);
    double delta = logR2 - logR1;
    double q2 = (log(r)-logR1)/delta;
    double q1 = 1.0-q2;
    return f1*q1+f2*q2;
  }
};


class InverseGrid : public Grid {
  double x0, dri, rr;
 public:
  static const string typeString;

  double clusterFactor;
  double GetExtra(const int i) const {
    return clusterFactor;
  }

  double Elem(const int i) const {
    double val = rr+dri/(i+1.0-x0);
    return val;
  }
  
  int Index(const double r) const {
    int index = int( floor (dri/(r-rr) -1.0 + x0) );
    return index;
  }

  void SetType() {
    type = INVERSE;
    myTypeString = typeString;
    nExtra=1;
  }

  void SetNPoints(const int n) {
    nPoints=n;
    Init();
  }

  // extend grid beyond rMax by changing r2 but keeping resolution the same
  void ChangeRMax(const double rMax) {
    if (rMax<r1) error("ChangeRMax",r1,rMax);
    int i=Index(rMax*(1.0+1e-12));
    double r=Elem(i);
    if (r*(1.0+1e-12)<rMax) {
      i++;
      r=Elem(i);
    }
    Init(r1,r,clusterFactor,i+1);
  }

  void Init(const double r11, const double r21, 
	    const double cf, const int numPoints) {
    r1= r11;
    r2= r21;
    clusterFactor = cf;
    nPoints = numPoints;
    Init();
  }

  void Init() {
    x0 = (nPoints - clusterFactor)/(1.0-clusterFactor);
    dri = -(r2-r1)*( double(nPoints)-x0)*(1.0-x0) / 
      ( double(nPoints)-1.0);
    rr = r1 - dri/(1.0-x0);
    grid.resize(nPoints);
    for (int i=0; i<nPoints; i++)
      grid[i] = Elem(i);
  }
  
  InverseGrid() {
    SetType();
  }

  InverseGrid(const double r11, const double r21, const int numPoints,
	      const double cluster) {
    SetType();
    Init(r11,r21,cluster,numPoints);
  }
  ~InverseGrid() {}

  InverseGrid* Clone() const {
    return new InverseGrid(*this);
  }

  friend istream& operator>>(istream &is, InverseGrid & g ) {
    g.Read(is);
    return is;
  }

  friend ostream& operator<<(ostream &os, const InverseGrid & g ) {
    g.Print(os);
    return os;
  }

  void Read(istream & is) {
    CheckType(is);
    is >> r1 >> r2 >> nPoints >> clusterFactor;
    Init();
  }

  void Print(ostream &os) const {
    os << typeString << " "
       << r1 << " " << r2 << " " << nPoints << " "
       << clusterFactor;
    //       << sci(12,r1) << "  " << sci(12,r2) << "  " << nPoints << " "
    //       << sci(12,clusterFactor);
  }
  
  double InterpolateFunction(const double r, const int ii, const double f1, const double f2) const {
    error("This is function has not yet been defined.");
    return 0.0;
  }
};

class LinearGrid : public Grid {
 public:
  static const string typeString;
  static const string linearTypeString;
  // A linear grid has no extra parameters (nExtra==0); asking for one is an error.
  double GetExtra(const int i) const {
    error("linear grid");
    return 0.0;
  }

  void SetType() {
    type = LINEAR;
    myTypeString = typeString;
    nExtra =0;
  }

  // delta gets a defined value until Read()/Init() sets the real one.
  LinearGrid() : delta(0.0) {
    SetType();
  }
  
  // n equally spaced points from r11 to r21 inclusive.
  LinearGrid(const double r11, const double r21, const int n) {
    SetType();
    Init(r11, r21, n);
  }
  
  ~LinearGrid() {}

  // Polymorphic copy; the caller owns (and must delete) the result.
  LinearGrid* Clone() const {
    return new LinearGrid(*this);
  }

  // Point at index i (integer or fractional): r1 + i*delta.
  double Elem(const int i) const {
    return r1 + delta*double(i);
  }
  double Elem(const double i) const {
    return r1 + delta*i;
  }
  // Centre of the interval [Elem(i), Elem(i+1)].
  double GetMidPoint(const int i) const {
    return Elem(double(i)+0.5);
  }

  friend istream& operator>>(istream &is, LinearGrid & g ) {
    g.Read(is);
    return is;
  }

  friend ostream& operator<<(ostream &os, const LinearGrid & g ) {
    g.Print(os);
    return os;
  }

  // Read "<gridString> <typeString> r1 r2 nPoints" and rebuild the grid.
  void Read(istream & is) {
    CheckType(is);
    is >> r1 >> r2 >> nPoints;
    Init();
  }

  // Write the format Read() expects.
  void Print(ostream &os) const {
    os << gridString << " " << typeString << " "
      //       << sci(12,r1) << "  " << sci(12,r2) << "  " << nPoints;
       << r1 << " " << r2 << " " << nPoints;
  }

  // 1 if r1 <= r <= r2, else 0.
  int Inside(const double r) const {
    return (r1<=r) && (r<=r2);
  }
  // 1 if r1 <= r < r2, else 0.
  int InsideNoUpperEnd(const double r) const {
    return (r1<=r) && (r<r2);
  }
  // floor((r-r1)/delta); not clamped to the grid.
  int Index(const double r) const {
    return int(floor((r-r1) / delta));
  }
  // Offset of r above the grid point Elem(Index(r)).
  double Remainder(const double r) const {
    return r-Elem(Index(r)); 
  }
  // Fractional index of r, (r-r1)/delta.
  double InverseMap(const double r) const {
    return (r-r1) / delta;
  }
  // Interval width; identical everywhere on a linear grid, so r is unused.
  double IntervalSize(const double r) const {
    return delta;
  }

  // Change the number of points keeping r1 and r2, and rebuild the grid.
  void SetNPoints(const int n) {
    nPoints=n;
    Init();
  }

  // extend grid beyond rMax by changing r2 but keeping resolution the same.
  // The new r2 is the first existing grid point at or beyond rMax. Values
  // are nudged upward by a relative 1e-12, so a grid point equal to rMax up
  // to round-off is accepted as the end point. The nudge is written as
  // x+1e-12*|x| rather than x*(1+1e-12) so that it still points upward for
  // negative x. Otherwise a negative rMax lying on a grid point could gain
  // an extra point.
  void ChangeRMax(const double rMax) {
    if (rMax<r1) error("ChangeRMax",r1,rMax);
    int i=Index(rMax+1e-12*fabs(rMax));
    double r=Elem(i);
    if (r+1e-12*fabs(r)<rMax) {
      i++;
      r=Elem(i);
    }
    Init(r1,r,i+1);
  }

  // Recompute delta and the cached points from r1, r2 and nPoints.
  void Init() {
    delta = (r2 - r1) / (double(nPoints) - 1.0);
    grid.resize(nPoints);
    for (int i=0; i<nPoints; i++)
      grid[i] = Elem(i);
  }

  void Init(const double r11, const double r21, const int n) {
    r1 = r11;
    r2 = r21;
    nPoints = n;
    Init();
  }

  // Interpolation that is linear in r between Elem(ii) (value f1) and
  // Elem(ii+1) (value f2).
  // can also be used for linear EXTRApolation if ii=0 or ii==n-2
  double InterpolateFunction(const double r, const int ii, const double f1, const double f2) const {
    double q2 = (r-Elem(ii))/delta;
    double q1 = 1.0-q2;
    return f1*q1+f2*q2;
  }

  // Width of every interval, (r2-r1)/(nPoints-1); set by Init().
  double delta;
};

class GeometricGrid : public Grid {
 public:
  double f;
  double intervalRatio;
  double delta;

  static const string typeString;
  static const string linearTypeString;
  double GetExtra(const int i) const {
    error("geometric grid");
    return 0.0;
  }

  void SetType() {
    type = GEOMETRIC;
    myTypeString = typeString;
    nExtra =0;
  }

  GeometricGrid() {
    SetType();
  }
  
  GeometricGrid(const double r1, const double r2, const int n, double intervalRatio) {
    SetType();
    Init(r1, r2, n, intervalRatio);
  }
  
  GeometricGrid(const GeometricGrid & g) {
    SetType();
    Init(g.r1, g.r2, g.nPoints, g.intervalRatio);
  }

  ~GeometricGrid() {}

  GeometricGrid* Clone() const {
    return new GeometricGrid(*this);
  }

  double Elem(const int i) const {
    return Elem(double(i));
  }
  double Elem(const double i) const {
    if (i==0.0) return r1;
    return r1 + delta * (pow(f,i)-1.0) / (f-1.0);
  }
  double GetMidPoint(const int i) const {
    return Elem(double(i)+0.5);
  }

   // extend grid beyond rMax by changing r2 but keeping resolution the same
  void ChangeRMax(const double rMax) {
    error("ChangeRMax() not implemented.");
  }

  friend istream& operator>>(istream &is, GeometricGrid & g ) {
    g.Read(is);
    return is;
  }

  friend ostream& operator<<(ostream &os, const GeometricGrid & g ) {
    g.Print(os);
    return os;
  }

  void Read(istream & is) {
    CheckType(is);
    is >> r1 >> r2 >> nPoints >> intervalRatio;
    Init();
  }

  void Print(ostream &os) const {
    os << gridString << " " << typeString << " "
      //       << sci(12,r1) << "  " << sci(12,r2) << "  " << nPoints;
       << r1 << " " << r2 << " " << nPoints << " " << intervalRatio;
  }

  int Inside(const double r) const {
    return (r1<=r) && (r<=r2);
  }
  int InsideNoUpperEnd(const double r) const {
    return (r1<=r) && (r<r2);
  }
  double InverseMap(const double r) const {
    double z = (r-r1)/delta * (f-1.0) + 1.0;
    double index = log(z-f);
    return index;
  }
  int Index(const double r) const {
    return int(floor(InverseMap(r)));
  }
  double Remainder(const double r) const {
     return r-Elem(Index(r)); 
  }

  void SetNPoints(const int n) {
    nPoints=n;
    Init();
  }

  void Init() {
    f = pow(intervalRatio,1.0/(nPoints-2.0));
    double y = (pow(f,nPoints-1)-1.0) / (f-1.0);
    delta = (r2 - r1) / y;
    //    Write3(r1,r2,nPoints);
    //    Write3(intervalRatio,f,delta);
    grid.resize(nPoints);
    for (int i=0; i<nPoints; i++) {
      grid[i] = Elem(i);
      //      Write2(i,grid[i]);
    }
  }

  void Init(const double r1_, const double r2_, const int nPoints_, const double intervalRatio_) {
    r1 = r1_;
    r2 = r2_;
    nPoints = nPoints_;
    intervalRatio = intervalRatio_;
    Init();
  }

  double InterpolateFunction(const double r, const int ii, const double f1, const double f2) const {
    error("This is function has not yet been defined.");
    return 0.0;
  }

};

#endif
