// Copyright 2019
//
// Alex G. Harvey with contributions from Danilo S. Brambila and Zdenek Masin.
//
// This file is part of UKRmol-out (UKRmol+ suite).
//
//     UKRmol-out is free software: you can redistribute it and/or modify
//     it under the terms of the GNU General Public License as published by
//     the Free Software Foundation, either version 3 of the License, or
//     (at your option) any later version.
//
//     UKRmol-out is distributed in the hope that it will be useful,
//     but WITHOUT ANY WARRANTY; without even the implied warranty of
//     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//     GNU General Public License for more details.
//
//     You should have received a copy of the GNU General Public License
//     along with  UKRmol-out (in source/COPYING). Alternatively, you can also visit
//     <https://www.gnu.org/licenses/>.
//


#include <iostream>
#include <fstream>
#include <string>
#include <sstream>
#include <vector>

#include <ctime>

#include <mblas_gmp.h>
#include <mlapack_gmp.h>

using namespace std;


// C-linkage declaration of the entry point defined at the bottom of this file.
// Declaring it inside extern "C" gives the later definition an unmangled symbol
// name. Every argument is passed by pointer.
// NOTE(review): the `_f` suffixes and pointer-only interface suggest the caller
// is Fortran — confirm against the calling code.
extern "C" { 
    void compak_multi_prec(int *nchan_f, int *nopen_f,double *rmatr_f, double *rafinv_f,
                       double *rmat_a_f, double *crv_f, double *fx_f, double *fxp_f,
                       double *open_kmatrix_f,double *full_kmatrix_f, double *wamp,
                       int *nocsf_f, double *eig, double *etarg, double *echl, int *nrk,
                       double *en_f,int *ntarg_f, double *akr, double *aki);
}

// Computes a K-matrix from an R-matrix and two sets of radial functions.
//
// All matrices are nchan x nchan, stored column-major (element (i,j) at
// index i + nchan*j). The routine forms
//     AA = fx_2 - rmat * fxp_2
//     BB = rmat * fxp_1 - fx_1
// and solves AA * K = BB for K, which is written to K_matrix.
//
// Parameters:
//   nchan          matrix dimension
//   rmat           R-matrix (input)
//   fx_1, fx_2     radial functions (input)
//   fxp_1, fxp_2   the matching "fxp" arrays (input; presumably the radial
//                  derivatives — confirm with the caller)
//   K_matrix       output, nchan*nchan elements, allocated by the caller
//
// The `info` status returned by Rgetrf/Rgetrs is not inspected.
void calculate_kmatrix(int nchan, mpf_class *rmat, mpf_class *fx_1, mpf_class *fx_2,  mpf_class *fxp_1, mpf_class *fxp_2, mpf_class *K_matrix)
{

    mpackint noelm = nchan*nchan;
    mpf_class *AA = new mpf_class[noelm];
    mpf_class *BB = new mpf_class[noelm];
    mpf_class df;

//  AA = fx_2 - rmat*fxp_2
    for(int j=0;j<nchan;j++) {
        for(int i=0;i<nchan;i++) {
            AA[i+nchan*j]=fx_2[i+nchan*j];
        }
    }

    for(int j=0;j<nchan;j++) {
        for(int k=0;k<nchan;k++) {
            df =  fxp_2[k+nchan*j];
            for(int i=0;i<nchan;i++) {
                AA[i+nchan*j] = AA[i+nchan*j]-rmat[i+nchan*k]*df;
            }
        }
    }

//  BB = rmat*fxp_1 - fx_1
    for(int j=0;j<nchan;j++) {
        for(int i=0;i<nchan;i++) {
            BB[i+nchan*j]=-fx_1[i+nchan*j];
        }
    }
    for(int j=0;j<nchan;j++) {
        for(int k=0;k<nchan;k++) {
            df =  fxp_1[k+nchan*j];
            for(int i=0;i<nchan;i++) {
                BB[i+nchan*j] = BB[i+nchan*j]+rmat[i+nchan*k]*df;
            }
        }
    }

    if (nchan==1) {
//      Scalar case: the only element lives at index 0 (arrays are 0-based).
        BB[0]=BB[0]/AA[0];
    }
    else {
//      General case: LU-factorise AA and solve AA*X = BB in place (X overwrites BB).
        mpackint info;
        mpackint *ipiv = new mpackint[nchan];
        Rgetrf(nchan, nchan, AA, nchan, ipiv, &info);
        Rgetrs("n",nchan,nchan, AA, nchan, ipiv, BB, nchan, &info);
        delete[]ipiv;

    }
    for(int i=0;i<nchan*nchan;i++) {
        K_matrix[i]=BB[i];
    }
    delete[]AA;
    delete[]BB;
}

