/* File LinLSQStepper.cpp
 * 
 * Implementation of a timestepper for linear DAEs
 * 
 * C Michael Hanke 2021
 * Version: 2021-10-18
 */

#include "LinLSQStepper.hpp"
#include "AccurateIC.hpp"
#include "LinLSQSolver.hpp"
#include "LinLSQMatrices.hpp"
#include "Radau.hpp"
#include "Chebyshev2.hpp"
#include "GridFkt.hpp"
#include "Grid.hpp"
#include "Basis.hpp"
#include "Xn.hpp"
#include "Yn.hpp"
#include "Eigen/Dense"
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <deque>
#include <iostream>
#include <limits>
#include <memory>
#include <vector>

using namespace LSCM;
using namespace Eigen;
using namespace std;

namespace {
    // This class implements the interface to a class with modified initial
    // conditions. It is assumed that sanity checks have been carried out in
    // solve()!
    class DAEmod : public DAE {
    private:
        shared_ptr<DAE> daeobj;
        LSCMint N_;
        shared_ptr<Chebyshev2> nodes;
        VectorXd cpnew;
        MatrixXd Gupdt;
        LSCMint m, l;
        bool valid = false;
    public:
        DAEmod(shared_ptr<DAE> dae, LSCMint N) : daeobj(dae), N_(N) {
            Gupdt = dae->getGa();
            cpnew = dae->getr();
            m = Gupdt.cols();
            l = Gupdt.rows();
            LSCMint M = N+1;
            if (M%2 == 0) ++M;
            nodes = make_shared<Chebyshev2>(M);
        }
        
        void resetInival(double t, const VectorXd ininew, double H,
                         AccurateIC::CILocation loc) {
            AccurateIC aic(N_,nodes);
            Gupdt = aic.getGa(daeobj,t,H,loc);
            cpnew = Gupdt*ininew;
            if (Gupdt.rows() != l) {
                cerr << "LinLSQStepper: Change of number of dynamical degrees of freedom!" << endl;
                exit(1);
            }
        }
        
        Eigen::MatrixXd A(double t) { return daeobj->A(t); }
        
        Eigen::MatrixXd B(double t) { return daeobj->B(t); }
        
        Eigen::VectorXd q(double t) { return daeobj->q(t); }
        
        std::vector<bool> getD() { return daeobj->getD(); }
        
        Eigen::MatrixXd getGa() { return Gupdt; }
        
        Eigen::MatrixXd getGb() { return Eigen::MatrixXd::Zero(l,m); }
        
        Eigen::VectorXd getr() { return cpnew; }
        
        LSCMint getl() { return l; }
        
        LSCMint getindex() { return daeobj->getindex(); }
    };

}

// Solve the initial value problem dae_ by time stepping over the coarse grid.
//
// For each coarse interval [t_{i-1}, t_i] of `grid`:
// - a local Grid(t_{i-1}, t_i, micro) and the space Xn are built;
// - the least-squares problem for the modified DAE is assembled and solved
//   with slv_;
// - the value of that solution at t_i is used to set the initial condition
//   of the next interval (via AccurateIC, CI_CENTER).
//
// Parameters:
//   grid      coarse grid with n = grid->getn() intervals
//   micro     number handed to the local Grid for each interval, also used
//             as the divisor for the AccurateIC step size. Must be >= 1.
//   checkIC   if true, report the AccurateIC "opening" of the provided
//             initial condition at the left end (diagnostic only)
//   BCversion forwarded to LinLSQMatrices::setBCversion
//
// Returns one GridFkt per coarse interval, in order.
//
// Timings: timings.tsub is the sum of assembly/factorization/solve times
// reported by slv_ during this call. timings.tiv is the remaining CPU time
// (std::clock) of this call.
//
// Terminates the program if dae_ is not an initial value problem (Gb != 0)
// or if micro < 1.
shared_ptr<deque<GridFkt>> LinLSQStepper::solve(shared_ptr<Grid> grid,
                  LSCMint micro,
                  bool checkIC,
                  bool BCversion) {
    const clock_t trun = clock();
    // Report only the time spent in this call (tiv is computed from tsub below).
    timings.tsub = 0.0;

    if (micro < 1) {
        cerr << "LinLSQStepper: micro must be at least 1!" << endl;
        exit(1);
    }

    const LSCMint n = grid->getn();
    double t0 = (*grid)[0];

    // This is a cheap check.
    const auto Gb = dae_->getGb();
    if ((Gb.array().abs() > 0).any()) {
        cerr << "LinLSQStepper: Not an initial value problem!" << endl;
        exit(1);
    }
    // This one is more expensive! It needs at least one coarse interval to
    // define a step size, so it is skipped for an empty grid.
    if (checkIC && n > 0) {
        shared_ptr<Radau> nodes = make_shared<Radau>(colloc_->getM());
        AccurateIC aic(basis_->getdim(),nodes);
        const double err = aic.opening(dae_->getGa(),dae_,t0,((*grid)[1]-t0)/micro,
                                       AccurateIC::CI_LEFT);
        cerr << "Opening of provided initial condition: " << err << endl;
        if (err > sqrt(numeric_limits<double>::epsilon())) {
            cerr << "Warning: Provided initial value may be inaccurate!" << endl;
        }
    }

    auto D = dae_->getD();
    shared_ptr<DAEmod> DAEupd = make_shared<DAEmod>(dae_,basis_->getdim());

    // GridFkt has no default constructor, so a pre-sized container cannot be
    // created. Start empty and append one solution per interval.
    shared_ptr<deque<GridFkt>> res = make_shared<deque<GridFkt>>();

    for (LSCMint i = 1; i <= n; ++i) {
        const double t1 = (*grid)[i];
        shared_ptr<Grid> gloc = make_shared<Grid>(t0,t1,micro);
        shared_ptr<Xn> space = make_shared<Xn>(D,gloc,basis_);
        shared_ptr<LinLSQMatrices> mat =
                make_shared<LinLSQMatrices>(DAEupd,space,colloc_);
        mat->setBCversion(BCversion);
        slv_->replaceLSQMatrices(mat);
        res->push_back(slv_->solve());

        const auto tmgs = slv_->gettimings();
        timings.tsub += tmgs.tass+tmgs.tfact+tmgs.tslv;

        // Hand the end value over as the initial condition of the next
        // interval. After the last interval there is no next one, so the
        // (costly) AccurateIC update would be wasted.
        if (i < n) {
            DAEupd->resetInival(t1,res->back()(t1),(t1-t0)/micro,
                                AccurateIC::CI_CENTER);
        }
        t0 = t1;
    }

    // DONE
    timings.tiv = static_cast<double>(clock()-trun)/CLOCKS_PER_SEC-timings.tsub;
    return res;
}
