/* File: LinLSQSolver.hpp
 *
 * This class is intended to approximate the solution to nonlinear DAEs for a
 * given grid and approximation space. The main functions are solve, solveD
 * and nlsq.
 * 
 * Copyright (C) Michael Hanke 2019
 * Version: 2022-06-10
 */

/* 
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.

*/

#include "LSQSolver.hpp"
#include "LinLSQSolver.hpp"
#include "NlDAE.hpp"
#include "LinearizedDAE.hpp"
#include "Xn.hpp"
#include "Yn.hpp"
#include "GridFkt.hpp"
#include <Eigen/Dense>
#include <algorithm>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <limits>
#include <memory>

using namespace LSCM;
using namespace std;
using namespace Eigen;

/* Undamped iteration for the nonlinear DAE.
 *
 * Starting from the grid function built from x0 on `space`, every sweep
 * linearizes the DAE at the current iterate, hands the resulting
 * LinLSQMatrices to `linsolver`, and subtracts the computed correction zn
 * from the iterate. The H1 norm of zn is used as the error measure.
 *
 * The loop runs while err > tol and it <= itmax, i.e. at most itmax+1 sweeps.
 * The iteration history is written to cout; a warning is written to cerr if
 * the loop ends without the error measure having dropped to tol.
 *
 * x0:      initial guess as a function of the independent variable
 * returns: the last iterate
 */
GridFkt LSCM::LSQSolver::solve(function<VectorXd(double)> x0) const
{
    double err = numeric_limits<double>::infinity();
    GridFkt x(space,x0);
    LSCMint it = 0;
    
    cout << "Iteration history:" << endl;
    while (err > tol && it <= itmax) {
        // Linearize around the current iterate and assemble the linear problem
        auto lindae = make_shared<LinearizedDAE>(dae,x);
        auto mat = make_shared<LinLSQMatrices>(lindae,space,colloc);
        mat->setAlpha(alpha);
        mat->setBCversion(BCversion);
        
        // Solver selection
        linsolver.replaceLSQMatrices(mat);
        GridFkt zn = linsolver.solve();
        
        x = x-zn;
        err = zn.h1norm();
        
        cout << "it = " << it << ",  err: " << err << ",  tol: " << tol << endl;
        it++;
    }
    
    // Written as !(err <= tol) so that a NaN error norm, which also ends the
    // loop above, is reported instead of being mistaken for convergence.
    if (!(err <= tol)) {
        cerr << "Warning solve: Tolerance not reached" << endl;
    }
    return x;
}

/* Damped variant of solve.
 *
 * Every outer sweep solves the linear problem belonging to the current
 * iterate x for a correction zn and then searches for a damping factor lam
 * such that the Euclidean norm of the right-hand side at x - lam*zn is
 * below 0.99 times the norm at x. lam starts at 1 (0.01 in the very first
 * sweep, whose trial step is accepted unconditionally) and is halved after
 * each rejected trial.
 *
 * When a trial is accepted, the iterate, its matrices and its right-hand
 * side are all replaced together, so that the next linear solve uses the
 * linearization at the accepted iterate. If lam drops below lambdamin
 * without an acceptable trial, nothing is updated; repeating the sweep would
 * reproduce the same correction, so the iteration stops with a warning.
 *
 * The loop runs while err > tol and it <= itmax, where err is the H1 norm of
 * the last correction. The iteration history is written to cout.
 *
 * x0:      initial guess as a function of the independent variable
 * returns: the last accepted iterate
 */