void forward_propagate_rmatrix(int nchan, mpf_class *rmat_a, mpf_class *r11, mpf_class *r12,  mpf_class *r21, mpf_class *r22, mpf_class *rmat_c)
{
    mpackint noelm=nchan*nchan;
    mpf_class *A = new mpf_class[noelm];

    for (int i=0;i<noelm;i++)
    {    
        A[i]=rmat_a[i]+r11[i];
    }


    mpackint lwork, info;
    mpackint *ipiv = new mpackint[nchan];

//  work space query
    lwork = -1;
    mpf_class *work = new mpf_class[1];

    Rgetri(nchan, A, nchan, ipiv, work, lwork, &info);
    lwork = (int) work[0].get_d();
    delete[]work;
    work = new mpf_class[std::max(1, (int) lwork)];
    
//  invert matrix
    Rgetrf(nchan, nchan, A, nchan, ipiv, &info);
    Rgetri(nchan, A, nchan, ipiv, work, lwork, &info);

    mpf_class *temp_rmat_c = new mpf_class[noelm];  
    
    Rgemm("N","N",nchan,nchan,nchan,1.0,A,nchan,r12,nchan,0.0,temp_rmat_c,nchan);
    Rgemm("N","N",nchan,nchan,nchan,1.0,r21,nchan,temp_rmat_c,nchan,0.0,rmat_c,nchan);

    for (int i=0;i<noelm;i++)
    {    
        rmat_c[i]=r22[i]-rmat_c[i];
    }
    
//     printf("invA =");
//     printmat(nchan, nchan, rmat_c, nchan);
//     printf("\n");
    
    
    delete[]work;
    delete[]ipiv;
    delete[]A;
    delete[]temp_rmat_c;
}

