/* File: DeferredCorrectionSolver.cpp
 *
 * Realization of the direct solver of linear DAEs for a given grid
 * and approximation space. Taken from Barlow 1992.
 * 
 * Copyright (C) Michael Hanke 2020
 * Version: 2022-06-03
 */

/* 
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.

*/

#include "DeferredCorrectionSolver.hpp"
#include "GridFkt.hpp"
#include <Eigen/Dense>
#include <Eigen/SparseCore>
#include <Eigen/SPQRSupport>
#include <cstdlib>
#include <ctime>
#include <iostream>

using namespace LSCM;
using namespace Eigen;
using namespace std;

void LSCM::DeferredCorrectionSolver::factorizeExec()
{
    clock_t tass = clock();
    X = genmatT();
    timings.tass = ((double)(clock()-tass))/CLOCKS_PER_SEC;
    
    mem.nnzA = X.nonZeros();
    mem.dimA = Xrows;
    mem.ndim = X.cols();
    
    clock_t tfact = clock();
    
    // Initialization QR
    qrX = new SPQR<LSCMSparseMatrix>;
    qrX->compute(X);
    
    timings.tfact = ((double)(clock()-tfact))/CLOCKS_PER_SEC;

    if (qrX->info() != Success) {
        // cerr << QR.lastErrorMessage() << endl;
        cerr << "DeferredCorrectionSolver: QR decomposition failed" << endl;
        exit(1);
    }

    // Check statistics
    cholmod_common *cc = qrX->cholmodCommon();
    mem.nWork = cc->SPQR_istat[0]+cc->SPQR_istat[1];

    DecompositionAvail = true;
}

// Solves the stacked system for the right-hand side f with the QR
// factorization held in qrX, then applies at most kmax correction sweeps.
//
// f: right-hand side. With homConstraints set, f is copied into the last
//    Arows entries of the stacked right-hand side and the remaining
//    entries stay zero. Otherwise the first Arows entries of f go there
//    and the last numBC entries of f, multiplied by weight, fill the
//    leading numBC entries.
//
// Returns the final iterate wrapped in a GridFkt on mat->getspace().
// The elapsed CPU time is stored in timings.tslv.
//
// A failure of the initial solve terminates the program with exit(1).
// A failure inside a correction sweep, or leaving the loop without
// meeting the tolerance, only prints a warning to cerr; the current
// iterate is returned in both cases.
GridFkt LSCM::DeferredCorrectionSolver::solveExec(const VectorXd& f)
{
    const clock_t solveStart = clock();

    const double scale = weight;
    const double threshold = (tol*scale)*(tol*scale);

    // Stacked right-hand side, zero wherever nothing is written below
    VectorXd rhs = VectorXd::Zero(Xrows);
    if (!homConstraints) {
        rhs.tail(Arows) = f.head(Arows);
        rhs.head(numBC) = weight*f.tail(numBC);
    } else {
        rhs.tail(Arows) = f;
    }

    // Initial solve
    VectorXd sol = qrX->solve(rhs);
    if (qrX->info() != Success) {
        cerr << "DeferredCorrectionSolver: Solve failed" << endl;
        exit(1);
    }

    // Residual of the stacked system and the scaled accumulator built
    // from its leading Crows entries
    // NOTE(review): the accumulator presumably plays the role of the
    //               multiplier estimate in Barlow's scheme - confirm
    VectorXd resid = rhs-X*sol;
    VectorXd accum = scale*resid.head(Crows);

    // Correction sweeps; the loop ends as soon as the test is met
    bool converged = false;
    for (LSCMint sweep = 0; !converged && sweep < kmax; ++sweep) {
        // Right-hand side of the correction: the current residual with
        // the rescaled accumulator added to its leading Crows entries
        rhs = resid;
        rhs.head(Crows) += (1.0/scale)*accum;

        const VectorXd delta = qrX->solve(rhs);
        if (qrX->info() != Success) {
            cerr << "DeferredCorrectionSolver: Warning, Solve failed in correction" << endl;
            break;
        }

        resid -= X*delta;
        sol += delta;
        accum += scale*resid.head(Crows);

        // Convergence test: squared norm of the leading residual entries
        // plus squared norm of X^T applied to the scaled residual whose
        // leading Crows entries are replaced by the accumulator
        VectorXd probe = scale*resid;
        probe.head(Crows) = accum;
        const double headNorm2 = resid.head(Crows).squaredNorm();
        const double projNorm2 = (X.transpose()*probe).squaredNorm();
        converged = (headNorm2+projNorm2 <= threshold);
    }

    if (!converged) {
        cerr << "DeferredCorrectionSolver: Warning, not converged!" << endl;
    }

    timings.tslv = ((double)(clock()-solveStart))/CLOCKS_PER_SEC;

    return GridFkt(mat->getspace(),sol);
}