GridFkt LSCM::LSQSolver::solveD(function<VectorXd(double)> x0) const
{
    // Assembles the matrices of the DAE linearized at the given iterate.
    const auto linearizeAt = [this](GridFkt& at) -> shared_ptr<LinLSQMatrices> {
        auto lin = make_shared<LinearizedDAE>(dae,at);
        auto m = make_shared<LinLSQMatrices>(lin,space,colloc);
        m->setAlpha(alpha);
        m->setBCversion(BCversion);
        return m;
    };

    double err = numeric_limits<double>::infinity();
    GridFkt x(space,x0);
    LSCMint it = 0;
    shared_ptr<LinLSQMatrices> mat = linearizeAt(x);

    VectorXd f = mat->genrhsq();
    double normold = f.norm();

    cout << "Iteration history:" << endl;

    while (err > tol && it <= itmax) {
        
        linsolver.replaceLSQMatrices(mat);
        GridFkt zn = linsolver.solve(f);
        
        // Damping cycle
        double lam = (it == 0) ? 0.01 : 1.0;
        bool accepted = false;
        while (true) {
            GridFkt xn = x-lam*zn;
            shared_ptr<LinLSQMatrices> matn = linearizeAt(xn);

            VectorXd fn = matn->genrhsq();
            // Effectively, this is the L2-norm in Yn!
            double normnew = fn.norm();
            cout << normold << ",   " << normnew << endl;
            if ((it == 0) || (normnew < 0.99*normold)) { 
                // Keep iterate, matrices and right-hand side in sync
                x = xn;
                mat = matn;
                f = fn;
                normold = normnew;
                accepted = true;
                break;
            }
            lam = lam*0.5;
            if (lam < lambdamin) {
                break;
            }
        }
        if (!accepted) {
            cerr << "Warning solveD: Underflow damping parameter" << endl;
            return x;
        }
        err = zn.h1norm();
        
        cout << "it = " << it << ",  err: " << err << ",  tol: " << tol << endl;

        it++;
    }
    
    // !(err <= tol) also catches a NaN error norm, which ends the loop above.
    if (!(err <= tol)) {
        cerr << "Warning solveD: Tolerance not reached" << endl;
    }
    return x;
}

// The following implementation follows the code NLSQ_ERR.m that is based on
// P Deuflhard's book. This implementation is more Fortran-like.
//
// Outline of one outer step k:
//   1. Factorize the matrices of the linearization at xk and compute the
//      correction dxk from the right-hand side Fk.
//   2. For k > 0, predict the damping factor lambdak from the previous step,
//      using the previous matrices (matkm1) and the simplified correction
//      dxbark that was computed with the previous factorization.
//   3. Damping loop (at most imax trials): form xkp1 = xk + lambdak*dxk,
//      compute the simplified correction dxbarkp1 with the factorization at
//      xk, and either accept the trial (monotonicity test) or reduce lambdak.
//
// All norms are H1 norms of corrections scaled componentwise by xwght.
//
// x0:      initial guess as a function of the independent variable
// returns: xkp1 when the convergence test succeeds; otherwise the last
//          accepted iterate xk (after itmax steps, after imax rejected
//          damping trials, or when lambdak falls below lambdamin). Progress
//          is written to cout, warnings and errors to cerr.

