/* File: LinLSQMatrices.cpp
 *
 * This class is intended to approximate the solution to linear DAEs for a
 * given grid and approximation space. Some auxiliary routines are made
 * public for testing purposes. However, the main functions are solve and
 * solveL.
 * 
 * Copyright (C) Michael Hanke 2020
 * Version: 2020-04-27
 */

/* 
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.

*/

#include "LinLSQMatrices.hpp"
#include "DAE.hpp"
#include "Xn.hpp"
#include "Yn.hpp"
#include "Grid.hpp"
#include "GridFkt.hpp"
#include <Eigen/Dense>
#include <Eigen/SparseCore>
#include <Eigen/SPQRSupport>
#include <unsupported/Eigen/KroneckerProduct>
#include <memory>
#include <vector>
#include <iostream>
#include <cstdlib>

using namespace LSCM;
using namespace std;
using namespace Eigen;

/* Set up the matrix generator from a DAE, an approximation space and a
 * collocation scheme.
 *
 * The constructor caches the quantities the generator routines use: the D
 * flags and the grid of the space, the grid size n, the sizes m, k, N and
 * getdim() of the space, l from the DAE, and M and the norm selector of the
 * collocation scheme.
 *
 * The program is terminated with exit(1) if M <= N, or if the D flags of the
 * space and those of the DAE differ in length or in any entry.
 */
LSCM::LinLSQMatrices::LinLSQMatrices(shared_ptr<DAE> dae_, shared_ptr<Xn> space_,
                 shared_ptr<Yn> colloc_) :
                 dae(dae_), space(space_), colloc(colloc_) {
    // Print a fatal set-up message and terminate the program.
    const auto fatal = [](const char* message) {
        cerr << message << endl;
        exit(1);
    };

    // Cache the problem dimensions
    D = space->getDptr();
    grid = space->getgrid();
    n = grid->getn();
    m = space->getm();
    k = space->getk();
    l = dae->getl();
    N = space->getN();
    M = colloc->getM();
    nun = space->getdim();
    lsqnorm = colloc->getnorm();

    // More collocation points than N are required
    if (M <= N) fatal("LinLSQMatrices: Too few collocation points");

    // The D flags of the space must match those of the DAE, both in number
    // and entry by entry
    vector<bool> Ddae = dae->getD();
    bool consistent = !(m != Ddae.size());
    for (LSCMint i = 0; consistent && i < m; ++i) {
        if (D[i] != Ddae[i]) consistent = false;
    }
    if (!consistent) fatal("LinLSQMatrices: space and DAE inconsistent");
}

/* Assemble the weighted collocation matrix of subinterval j.
 *
 * The result has M*m rows (one block of m rows per collocation point) and
 * nun columns. For the collocation point t = t_j + h*tau the entry in row r
 * that belongs to component i and basis function b is
 *     B(t)(r,i)*psi(b) + A(t)(r,i)*dpsi(b)   if D[i] is set,
 *     B(t)(r,i)*psi(b)                        otherwise,
 * where psi = space->evalsub(tau,h) and dpsi = space->devalsub(tau,h).
 * A component with D[i] set occupies N+1 consecutive columns, any other
 * component N columns.
 *
 * The rows are finally weighted according to lsqnorm:
 *   LSQ_RN  - all rows are scaled by sqrt(h/M),
 *   LSQ_INT - the row block of collocation point c is scaled by
 *             sqrt(w_c*h) with w = colloc->getinteg(),
 *   LSQ_L2  - the matrix is multiplied from the left by kron(L, I_m) with
 *             L = colloc->getmass().
 *
 * j must satisfy 0 <= j < n; otherwise the program is terminated with
 * exit(1). The same happens for an unknown value of lsqnorm.
 */
