/////////////////////////////////////////////////////////////////////////////
//                                                                         //
//  1D search routines for roots of equation                               //
//  uses templates                                                         //
//                                                                         //
//  B. Militzer                                     Livermore 09-03-01     //
//                                                                         //
/////////////////////////////////////////////////////////////////////////////

#ifndef _FINDROOT_
#define _FINDROOT_

#include "Standard.h"

// Transfers the sign of b onto the magnitude of a:
// returns +|a| when b >= 0.0 and -|a| otherwise.
inline double NRSign(const double a, const double b) {
  const double magnitude = fabs(a);
  if (b >= 0.0) return magnitude;
  return -magnitude;
}

// Record describing a failed root search. Objects of this class are thrown by
// value by the *Throw root finders in this file when both interval boundaries
// carry the same sign.
//   s      - human-readable description of the failure
//   x1, x2 - interval boundaries that were passed to the root finder
//   f1, f2 - function values at x1 and x2
class FindRootProblem {
 public:
  std::string s;
  double x1,x2,f1,f2;

  // The message is taken by const reference so that temporaries and const
  // strings are accepted in addition to named non-const strings.
  FindRootProblem(const std::string & s_, double x1_, double x2_, double f1_, double f2_)
    :s(s_),x1(x1_),x2(x2_),f1(f1_),f2(f2_){}

  // Writes the message followed by " x1= ", " x2= ", " f(x1)= ", " f(x2)= "
  // and the corresponding values on a single line (no trailing newline).
  friend std::ostream& operator<<(std::ostream & os, const FindRootProblem & frp) {
    os << frp.s
       << " x1= " << frp.x1
       << " x2= " << frp.x2
       << " f(x1)= " << frp.f1
       << " f(x2)= " << frp.f2;
    return os;
  }
};

///////////////////////////////////////////////////////////////////////////////

// Brent-style root search for f on the interval spanned by x1 and x2.
// The function values at the interval ends are passed in (fx1, fx2) instead of
// being recomputed, because the caller may already have had to evaluate them.
//   x1, x2   - interval ends
//   fx1, fx2 - f(x1) and f(x2) as supplied by the caller; they must not both
//              be > 0 or both be < 0
//   tolX     - tolerance on x; enters the convergence threshold as 0.5*tolX
//   f        - callable, evaluated as f(b) once per iteration
//   name     - label inserted into the error messages
// Returns b as soon as |0.5*(c-b)| <= tol1 or f(b) is exactly 0.0.
// Calls error(...) if fx1 and fx2 have the same sign, or if ITMAX iterations
// pass without convergence; the trailing `return 0.0` is only reached if
// error() returns.
template <class T>
double FindRootBrentsMethod(const double x1, const double x2, const double fx1, const double fx2, const double tolX, T & f, const string name="") {
  double a=x1,b=x2,c=x2,d,e=0.0;
  //  double fa=f(a);
  //  double fb=f(b);
  double fa=fx1;
  double fb=fx2;
  double fc=fb;
  double p,q,r;

  if ((fa > 0.0 && fb > 0.0) || (fa < 0.0 && fb < 0.0))
    error("FindRootBrentsMethod() \""+name+"\" : Both function values have same sign:",fa,fb,a,b);

  const int ITMAX = 100;
  for (int iter=0; iter<ITMAX; iter++) {
    // if f(b) and f(c) have the same sign, replace c by a and restart the
    // step history (d and e) with b-a
    if ((fb > 0.0 && fc > 0.0) || (fb < 0.0 && fc < 0.0)) {
      c=a;
      fc=fa;
      e=d=b-a;
    }
    // make b the point with the smaller |f|: b takes c's value, while a and c
    // both take b's old value
    if (fabs(fc) < fabs(fb)) {
      a=b;
      b=c;
      c=a;
      fa=fb;
      fb=fc;
      fc=fa;
    }
    double eps  = 3.0e-8;
    //    double eps  = 3.0e-12; // does not improve things
    // convergence threshold: part relative to |b| plus half of tolX
    double tol1 = 2.0*eps*fabs(b)+0.5*tolX;
    // half of the signed distance from b to c
    double xm   = 0.5*(c-b);
    if (fabs(xm) <= tol1 || fb == 0.0) return b;

    // try an interpolation step only if e is at least tol1 in magnitude and
    // |f| decreased from a to b
    if (fabs(e) >= tol1 && fabs(fa) > fabs(fb)) {
      double s=fb/fa;
      if (a == c) {
	// a and c coincide: two-point formula
	p = 2.0*xm*s;
	q = 1.0-s;
      } else {
	// three distinct points a, b, c: three-point formula
	q = fa/fc;
	r = fb/fc;
	p = s*(2.0*xm*q*(q-r)-(b-a)*(r-1.0));
	q = (q-1.0)*(r-1.0)*(s-1.0);
      }
      // flip the sign of q when p is positive, then continue with |p|
      if (p > 0.0) q = -q;
      p = fabs(p);
      double min1 = 3.0*xm*q-fabs(tol1*q);
      double min2 = fabs(e*q);
      // accept the interpolated step p/q only if 2p is below both limits;
      // otherwise fall back to the bisection step xm
      if (2.0*p < (min1 < min2 ? min1 : min2)) {
	e = d;
	d = p/q;
      } else {
	d = xm;
	e = d;
      }
    } else {
      // bisection step
      d = xm;
      e = d;
    }

    // remember the current b as a, then advance b by d, or by tol1 in the
    // direction of xm when |d| is not larger than tol1
    a  = b;
    fa = fb;
    if (fabs(d) > tol1) {
      b += d;
    } else {
      b += NRSign(tol1,xm);
    }
    fb=f(b);
  }
  error("in ZBrent \""+name+"\" : Maximum number of iterations exceeded in ZBrent",ITMAX);
  return 0.0;
}

