/* File: DirectSolver.cpp
 *
 * Realization of the direct solver of linear DAEs
 * for a
 * given grid and approximation space. Taken from Barlow 1992
 * 
 * Copyright (C) Michael Hanke 2020
 * Version: 2022-06-06
 */

/* 
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.

*/

#include "DirectSolver.hpp"
#include "GridFkt.hpp"
#include <Eigen/Dense>
#include <Eigen/SparseCore>
#include <Eigen/SparseLU>
#include <Eigen/SPQRSupport>
#include <vector>
#include <ctime>
#ifdef USEMKL
#include <mkl.h>
#else
#include <lapacke.h>
#include <cblas.h>
#endif

// Fortran name mangling: the routine is referenced under the symbol name
// dlacon_ (lower case, trailing underscore).
#define DLACON dlacon_

// Prototype of DLACON, declared here only for builds without USEMKL
// (presumably mkl.h supplies its own declaration -- confirm against the MKL
// headers in use).
//
// Calling protocol as used in this file (reverse communication):
//   n     length of the vectors v, x and isgn
//   v     work array of length n
//   x     vector of length n; between two calls the caller overwrites it
//         with the result of the operation requested through kase
//   isgn  integer work array of length n
//   est   receives the estimate
//   kase  must be 0 on the first call; the caller keeps calling while the
//         returned value is > 0
#ifndef USEMKL
extern "C" {
    void DLACON(const lapack_int* n, double* v, double* x, lapack_int* isgn,
              double* est, lapack_int* kase);
}
#endif

using namespace LSCM;
using namespace Eigen;
using namespace std;

// This function is a workaround for a deficiency in Eigen's SPQR interface:
// the orthogonal factor is only used here as an operator that is applied to
// dense vectors, so the required columns are built one at a time.
namespace {
// Build a sparse dim x col matrix from the orthogonal factor of qrC.
//
// Column j of the result is qrC.matrixQ() applied to the unit vector
// e_(istart+j), for j = 0, ..., col-1. Entries that are exactly zero are not
// stored.
//
//   qrC     SPQR factorization whose Q factor is applied
//   istart  index of the first unit vector
//   col     number of columns to generate
//   dim     length of the unit vectors (= number of rows of the result)
LSCMSparseMatrix
SPQRtoQ(const SPQR<LSCMSparseMatrix>& qrC, LSCMint istart, LSCMint col, LSCMint dim)
{
    vector<LSCMTriplet> entries;
    entries.reserve(dim*col/10); // Factor from experiments

    // Work vector that holds exactly one unit entry while Q is applied.
    VectorXd unit = VectorXd::Zero(dim);

    for (LSCMint j = 0; j < col; ++j) {
        const LSCMint pos = istart+j;

        unit(pos) = 1.0;
        const VectorXd qcol = qrC.matrixQ()*unit;
        unit(pos) = 0.0;   // restore the zero vector for the next column

        for (LSCMint i = 0; i < dim; ++i) {
            const double val = qcol(i);
            if (val == 0.0)
                continue;
            entries.emplace_back(i,j,val);
        }
    }

    LSCMSparseMatrix result(dim,col);
    result.setFromTriplets(entries.begin(),entries.end());

    return result;
}
}