void back_propagate_radial_functions(int nchan, mpf_class *rmat_a, mpf_class *rmat_c, 
                                     mpf_class *r11, mpf_class *r12,  mpf_class *r21, mpf_class *r22,
                                     mpf_class *fx_1, mpf_class *fx_2,  mpf_class *fxp_1, mpf_class *fxp_2,
                                     mpf_class *fx_prop_1, mpf_class *fx_prop_2,  mpf_class *fxp_prop_1, mpf_class *fxp_prop_2)
{
  
    mpackint noelm=nchan*nchan;    
    mpf_class *rhs_1 = new mpf_class[noelm];
    mpf_class *rhs_2 = new mpf_class[noelm];
    mpf_class *lhs = new mpf_class[noelm];

    
    for(int j=0;j<nchan;j++) {   
        for(int i=0;i<nchan;i++) {
            rhs_1[i+nchan*j]=fx_1[i+nchan*j];
            rhs_2[i+nchan*j]=fx_2[i+nchan*j];
        }
    }  

//  Trying just using fxp
//  ---------------------
//     Rgemm("N","N",nchan,nchan,nchan,1.0,rmat_c,nchan,fxp_1,nchan,0.0,rhs_1,nchan);
//     Rgemm("N","N",nchan,nchan,nchan,1.0,rmat_c,nchan,fxp_2,nchan,0.0,rhs_2,nchan);      
//  ---------------------    

    Rgemm("N","N",nchan,nchan,nchan,1.0,r22,nchan,fxp_1,nchan,-1.0,rhs_1,nchan);
    Rgemm("N","N",nchan,nchan,nchan,1.0,r22,nchan,fxp_2,nchan,-1.0,rhs_2,nchan);   

    for(int i=0;i<nchan*nchan;i++) {       
        lhs[i]=r21[i]  ;
    }
    
//     mpf_class *fxp_prop_1 = new mpf_class[noelm];
//     mpf_class *fxp_prop_2 = new mpf_class[noelm];

    mpackint info;
    mpackint *ipiv = new mpackint[nchan];
    Rgetrf(nchan, nchan, lhs, nchan, ipiv, &info); 
    Rgetrs("n",nchan,nchan, lhs, nchan, ipiv, rhs_1, nchan, &info);
    Rgetrs("n",nchan,nchan, lhs, nchan, ipiv, rhs_2, nchan, &info);
    delete[]ipiv; 
    

    for(int i=0;i<nchan*nchan;i++) {       
        fxp_prop_1[i]=rhs_1[i];
        fxp_prop_2[i]=rhs_2[i];
    }    
    
//     mpf_class *fx_prop_1 = new mpf_class[noelm];
//     mpf_class *fx_prop_2= new mpf_class[noelm];    

//   fx_prop_1=r12*fxp_square_1 - r11*fxp_prop_1
//     fx_prop_2=r12*fxp_square_2 - r11*fxp_prop_2   


    Rgemm("N","N",nchan,nchan,nchan,1.0,r11,nchan,fxp_prop_1,nchan,0.0,rhs_1,nchan);
    Rgemm("N","N",nchan,nchan,nchan,1.0,r11,nchan,fxp_prop_2,nchan,0.0,rhs_2,nchan);
    
    Rgemm("N","N",nchan,nchan,nchan,1.0,r12,nchan,fxp_1,nchan,-1.0,rhs_1,nchan);
    Rgemm("N","N",nchan,nchan,nchan,1.0,r12,nchan,fxp_2,nchan,-1.0,rhs_2,nchan);       

    for(int i=0;i<nchan*nchan;i++) {       
        fx_prop_1[i]=rhs_1[i];
        fx_prop_2[i]=rhs_2[i];
    }
    
    delete[]lhs;
    delete[]rhs_1;
    delete[]rhs_2;
  
}
// Back-propagates the K-matrix-combined radial functions.
//
// First builds, for the first nopen columns (remaining columns stay at the
// default-constructed value of mpf_class),
//     full_fx  = fx_1  + fx_2  * K_matrix
//     full_fxp = fxp_1 + fxp_2 * K_matrix
// and then applies the same propagation as back_propagate_radial_functions:
//     full_fxp_prop = r21^{-1} * (r22 * full_fxp - full_fx)
//     full_fx_prop  = r12 * full_fxp - r11 * full_fxp_prop
//
// All matrices are stored with leading dimension nchan. rmat_a and rmat_c are
// accepted but not used. full_fx_prop and full_fxp_prop must be allocated by
// the caller with nchan*nchan elements. The `info` status from Rgetrf/Rgetrs
// is not inspected.
void back_propagate_radial_functions_full(int nchan, int nopen,mpf_class *rmat_a, mpf_class *rmat_c,mpf_class *K_matrix, 
                                     mpf_class *r11, mpf_class *r12,  mpf_class *r21, mpf_class *r22,
                                     mpf_class *fx_1, mpf_class *fx_2,  mpf_class *fxp_1, mpf_class *fxp_2,
                                     mpf_class *full_fx_prop, mpf_class *full_fxp_prop)
{

    mpf_class *full_fx = new mpf_class[nchan*nchan];
    mpf_class *full_fxp = new mpf_class[nchan*nchan];

//  First construct the K-matrix radial functions
    for(int i=0;i<nchan*nopen;i++) {

         full_fx[i]=fx_1[i];
         full_fxp[i]=fxp_1[i];
    }
    Rgemm("N","N",nchan,nopen,nchan,1.0,fxp_2,nchan,K_matrix,nchan,1.0,full_fxp,nchan);
    Rgemm("N","N",nchan,nopen,nchan,1.0,fx_2, nchan,K_matrix,nchan,1.0,full_fx,nchan);

    mpackint noelm=nchan*nchan;
    mpf_class *rhs_1 = new mpf_class[noelm];
    mpf_class *lhs = new mpf_class[noelm];


    for(int j=0;j<nchan;j++) {
        for(int i=0;i<nchan;i++) {
            rhs_1[i+nchan*j]=full_fx[i+nchan*j];
        }
    }

//  rhs_1 = r22*full_fxp - full_fx
    Rgemm("N","N",nchan,nchan,nchan,1.0,r22,nchan,full_fxp,nchan,-1.0,rhs_1,nchan);

//  Work on a copy of r21: the LU factorisation overwrites its input.
    for(int i=0;i<nchan*nchan;i++) {
        lhs[i]=r21[i]  ;
    }

    mpackint info;
    mpackint *ipiv = new mpackint[nchan];
    Rgetrf(nchan, nchan, lhs, nchan, ipiv, &info);
    Rgetrs("n",nchan,nchan, lhs, nchan, ipiv, rhs_1, nchan, &info);
    delete[]ipiv;

    for(int i=0;i<nchan*nchan;i++) {
        full_fxp_prop[i]=rhs_1[i];
    }

//  full_fx_prop = r12*full_fxp - r11*full_fxp_prop
    Rgemm("N","N",nchan,nchan,nchan,1.0,r11,nchan,full_fxp_prop,nchan,0.0,rhs_1,nchan);

    Rgemm("N","N",nchan,nchan,nchan,1.0,r12,nchan,full_fxp,nchan,-1.0,rhs_1,nchan);
    for(int i=0;i<nchan*nchan;i++) {
        full_fx_prop[i]=rhs_1[i];
    }

//  Release all scratch storage, including the combined radial functions
//  allocated at the top of the routine.
    delete[]lhs;
    delete[]rhs_1;
    delete[]full_fx;
    delete[]full_fxp;

}

// Expands a packed triangle into a full symmetric square matrix, in place.
//
// On entry A holds n*(n+1)/2 values ordered (0,0), (1,0), (1,1), (2,0), ...
// On exit A holds n*n values with element (i,j) at index i + n*j and
// A(i,j) == A(j,i). The dimension n is recovered from the packed length.
void triangle_to_square(vector<double> &A)
{
    const int packed_len = A.size();
    const int dim = (sqrt(8*packed_len+1)-1)/2;

    vector<double> full(dim*dim);

    int src = 0;
    for (int row = 0; row < dim; ++row) {
        for (int col = 0; col <= row; ++col) {
            // Mirror each packed value into both halves of the square.
            full.at(row + dim*col) = A.at(src);
            full.at(col + dim*row) = A.at(src);
            ++src;
        }
    }

    A.swap(full);
}

