// Cytosim was created by Francois Nedelec. Copyright Cambridge University 2020

#include "mecafil_code.cc"
#include "exceptions.h"
#include "xpttrf.h"

// required for debugging:
#include "cytoblas.h"
#include "vecprint.h"

/*
 Selection of the routines used to factorize (DPTTRF) and to solve (DPTTS2)
 the symmetric tridiagonal system built in Mecafil::makeProjection().

 Original author's note:
 The LAPACK implementation is the safest choice
 The Alsatian version is faster as it avoids divisions

 With `#if ( 1 )`, the pair lapack_xpttrf / lapack_xptts2 is selected;
 change the condition to 0 to use lapack::xpttrf / lapack::xptts2 instead.
 NOTE(review): which of the two pairs is the 'Alsatian' one cannot be told
 from this file; check "xpttrf.h" before switching.
 */

#if ( 1 )
#  define DPTTRF lapack_xpttrf
#  define DPTTS2 lapack_xptts2
#else
#  define DPTTRF lapack::xpttrf
#  define DPTTS2 lapack::xptts2
#endif


/*
Selection of projectForces() routines optimized for some architectures

 `projectForcesU` and `projectForcesD` are mapped, at compile time, onto the
 implementation matching the dimension (DIM), the floating point precision
 (REAL_IS_DOUBLE) and the available instruction set (__AVX__ / __SSE3__).
 The scalar reference functions projectForcesU_() and projectForcesD_()
 defined in this file are the fallback, and a compile-time warning is issued
 when they are selected.
*/

// 3D, double precision, AVX
#if ( DIM == 3 ) && REAL_IS_DOUBLE && defined(__AVX__)
#  define projectForcesU projectForcesU3D_AVX
#  define projectForcesD projectForcesD3D_AVX
// 3D, double precision, SSE3
#elif ( DIM == 3 ) && REAL_IS_DOUBLE && defined(__SSE3__)
#  define projectForcesU projectForcesU3D_SSE
#  define projectForcesD projectForcesD3D_SSE
// 3D, SSE3, REAL_IS_DOUBLE false: scalar 'U' combined with an SSE 'D'
// NOTE(review): presumably projectForcesD3D_SSE has a single-precision
// variant that is picked up here -- confirm in "mecafil_code.cc"
#elif ( DIM == 3 ) && defined(__SSE3__)
#  define projectForcesU projectForcesU_
#  define projectForcesD projectForcesD3D_SSE
// 2D, double precision, AVX
#elif ( DIM == 2 ) && REAL_IS_DOUBLE && defined(__AVX__)
#  define projectForcesU projectForcesU2D_AVX
#  define projectForcesD projectForcesD2D_AVX
// 2D, double precision, SSE3
#elif ( DIM == 2 ) && REAL_IS_DOUBLE && defined(__SSE3__)
#  define projectForcesU projectForcesU2D_SSE
#  define projectForcesD projectForcesD2D_SSE
// any other configuration: scalar reference code
#else
#  warning "Using scalar Fiber::projectForces"
#  define projectForcesU projectForcesU_
#  define projectForcesD projectForcesD_
#endif


//------------------------------------------------------------------------------
#pragma mark -


void Mecafil::initProjection()
{
    //reset all variables for the projections:
    iJJt   = nullptr;
#if ADD_PROJECTION_DIFF
    iJJtJF = nullptr;
#endif
}


void Mecafil::allocateProjection(const size_t ms)
{
    //std::clog << reference() << "allocateProjection(" << nbp << ")\n";
    free_real(iJJt);
#if ADD_PROJECTION_DIFF
    real * mem = new_real(3*ms);
    //zero_real(3*ms, mem);
    iJJt   = mem;
    iJJtU  = mem + ms;
    iJJtJF = mem + ms * 2;
#else
    real * mem = new_real(2*ms);
    //zero_real(2*ms, mem);
    iJJt   = mem;
    iJJtU  = mem + ms;
#endif
}


/**
 Release the memory used by the projection and reset all associated pointers.

 Only `iJJt` is passed to free_real(): in allocateProjection(), `iJJtU` and
 `iJJtJF` are set to point inside the block that starts at `iJJt`.
 */
void Mecafil::destroyProjection()
{
    //std::clog << reference() << "destroyProjection\n";
    free_real(iJJt);

    // none of these pointers is valid anymore
#if ADD_PROJECTION_DIFF
    iJJtJF = nullptr;
#endif
    iJJtU  = nullptr;
    iJJt   = nullptr;
}