void LSCM::DirectSolver::factorizeExec()
{
    clock_t tass = clock();
    
    LSCMSparseMatrix A = mat->genA();
    LSCMSparseMatrix C = mat->genC();
    
    timings.tass = ((double)(clock()-tass))/CLOCKS_PER_SEC;
    
    Crow = C.rows();
    nun = C.cols();
    Arow = A.rows();
    
    mem.nnzA = A.nonZeros();
    mem.dimA = Arow;
    mem.nnzC = C.nonZeros();
    mem.dimC = Crow;
    mem.ndim = nun;
    
    clock_t tfact = clock();
   
    // The algorithm requires too much memory in this implementation!
    // This version of SPQR is not documented!!
    //SPQR<SparseMatrix<double>> qrC(C);
    SPQR<LSCMSparseMatrix> qrC;
    qrC.compute(C);
    if (qrC.info() != Success) {
        // cerr << QR.lastErrorMessage() << endl;
        cerr << "DirectSolver: QR decomposition failed" << endl;
        exit(1);
    }
    if (!homConstraints)
        Q = SPQRtoQ(qrC,0,Crow,Crow);
    auto U = qrC.matrixR();
    U1 = U.leftCols(Crow);
    // Check fullrank: Always true!
    for (LSCMint i = 0; i < Crow; ++i)
        if (U1.coeff(i,i) == 0.0) {
            cerr << "DirectSolver: Continuity conditions dependent!" << endl;
            exit(1);
        }

    U2 = U.rightCols(nun-Crow);
    P = qrC.colsPermutation();
    auto Xperm = A*P;
    X1 = Xperm.leftCols(Crow);
    auto X2 = Xperm.rightCols(nun-Crow);
    
    // Workaround considering efficiency
    // Check if U1 is a diagonal matrix
    if (U1.nonZeros() == Crow) {
        // U1 is diagonal
        // Compute the inverses in order to avoid division
        double *U1values = U1.valuePtr();
        for (LSCMint i = 0; i < Crow; ++i) U1values[i] = 1.0/U1values[i];
        // Sparse Matrix format is column major!
        for (LSCMint k = 0; k < U2.outerSize(); ++k) 
            for (LSCMSparseMatrix::InnerIterator it(U2,k); it; ++it)
                //it.valueRef() *= U1.coeff(it.row(),it.row());
                it.valueRef() *= U1values[it.row()];
    }
    else {
        // This does not work! Not implemented in Eigen!!!
        // tmp = U1.triangularView<Upper>().solve(U2);
        //
        // The following procedure is a workaround for a deficiency in Eigen's SPQR
        // interface: SuiteSparseQR_solve() for sparse matrix right-hand sides is not
        // accessible in Eigen. Hopefully, this procedure will only seldom be used.
        vector<LSCMTriplet> tmp;
        tmp.reserve(10*U2.nonZeros());
        for (LSCMint i = 0; i < nun-Crow; ++i) {
            VectorXd f = VectorXd::Zero(Crow);
            for (LSCMSparseMatrix::InnerIterator it(U2,i); it; ++it)
                f(it.row()) = it.value();
            VectorXd r = U1.triangularView<Upper>().solve(f);
            for (LSCMint j = 0; j < Crow; ++j)
                if (r(j) != 0.0) tmp.push_back(LSCMTriplet(j,i,r(j)));
        }
        U2.setFromTriplets(tmp.begin(),tmp.end());
    }

    LSCMSparseMatrix Xtilde = X2-(X1*U2).pruned();
    Xtilde.makeCompressed();
    Xcol = Xtilde.cols();
    qrX = new SPQR<LSCMSparseMatrix>;
    qrX->compute(Xtilde);
    if (qrX->info() != Success) {
        // cerr << QR.lastErrorMessage() << endl;
        std::cerr << "DirectSolver: QR decomposition failed" << std::endl;
        exit(1);
    }
    
    timings.tfact = ((double)(clock()-tfact))/CLOCKS_PER_SEC;
   
    // Check statistics
    cholmod_common *cc = qrC.cholmodCommon();
    cholmod_common *cx = qrX->cholmodCommon();
    mem.nWork = 
        cc->SPQR_istat[0]+   // cc->SPQR_istat[1] not needed (but stored in this implementation
        cx->SPQR_istat[0]+
        cx->SPQR_istat[1]+
        U.nonZeros()+
        U2.nonZeros();
        // cerr << "+++ H nnz: " << cx->SPQR_istat[1] << endl;

    DecompositionAvail = true;
}