// Transposes a square matrix stored as a flat vector, in place.
//
// The dimension n is the integer square root of A.size(); output element
// n*i + j receives input element i + n*j.
void column_to_row_order(vector<double> &A)
{
    const int total = A.size();
    const int dim = sqrt(total);

    vector<double> transposed(dim*dim);
    for (int row = 0; row < dim; ++row) {
        for (int col = 0; col < dim; ++col) {
            transposed[row*dim + col] = A.at(row + dim*col);
        }
    }

    A.swap(transposed);
}
//Matlab/Octave format
// Replaces the first occurrence of `from` in `str` with `to`.
// Returns true if a replacement was made, false if `from` was not found
// (in which case `str` is left untouched).
bool replace(std::string& str, const std::string& from, const std::string& to) {
    const std::string::size_type hit = str.find(from);
    const bool found = (hit != std::string::npos);
    if (found) {
        str.replace(hit, from.length(), to);
    }
    return found;
}


void printmat(int N, int M, mpf_class * A, int LDA)
{
    mpf_class mtmp;

    printf("[ ");
    for (int i = 0; i < N; i++) {
      if (i==0)
	    printf("[ ");
      else
          printf("[   ");
	for (int j = 0; j < M; j++) {
	    mtmp = A[i + j * LDA];
	    gmp_printf("% -8.5Fe", mtmp.get_mpf_t());
	    if (j < M - 1)
		printf(", ");
	}
	if (i < N - 1){
	    printf("]; ");
          cout << endl;
      }
	else
	    printf("] ");
    }
    printf("]");
}

void cprintmat(int N, int M, mpc_class * A, int LDA)
{
    mpc_class mtmp;

    printf("[ ");
    for (int i = 0; i < N; i++) {
      if (i==0)
          printf("[ ");
      else
          printf("[   ");
      for (int j = 0; j < M; j++) {
          mtmp = A[i + j * LDA];
//           gmp_printf("% -5.2Fe", mtmp.real().get_t());
//           gmp_printf("% -5.2Fe", mtmp.imag().get_t());
          cout <<mtmp.real().get_d() << "  i "<< mtmp.imag().get_d();
//           printf("% -20.5e", mtmp.real().get_d());
//           cout << "  i ";
//           printf("% -20.5e", mtmp.imag().get_d());
//           if (j < M - 1)
            printf(", ");
      }
      if (i < N - 1){
          printf("]; ");
          cout << endl;
      }
      else
          printf("] ");
    }
    printf("]");
}