/** This is the standard version assuming isotropic drag coefficients */
void Mecafil::makeProjection()
{
    assert_true( nbPoints() >= 2 );

    //set the diagonal and off-diagonal of J*J'
    const size_t nbu = nbPoints() - 2;

    for ( size_t jj = 0; jj < nbu; ++jj )
    {
        const real* X = iDir + DIM * jj;
#if ( DIM == 2 )
        real xn = X[0]*X[2] + X[1]*X[3];
#else
        real xn = X[0]*X[3] + X[1]*X[4] + X[2]*X[5];
#endif
        
        // this term should be 2.0, since iDir[] vectors are normalized:
#if ( DIM == 2 )
        iJJt[jj] = 2 * ( X[0]*X[0] + X[1]*X[1] );
#else
        iJJt[jj] = 2 * ( X[0]*X[0] + X[1]*X[1] + X[2]*X[2] );
#endif
        // iJJt[jj]  = 2.0;
        
        iJJtU[jj] = -xn;
    }
    
    const real* X = iDir + DIM*nbu;
    // this term should be 2, since iDir[] vectors are normalized
#if ( DIM == 2 )
    iJJt[nbu] = 2 * ( X[0]*X[0] + X[1]*X[1] );
#else
    iJJt[nbu] = 2 * ( X[0]*X[0] + X[1]*X[1] + X[2]*X[2] );
#endif
    //iJJt[nbu] = 2.0;

    int info = 0;
    DPTTRF(nbu+1, iJJt, iJJtU, &info);

    if ( 0 )
    {
        VecPrint::print("D", nbu+1, iJJt, 2);
        VecPrint::print("U", nbu, iJJtU, 2);
        //VecPrint::print("X", DIM*(nbu+2), pPos, 2);
    }

    if ( info )
    {
        std::clog << "Mecafil::makeProjection failed (" << info << ")\n";
        throw Exception("could not build Fiber's projection matrix");
    }
}


//------------------------------------------------------------------------------
#pragma mark - Reference (scalar) code

/**
 Perform first calculation needed by projectForces:
     mul[] <- dot(dif[], src[+DIM] - src[])
 which is:
     mul[i] <- dot(dif[i*DIM], src[i*DIM+DIM] - src[i*DIM])
     for i in [ 0, nbs-1 ]
 
 with 'nbs' = number of segments, and
      dif[] of size nbs*DIM
      vec[] of size (nbs+1)*DIM
      mul[] of size nbs
 
 Note that this should work even if 'mul==src'
 */
void projectForcesU_(size_t nbs, const real* dir, const real* src, real* mul)
{
    const real *const end = mul + nbs;

    while ( mul < end )
    {
        *mul = dir[0] * ( src[DIM  ] - src[0] )
             + dir[1] * ( src[DIM+1] - src[1] )
#if ( DIM > 2 )
             + dir[2] * ( src[DIM+2] - src[2] )
#endif
        ;
        src += DIM;
        dir += DIM;
        ++mul;
    }
}

/**
 Perform second calculation needed by projectForces:
     dst <- src +/- dif * mul
 which is:
     for i in DIM * [ 0, nbs-1 ]
         dst[i] <- src[i] + dif[i] * mul[i] - dif[i-1] * mul[i-1]

 with 'nbs' = number of segments, and
      dif[] of size nbs*DIM
      src[] and dst[] of size (nbs+1)*DIM
      mul[] of size nbs
 
 Note that this should work even if 'dst==src'
 */
void projectForcesD_(const size_t nbs, const real* dir, const real* src, const real* mul, real* dst)
{
    for ( size_t s = 0; s < DIM; ++s )
        dst[s] = src[s] + dir[s] * mul[0];
    
    for ( size_t e = DIM*nbs; e < DIM*(nbs+1); ++e )
        dst[e] = src[e] - dir[e-DIM] * mul[nbs-1];
    
    for ( size_t j = 1; j < nbs; ++j )
    {
        const size_t kk = DIM * j;
        const real M = mul[j], P = mul[j-1];
        dst[kk  ] = src[kk  ] + dir[kk  ] * M - dir[kk-DIM  ] * P;
        dst[kk+1] = src[kk+1] + dir[kk+1] * M - dir[kk-DIM+1] * P;
#if ( DIM > 2 )
        dst[kk+2] = src[kk+2] + dir[kk+2] * M - dir[kk-DIM+2] * P;
#endif
    }
}


/**
 Update each of the `nbp` vectors of `dst`:

     dst <- src + (TT') src

 where T is the local direction given by `dir` for each vector.
 This adds to `src` its component along `dir`: if `dir` is a unit vector,
 the tangential component is doubled and the orthogonal ones are unchanged.

 `src`, `dir` and `dst` all hold DIM scalars per vector.
 Code is only provided for DIM == 2 and DIM >= 3.

 Note that this must work properly even if `dst` == `src`:
 the scalar product is calculated before any value is written.
 */