MatrixXd LSCM::LinLSQMatrices::gensub(LSCMint j) const
{
    if ((j < 0) || (j >= n)) {
        cerr << "gensub: Index out of range" << endl;
        exit(1);
    }

    MatrixXd ABj(M*m,nun);

    const double tleft = (*grid)[j];
    const double h = (*grid)[j+1]-tleft;
    shared_ptr<vector<double>> cp = colloc->gettau();

    LSCMint rowoff = 0;
    for (LSCMint c = 0; c < M; ++c) {
        const double tau = (*cp)[c];
        const double t = tleft+h*tau;
        const MatrixXd Bt = (dae->B)(t);
        const MatrixXd At = (dae->A)(t);
        const VectorXd psi = space->evalsub(tau,h);
        const VectorXd dpsi = space->devalsub(tau,h);
        for (LSCMint row = 0; row < m; ++row) {
            LSCMint coloff = 0;
            for (LSCMint comp = 0; comp < m; ++comp) {
                if (D[comp]) {
                    // Component with D set: N+1 basis functions, and the
                    // derivative enters through A
                    for (LSCMint b = 0; b <= N; ++b) {
                        ABj(rowoff+row,coloff+b) = Bt(row,comp)*psi(coloff+b)+At(row,comp)*dpsi(coloff+b);
                    }
                    coloff += N+1;
                }
                else {
                    // Component without D: N basis functions, B term only
                    for (LSCMint b = 0; b < N; ++b) {
                        ABj(rowoff+row,coloff+b) = Bt(row,comp)*psi(coloff+b);
                    }
                    coloff += N;
                }
            }
        }
        rowoff += m;
    }

    switch (lsqnorm) {
        case Yn::LSQ_RN:
            // Scale in place and return the local object so that it can be
            // moved instead of copied
            ABj *= sqrt(h/M);
            return ABj;
        case Yn::LSQ_INT:
        {
            shared_ptr<vector<double>> zw = colloc->getinteg();
            for (LSCMint c = 0; c < M; ++c)
                ABj.middleRows(c*m,m) *= sqrt((*zw)[c]*h);
            return ABj;
        }
        case Yn::LSQ_L2:
        {
            shared_ptr<MatrixXd> L = colloc->getmass();
            MatrixXd ABwj = (kroneckerProduct(*L,MatrixXd::Identity(m,m)))*ABj;
            return ABwj;
        }
        default:
            // Cannot happen!
            cerr << "LinLSQMatrices: Wrong functional" << endl;
            exit(1);
    }
}

/* Build the matrix that applies Ga = dae->getGa() to the basis values
 * space->evalr(0).
 *
 * The result has l rows and N*m+k columns. The entry in row r that belongs to
 * component i and basis function b is Ga(r,i)*psi(b). A component with D[i]
 * set occupies N+1 consecutive columns, any other component N columns.
 */
MatrixXd LSCM::LinLSQMatrices::genGa() const
{
    const MatrixXd Ga = dae->getGa();
    const VectorXd psi = space->evalr(0);
    MatrixXd Gamat(l,N*m+k);

    for (LSCMint row = 0; row < l; ++row) {
        LSCMint coloff = 0;
        for (LSCMint comp = 0; comp < m; ++comp) {
            // Number of basis functions of this component
            const LSCMint nbasis = D[comp] ? N+1 : N;
            for (LSCMint b = 0; b < nbasis; ++b)
                Gamat(row,coloff+b) = Ga(row,comp)*psi(coloff+b);
            coloff += nbasis;
        }
    }

    return Gamat;
}

/* Build the matrix that applies Gb = dae->getGb() to the basis values
 * space->evall(n).
 *
 * The result has l rows and N*m+k columns. The entry in row r that belongs to
 * component i and basis function b is Gb(r,i)*psi(b). A component with D[i]
 * set occupies N+1 consecutive columns, any other component N columns.
 */