GridFkt LSCM::LSQSolver::nlsq(function<VectorXd(double)> x0) const {
    
    const double xthresh = tol;  // for highly nonlinear problems
    double lambdak = 0.01;  // Should be \lambda_0 - initial damping parameters
    const double delk = 0.0;  // Avoid to compute: expensive!
    // The quantities below are only read for k > 0, after the first accepted
    // step has assigned them; they are zero-initialized for definiteness.
    double thetak = 0.0, dxk_norm = 0.0, dxkm1_norm = 0.0;
    double dxbark_norm = 0.0, lambdakm1 = 0.0;
    GridFkt dxk(space), dxkm1(space), dxbark(space), dbark(space);
    
    // Scaling weights of an iterate, bounded below by xthresh.
    // OBS: We are using the L2-norm here. This is not consistent since our norm is
    // H^1_D!
    const auto weightsOf = [xthresh](GridFkt& g) -> VectorXd {
        VectorXd w = g.cl2norm();
        for (LSCMint j = 0; j < w.size(); ++j)
            if (w(j) < xthresh) w(j) = xthresh;
        return w;
    };
    
    // Assembles the matrices of the DAE linearized at the given iterate.
    const auto linearizeAt = [this](GridFkt& at) -> shared_ptr<LinLSQMatrices> {
        auto lin = make_shared<LinearizedDAE>(dae,at);
        auto m = make_shared<LinLSQMatrices>(lin,space,colloc);
        m->setAlpha(alpha);
        m->setBCversion(BCversion);
        return m;
    };
    
    GridFkt xk(space,x0);
    VectorXd xwght = weightsOf(xk);
    
    shared_ptr<LinLSQMatrices> matkm1;
    shared_ptr<LinLSQMatrices> mat = linearizeAt(xk);
    VectorXd Fk = mat->genrhsq();
    
    LSCMint k = 0;
    while (k < itmax) {
        linsolver.replaceLSQMatrices(mat);
        dxk = linsolver.solve(-Fk); // includes factorization!
        dxk_norm = (dxk/xwght).h1norm();  // better to use a different norm??
        // cerr << "dxk:" << endl << dxk_norm << endl;
       
        if (k > 0) {
            thetak = dxk_norm/dxkm1_norm;
            // a-priori estimate for lambda
            VectorXd v = Fk+matkm1->feval(dxbark);
            dbark = linsolver.solve(-v); // Using already computed factorization!
            GridFkt dxh = dxbark-dxk+dbark;
            double omegabark = (dxh/xwght).h1norm()/(lambdakm1*dxkm1_norm*dxbark_norm);
            lambdak = min(1.0,(1-delk)/(omegabark*dxk_norm));
            
            if (lambdak < lambdamin) {
                // Same reaction as for an underflow inside the damping loop:
                // report and hand back the last accepted iterate.
                cerr << "Error NLSQ:: Underflow of damping parameter" << endl;
                return xk;
            }
            
            // The updating of delk is switched off in NLSQ_ERR since expensive!
            // It should follow here!
        }
        
        // Damping loop
        LSCMint i = 0;
        while (i < imax) {
            GridFkt xkp1 = xk+lambdak*dxk;
            
            shared_ptr<LinLSQMatrices> matkp1 = linearizeAt(xkp1);

            VectorXd Fkp1 = matkp1->genrhsq();
            // Simplified correction: linsolver still holds the factorization at xk
            GridFkt dxbarkp1 = linsolver.solve(-Fkp1);
            double dxbarkp1_norm = (dxbarkp1/xwght).h1norm();
            double thetabark = dxbarkp1_norm/dxk_norm;
            cout << "It: " << k << ", theta: " << thetabark << ", err: " << dxbarkp1_norm <<
            ", lambda: " << lambdak << endl;
            
            // Convergence test (not before k > 0 because it needs thetak)
            if ((k > 0) && (dxbarkp1_norm <= tol)) {
                double esterr = thetak*dxk_norm/(1.0-thetak); // What to do with this??
                cout << "NLSQ: estimated error: " << esterr << endl;
                return xkp1;
            }
            
            // Monotonicity test
            if (thetabark > 1.0-delk*lambdak) {
                GridFkt dxh = dxbarkp1-(1-lambdak)*dxk;
                double hk = 2.0*(dxh/xwght).h1norm()/(lambdak*lambdak*dxk_norm);
                lambdak = min(0.5*lambdak,(1.0-delk)/hk);
                if (lambdak < lambdamin) {
                    cerr << "Error NLSQ:: Underflow of damping parameter" << endl;
                    return xk;
                }
                i++;
            }
            else {
                // Prepare next loop
                xk = xkp1;
                Fk = Fkp1;
                // Shift the linearizations: the one just used becomes the
                // "previous" one, the one at the accepted iterate is
                // factorized at the top of the next outer step.
                matkm1 = mat;
                mat = matkp1;
                dxkm1 = dxk;
                dxbark = dxbarkp1;
                lambdakm1 = lambdak;
                // Update x weights
                xwght = weightsOf(xkp1);
                
                dxkm1_norm = (dxkm1/xwght).h1norm();
                dxbark_norm = (dxbark/xwght).h1norm();
                break;
            }
        }
        if (i == imax) {
            cerr << "Warning NLSQ: Too many damping steps" << endl;
            break;
        }
        
        // Prepare next loop
        k++;
    }
    
    if (k == itmax) {
        cerr << "Warning NLSQ: Max iteration number reached" << endl;
    }
    
    return xk;
}