// Convenience overload: evaluates f at both interval ends and forwards to the
// overload that takes the precomputed function values.
template <class T>
inline double FindRootBrentsMethod(const double x1, const double x2, const double tolX, T & f, const string name="") {
  const double fAtX1 = f(x1);
  const double fAtX2 = f(x2);
  return FindRootBrentsMethod(x1,x2,fAtX1,fAtX2,tolX,f,name);
}

// Short alias for FindRootBrentsMethod(x1,x2,tolX,f,name).
// (Author's note: not a perfect name, given to it recently.)
template <class T>
inline double ZBrent(const double x1, const double x2, const double tolX, T & f, const string name="") {
  const double root = FindRootBrentsMethod(x1,x2,tolX,f,name);
  return root;
}

// Geometric bracket search followed by Brent's method.
// Starting from x1, the next point is x1/factor when f(x1)*sign_dfdx > 0 and
// x1*factor otherwise. As soon as f(x1)*f(x2) < 0 for two consecutive points,
// that pair is handed to FindRootBrentsMethod(). The search loop itself has no
// iteration limit.
// NOTE: error("Not tested 1/16/24") is called unconditionally on entry.
template <class T>
double FindRootScaleThenBrent(double x1, const double tolX, const double sign_dfdx, T & f, const double factor=1.1, const string name="") {
  error("Not tested 1/16/24");
  double fx1 = f(x1);
  for (;;) {
    double x2;
    if (fx1*sign_dfdx>0.0) {
      x2 = x1/factor;
    } else {
      x2 = x1*factor;
    }
    const double fx2 = f(x2);
    if (fx1*fx2<0.0) return FindRootBrentsMethod(x1,x2,fx1,fx2,tolX,f,name);
    // no sign change yet: continue from the new point
    x1  = x2;
    fx1 = fx2;
  }
}

