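/* Test driver for the linear time stepper (LinLSQStepper): solves the
 * LinCaMo DAE once with the global approach and once with the stepper,
 * and prints the errors of both against the exact solution.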
 *
 * C Michael Hanke 2023
 * Version 2023-05-15
 */

#include "LinCaMo.hpp"
#include "LinLSQStepper.hpp"
#include "LinLSQMatrices.hpp"
#include "Xn.hpp"
#include "Yn.hpp"
#include "Chebyshev.hpp"
#include "GaussLegendre.hpp"
#include "DirectSolver.hpp"
#include "GnuPlot.hpp"
#include <Eigen/Dense>
#include <vector>
#include <string>
#include <iostream>
#include <cmath>
#include <cstdlib>

using namespace Eigen;
using namespace std;
using namespace LSCM;

// Parameter of the LinCaMo test problem; it is passed to the LinCaMo
// constructor in main(). It is defined before the include below, presumably
// because LinCaMo_exsol.cpp (which supplies LinCaMo_ex / LinCaMo_dex) refers
// to it as well -- confirm before moving it.
#define rho 5.0

#include "LinCaMo_exsol.cpp"

// Interval endpoints: both grids in main() are built as Grid(A,B,...),
// presumably spanning [A,B].
#define A 0.0
#define B 5.0

// L: number of intervals of the stepper grid (grid2 = Grid(A,B,L)).
// n: passed to the stepper together with grid2, while the global grid uses
//    L*n intervals; presumably n is the number of subintervals per stepper
//    window, so both runs work on the same mesh -- confirm.
// NOTE(review): these single-letter macros textually replace every later
// occurrence of the tokens A, B, L, n, N, M (including inside any header
// included after this point), so keep all #includes above these definitions.
#define L 10
#define n 3

// N: parameter of the Chebyshev ansatz basis (Chebyshev(N)), presumably the
//    polynomial degree -- confirm.
// M: parameter of the Gauss-Legendre collocation rule (GaussLegendre(M)),
//    presumably the number of collocation nodes per subinterval -- confirm.
#define N 6
#define M 7

int main() {
    // data
    Yn::LSQnorm norm = Yn::LSQ_L2;
    auto dae = make_shared<LinCaMo>(rho);
    auto basis = make_shared<Chebyshev>(N);
    auto grid1 = make_shared<Grid>(A,B,L*n);
    auto colloc = make_shared<Yn>(make_shared<GaussLegendre>(M),norm);
    vector<bool> D1 = dae->getD();
    auto space1 = make_shared<Xn>(D1,grid1,basis);
    auto linlsq1 = make_shared<LinLSQMatrices>(dae,space1,colloc);
    
    DirectSolver slv(linlsq1);
    LinLSQSolver& slvp = slv;
    
    GridFkt sol1 = slvp.solve();
    
    // Error
    GridFkt exsol1(space1,LinCaMo_ex);
    cout << "DAE solved using the global approach!" << endl << "These are the errors:" << endl;
    cout << "Inf: " << sol1.infdist(LinCaMo_ex) << endl;
    cout << "L2:  " << sol1.l2dist(LinCaMo_ex) << endl;
    cout << "H1:  " << sol1.h1dist(LinCaMo_ex,LinCaMo_dex) << endl;
    
    // And now do the same with the stepper
    auto grid2 = make_shared<Grid>(A,B,L);
    
    // GO
    auto slvn = make_shared<DirectSolver>();
    LinLSQStepper stepper(dae,basis,colloc,slvn);
    auto sol2 = stepper.solve(grid2,n,true);
    
    // Computation of errors
    double errinf = 0.0, errL2 = 0.0, errH1 = 0.0;
    for (auto &subsol : *sol2) {
        double erri = subsol.infdist(LinCaMo_ex);
        double err1 = subsol.h1dist(LinCaMo_ex,LinCaMo_dex);
        double err2 = subsol.l2dist(LinCaMo_ex);
        if (errinf < erri) errinf = erri;
        errL2 += err2*err2;
        errH1 += err1*err1;
    }
    cout << "DAE solved using the stepper!" << endl << "These are the errors:" << endl;
    cout << "Inf: " << errinf << endl;
    cout << "L2:  " << sqrt(errL2) << endl;
    cout << "H1:  " << sqrt(errH1) << endl;
    
    // Plot the solution
    GnuPlot gp1;
    vector<int> comp1 = {0,1,2};
    vector<string> label1 = {"x_1","x_2","x_3"};
    gp1.plot(sol2,comp1,label1);
    gp1.show();
    
    GnuPlot gp2;
    vector<int> comp2 = {3,4,5};
    vector<string> label2 = {"x_4","x_5","x_6"};
    gp2.plot(sol2,comp2,label2);
    gp2.show();

    GnuPlot gp3;
    vector<int> comp3 = {6};
    vector<string> label3 = {"x_7"};
    gp3.plot(sol2,comp3,label3);
    gp3.show();
    // DONE
    return 0;
}
    
    