// Builds the stacked sparse matrix
//
//         [ weight*C ]   rows 0 .. Crows-1
//     X = [          ]
//         [    A     ]   rows Crows .. Crows+Arows-1
//
// from A = mat->genA() and C = mat->genC() by merging their compressed
// storage column by column.
//
// Side effects: sets Arows, Crows, Acols and Xrows, and allocates the
// member arrays Xvals, Xcolptr and Xrowidx with new[].
//
// Returns the stacked matrix as an LSCMSparseMatrix built from a Map over
// those arrays. Terminates the program with exit(1) when A and C do not
// have the same number of columns.
LSCMSparseMatrix LSCM::DeferredCorrectionSolver::genmatT()
{
    // This version should be more stable w r t Householder QR.
    // NOTE: This version uses less memory than genmatS() at the
    //       expense of explicit memory management
    // NOTE: A and C are assumed to be compressed as output of
    //       LinLSQMatrices!
    // NOTE(review): Xvals, Xcolptr and Xrowidx are not released in this
    //       function, and arrays from an earlier call are not freed
    //       before being overwritten. Presumably the class releases them
    //       elsewhere - confirm.
    // NOTE(review): the column-wise merge assumes LSCMSparseMatrix uses
    //       column-major storage - confirm against its typedef.
    LSCMSparseMatrix A = mat->genA();
    LSCMSparseMatrix C = mat->genC();
    A.makeCompressed();  // For security reasons
    C.makeCompressed();

    // Both matrices are walked with one shared column index below. With
    // differing column counts the merge would read past the end of C's
    // column pointers or silently drop columns of C.
    if (C.cols() != A.cols()) {
        cerr << "DeferredCorrectionSolver: A and C have different numbers of columns" << endl;
        exit(1);
    }

    Arows = A.rows();
    Crows = C.rows();
    Acols = A.cols();

    auto Acolptr = A.outerIndexPtr();
    auto Arowidx = A.innerIndexPtr();
    auto Avals = A.valuePtr();
    auto Ccolptr = C.outerIndexPtr();
    auto Crowidx = C.innerIndexPtr();
    auto Cvals = C.valuePtr();

    Xrows = Arows+Crows;
    LSCMint Xnnz = A.nonZeros()+C.nonZeros();
    Xvals = new double[Xnnz];
    Xcolptr = new LSCMindex[Acols+1];
    Xrowidx = new LSCMindex[Xnnz];

    LSCMindex Xstj = 0;  // next free slot in Xrowidx/Xvals
    for (LSCMindex j = 0; j < Acols; ++j) {
        const auto Astj = Acolptr[j];
        const auto Aendj = Acolptr[j+1];
        const auto Cstj = Ccolptr[j];
        const auto Cendj = Ccolptr[j+1];

        Xcolptr[j] = Xstj;
        // Entries of C come first in every column: row indices are kept,
        // values are scaled by weight (so no plain block copy of values)
        for (auto p = Cstj; p < Cendj; ++p, ++Xstj) {
            Xrowidx[Xstj] = Crowidx[p];
            Xvals[Xstj] = weight*Cvals[p];
        }
        // Entries of A follow: values are kept, row indices are shifted
        // below the C block (so no plain block copy of indices)
        for (auto p = Astj; p < Aendj; ++p, ++Xstj) {
            Xrowidx[Xstj] = Arowidx[p]+Crows;
            Xvals[Xstj] = Avals[p];
        }
    }
    Xcolptr[Acols] = Xstj;

    // View over the raw arrays; named so that it does not shadow member X
    Map<LSCMSparseMatrix> Xmap(Xrows,Acols,Xnnz,Xcolptr,Xrowidx,Xvals);
    return Xmap;
}
