/* File: WeightedSolver.cpp
 *
 * Realization of the weighted solver for linear DAEs
 * on a given grid and approximation space.
 * 
 * Copyright (C) Michael Hanke 2020
 * Version: 2022-06-05
 */

/* 
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.

*/

#include "LSCMConfig.hpp"
#include "WeightedSolver.hpp"
#include "GridFkt.hpp"
#include <Eigen/Dense>
#include <Eigen/SparseCore>
#include <Eigen/SPQRSupport>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <memory>

using namespace LSCM;
using namespace Eigen;
using namespace std;

void LSCM::WeightedSolver::factorizeExec()
{
    clock_t tass = clock();
    LSCMSparseMatrix X = genmatT();
    timings.tass = ((double)(clock()-tass))/CLOCKS_PER_SEC;
    
    mem.nnzA = X.nonZeros();
    mem.dimA = X.rows();
    mem.ndim = X.cols();
    
    clock_t tfact = clock();
    
    QR = new SPQR<LSCMSparseMatrix>;
    QR->compute(X);
    
    timings.tfact = ((double)(clock()-tfact))/CLOCKS_PER_SEC;

    if (QR->info() != Success) {
        // cerr << QR.lastErrorMessage() << endl;
        cerr << "WeightedSolver: QR decomposition failed" << endl;
        exit(1);
    }

    // Check statistics
    cholmod_common *cc = QR->cholmodCommon();
    mem.nWork = cc->SPQR_istat[0]+cc->SPQR_istat[1];
    
    DecompositionAvail = true;
}

// Solve the weighted least-squares problem with the factorization held in QR.
//
// The right-hand side has Arows+Crows entries and is built from
// mat->genrhsq(), following the row layout used by genmatT() (constraint rows
// first, then the rows of A):
//   * homConstraints set:  the whole vector from genrhsq() is placed at the
//     end of the right-hand side; the leading entries stay zero.
//   * otherwise:           the first Arows entries of genrhsq() go to the
//     last Arows positions, and its last numBC entries, scaled by weight, go
//     to the first numBC positions.
//
// The processor time reported by clock() for the whole call is stored in
// timings.tslv (in seconds). If the solve does not report Success, a message
// is written to cerr and the program terminates via exit(1).
//
// f:       not used by this implementation.
// returns: a GridFkt constructed from mat->getspace() and the solution
//          coefficients.
GridFkt LSCM::WeightedSolver::solveExec(const VectorXd& f)
{
    const clock_t solveStart = clock();

    // Assemble the right-hand side for the stacked system.
    const VectorXd q = mat->genrhsq();
    VectorXd b = VectorXd::Zero(Arows+Crows);
    if (!homConstraints) {
        // Keep this assignment order: it matters should the segments overlap.
        b.tail(Arows) = q.head(Arows);
        b.head(numBC) = weight*q.tail(numBC);
    }
    else {
        b.tail(q.size()) = q;
    }

    VectorXd coeffs = QR->solve(b);
    const bool solved = (QR->info() == Success);
    if (!solved) {
        cerr << "WeightedSolver: Solve failed" << endl;
        exit(1);
    }

    const clock_t solveTicks = clock()-solveStart;
    timings.tslv = static_cast<double>(solveTicks)/CLOCKS_PER_SEC;

    return GridFkt(mat->getspace(),coeffs);
}

// Assemble the stacked, weighted system matrix
//
//         [ weight*C ]
//     X = [          ]
//         [    A     ]
//
// from A = mat->genA() and C = mat->genC().
//
// Both matrices are traversed through their compressed-storage arrays and
// merged outer vector by outer vector: within each one the entries of C come
// first (values scaled by weight, inner indices unchanged), followed by the
// entries of A with their inner indices shifted by C.rows(). The code treats
// the outer dimension as columns, i.e. it assumes LSCMSparseMatrix is
// column-major (TODO confirm against LSCMConfig.hpp).
//
// Side effects:
//   * Arows, Crows and Acols are set to the dimensions of A and C.
//   * Xvals, Xcolptr and Xrowidx are allocated with new[] and filled with the
//     merged matrix. They are not released in this function.
//     NOTE(review): a repeated call overwrites these pointers without freeing
//     the old buffers; confirm where they are released.
//
// If A and C do not have the same number of columns, the merge would read
// beyond C's outer-index array (or drop columns of C). In that case a message
// is written to cerr and the program terminates via exit(1), following the
// error handling used elsewhere in this file.
//
// returns: the merged matrix, obtained by converting a Map over the buffers
//          above into an LSCMSparseMatrix.
//          NOTE(review): this conversion presumably copies the buffers, which
//          would offset the intended memory saving - confirm.
//
// This version should be more stable w.r.t. Householder QR.
// NOTE: This version is intended to use less memory than genmatS() at the
//       expense of explicit memory management.
LSCMSparseMatrix LSCM::WeightedSolver::genmatT()
{
    LSCMSparseMatrix A = mat->genA();
    LSCMSparseMatrix C = mat->genC();
    A.makeCompressed();  // Both A and C should already be compressed as output of
    C.makeCompressed();  // LinLSQMatrices!

    // The merge below indexes both outer-index arrays with the same j.
    if (C.cols() != A.cols()) {
        std::cerr << "WeightedSolver: A and C have different numbers of columns" << std::endl;
        std::exit(1);
    }

    Arows = A.rows();
    Crows = C.rows();
    Acols = A.cols();

    // Raw compressed-storage arrays of the two input matrices.
    auto Acolptr = A.outerIndexPtr();
    auto Arowidx = A.innerIndexPtr();
    auto Avals = A.valuePtr();
    auto Ccolptr = C.outerIndexPtr();
    auto Crowidx = C.innerIndexPtr();
    auto Cvals = C.valuePtr();

    LSCMint Xrows = Arows+Crows;
    LSCMint Xnnz = A.nonZeros()+C.nonZeros();
    Xvals = new double[Xnnz];
    Xcolptr = new LSCMindex[Acols+1];
    Xrowidx = new LSCMindex[Xnnz];

    // pos is the next free slot in Xrowidx/Xvals.
    LSCMindex pos = 0;
    for (LSCMindex j = 0; j < Acols; ++j) {
        Xcolptr[j] = pos;

        // Entries of C: same inner index, value scaled by weight.
        for (auto k = Ccolptr[j]; k < Ccolptr[j+1]; ++k, ++pos) {
            Xrowidx[pos] = Crowidx[k];
            Xvals[pos] = weight*Cvals[k];
        }

        // Entries of A: inner index shifted below the block of C.
        for (auto k = Acolptr[j]; k < Acolptr[j+1]; ++k, ++pos) {
            Xrowidx[pos] = Arowidx[k]+Crows;
            Xvals[pos] = Avals[k];
        }
    }
    Xcolptr[Acols] = pos;

    Map<LSCMSparseMatrix> X(Xrows,Acols,Xnnz,Xcolptr,Xrowidx,Xvals);
    return X;
}