// Reads R-matrix test data from the file "test.input" in the working directory.
//
// The file starts with four whitespace-separated values
//     nchan nopen rmatr rafinv
// which are stored in the corresponding reference arguments and echoed to
// stdout, followed by 7*nchan*nchan + nchan floating-point values, which are
// appended to `everything` in file order.
//
// Failures are reported on stdout and the function returns:
//   - if the file cannot be opened, nothing is modified;
//   - if the header cannot be parsed, `everything` is not modified;
//   - if the data section is short or malformed, only the values that were
//     read successfully are appended.
void read_rmatrix_data(int &nchan, int &nopen, double &rmatr, double &rafinv,  vector<double> &everything)
{

   ifstream myfile ("test.input");
   if (!myfile.is_open())
   {
     cout << "Unable to open file";
     return;
   }

   if (!(myfile >> nchan >> nopen >> rmatr >> rafinv))
   {
     cout << "Unable to read header from file" << '\n';
     return;
   }

   cout << nchan << '\n';
   cout << nopen << '\n';
   cout << rmatr << '\n';
   cout << rafinv << '\n';

   const int nvalues=7*nchan*nchan+nchan;

   // Read one value at a time instead of staging through a stack array whose
   // size is dictated by the file contents.
   for (int i=0;i<nvalues;i++)
   {
       double value;
       if (!(myfile >> value))
       {
           cout << "Unable to read value " << i << " of " << nvalues << " from file" << '\n';
           break;
       }
       everything.push_back(value);
   }

   myfile.close();

}
// Multiprecision evaluation of K-matrices, propagated radial functions and
// scattering wavefunction coefficients from double-precision inputs.
//
// All arithmetic is carried out with GMP floats at `default_prec` bits and the
// results are converted back to double on output. Progress timings are
// written to stdout ("Time N: ... ms").
//
// Inputs (all passed by pointer; matrices are flat arrays):
//   nchan_f, nopen_f   number of channels / open channels
//   rmatr_f, rafinv_f  two radii; propagation is performed only when
//                      *rafinv_f > *rmatr_f
//   rmat_a_f           R-matrix, nchan*nchan values
//   crv_f              packed propagator blocks, 2*nchan*nchan+nchan values:
//                      r11 (packed triangle), r12 (full square), r22 (packed
//                      triangle); r21 is taken as the transpose of r12
//   wamp               nchan x nocsf array (element (i,j) at i + nchan*j)
//   nocsf_f            number of entries read from eig / rows of the result
//   eig                nocsf values; etarg[0] and *en_f enter the pole
//                      denominator eig[i] - etarg[0] - en/2
//   echl, nrk, ntarg_f accepted but not used by this routine
//
// In/out:
//   fx_f, fxp_f        2*nchan*nchan values each (two consecutive nchan x nchan
//                      blocks); overwritten with the back-propagated functions
//                      (or left numerically unchanged when no propagation is
//                      performed)
//
// Outputs:
//   full_kmatrix_f     nchan*nchan K-matrix
//   open_kmatrix_f     nopen*nopen leading block of the K-matrix
//   akr, aki           real / imaginary parts of the nocsf*nopen coefficients
//
// NOTE(review): the interface (pointer-only arguments, `_f` suffixes) looks
// like it is meant to be called from Fortran — confirm against the caller.
void compak_multi_prec(int *nchan_f, int *nopen_f,double *rmatr_f, double *rafinv_f,
                       double *rmat_a_f, double *crv_f, double *fx_f, double *fxp_f,
                       double *open_kmatrix_f, double *full_kmatrix_f, double *wamp,
                       int *nocsf_f, double *eig, double *etarg, double *echl,int *nrk,
                       double *en_f,int *ntarg_f, double *akr, double *aki)
{

    int default_prec = 512; // quad-quad precision
//  Alternatives: 256 (double-quad), 128 (quad), 64 (double).

    clock_t start;
    start = clock();

    int nchan=*nchan_f;
    int nopen=*nopen_f;
    int nocsf=*nocsf_f;
    double rmatr= *rmatr_f;
    double rafinv= *rafinv_f;
    double en=*en_f;

//  Two packed triangles of (n*n+n)/2 values each plus one full n*n square.
    int nsize_crv=2*nchan*nchan+nchan;

    vector<double> vec_rmat_a, vec_crv, vec_fx, vec_fxp;

    vec_rmat_a.assign(rmat_a_f, rmat_a_f + nchan*nchan);
    vec_crv.assign(crv_f, crv_f + nsize_crv);
    vec_fx.assign(fx_f, fx_f + 2*nchan*nchan);
    vec_fxp.assign(fxp_f, fxp_f + 2*nchan*nchan);

//  Unpack curly R matrices used for propagation

    int ir11=0;
    int ir12=ir11+(nchan*nchan+nchan)/2;
    int ir22=ir12+nchan*nchan;

    vector<double> vec_r11, vec_r12, vec_r21, vec_r22;

    vec_r11.insert(vec_r11.begin(),vec_crv.begin()+ir11, vec_crv.begin()+ir12);
    vec_r12.insert(vec_r12.begin(),vec_crv.begin()+ir12, vec_crv.begin()+ir22);
    vec_r21.insert(vec_r21.begin(),vec_crv.begin()+ir12, vec_crv.begin()+ir22);
    vec_r22.insert(vec_r22.begin(),vec_crv.begin()+ir22, vec_crv.end());

//  Unpack regular and irregular coulomb functions

    vector<double> vec_fx_1, vec_fx_2, vec_fxp_1, vec_fxp_2;

    vec_fx_1.insert(vec_fx_1.begin(),vec_fx.begin(), vec_fx.begin()+nchan*nchan);
    vec_fx_2.insert(vec_fx_2.begin(),vec_fx.begin()+nchan*nchan, vec_fx.end());
    vec_fxp_1.insert(vec_fxp_1.begin(),vec_fxp.begin(), vec_fxp.begin()+nchan*nchan);
    vec_fxp_2.insert(vec_fxp_2.begin(),vec_fxp.begin()+nchan*nchan, vec_fxp.end());

//  Unpack triangular matrices. Transpose r21

    column_to_row_order(vec_r21);
    triangle_to_square(vec_r11);
    triangle_to_square(vec_r22);

//  Convert to multiprecision matrices. The default precision must be set
//  before the first mpf_class object is constructed.

    mpf_set_default_prec(default_prec);

    mpackint noelm = nchan*nchan;

    mpf_class *rmat_a = new mpf_class[noelm];
    mpf_class *r11 = new mpf_class[noelm];
    mpf_class *r12 = new mpf_class[noelm];
    mpf_class *r21 = new mpf_class[noelm];
    mpf_class *r22 = new mpf_class[noelm];
    mpf_class *fx_1 = new mpf_class[noelm];
    mpf_class *fx_2 = new mpf_class[noelm];
    mpf_class *fxp_1 = new mpf_class[noelm];
    mpf_class *fxp_2 = new mpf_class[noelm];

    for (int i=0;i<noelm;i++)
    {
       rmat_a[i]=vec_rmat_a.at(i);
       r11[i]=vec_r11.at(i);
       r12[i]=vec_r12.at(i);
       r21[i]=vec_r21.at(i);
       r22[i]=vec_r22.at(i);
       fx_1[i]=vec_fx_1.at(i);
       fx_2[i]=vec_fx_2.at(i);
       fxp_1[i]=vec_fxp_1.at(i);
       fxp_2[i]=vec_fxp_2.at(i);
    }

    cout << "Time 1: " << (clock() - start) / (double)(CLOCKS_PER_SEC / 1000) << " ms" << endl;

//  Propagate R-matrix to matching radius.
//  --------------------------------------
    mpf_class *rmat_c = new mpf_class[noelm];

    if (rafinv > rmatr){
       forward_propagate_rmatrix(nchan, rmat_a, r11, r12, r21, r22, rmat_c);
    }
    else {
       for(int i=0;i<nchan*nchan;i++) {
          rmat_c[i]=rmat_a[i];
       }
    }

    cout << "Time: 2 " << (clock() - start) / (double)(CLOCKS_PER_SEC / 1000) << " ms" << endl;

//  Calculate K-matrix at matching radius.
//  --------------------------------------

    mpf_class *K_matrix_c = new mpf_class[noelm];

    calculate_kmatrix(nchan, rmat_c, fx_1, fx_2, fxp_1, fxp_2, K_matrix_c);

    cout << "Time 3: " << (clock() - start) / (double)(CLOCKS_PER_SEC / 1000) << " ms" << endl;

//  Back propagation of the radial functions
//  ----------------------------------------

    mpf_class *fxp_prop_1 = new mpf_class[noelm];
    mpf_class *fxp_prop_2 = new mpf_class[noelm];
    mpf_class *fx_prop_1 = new mpf_class[noelm];
    mpf_class *fx_prop_2= new mpf_class[noelm];

    if (rafinv > rmatr)
    {
       back_propagate_radial_functions(nchan, rmat_a, rmat_c,
                                       r11, r12, r21, r22,
                                       fx_1, fx_2, fxp_1, fxp_2,
                                       fx_prop_1, fx_prop_2, fxp_prop_1, fxp_prop_2);
    }
    else {
       for(int i=0;i<nchan*nchan;i++) {
          fx_prop_1[i]=fx_1[i];
          fx_prop_2[i]=fx_2[i];
          fxp_prop_1[i]=fxp_1[i];
          fxp_prop_2[i]=fxp_2[i];
       }
    }

    cout << "Time 4: " << (clock() - start) / (double)(CLOCKS_PER_SEC / 1000) << " ms" << endl;

//  K-matrix at the R-matrix boundary
//  ---------------------------------
    mpf_class *K_matrix_a = new mpf_class[noelm];
    mpf_class *open_K_matrix = new mpf_class[noelm];

    calculate_kmatrix(nchan, rmat_a, fx_prop_1, fx_prop_2, fxp_prop_1, fxp_prop_2, K_matrix_a);

    cout << "Time 5: " << (clock() - start) / (double)(CLOCKS_PER_SEC / 1000) << " ms" << endl;

//  Assign calculated quantites to double precision output variables
//
    mpf_class mtmp;
    for(int i=0;i<nchan*nchan;i++) {
        mtmp=fx_prop_1[i];
        fx_f[i]=mpf_get_d(mtmp.get_mpf_t());

        mtmp=fx_prop_2[i];
        fx_f[i+nchan*nchan]=mpf_get_d(mtmp.get_mpf_t());

        mtmp=fxp_prop_1[i];
        fxp_f[i]=mpf_get_d(mtmp.get_mpf_t());

        mtmp=fxp_prop_2[i];
        fxp_f[i+nchan*nchan]=mpf_get_d(mtmp.get_mpf_t());

        mtmp=K_matrix_a[i];
        full_kmatrix_f[i]=mpf_get_d(mtmp.get_mpf_t());
    }

//  Leading nopen x nopen block of the K-matrix, repacked with leading dimension nopen.
    for(int j=0;j<nopen;j++) {
        for(int i=0;i<nopen;i++) {
           open_kmatrix_f[i+nopen*j]=full_kmatrix_f[i+nchan*j];
           open_K_matrix[i+nopen*j]=K_matrix_a[i+nchan*j];
        }

    }
    cout << "Time 6: " << (clock() - start) / (double)(CLOCKS_PER_SEC / 1000) << " ms" << endl;

//   Apply photoionisation boundary conditions to radial functions
//   -------------------------------------------------------------

    mpackint ibctyp=1;
    mpf_class gamma=0.0;

    mpackint noelm_open = nopen*nopen;
    mpc_class *zidentity_plusminus_iK = new mpc_class[noelm_open];

//  Build 1 + i*ibctyp*K on the open channels.
//  NOTE(review): this relies on mpc_class default-constructing to zero for the
//  off-diagonal real parts — confirm.
    for(int i=0;i<nopen;i++) {
        zidentity_plusminus_iK[i+i*nopen].real()=1.0;
    }

    for(int j=0;j<nopen;j++) {
        for(int i=0;i<nopen;i++) {
            zidentity_plusminus_iK[i+j*nopen].imag()=ibctyp*open_K_matrix[i+nopen*j];
        }
    }

    mpackint info;
    mpackint *ipiv = new mpackint[nchan];
    Cgetrf(nopen, nopen, zidentity_plusminus_iK, nopen, ipiv, &info);

    mpc_class *fx_times_fkmat_all = new mpc_class[nchan*nopen];
    mpc_class *fxp_times_fkmat_all = new mpc_class[nchan*nopen];
    mpc_class *complex_fx_1 = new mpc_class[nchan*nchan];
    mpc_class *complex_fx_2 = new mpc_class[nchan*nchan];
    mpc_class *complex_fxp_1 = new mpc_class[nchan*nchan];
    mpc_class *complex_fxp_2 = new mpc_class[nchan*nchan];
    mpc_class *complex_K_matrix = new mpc_class[nchan*nopen];

    for(int i=0;i<nchan*nchan;i++) {
        complex_fx_1[i].real()=fx_prop_1[i];
        complex_fx_2[i].real()=fx_prop_2[i];
        complex_fxp_1[i].real()=fxp_prop_1[i];
        complex_fxp_2[i].real()=fxp_prop_2[i];

    }

    for(int i=0;i<nchan*nopen;i++) {
        complex_K_matrix[i].real()=K_matrix_a[i];
    }

    mpc_class alpha;
    mpc_class beta;
    alpha.real()=1.0;
    beta.real()=0.0;

    cout << "Time 7: " << (clock() - start) / (double)(CLOCKS_PER_SEC / 1000) << " ms" << endl;

//  f_2 * K for the first nopen columns ...
    Cgemm("N","N",nchan,nopen,nchan,alpha,complex_fxp_2,nchan,complex_K_matrix,nchan,beta,fxp_times_fkmat_all,nchan);
    Cgemm("N","N",nchan,nopen,nchan,alpha,complex_fx_2, nchan,complex_K_matrix,nchan,beta,fx_times_fkmat_all,nchan);

    cout << "Time 8: " << (clock() - start) / (double)(CLOCKS_PER_SEC / 1000) << " ms" << endl;

//  ... plus f_1.
    for(int i=0;i<nchan*nopen;i++) {

         fx_times_fkmat_all[i]+=complex_fx_1[i];
         fxp_times_fkmat_all[i]+=complex_fxp_1[i];

    }

//   Transpose to nopen x nchan (leading dimension nopen)
    mpc_class *fx_times_fkmat_all_trans = new mpc_class[nchan*nopen];
    mpc_class *fxp_times_fkmat_all_trans = new mpc_class[nchan*nopen];

    int k=0;
    for(int i=0;i<nchan;i++)
    {
        for(int j=0;j<nopen;j++)
        {
            fx_times_fkmat_all_trans[k]=fx_times_fkmat_all[i +nchan*j];
            fxp_times_fkmat_all_trans[k]=fxp_times_fkmat_all[i +nchan*j];
            k++;
        }
    }

//  When propagation is active, the real parts are replaced by the functions
//  obtained by back-propagating the K-matrix combination built at the
//  matching radius (marked "Testing stuff" by the original author).
//  -------------------------------------------------------------------------------
    if (rafinv > rmatr)
    {
       mpf_class *full_fxp_prop = new mpf_class[nchan*nchan];
       mpf_class *full_fx_prop = new mpf_class[nchan*nchan];

       back_propagate_radial_functions_full(nchan,nopen, rmat_a, rmat_c,K_matrix_c,
                                            r11, r12, r21, r22,
                                            fx_1, fx_2, fxp_1, fxp_2,
                                            full_fx_prop, full_fxp_prop);

       k=0;
       for(int i=0;i<nchan;i++)
       {
          for(int j=0;j<nopen;j++)
          {
             fx_times_fkmat_all_trans[k].real()=full_fx_prop[i +nchan*j];
             fxp_times_fkmat_all_trans[k].real()=full_fxp_prop[i +nchan*j];
             k++;
          }
       }

       delete[] full_fxp_prop;
       delete[] full_fx_prop;
    }
//  --------------------------------------------------------------------------------

    cout << "Time 9: " << (clock() - start) / (double)(CLOCKS_PER_SEC / 1000) << " ms" << endl;

//  Solve with the transposed, LU-factorised (1 + i*ibctyp*K); the solution
//  overwrites the right-hand sides.
    Cgetrs("T",nopen,nchan, zidentity_plusminus_iK, nopen, ipiv, fx_times_fkmat_all_trans, nopen, &info);
    Cgetrs("T",nopen,nchan, zidentity_plusminus_iK, nopen, ipiv, fxp_times_fkmat_all_trans, nopen, &info);

    mpc_class *fx_plusminus = new mpc_class[nchan*nopen];
    mpc_class *fxp_plusminus = new mpc_class[nchan*nopen];

//   Transpose back to nchan x nopen (leading dimension nchan)
    k=0;
    for(int i=0;i<nopen;i++)
    {
        for(int j=0;j<nchan;j++)
        {
            fx_plusminus[k]=fx_times_fkmat_all_trans[i +nopen*j];
            fxp_plusminus[k]=fxp_times_fkmat_all_trans[i +nopen*j];
            k++;
        }
    }

//  NOTE(review): pi is assigned from a double literal, so it (and the
//  normalisation derived from it) carries only double precision.
    mpf_class two_over_pi, root_two_over_pi, pi;
    pi=3.141592653589793238462643383279502884197169399375105;
    two_over_pi=2.0/pi;
    mpf_sqrt(root_two_over_pi.get_mpf_t(), two_over_pi.get_mpf_t());

    for(int i=0;i<nchan*nopen;i++)
    {
        fx_plusminus[i]*=root_two_over_pi;
        fxp_plusminus[i]*=root_two_over_pi;
    }

//  Calculate scattering wavefunction coefficients.
//  -----------------------------------------------

    mpc_class *delta_pole = new mpc_class[nocsf];
    mpf_class one, two;
    mpc_class denominator;
    one=1.0;
    two=2.0;

//  delta_pole[i] = 1 / (eig[i] - etarg[0] - en/2 + i*ibctyp*gamma)
    for(int i=0;i<nocsf;i++)
    {

        denominator.real()=eig[i]-etarg[0]-en/two;
        denominator.imag()=ibctyp*gamma;
        delta_pole[i]=one/denominator;

    }

    mpc_class *wampt = new mpc_class[nocsf*nchan];
    mpc_class *wampt_x_delta_pole = new mpc_class[nocsf*nchan];

//   Transpose wamp into an nocsf x nchan array (leading dimension nocsf)
    k=0;
    for(int i=0;i<nchan;i++)
    {
        for(int j=0;j<nocsf;j++)
        {
            wampt[k].real()=wamp[i +nchan*j];
            k++;
        }
    }

    mpf_class root_two, bamp_norm,half;
    mpf_sqrt(root_two.get_mpf_t(), two.get_mpf_t());

    bamp_norm=root_two;
    half=0.5;

//  Scale row i of the transposed amplitudes by delta_pole[i].
    for(int i=0;i<nocsf;i++){
        for(int j=0;j<nchan;j++)
        {
            wampt_x_delta_pole[i+j*nocsf]=delta_pole[i]*wampt[i+j*nocsf];
        }
    }

    mpc_class *wavefunction_coefs = new mpc_class[nocsf*nopen];

    cout << "Time 10: " << (clock() - start) / (double)(CLOCKS_PER_SEC / 1000) << " ms" << endl;
    Cgemm("N","N",nocsf,nopen,nchan,alpha,wampt_x_delta_pole,nocsf,fxp_plusminus,nchan,beta,wavefunction_coefs,nocsf);
    cout << "Time 11: " << (clock() - start) / (double)(CLOCKS_PER_SEC / 1000) << " ms" << endl;

    for(int i=0;i<nocsf*nopen;i++){
        wavefunction_coefs[i]*=bamp_norm*half;
    }

    for(int i=0;i<nocsf*nopen;i++){
        akr[i]=wavefunction_coefs[i].real().get_d();
        aki[i]=wavefunction_coefs[i].imag().get_d();
    }
    cout << "Time 13: " << (clock() - start) / (double)(CLOCKS_PER_SEC / 1000) << " ms" << endl;

//  Release every multiprecision work array allocated above.
    delete[] wavefunction_coefs;
    delete[] wampt_x_delta_pole;
    delete[] wampt;
    delete[] delta_pole;
    delete[] fxp_plusminus;
    delete[] fx_plusminus;
    delete[] fxp_times_fkmat_all_trans;
    delete[] fx_times_fkmat_all_trans;
    delete[] complex_K_matrix;
    delete[] complex_fxp_2;
    delete[] complex_fxp_1;
    delete[] complex_fx_2;
    delete[] complex_fx_1;
    delete[] fxp_times_fkmat_all;
    delete[] fx_times_fkmat_all;
    delete[] ipiv;
    delete[] zidentity_plusminus_iK;
    delete[] open_K_matrix;
    delete[] K_matrix_a;
    delete[] fx_prop_2;
    delete[] fx_prop_1;
    delete[] fxp_prop_2;
    delete[] fxp_prop_1;
    delete[] K_matrix_c;
    delete[] rmat_c;
    delete[] fxp_2;
    delete[] fxp_1;
    delete[] fx_2;
    delete[] fx_1;
    delete[] r22;
    delete[] r21;
    delete[] r12;
    delete[] r11;
    delete[] rmat_a;
}

