/* File: QuadratureRule.cpp
 *
 * Interface class for quadrature rules
 * 
 * C Michael Hanke 2023
 * Version: 2023-02-13
 */

/* 
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.

*/

#include "QuadratureRule.hpp"
#include "LSCMConfig.hpp"
#include <Eigen/Dense>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <limits>

using namespace LSCM;
using namespace Eigen;
using namespace std;

// Local functions
namespace {
    
// Evaluate Jacobi polynomial and derivative via recurrence relation.
void jac_eval(const LSCMint n, const VectorXd& x, const double a, const double b,
              VectorXd& P, VectorXd& Pp) {
    // Initialise:
    double ab = a + b;
    P = (x*(ab+2.0)+VectorXd::Constant(x.size(),a-b))*0.5;
    VectorXd Pm1 = VectorXd::Constant(x.size(),1.0);
    Pp = VectorXd::Constant(x.size(),0.5*(ab+2));
    VectorXd Ppm1 = VectorXd::Zero(x.size());
    
    if (n == 0) {
        P = Pm1;
        Pp = Ppm1;
        return;
    }

    for (LSCMint k = 1; k < n; ++k) {
        // Useful values:
        double A1 = 1.0/(2*(k+1)*(k+ab+1.0)*(2*k+ab));
        double B = (2*k + ab + 1.0)*(a*a - b*b);
        double C = (2*k+ab)*(2*k+ab+1.0)*(2*k+ab+2.0);
        double D = 2*(k + a)*(k + b)*(2*k + ab + 2.0);

        // Recurrence:
        VectorXd Pa1 = ((VectorXd::Constant(x.size(),B)+C*x).cwiseProduct(P)
                    -D*Pm1)*A1;
        VectorXd Ppa1 = ((VectorXd::Constant(x.size(),B)+C*x).cwiseProduct(Pp)
                    +C*P-D*Ppm1)*A1;

        // Update:
        Pm1 = P; 
        P = Pa1;  
        Ppm1 =  Pp; 
        Pp = Ppa1;
    }
}

// Jacobi polynomial recurrence relation.
void jac_main(const LSCMint n, const double a, const double b,
              const bool flag, VectorXd& x, VectorXd& PP) {
    // Asymptotic formula (WKB) - only positive x.
    LSCMint rmax = flag ? (n+1)/2 : n/2;  // integer division
    ArrayXd C(rmax);
    for (LSCMint r = rmax; r > 0; --r) 
        C(rmax-r) = (M_PI/(2*n+a+b+1.0))*(2*r+a-0.5);
    ArrayXd tanC = tan(C*0.5);
    x = (C+1.0/pow(2*n+a+b+1.0,2)*
         ((ArrayXd::Constant(rmax,0.25-a*a)/tanC)-
                    tanC*(0.25-b*b))).cos();

    // Initialise:
    VectorXd dx =
                VectorXd::Constant(rmax,numeric_limits<double>::infinity()); 
    LSCMint l = 0;
    VectorXd P;
    // Loop until convergence:
    while ( 
        (dx.lpNorm<Infinity>() >
                    1e-3*sqrt(numeric_limits<double>::epsilon()))
        && (l < 10) ) {
        ++l;
        jac_eval(n,x,a,b,P,PP);
        dx = -P.array()/PP.array(); 
        x += dx;
    }
    
    // Once more for derivatives:
    jac_eval(n,x,a,b,P,PP);
}

// End unnamed namespace
}

void QuadratureRule::Jacobi(const LSCMint n, const double a, const double b,
                             VectorXd& x,
                             VectorXd& w, VectorXd& c, VectorXd& t, double& scale) {
    // Check inputs
    if ( a <= -1 || b <= -1 ) {
        cerr << "Error QuadratureRule: Alpha and beta must be greater than -1." << endl;
        exit(1);
    }
    
    // Go! Inspired by Chebfun
    VectorXd x1, x2;
    VectorXd ders1, ders2;
    x = VectorXd(n);
    VectorXd ders(n);
    
    jac_main(n,a,b,true,x1,ders1);
    if (a != b) {
        jac_main(n,b,a,false,x2,ders2);
        x << -x2(seq(last,0,fix<-1>)), x1;
        ders << ders2(seq(last,0,fix<-1>)), ders1;
    }
    else {
        x << -x1(seq(last,n%2,fix<-1>)), x1;
        ders << -ders1(seq(last,n%2,fix<-1>)), ders1;
    }

    // Security checks
    // Monotonicy
    for (LSCMint i = 1; i < n; ++i) {
        if (x(i) <= x(i-1)) {
            cerr << "Error QuadratureRule: Something went wrong" << endl;
            exit(1);
        }
    }
    
    // Final step
    c = (ArrayXd::Constant(n,1.0)/ders.array()).abs();
    for (LSCMint i = n-2; i >= 0; i -= 2) c(i) = -c(i);
    
    w = c.cwiseProduct(c).array()/
                (VectorXd::Constant(n,1.0)-x.cwiseProduct(x)).array();
    double C = pow(2,a+b+1.0) * exp(lgamma(n+a+1.0)-lgamma(n+a+b+1)+
                lgamma(n+b+1.0)-lgamma(n+1.0));
    w *= C;
    scale = bscaljac(n,a,b);
    t = VectorXd::Zero(0);
}