MatrixXd LSCM::LinLSQMatrices::genGb() const
{
    const MatrixXd Gb = dae->getGb();
    const VectorXd psi = space->evall(n);
    MatrixXd Gbmat(l,N*m+k);

    for (LSCMint row = 0; row < l; ++row) {
        LSCMint coloff = 0;
        for (LSCMint comp = 0; comp < m; ++comp) {
            // Number of basis functions of this component
            const LSCMint nbasis = D[comp] ? N+1 : N;
            for (LSCMint b = 0; b < nbasis; ++b)
                Gbmat(row,coloff+b) = Gb(row,comp)*psi(coloff+b);
            coloff += nbasis;
        }
    }

    return Gbmat;
}

/* Assemble the weighted right-hand side of subinterval j.
 *
 * The vector has M*m entries: for every collocation point t = t_j + h*tau the
 * m values dae->q(t) are stored consecutively. The entries are then weighted
 * according to lsqnorm in the same way as the rows produced by gensub():
 *   LSQ_RN  - scaling by sqrt(h/M),
 *   LSQ_INT - the block of collocation point c is scaled by sqrt(w_c*h) with
 *             w = colloc->getinteg(),
 *   LSQ_L2  - multiplication from the left by kron(L, I_m) with
 *             L = colloc->getmass().
 *
 * j must satisfy 0 <= j < n; otherwise the program is terminated with
 * exit(1). The same happens for an unknown value of lsqnorm.
 */
VectorXd LSCM::LinLSQMatrices::gensubrhs(LSCMint j) const
{
    if ((j < 0) || (j >= n)) {
        cerr << "gensubrhs: Index out of range" << endl;
        exit(1);
    }

    VectorXd Qj(M*m);

    const double tleft = (*grid)[j];
    const double h = (*grid)[j+1]-tleft;
    shared_ptr<vector<double>> cp = colloc->gettau();
    for (LSCMint c = 0; c < M; ++c) {
        const VectorXd Qc = (dae->q)(tleft+h*(*cp)[c]);
        Qj.segment(c*m,m) = Qc;
    }

    switch (lsqnorm) {
        case Yn::LSQ_RN:
            // Scale in place and return the local object so that it can be
            // moved instead of copied
            Qj *= sqrt(h/M);
            return Qj;
        case Yn::LSQ_INT:
        {
            shared_ptr<vector<double>> zw = colloc->getinteg();
            for (LSCMint c = 0; c < M; ++c)
                Qj.segment(c*m,m) *= sqrt((*zw)[c]*h);
            return Qj;
        }
        case Yn::LSQ_L2:
        {
            shared_ptr<MatrixXd> L = colloc->getmass();
            VectorXd Qwj = (kroneckerProduct(*L,MatrixXd::Identity(m,m)))*Qj;
            return Qwj;
        }
        default:
            // Cannot happen!
            cerr << "LinLSQMatrices: Wrong functional" << endl;
            exit(1);
    }
}

/* Build the coupling block for the grid point with index j.
 *
 * The result has k rows and 2*nun columns and is zero except for one row per
 * component with D set. In that row the columns belonging to the component
 * hold the values of space->evall(j) in the first nun columns and the negated
 * values of space->evalr(j) in the second nun columns. Presumably this
 * expresses continuity of these components across grid point j, with the two
 * column halves addressing the coefficients of subintervals j-1 and j.
 *
 * j must satisfy 1 <= j < n; otherwise the program is terminated with
 * exit(1).
 */
MatrixXd LSCM::LinLSQMatrices::gencont(const LSCMint j) const
{
    if ((j < 1) || (j >= n)) {
        cerr << "Index out of range! >>> Exit" << endl;
        exit(1);
    }

    MatrixXd ccj = MatrixXd::Zero(k,2*nun);

    const VectorXd psir = space->evalr(j);
    const VectorXd psil = space->evall(j);

    LSCMint coloff = 0;     // first column of the current component
    LSCMint row = 0;        // next row to be filled
    for (LSCMint comp = 0; comp < m; ++comp) {
        if (D[comp]) {
            for (LSCMint b = 0; b <= N; ++b) {
                ccj(row,coloff+b) = psil(coloff+b);
                ccj(row,nun+coloff+b) = -psir(coloff+b);
            }
            coloff += N+1;
            ++row;
        }
        else coloff += N;   // components without D contribute no row
    }
    return ccj;
}