void scaleTangentially(size_t nbp, const real* src, const real* dir, real* dst)
{
    for ( size_t n = nbp; n > 0; --n, src += DIM, dir += DIM, dst += DIM )
    {
#if ( DIM == 2 )
        const real dot = src[0] * dir[0] + src[1] * dir[1];
        dst[0] = src[0] + dot * dir[0];
        dst[1] = src[1] + dot * dir[1];
#elif ( DIM >= 3 )
        const real dot = src[0] * dir[0] + src[1] * dir[1] + src[2] * dir[2];
        dst[0] = src[0] + dot * dir[0];
        dst[1] = src[1] + dot * dir[1];
        dst[2] = src[2] + dot * dir[2];
#endif
    }
}

//------------------------------------------------------------------------------

/*
 Y <- components of X that are compatible with the length constraints

 This is done in three steps:
   1. iLLG <- projection of the differences of X on the segment directions
   2. iLLG <- inv( J * Jt ) * iLLG, using the factorization in iJJt, iJJtU
   3. Y    <- X combined with iDir scaled by the multipliers in iLLG

 X is not modified by the first two steps, and this should work correctly
 even if ( X == Y ). The multipliers are left in `iLLG`.
 */
void Mecafil::projectForces(const real* X, real* Y) const
{
    const size_t cnt = nbSegments();
    //VecPrint::print("X", DIM*nbPoints(), X);

    // step 1: right-hand side of the linear system
    projectForcesU(cnt, iDir, X, iLLG);

    // step 2: Lagrange multipliers, solved in place
    DPTTS2(cnt, 1, iJJt, iJJtU, iLLG, cnt);

    // step 3: set Y, using values in X and iLLG
    projectForcesD(cnt, iDir, X, iLLG, Y);

    //VecPrint::print("Y", DIM*nbPoints(), Y);
}


/**
 Set `iLag` to the Lagrange multipliers corresponding to the given forces.

 This performs the first two steps of projectForces(), with the result
 stored in `iLag` rather than in `iLLG`; `force` is not modified.
 */
void Mecafil::computeTensions(const real* force)
{
    const size_t cnt = nbSegments();

    // iLag <- differences of `force` projected on the segment directions
    projectForcesU(cnt, iDir, force, iLag);

    // iLag <- inv( J * Jt ) * iLag, solved in place
    DPTTS2(cnt, 1, iJJt, iJJtU, iLag, cnt);
}


/** This extracts the matrix underlying the 'Mecafil::projectForces()' */
void Mecafil::printProjection(std::ostream& os) const
{
    const size_t nbv = DIM * nbPoints();
    real * res = new_real(nbv*nbv);
    real * src = new_real(nbv);
    real * dst = new_real(nbv);
    zero_real(nbv, src);
    zero_real(nbv, dst);
    for ( size_t i = 0; i < nbv; ++i )
    {
        src[i] = 1.0;
        projectForces(src, dst);
        copy_real(nbv, dst, res+nbv*i);
        src[i] = 0.0;
    }
    free_real(dst);
    free_real(src);
    os << "Mecafil:Projection  " << reference() << " (" << nbPoints() << ")\n";
    VecPrint::full(os, nbv, nbv, res, nbv);
    free_real(res);
}



//------------------------------------------------------------------------------
#pragma mark - Correction terms to the Projection

#if ADD_PROJECTION_DIFF

// add debug code to compare with reference implementation
#define CHECK_PROJECTION_DIFF 0

/**
 This assumes that the Lagrange multipliers in 'iLLG' can be used,
 i.e. that they were computed by the last call to projectForces().

 `useProjectionDiff` is set to true if any multiplier exceeds `threshold`;
 this function never sets it back to false.
 If `useProjectionDiff` is true, the correction terms are calculated:
     iJJtJF[i] = max(threshold, iLLG[i] / segmentation())
 such that values below `threshold` are replaced by `threshold`.
 */
void Mecafil::setProjectionDiff(const real threshold)
{
    const size_t cnt = nbSegments();

    // look for the first extensile ( positive ) multiplier
    size_t s = 0;
    while ( s < cnt && !( iLLG[s] > threshold ) )
        ++s;
    if ( s < cnt )
        useProjectionDiff = true;

    if ( !useProjectionDiff )
        return;

    // scale the multipliers, removing compressive ( negative ) ones
    const real alpha = 1.0 / segmentation();
    #pragma vector unaligned
    for ( size_t k = 0; k < cnt; ++k )
        iJJtJF[k] = std::max(threshold, alpha * iLLG[k]);

    //std::clog << "projectionDiff: " << blas::nrm2(cnt, iJJtJF) << '\n';
    //VecPrint::print("projectionDiff:", std::min(20u,cnt), iJJtJF);
}