// Implements formula (2.25) of Wang, Huybrechts, Vandewalle
//
//   n     number of points
//   a, b  Jacobi parameters
//
// Returns
//   Gamma(2n+a+b+3)
//     / sqrt(Gamma(n+2)*Gamma(n+a+b+2)*Gamma(n+a+2)*Gamma(n+b+2))
//     / 2^(n+1+(a+b+1)/2).
//
// The whole expression, including the power of two, is assembled in
// logarithmic form and exponentiated only once. The gamma quotient alone
// grows roughly like 4^n and would overflow to infinity for n of about 512
// although the final value is still representable.
double QuadratureRule::bscaljac(const LSCMint n, const double a, const double b) {
    const double logGammaQuotient = std::lgamma(2*n+a+b+3)
                -0.5*(std::lgamma(static_cast<double>(n)+2)+std::lgamma(n+a+b+2)
                      +std::lgamma(n+a+2)+std::lgamma(n+b+2));
    const double logPowerOfTwo = (n+1+0.5*(a+b+1))*std::log(2.0);
    return std::exp(logGammaQuotient-logPowerOfTwo);
}

void QuadratureRule::genDmat() {
    // Construct Dx and Dw.
    MatrixXd Dxi(n_,n_);
    // Could both of these cases done in one step?
    if (RuleProperties & QR_SYMMETRIC) {
        if (t.size() == n_) {
            //Trig identity
            //VectorXd tloc(n_);
            //for (LSCMint i = 0; i < n_; ++i) tloc(i) = 0.5*t(i); //(n_-i-1);
            for (LSCMint k = 0; k < n_; ++k) {
                Dxi(k,k) = 1.0;
                for (LSCMint j = n_-k-1; j < n_; ++j) {
                    if (j == k) continue;
                    Dxi(k,j) = 0.5/(sin(0.5*(t(k)+t(j)))*sin(0.5*(t(j)-t(k))));
                    Dxi(n_-k-1,n_-j-1) = -Dxi(k,j);
                }
            }
        }
        else
            for (LSCMint k = 0; k < n_; ++k) {
                Dxi(k,k) = 1.0;
                for (LSCMint j = n_-k-1; j < n_; ++j) {
                    if (j == k) continue;
                    Dxi(k,j) = 1.0/(x(k)-x(j));
                    Dxi(n_-k-1,n_-j-1) = -Dxi(k,j);
                }
            }
    }
    else {
        if (t.size() == n_) {
            //Trig identity
            //VectorXd tloc(n_);
            //for (LSCMint i = 0; i < n_; ++i) tloc(i) = 0.5*t(i); //(n_-i-1);
            for (LSCMint k = 0; k < n_; ++k) {
                Dxi(k,k) = 1.0;
                for (LSCMint j = 0; j < n_; ++j) {
                    if (j == k) continue;
                    Dxi(k,j) = 0.5/(sin(0.5*(t(k)+t(j)))*sin(0.5*(t(j)-t(k))));
                }
            }
        }
        else {
            for (LSCMint k = 0; k < n_; ++k) {
                Dxi(k,k) = 1.0;
                for (LSCMint j = 0; j < n_; ++j) {
                    if (j == k) continue;
                    Dxi(k,j) = 1.0/(x(k)-x(j));
                }
            }
        }
    }

    // Pairwise divisions
    Dmat = MatrixXd(n_,n_);
    for (LSCMint k = 0; k < n_; ++k) {
        for (LSCMint j = 0; j < n_; ++j) Dmat(k,j) = c(j)/c(k);
        Dmat(k,k) = 0.0;
    }

    //   Generate
    Dmat = Dmat.cwiseProduct(Dxi);
}

// Return the inverse of the n_ x n_ matrix whose column j contains the
// Legendre polynomial of degree j, multiplied by sqrt(2j+1), evaluated at
// the points x.
MatrixXd QuadratureRule::gen2mass() {
    MatrixXd V(n_,n_);

    // Legendre polynomials column by column. The three-term recurrence is
    // used for stability:
    //   P_j = x*P_{j-1} + ((j-1)/j)*(x*P_{j-1} - P_{j-2})
    V.col(0).setOnes();
    if (n_ > 1) {
        V.col(1) = x;
        for (LSCMint j = 2; j < n_; ++j) {
            const double ratio = static_cast<double>(j-1)/j;
            const VectorXd xP = x.cwiseProduct(V.col(j-1));
            V.col(j) = ratio*(xP-V.col(j-2))+xP;
        }
    }

    // Scale column j by sqrt(2j+1). This must not be merged into the loop
    // above because the recurrence needs the unscaled columns.
    for (LSCMint j = 0; j < n_; ++j)
        V.col(j) *= sqrt(2.0*j+1.0);

    // Create transformation matrix: invert via LU with full pivoting
    const FullPivLU<MatrixXd> decomposition(V);
    return decomposition.inverse();
}