// Solve the factorized problem for the right-hand side f and return the
// solution as a GridFkt on mat->getspace().
//
// Uses the quantities prepared by factorizeExec() (Q, U1, U2, X1, P, qrX):
//   - homConstraints set:  the reduced right-hand side is f itself and t = 0.
//   - otherwise:           b holds f.tail(numBC) in its first numBC entries
//                          (rest zero), t solves U1*t = Q^T*b and the reduced
//                          right-hand side is f.head(Arow) - X1*t.
//   - al is the solution delivered by qrX for the reduced right-hand side.
//   - The coefficient vector is P*[t - U2*al; al], or P*[-U2*al; al] when
//     mat->getBCversion() is true.
//
// The elapsed CPU time is stored in timings.tslv. If errorest is set, an
// error estimate based on a condition estimate of the R factor of qrX is
// stored in error. A failed solve is reported on cerr and terminates the
// program via exit(1).
GridFkt LSCM::DirectSolver::solveExec(const VectorXd& f)
{
    const clock_t tslv = clock();

    // Zero-initialized: with homogeneous constraints t is never computed but
    // may still be read below (t - U2*al); zero is the value implied by b = 0.
    VectorXd t = VectorXd::Zero(Crow);

    // Right-hand side of the reduced problem that is handed to qrX
    VectorXd rhs;
    if (homConstraints)
        rhs = f;
    else {
        VectorXd b = VectorXd::Zero(Crow);
        b.head(numBC) = f.tail(numBC);
        // Solve U1*t = Q.transpose()*b
        // Evaluate the product once; a lazy expression (auto) would recompute
        // it for every coefficient access in the loop below.
        const VectorXd Qtb = Q.transpose()*b;
        if (U1.nonZeros() == Crow) {
            // U1 is diagonal and already holds the reciprocals of its entries
            const double *U1values = U1.valuePtr();
            for (LSCMint i = 0; i < Crow; ++i) t(i) = U1values[i]*Qtb(i);
        }
        else {
            t = U1.triangularView<Upper>().solve(Qtb);
        }
        rhs = f.head(Arow)-X1*t;
    }

    VectorXd al = qrX->solve(rhs);

    if (qrX->info() != Success) {
        // cerr << QR.lastErrorMessage() << endl;
        cerr << "DirectSolver: solve failed" << endl;
        exit(1);
    }

    VectorXd tmpCoeffs(nun);
    if (mat->getBCversion())
        tmpCoeffs.head(U2.rows()) = -U2*al;
    else
        tmpCoeffs.head(U2.rows()) = t-U2*al;
    tmpCoeffs.tail(Xcol) = al;
    // Lazy expression: the permutation is applied when the result is
    // converted to a vector in the return statement. tmpCoeffs stays alive
    // until then.
    auto SolCoeffs = P*tmpCoeffs;

    timings.tslv = static_cast<double>(clock()-tslv)/CLOCKS_PER_SEC;

    // Error estimation if required: Adapted from LAPACK documentation
    if (errorest) {
        const double eps = std::numeric_limits<double>::epsilon();

        // Only Xtilde, corresponding to \kappa_C(A), is used. Consequently
        // the right-hand side considered here is the reduced one that was
        // actually passed to qrX: it has Arow entries, matching the Q factor
        // of qrX, whereas f itself may be longer.
        const double bnorm = rhs.norm();
        // Norm of residual
        const double rnorm = (qrX->matrixQ().transpose()*rhs).tail(Arow-Xcol).norm();

        // Condition number estimator
        // The following version is stolen in part from SuiteSparseQRSupport.h
        // R is fetched once and reused for all triangular solves and for the
        // norm computation.
        const LSCMSparseMatrix R = qrX->matrixR();

        std::vector<double> v(Xcol);
        std::vector<LSCMint> isgn(Xcol);
        VectorXd px(Xcol);
        double *x = px.data();
        double est;
        LSCMint kase = 0;
        // Reverse communication: after each call kase tells which operation
        // has to be applied to x (= px) before calling again.
        DLACON(&Xcol,v.data(),x,isgn.data(),&est,&kase);
        while (kase > 0) {
            if (kase == 1) {
                // px <- R^{-1}*px
                R.triangularView<Upper>().solveInPlace(px);
            }
            else {
                // px <- R^{-T}*px
                R.transpose().leftCols(Xcol).triangularView<Lower>().solveInPlace(px);
            }
            DLACON(&Xcol,v.data(),x,isgn.data(),&est,&kase);
        }

        // Maximum absolute column sum of R
        // from Eigen documentation: Does only work for full matrices!
        //double Rnorm = r.cwiseAbs().colwise().sum().maxCoeff();
        double Rnorm = 0.0;
        for (LSCMint i = 0; i < R.outerSize(); ++i) {
            double sum = 0.0;
            for (LSCMSparseMatrix::InnerIterator it(R,i); it; ++it)
                sum += std::abs(it.value());
            if (Rnorm < sum) Rnorm = sum;
        }
        double rcond = 1.0/(Rnorm*est);

        // error estimator
        if (rcond < eps)
            rcond = eps;
        double sint = 0.0;
        if (bnorm > 0.0) sint = rnorm/bnorm;
        // Rounding may push the ratio slightly above 1; clamp it so that the
        // square root below never receives a negative argument.
        if (sint > 1.0) sint = 1.0;
        double cost = sqrt((1.0-sint)*(1.0+sint));
        if (cost < eps)
            cost = eps;
        const double tant = sint/cost;
        error = (2.0/(rcond*cost)+tant/(rcond*rcond))*eps;
    }

    return GridFkt(mat->getspace(),static_cast<VectorXd>(SolCoeffs));
}