/// Reference (scalar) code
/**
 This looks similar to addProjectionD() except with dir[i] = X[i+DIM]-x[i]
 */
inline void addProjectionDiff_(const size_t nbs, const real* mul, const real* X, real* Y)
{
    for ( size_t i = 0; i < nbs; ++i )
    {
        for ( size_t d = 0; d < DIM; ++d )
        {
            const real w = mul[i] * ( X[DIM*i+DIM+d] - X[DIM*i+d] );
            Y[DIM*i+(    d)] += w;
            Y[DIM*i+(DIM+d)] -= w;
        }
    }
}


/**
 Add the correction terms to Y, using the coefficients stored in `iJJtJF`
 and the values of X. The implementation is selected at compile time,
 according to DIM, REAL_IS_DOUBLE and the availability of SSE3.

 If CHECK_PROJECTION_DIFF is enabled, the scalar reference code is also
 applied to a copy of Y, and a message is printed to std::clog if the
 value returned by blas::difference() for the two results exceeds 1e-6.
 */
void Mecafil::addProjectionDiff(const real* X, real* Y) const
{
    const size_t n_seg = nbSegments();

#if CHECK_PROJECTION_DIFF
    // apply the reference code to a copy of Y, before Y is modified
    const size_t n_val = DIM * nbPoints();
    real * ref = new_real(n_val);
    copy_real(n_val, Y, ref);
    addProjectionDiff_(n_seg, iJJtJF, X, ref);
#endif

#if ( DIM == 2 ) && REAL_IS_DOUBLE && defined(__SSE3__)
    addProjectionDiff2D_SSE(n_seg, iJJtJF, X, Y);
    //addProjectionDiff_AVX(n_seg, iJJtJF, X, Y);
#elif ( DIM == 3 ) && REAL_IS_DOUBLE && defined(__SSE3__)
    addProjectionDiff3D_SSE(n_seg, iJJtJF, X, Y);
#else
    addProjectionDiff_F(n_seg, iJJtJF, X, Y);
    //addProjectionDiff_(n_seg, iJJtJF, X, Y);
#endif

#if CHECK_PROJECTION_DIFF
    // compare the selected implementation with the reference
    const real err = blas::difference(n_val, Y, ref);
    if ( err > 1e-6 )
    {
        std::clog << "Mecafil:addProjectionDiff(" << n_val << ") error " << err << "\n";
        VecPrint::edges("ref ", n_val, ref, 3);
        VecPrint::edges("sse ", n_val, Y, 3);
    }
    free_real(ref);
#endif
}

#endif

#if ADD_PROJECTION_DIFF == 2

/**
 This is the debug pathway, compiled if ADD_PROJECTION_DIFF == 2.

 - if `force` is not null, the tensions are computed in `iLag` from the
   given forces with computeTensions(), and nothing else is done.
 - if `force` is null, `iLLG` is compared with `iLag` and the first values
   of both arrays (at most 16) are printed to std::clog if the value
   returned by blas::difference() exceeds 1e-6.
   setProjectionDiff(0) is then called.
 */
void Mecafil::makeProjectionDiff(const real* force)
{
    if ( force )
    {
        // compute tensions in 'iLag' using the given forces
        computeTensions(force);
        return;
    }
    
    // Check that `iLLG` contains the same values as `iLag`
    real e = blas::difference(nbSegments(), iLLG, iLag);
    if ( e > 1e-6 )
    {
        // the template argument is given explicitly, since the deduction fails
        // in std::min(16LU, ...) wherever size_t is not `unsigned long`
        const size_t S = std::min<size_t>(16, nbSegments());
        std::clog << "Mecafil: |iLag - iLLG| = " << e << "\n";
        VecPrint::print("iLag ", S, iLag);
        VecPrint::print("iLLG ", S, iLLG);
    }
    setProjectionDiff(0);
}

#elif ADD_PROJECTION_DIFF

/**
 This is the normal pathway without verifications.

 The argument is ignored: setProjectionDiff() is called with a threshold
 of zero, and uses the Lagrange multipliers already present in `iLLG`.
 */
void Mecafil::makeProjectionDiff(const real*)
{
    setProjectionDiff(0);
}

#endif