// Widens the interval [x1,x2] until f changes sign across it, then refines
// with FindRootBrentsMethod().
// In each pass the two ends are ordered so that x2 carries the smaller |f|.
// x2 is then pushed away from x1 by `factor`, and if that does not produce a
// sign change x1 is pushed away from x2 by `factor`. There is no limit on the
// number of passes.
template <class T>
double FindRootScaleIntervalThenBrent(double x1, double x2, const double tolX, T & f, const double factor=1.1, const string name="") {
  double fx1 = f(x1);
  double fx2 = f(x2);
  bool bracketed = (sign(fx1)!=sign(fx2));

  while (!bracketed) {
    // make x2 the end that is likely to be closer to the root
    if (fabs(fx2)>fabs(fx1)) {
      Swap(x1,x2);
      Swap(fx1,fx2);
    }

    // extend on the x2 side
    x2  = x1 + factor*(x2-x1);
    fx2 = f(x2);
    bracketed = (sign(fx1)!=sign(fx2));
    if (bracketed) break;

    // extend on the x1 side
    x1  = x2 + factor*(x1-x2);
    fx1 = f(x1);
    bracketed = (sign(fx1)!=sign(fx2));
  }
  return FindRootBrentsMethod(x1,x2,fx1,fx2,tolX,f,name);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Variant of FindRootBrentsMethod() that reports failures by throwing instead
// of calling error().
// The function values at the interval ends are passed in (fx1, fx2) instead of
// being recomputed, because the caller may already have had to evaluate them.
//   x1, x2   - interval ends
//   fx1, fx2 - f(x1) and f(x2) as supplied by the caller
//   tolX     - tolerance on x; enters the convergence threshold as 0.5*tolX
//   f        - callable, evaluated as f(b) once per iteration
//   name     - label inserted into the messages
// Returns b as soon as |0.5*(c-b)| <= tol1 or f(b) is exactly 0.0.
// Throws a FindRootProblem (after printing its message to cout) when
// sign(fx1)==sign(fx2). When ITMAX iterations pass without convergence it
// throws the string built from the "In ZBrent ..." message, not a
// FindRootProblem.
template <class T>
double FindRootBrentsMethodThrow(const double x1, const double x2, const double fx1, const double fx2, const double tolX, T & f, const string name="") {
  double a=x1,b=x2,c=x2,d,e=0.0;
  //  double fa=f(a);
  //  double fb=f(b);
  double fa=fx1;
  double fb=fx2;
  double fc=fb;
  double p,q,r;

  //  if ((fa > 0.0 && fb > 0.0) || (fa < 0.0 && fb < 0.0))
  //    error("FindRootBrentsMethod() \""+name+"\" : Both function values have same sign:",fa,fb,a,b);

  // no sign change between the interval ends: print the message and throw it
  // together with the offending values
  if (sign(fx1)==sign(fx2)) {
    string s = "FindRootBrentsMethodThrow: same sign at both interval boundaries for \""+name+"\"";
    s += " x1= "+DoubleToString(x1);
    s += " x2= "+DoubleToString(x2);
    s += " f(x1)= "+DoubleToString(fx1);
    s += " f(x2)= "+DoubleToString(fx2);
    cout << s << endl;
    cout.flush();
    FindRootProblem frp(s,x1,x2,fx1,fx2);
    //    throw(s.c_str());
    throw(frp);
  }

  const int ITMAX = 100;
  for (int iter=0; iter<ITMAX; iter++) {
    // if f(b) and f(c) have the same sign, replace c by a and restart the
    // step history (d and e) with b-a
    if ((fb > 0.0 && fc > 0.0) || (fb < 0.0 && fc < 0.0)) {
      c=a;
      fc=fa;
      e=d=b-a;
    }
    // make b the point with the smaller |f|: b takes c's value, while a and c
    // both take b's old value
    if (fabs(fc) < fabs(fb)) {
      a=b;
      b=c;
      c=a;
      fa=fb;
      fb=fc;
      fc=fa;
    }
    double eps  = 3.0e-8;
    //    double eps  = 3.0e-12; // does not improve things
    // convergence threshold: part relative to |b| plus half of tolX
    double tol1 = 2.0*eps*fabs(b)+0.5*tolX;
    // half of the signed distance from b to c
    double xm   = 0.5*(c-b);
    if (fabs(xm) <= tol1 || fb == 0.0) return b;

    // try an interpolation step only if e is at least tol1 in magnitude and
    // |f| decreased from a to b
    if (fabs(e) >= tol1 && fabs(fa) > fabs(fb)) {
      double s=fb/fa;
      if (a == c) {
	// a and c coincide: two-point formula
	p = 2.0*xm*s;
	q = 1.0-s;
      } else {
	// three distinct points a, b, c: three-point formula
	q = fa/fc;
	r = fb/fc;
	p = s*(2.0*xm*q*(q-r)-(b-a)*(r-1.0));
	q = (q-1.0)*(r-1.0)*(s-1.0);
      }
      // flip the sign of q when p is positive, then continue with |p|
      if (p > 0.0) q = -q;
      p = fabs(p);
      double min1 = 3.0*xm*q-fabs(tol1*q);
      double min2 = fabs(e*q);
      // accept the interpolated step p/q only if 2p is below both limits;
      // otherwise fall back to the bisection step xm
      if (2.0*p < (min1 < min2 ? min1 : min2)) {
	e = d;
	d = p/q;
      } else {
	d = xm;
	e = d;
      }
    } else {
      // bisection step
      d = xm;
      e = d;
    }

    // remember the current b as a, then advance b by d, or by tol1 in the
    // direction of xm when |d| is not larger than tol1
    a  = b;
    fa = fb;
    if (fabs(d) > tol1) {
      b += d;
    } else {
      b += NRSign(tol1,xm);
    }
    fb=f(b);
  }
  //  error("in ZBrent \""+name+"\" : Maximum number of iterations exceeded in ZBrent",ITMAX);
  throw("In ZBrent \""+name+"\" : Maximum number of iterations exceeded in ZBrent");
  return 0.0;
}


////////////////////////////////////////////////////////////////////////////////////////////////////

// Walks away from x1 in increments of `step` until two consecutive points
// bracket a root, then refines the root with FindRootBrentsMethod().
//   x1   - starting point
//   tolX - tolerance on x, passed through to FindRootBrentsMethod()
//   f    - callable, evaluated as f(x)
//   step - initial increment; whenever |f| grows, the step is reversed and
//          shrunk (multiplied by -0.8) and the walk restarts from the same x1
//   name - label used in the error messages
// A point at which f is exactly 0.0 is returned immediately as the root.
// Calls error(...) once the step counter n reaches nMax.
template <class T>
double FindRootStepThenBrent(double x1, const double tolX, T & f, double step, const string name="") {
  const int nMax = 100;
  int n=1;
  double fx1 = f(x1);
  if (fx1==0.0) return x1; // the starting point already is a root

  while (true) {
    const double x2 = x1 + step;
    const double fx2 = f(x2);
    if (fx2==0.0) return x2; // landed exactly on a root

    // Compare the signs directly instead of testing fx1*fx2<0.0: the product
    // of two tiny values of opposite sign can underflow to -0.0, which would
    // hide the sign change.
    const bool signChange = (fx1<0.0 && fx2>0.0) || (fx1>0.0 && fx2<0.0);
    if (signChange) return FindRootBrentsMethod(x1,x2,fx1,fx2,tolX,f,name);

    if (++n>=nMax) error("FindRootStepThenBrent() exceeded number of allowed steps",nMax,name);

    if (fabs(fx2)>fabs(fx1)) {
      // |f| got larger: turn around. The step is also made a bit smaller so
      // that we do not oscillate between only two points when we are near a
      // minimum of |f|.
      step *= -0.8;
      continue; // do not change x1
    }

    x1  = x2;
    fx1 = fx2;
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Bisection on the interval [x1,x2]; only the sign of f is used.
//   x1, x2 - interval boundaries; sign(f(x1)) and sign(f(x2)) must differ
//   tolX   - tolerance relative to the initial interval width: the iteration
//            stops once the bracket is no wider than fabs(x2-x1)*tolX.
//            tolX==0.0 is accepted and iterates until the bracket cannot be
//            halved any further in floating point.
//   f      - callable, evaluated as f(x)
// Returns the last midpoint that was computed.
// Calls error(...) if tolX<0.0 or if both boundaries carry the same sign.
template <class T>
double FindRootBisection(const double x1,
                         const double x2,
                         const double tolX,
                         const T & f) {
  if (tolX<0.0)
    error("The accuracy parameter must be positive",tolX);

  const double f1=sign(f(x1));
  const double f2=sign(f(x2));

  if (f1==f2)
    error("FindRootBisection: same sign at both interval boundaries",x1,x2,f1,f2);

  const double acc=fabs(x2-x1)*tolX;
  double xx1=x1;
  double xx2=x2;
  double xx;
  do {
    xx=(xx2+xx1)*0.5;

    // The midpoint coincides with one of the ends: no representable point is
    // left between them, so the bracket cannot shrink any more. Without this
    // exit a tolerance below the floating-point spacing never terminates.
    if (xx==xx1 || xx==xx2) break;

    const double ff=sign(f(xx));

    if (ff==f1) xx1=xx;
    else xx2=xx;

  } while (fabs(xx2-xx1)>acc);
  return xx;
}

// Linear-interpolation root search started from the two points x1 and x2.
// Each pass interpolates between the two most recent points and then discards
// the older one; the signs of f at the two points are not checked.
//   x1, x2 - starting points
//   tolX   - tolerance relative to the initial separation: the iteration stops
//            once the two most recent points differ by at most
//            fabs(x2-x1)*tolX
//   f      - callable, evaluated as f(x)
// Returns the last interpolated point. Calls error(...) if tolX<0.0.
template <class T>
double FindRootRegulaFalsi(const double x1,
                           const double x2,
                           const double tolX,
                           const T & f) {
  if (tolX<0.0)
    error("The accuracy parameter must be positive",tolX);

  // (xOld,fOld) is the older and (xNew,fNew) the newer of the two points
  double xOld=x1;
  double fOld=f(x1);
  double xNew=x2;
  double fNew=f(x2);

  const double acc=fabs(x2-x1)*tolX;

  double xx;
  do {
    xx=(fOld*xNew-fNew*xOld)/(fOld-fNew);
    const double ff=f(xx);
    xOld=xNew;
    fOld=fNew;
    xNew=xx;
    fNew=ff;
  } while (fabs(xNew-xOld)>acc);
  return xx;
}

// Regula falsi that keeps the root bracketed and inserts a bisection step
// whenever the updates become too one-sided.
// The initial function values f1 and f2 are known and passed in, so f(x1) and
// f(x2) are not evaluated here - good for expensive f() functions.
//   x1, x2  - interval boundaries
//   f1, f2  - f(x1) and f(x2); sign(f1) and sign(f2) must differ
//   relTolX - tolerance relative to the initial interval width: the iteration
//             stops once the bracket is no wider than fabs(x2-x1)*relTolX
//   f       - callable, evaluated as f(x)
//   name    - label used in the messages
//   print   - when true, the state is written to cout after every step
// Returns the last point evaluated, or immediately the point at which f is
// exactly 0.0. Calls error(...) if relTolX<0.0 or if f1 and f2 have the same
// sign.
template <class T>
double FindRootRegulaFalsiSafely(const double x1,
                                 const double x2,
                                 double f1,
                                 double f2,
                                 const double relTolX,
                                 T & f, const string name="", const bool print=false) {
  if (relTolX<0.0)
    error("The accuracy parameter must be positive",relTolX);

  if (sign(f1)==sign(f2)) {
    cout << "FindRootRegulaFalsiSafely: " << name << endl;
    error("FindRootRegulaFalsiSafely: same sign at both interval boundaries",name,x1,x2,f1,f2);
  }

  const double acc=fabs(x2-x1)*relTolX;
  const int nOneSideMax = 5; // one-sidedness at which a bisection step is forced

  int n     = 2; // function values known so far (f1 and f2 were passed in)
  int nSide = 0; // decremented when xx1 moves, incremented when xx2 moves
  double xx1=x1;
  double xx2=x2;
  double xx;

  for (;;) {
    const bool lopsided = (abs(nSide)>=nOneSideMax);
    if (lopsided) {
      xx=(xx1+xx2)/2.0;           // bisection step, then start counting afresh
      nSide = 0;
    } else {
      xx=(f1*xx2-f2*xx1)/(f1-f2); // regular regula falsi step
    }

    const double ff=f(xx);
    n++;
    if (ff==0.0) return xx; // not sure if this will avoid getting stuck in all cases

    // replace the end whose sign matches, so f1 and f2 keep opposite signs
    if (sign(ff)==sign(f1)) {
      f1  = ff;
      xx1 = xx;
      nSide--;
    } else {
      f2  = ff;
      xx2 = xx;
      nSide++;
    }

    if (print) {
      bool flag = (fabs(xx2-xx1)<acc);
      cout << name; Write8(n,xx1,xx2,fabs(xx2-xx1),acc,flag,f1,f2);
    }

    if (!(fabs(xx2-xx1)>acc)) break;
  }
  return xx;
}

// Convenience overload: evaluates f at x1 and then at x2 and forwards to the
// overload that takes the function values as arguments.
template <class T>
double FindRootRegulaFalsiSafely(const double x1,
                                 const double x2,
                                 const double relTolX,
                                 T & f, const string name="", const bool print=false) {
  const double fAtX1=f(x1);
  const double fAtX2=f(x2);
  return FindRootRegulaFalsiSafely(x1,x2,fAtX1,fAtX2,relTolX,f,name,print);
}

/////////////////////////////////////////////////////////////////////////////////////////////

// Bracketing regula falsi with forced bisection steps, like
// FindRootRegulaFalsiSafely(), except that a missing sign change between the
// interval boundaries is reported by throwing a FindRootProblem.
//   x1, x2  - interval boundaries
//   f1, f2  - f(x1) and f(x2), supplied by the caller (not re-evaluated)
//   relTolX - tolerance relative to the initial interval width: the iteration
//             stops once the bracket is no wider than fabs(x2-x1)*relTolX
//   f       - callable, evaluated as f(x)
//   name    - label used in the messages
//   print   - when true, the state is written to cout during every step
// Returns the last point evaluated, or immediately the point at which f is
// exactly 0.0.
// Calls error(...) if relTolX<0.0. Throws a FindRootProblem (after printing
// its message to cout) if sign(f1)==sign(f2).
template <class T>
double FindRootRegulaFalsiThrow(const double x1,
                                const double x2,
                                double f1,
                                double f2,
                                const double relTolX,
                                T & f, const string name="", const bool print=false) {
  if (relTolX<0.0)
    error("The accuracy parameter must be positive",relTolX);

  int n=2; // function values known so far (f1 and f2 were passed in)

  if (sign(f1)==sign(f2)) {
    // the message names this function (it used to name
    // FindRootRegulaFalsiSafely, which pointed at the wrong routine)
    string s = "FindRootRegulaFalsiThrow: same sign at both interval boundaries for \""+name+"\"";
    s += " x1= "+DoubleToString(x1);
    s += " x2= "+DoubleToString(x2);
    s += " f(x1)= "+DoubleToString(f1);
    s += " f(x2)= "+DoubleToString(f2);
    cout << s << endl;
    cout.flush();
    FindRootProblem frp(s,x1,x2,f1,f2);
    throw(frp);
  }

  const double acc=fabs(x2-x1)*relTolX;

  double xx1=x1;
  double xx2=x2;
  double xx;

  int nOneSideMax = 5; // one-sidedness at which a bisection step is forced
  int nSide       = 0; // decremented when xx1 moves, incremented when xx2 moves
  do {
    if (abs(nSide)<nOneSideMax) {
      xx=(f1*xx2-f2*xx1)/(f1-f2); // regular regula falsi step
    } else {                      // do bisection step if things get too lop-sided
      xx=(xx1+xx2)/2.0;           // bisection step
      nSide = 0;
    }

    double ff=f(xx);
    n++;
    if (ff==0.0) return xx; // not sure if this will avoid getting stuck in all cases
    if (print) { cout << name; Write6(xx1,xx,xx2,f1,ff,f2); }
    if (sign(ff)==sign(f1)) { // make sure that f1 and f2 keep having opposite signs
      f1  = ff;
      xx1 = xx;
      nSide--;
    } else {
      f2  = ff;
      xx2 = xx;
      nSide++;
    }
    if (print) {
      bool flag = (fabs(xx2-xx1)<acc);
      cout << name; Write8(n,xx1,xx2,fabs(xx2-xx1),acc,flag,f1,f2);
    }
  } while (fabs(xx2-xx1)>acc);
  return xx;
}

// Convenience overload: evaluates f at x1 and then at x2 and forwards to the
// overload that takes the function values as arguments.
template <class T>
double FindRootRegulaFalsiThrow(const double x1,
                                const double x2,
                                const double relTolX,
                                T & f, const string name="", const bool print=false) {
  const double fAtX1=f(x1);
  const double fAtX2=f(x2);
  return FindRootRegulaFalsiThrow(x1,x2,fAtX1,fAtX2,relTolX,f,name,print);
}
////////////////////////////////////////////////////////////////////////////////////////////

template <class T>
double FindRootStepper(const double x1, 
		       const double x2, 
		       const double tolX,
		       const double tolF,
		       const T & f) {
  if (tolX<0.0)
    error("The accuracy parameter must be positive",tolX);
  if (tolF<0.0)
    error("The accuracy parameter must be positive",tolF);

  const double r=0.2;
  double f1=f(x1);
  double f2=f(x2);

  if (sign(f1)!=sign(f2)) 
    return FindRootBisection(x1,x2,tolX,f);

  double xx,dx,ff;
  if (fabs(f1)<fabs(f2)) {
    xx=x1;
    dx=(x2-x1)*r;
    ff=f1;
  } else {
    xx=x2;
    dx=(x1-x2)*r;
    ff=f2;
  }

  for(;;) {
    const double xxn=xx+dx;
    const double ffn=f(xxn);
    //    Write5(xx,xxn,dx,ff,ffn);

    if (sign(ffn)!=sign(ff))
      return FindRootBisection(xx,xxn,tolX,f);

    if (sign(ffn)!=sign(ff) || fabs(ffn)>fabs(ff)) {
      //      Write2(int(sign(ffn)!=(ff)),int(fabs(ffn)>fabs(ff)));
      dx *= -r;
      if (fabs(dx)<tolX && fabs(ffn)<tolF) {
	return (xxn+xx)*0.5;
      }

      const double overrun=1e-8;
      if (fabs(dx)<tolX*overrun) 
	error("FindRootStepper: Could not find root (dx got too small)",dx,ffn);
    }
    xx=xxn;
    ff=ffn;
  }
  return 0.0;
}

// Single-point overload: starts FindRootStepper() on the interval from
// (1-0.1)*x to (1+0.1)*x. For x==0.0 both ends coincide.
template <class T>
double FindRootStepper(const double x,
                       const double tolX,
                       const double tolF,
                       const T & f) {
  const double r=0.1;
  const double lower=(1.0-r)*x;
  const double upper=(1.0+r)*x;
  return FindRootStepper(lower,upper,tolX,tolF,f);
}

template <class T>
double FindRootStepperThenRegulaFalsi(const double x1, 
				      const double x2, 
				      const double relTolX,
				      const double tolF,
				      const T & f, const string name="", const bool print=false) {
  if (relTolX<0.0) error("The accuracy parameter must be positive",relTolX);

  const double r=0.2;
  double f1=f(x1);
  double f2=f(x2);

  if (sign(f1)!=sign(f2)) return FindRootRegulaFalsiSafely(x1,x2,f1,f2,relTolX,f,name,print);

  double xx,dx,ff;
  if (fabs(f1)<fabs(f2)) {
    xx=x1;
    dx=(x2-x1)*r;
    ff=f1;
  } else {
    xx=x2;
    dx=(x1-x2)*r;
    ff=f2;
  }

  for(;;) {
    const double xxn=xx+dx; 
    //    Write3(xx,xxn,dx);
    const double ffn=f(xxn);
    //    Write5(xx,xxn,dx,ff,ffn);

    if (sign(ffn)!=sign(ff)) return FindRootRegulaFalsiSafely(xx,xxn,ff,ffn,relTolX,f,name,print);

    if (fabs(ffn)>fabs(ff)) {
      dx *= -r; // shall we go the other way?
      if (fabs(dx)<xx*relTolX && fabs(ffn)<tolF) {
	return (xxn+xx)*0.5;
      }

      const double overrun=1e-8;
      if (fabs(dx)<xx*relTolX*overrun) 
	error("FindRootStepper: Could not find root (dx got too small)",dx,ffn);
    }
    xx=xxn;
    ff=ffn;
  }
  return 0.0;
}

////////////////////////////////////////////////////////////////////////////////////////////

// Root search in three stages:
//  1. use bisection directly if f changes sign between x1 and x2
//  2. otherwise survey the interval on a grid of nn=10 subintervals and bisect
//     on the first sign change found
//  3. otherwise look outside the interval: x1 is repeatedly divided and x2
//     repeatedly multiplied by (1+1/nn), up to nTrials=100 times, and the
//     first sign change found is bisected
//   x1, x2 - interval boundaries
//   tolX   - tolerance passed to FindRootBisection()
//   tolF   - only checked for being non-negative
//   f      - callable, evaluated as f(x)
// Calls error(...) if tolX<0.0, tolF<0.0, or no sign change is found at all;
// the trailing `return 0.0` is only reached if error() returns.
template <class T>
double FindRootBrancher(const double x1,
                        const double x2,
                        const double tolX,
                        const double tolF,
                        const T & f) {
  if (tolX<0.0)
    error("The accuracy parameter must be positive",tolX);
  if (tolF<0.0)
    error("The accuracy parameter must be positive",tolF);

  const double f1=f(x1);
  const double f2=f(x2);

  // stage 1: the given interval already brackets a root
  if (sign(f1)!=sign(f2))
    return FindRootBisection(x1,x2,tolX,f);

  // stage 2: are there any solutions inside the interval?
  const int nn=10;
  for(int i=1; i<nn; i++) {
    const double xx = x1+(x2-x1)*i/nn;
    if (sign(f(xx))==sign(f1)) continue;
    // both ends have the sign of f1, so either side of xx brackets a root
    return (i<=nn/2) ? FindRootBisection(x1,xx,tolX,f)
                     : FindRootBisection(xx,x2,tolX,f);
  }

  // stage 3: now look outside
  const int nTrials=100;
  const double growth=(1.0+1.0/nn);
  double xx1=x1;
  double xx2=x2;
  for(int i=0; i<nTrials; i++) {
    xx1 /= growth;
    xx2 *= growth;
    const double ff1 = f(xx1);
    const double ff2 = f(xx2);
    if (sign(f1)!=sign(ff1)) return FindRootBisection(x1,xx1,tolX,f);
    if (sign(f2)!=sign(ff2)) return FindRootBisection(x2,xx2,tolX,f);
  }

  error("Could not find any root in and outside of interval",x1,x2,xx1,xx2);
  return 0.0;
}

// Single-point overload: starts FindRootBrancher() on the interval from
// (1-0.1)*x to (1+0.1)*x.
template <class T>
double FindRootBrancher(const double x,
                        const double tolX,
                        const double tolF,
                        const T & f) {
  const double r=0.1;
  const double lower=(1.0-r)*x;
  const double upper=(1.0+r)*x;
  return FindRootBrancher(lower,upper,tolX,tolF,f);
}

///////////////////////////////////////////////////////////////////////////////////

// Adapter that turns a callable t into the function x -> t(x)-fTarget, so that
// the root finders in this file can solve t(x) = fTarget.
// Only a reference to t is stored; the wrapped object has to stay alive for as
// long as the adapter is used.
template <class T>
class EqualsConstant {
 public:
  const T & t;          // wrapped callable (not owned)
  const double fTarget; // value that t(x) is supposed to reach

  EqualsConstant(const T & t_, const double fTarget_)
    : t(t_), fTarget(fTarget_) {}

  // Difference between the wrapped function at x and the target value.
  double operator()(const double x) const {
    return t(x)-fTarget;
  }
};

template <class T>
double FindRootBisection(const double x1, 
			 const double x2, 
			 const double tolX,
			 const double fTarget,
			 const T & f) {
  EqualsConstant <T> ec(f,fTarget);
  return FindRootBisection(x1,x2,tolX,ec);
}

template <class T>
double FindRootRegulaFalsi(const double x1, 
			   const double x2, 
			   const double tolX,
			   const double fTarget,
			   const T & f) {
  EqualsConstant <T> ec(f,fTarget);
  return FindRootRegulaFalsi(x1,x2,tolX,ec);
}

template <class T>
double FindRootStepper(const double x1, 
		       const double x2, 
		       const double tolX,
		       const double tolF,
		       const double fTarget,
		       const T & f) {
  EqualsConstant <T> ec(f,fTarget);
  return FindRootStepper(x1,x2,tolX,tolF,ec);
}

/*
template <class T>
double FindRootStepper(const double x, 
		       const double tolX,
		       const double tolF,
		       const double fTarget,
		       const T & f) {
  EqualsConstant <T> ec(f,fTarget);
  return FindRootStepper(x,tolX,tolF,ec);
}
*/

template <class T>
double FindRootBrancher(const double x1, 
			const double x2, 
			const double tolX,
			const double tolF,
			const double fTarget,
			const T & f) {
  EqualsConstant <T> ec(f,fTarget);
  return FindRootBrancher(x1,x2,tolX,tolF,ec);
}

#endif // _FINDROOT_