/* Assemble the global right-hand side.
 *
 * The vector has M*m*n+l entries: the results of gensubrhs(j) for
 * j = 0,...,n-1 are stacked, followed by the l values alpha*dae->getr().
 * As in feval(), the trailing part is only written if l > 0.
 */
VectorXd LSCM::LinLSQMatrices::genrhsq() const
{
    VectorXd rhs(M*m*n+l);

    LSCMint off = 0;
    for (LSCMint j = 0; j < n; ++j) {
        const VectorXd sub = gensubrhs(j);
        const LSCMint len = sub.size();
        rhs.segment(off,len) = sub;
        off += len;
    }

    // Adding BCs
    if (l > 0) rhs.tail(l) = alpha*(dae->getr());

    return rhs;
}

/* Evaluate the weighted DAE operator applied to the grid function x on
 * subinterval j.
 *
 * The vector has M*m entries: for every collocation point t = t_j + h*tau the
 * m values A(t)*x.deval(t) + B(t)*x.eval(t) are stored consecutively
 * (x.deval is presumably the derivative of x). The entries are then weighted
 * according to lsqnorm in the same way as in gensubrhs():
 *   LSQ_RN  - scaling by sqrt(h/M),
 *   LSQ_INT - the block of collocation point c is scaled by sqrt(w_c*h) with
 *             w = colloc->getinteg(),
 *   LSQ_L2  - multiplication from the left by kron(L, I_m) with
 *             L = colloc->getmass().
 *
 * j must satisfy 0 <= j < n; otherwise the program is terminated with
 * exit(1). The same happens for an unknown value of lsqnorm.
 */
VectorXd LSCM::LinLSQMatrices::fevalsub(LSCMint j, const GridFkt x) const
{
    if ((j < 0) || (j >= n)) {
        cerr << "fevalsub: Index out of range" << endl;
        exit(1);
    }

    VectorXd Qj(M*m);

    const double tleft = (*grid)[j];
    const double h = (*grid)[j+1]-tleft;
    shared_ptr<vector<double>> cp = colloc->gettau();
    for (LSCMint c = 0; c < M; ++c) {
        const double tji = tleft+h*(*cp)[c];
        const VectorXd Qc = dae->A(tji)*x.deval(tji)+dae->B(tji)*x.eval(tji);
        Qj.segment(c*m,m) = Qc;
    }

    switch (lsqnorm) {
        case Yn::LSQ_RN:
            // Scale in place and return the local object so that it can be
            // moved instead of copied
            Qj *= sqrt(h/M);
            return Qj;
        case Yn::LSQ_INT:
        {
            shared_ptr<vector<double>> zw = colloc->getinteg();
            for (LSCMint c = 0; c < M; ++c)
                Qj.segment(c*m,m) *= sqrt((*zw)[c]*h);
            return Qj;
        }
        case Yn::LSQ_L2:
        {
            shared_ptr<MatrixXd> L = colloc->getmass();
            VectorXd Qwj = (kroneckerProduct(*L,MatrixXd::Identity(m,m)))*Qj;
            return Qwj;
        }
        default:
            // Cannot happen!
            cerr << "LinLSQMatrices: Wrong functional" << endl;
            exit(1);
    }
}

/* Evaluate the weighted DAE operator and the boundary operator for the grid
 * function x.
 *
 * The vector has M*m*n+l entries: the results of fevalsub(j,x) for
 * j = 0,...,n-1 are stacked. If l > 0, the last l entries hold
 * alpha*(Ga*x(a) + Gb*x(b)) with a = x.geta(), b = x.getb() and Ga, Gb taken
 * from the DAE.
 */
VectorXd LSCM::LinLSQMatrices::feval(const GridFkt x) const
{
    VectorXd res(M*m*n+l);

    // Stack the contributions of the subintervals
    LSCMint pos = 0;
    for (LSCMint sub = 0; sub < n; ++sub) {
        const VectorXd part = fevalsub(sub,x);
        const LSCMint len = part.size();
        res.segment(pos,len) = part;
        pos += len;
    }

    // Without boundary conditions there is nothing more to add
    if (!(l > 0)) return res;

    const double a = x.geta();
    const double b = x.getb();
    res.tail(l) = alpha*(dae->getGa()*x.eval(a)+dae->getGb()*x.eval(b));
    return res;
}

/* Assemble the global sparse matrix of the least-squares functional.
 *
 * The matrix has nun*n columns. Its first M*m*n rows are block diagonal:
 * the block of subinterval j is gensub(j) and occupies the columns
 * j*nun,...,(j+1)*nun-1. If BCversion is set, l further rows are appended
 * which hold alpha*genGa() in the first nun columns and alpha*genGb() in the
 * last nun columns.
 *
 * Entries which are exactly zero are not stored in the result.
 *
 * The program is terminated with exit(1) if n == 1 and (BCversion or
 * l == 0); genA1i() is intended for that case.
 */
LSCMSparseMatrix LSCM::LinLSQMatrices::genA() const
{
    // Equivalent to (n == 1) && (BCversion || (!BCversion && l == 0))
    const bool oneD = (n == 1) && (BCversion || l == 0);
    if (oneD) {
        cerr << "genA: Grid has only one subinterval. Use genA1i() instead" << endl;
        exit(1);
    }

    // Dimensions AB
    const LSCMint lloc = BCversion ? l : 0;
    const LSCMint rdim = M*m*n+lloc;  // BC
    const LSCMint alldofs = nun*n;

    // Upper bound for the number of entries: all dense subinterval blocks
    // plus the two boundary blocks
    vector<LSCMTriplet> DOK;
    DOK.reserve(m*M*nun*n+2*lloc*nun);

    // Append the nonzero entries of the leading nrows x ncols part of blk,
    // shifted to position (rowoff,coloff) of the global matrix. Leaving out
    // exact zeros here keeps them out of the sparse matrix from the start.
    const auto addNonzeros = [&DOK](const MatrixXd& blk, LSCMint nrows, LSCMint ncols,
                                    LSCMint rowoff, LSCMint coloff) {
        for (LSCMint c = 0; c < ncols; ++c)
            for (LSCMint r = 0; r < nrows; ++r) {
                const double v = blk(r,c);
                if (v != 0.0) DOK.emplace_back(rowoff+r,coloff+c,v);
            }
    };

    LSCMint iglob = 0;
    LSCMint jglob = 0;
    for (LSCMint sub = 0; sub < n; ++sub) {
        const MatrixXd AB = gensub(sub);
        const LSCMint ABi = AB.rows();
        addNonzeros(AB,ABi,nun,iglob,jglob);
        iglob += ABi;
        jglob += nun;
    }

    // Adding BCs: Ga acts on the first, Gb on the last subinterval
    if (BCversion) {
        const MatrixXd Gamat = alpha*genGa();
        const MatrixXd Gbmat = alpha*genGb();
        addNonzeros(Gamat,lloc,nun,iglob,0);
        addNonzeros(Gbmat,lloc,nun,iglob,jglob-nun);
    }

    LSCMSparseMatrix mat(rdim,alldofs);
    mat.setFromTriplets(DOK.begin(),DOK.end());
    struct keepFkt {
        bool operator() (const Index& row, const Index& col, const double& value) const {
            return value != 0.0;
        }
    } kF;
    // Kept so that the result is free of explicit zeros in any case; since
    // no zero triplets are inserted any more, there is little left to remove.
    mat.prune(kF);
    //mat.makeCompressed();   // Not necessary, part of prune()

    return mat;
}

/* Assemble the dense matrix of the least-squares functional for a grid with
 * a single subinterval.
 *
 * The result has M*m+l rows and nun columns: gensub(0) on top, followed by
 * the l rows alpha*(genGa()+genGb()).
 *
 * The routine may only be used if n == 1 and (BCversion or l == 0);
 * otherwise the program is terminated with exit(1).
 */
MatrixXd LSCM::LinLSQMatrices::genA1i() const {
    // Equivalent to (n == 1) && (BCversion || (!BCversion && l == 0))
    const bool oneD = (n == 1) && (BCversion || l == 0);
    if (!oneD) {
        cerr << "genA1i: Grid has more than one subinterval. Use genA() instead" << endl;
        exit(1);
    }

    const LSCMint Arow = M*m;
    MatrixXd X(Arow+l,nun);
    X.topRows(Arow) = gensub(0);
    // Both boundary matrices act on the coefficients of the only subinterval
    X.bottomRows(l) = alpha*(genGa()+genGb());
    return X;
}

/* Assemble the global sparse matrix of the equality constraints.
 *
 * The matrix has n*nun columns. If BCversion is not set, its first l rows
 * hold alpha*genGa() in the first nun columns and alpha*genGb() in the last
 * nun columns. They are followed by k rows for each of the grid points
 * 1,...,n-1: the block gencont(i) is placed in the columns
 * (i-1)*nun,...,(i+1)*nun-1.
 *
 * Entries which are exactly zero are not stored in the result.
 */
LSCMSparseMatrix LSCM::LinLSQMatrices::genC() const
{
    // Dimensions CM
    const LSCMint lloc = BCversion ? 0 : l;
    const LSCMint rdim = (n-1)*k+lloc;
    const LSCMint alldofs = n*nun;

    // Upper bound for the number of entries: all dense coupling blocks plus
    // the two boundary blocks
    vector<LSCMTriplet> DOK;
    DOK.reserve(2*k*nun*(n-1)+2*lloc*nun);

    // Append the nonzero entries of the leading nrows x ncols part of blk,
    // shifted to position (rowoff,coloff) of the global matrix. The blocks
    // from gencont() are zero-initialised and only partly filled, so leaving
    // out exact zeros here keeps them out of the sparse matrix from the start.
    const auto addNonzeros = [&DOK](const MatrixXd& blk, LSCMint nrows, LSCMint ncols,
                                    LSCMint rowoff, LSCMint coloff) {
        for (LSCMint c = 0; c < ncols; ++c)
            for (LSCMint r = 0; r < nrows; ++r) {
                const double v = blk(r,c);
                if (v != 0.0) DOK.emplace_back(rowoff+r,coloff+c,v);
            }
    };

    // Adding BCs: These are the first equations in the equality
    // constraints!
    if (!BCversion) {
        const MatrixXd Gamat = alpha*genGa();
        const MatrixXd Gbmat = alpha*genGb();
        // This procedure will work even for n = 1 since elements with equal
        // coordinates are summed up when forming the sparse matrix!
        addNonzeros(Gamat,lloc,nun,0,0);
        addNonzeros(Gbmat,lloc,nun,0,alldofs-nun);
    }

    // One coupling block per grid point 1,...,n-1
    LSCMint iglob = lloc;
    LSCMint jglob = 0;
    for (LSCMint ip = 1; ip < n; ++ip) {
        addNonzeros(gencont(ip),k,2*nun,iglob,jglob);
        iglob += k;
        jglob += nun;
    }

    LSCMSparseMatrix mat(rdim,alldofs);
    mat.setFromTriplets(DOK.begin(),DOK.end());
    struct keepFkt {
        bool operator() (const Index& row, const Index& col, const double& value) const {
            return value != 0.0;
        }
    } kF;
    // Still required: summed duplicates (n = 1) may cancel to exactly zero.
    // Since no zero triplets are inserted any more, there is little left to
    // remove.
    mat.prune(kF);
    // mat.makeCompressed();

    return mat;
}
