/* File: LinLSQSolver.hpp
 *
 * This class is intended to provide the interface for solvers of linear DAEs
 * for a given grid and approximation space.
 * 
 * Copyright (C) Michael Hanke 2020
 * Version: 2022-06-09
 */

/* 
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.

*/

#ifndef LINLSQSOLVER_HPP
#define LINLSQSOLVER_HPP

#include "LSCMConfig.hpp"
#include "LinLSQMatrices.hpp"
#include "GridFkt.hpp"
#include "Eigen/Dense"
#include <limits>
#include <memory>

namespace LSCM {

/**
 * Abstract class defining the interface to the linear solvers
 * 
 * For testing purposes, information about the timings as well as estimates
 * of the (permanent) memory consumption of the different solution
 * steps is also collected.
 */
class LinLSQSolver {
private:
    /** QR factorization used by the generic implementation, n = 1 */
    Eigen::ColPivHouseholderQR<Eigen::MatrixXd> QR;
    /**
     * NOTE(review): presumably marks the single-subinterval case handled by
     * factorize1i()/solve1i(); it is only written outside this header --
     * confirm. Initialized so that it never holds an indeterminate value.
     */
    bool oneD = false;

protected:
    /** object for generating matrices */
    std::shared_ptr<LinLSQMatrices> mat = nullptr;
    /** has factorize() been called? */
    bool DecompositionAvail = false;

public:
    /**
     * This type contains the timings for the linear solver. Unit is seconds.
     *
     * Declared as a named struct: an unnamed struct that receives its name
     * through a typedef must not contain default member initializers
     * (ill-formed since C++20).
     */
    struct LinTimings {
        double tass = 0.0;    /**< time for matrix assembly */
        double tfact = 0.0;   /**< time for factorization */
        double tslv = 0.0;    /**< time for solution */
    };

    /**
     * This type contains estimations for the memory usage (doubles).
     *
     * Declared as a named struct for the same reason as LinTimings.
     */
    struct MemUsage {
        LSCMint nnzA = 0;   /**< nonzeros of A */
        LSCMint nnzC = 0;   /**< nonzeros of C */
        LSCMint nWork = 0;  /**< main temporaries */
        LSCMint dimA = 0;   /**< rows of A */
        LSCMint ndim = 0;   /**< number of unknowns */
        LSCMint dimC = 0;   /**< rows of C */
    };

protected:
    /** timings for the solver */
    LinTimings timings;

    /** memory usage of the solver */
    MemUsage mem;

    /** flag set by setErrorest(): shall the error estimation be updated? */
    bool errorest = false;
    /** value reported by esterr(); NaN as long as no estimation is available */
    double error = std::numeric_limits<double>::quiet_NaN();
    /** NOTE(review): presumably "constraints are homogeneous" -- confirm with the solver classes */
    bool homConstraints = false;
    /** NOTE(review): presumably the number of boundary conditions -- confirm with the solver classes */
    LSCMint numBC = 0;

    /**
     * Solver specific clean-up. Called by replaceLSQMatrices() before the
     * matrix generating object is exchanged, provided one was set before.
     */
    virtual void resetLocalVars() = 0;

    /**
     * Solver specific factorization.
     * NOTE(review): presumably invoked by factorize(), which is defined
     * outside this header -- confirm.
     */
    virtual void factorizeExec() = 0;

    /**
     * Solver specific solution step.
     * NOTE(review): presumably invoked by solve(const Eigen::VectorXd&),
     * which is defined outside this header -- confirm.
     *
     * \param[in] f right-hand side vector
     */
    virtual GridFkt solveExec(const Eigen::VectorXd& f) = 0;

public:
    /**
     * Default constructor. No matrix generating object is attached; use
     * replaceLSQMatrices() to provide one.
     */
    LinLSQSolver() = default;

    /**
     * Constructor
     *
     * @param[in] genmat pointer to matrix generating object
     */
    LinLSQSolver(std::shared_ptr<LinLSQMatrices> genmat) : mat(genmat)
    {}

    /**
     * Replace matrix generating object
     *
     * If a generator was attached before, resetLocalVars() is called first.
     * Afterwards the factorization is marked as unavailable and the
     * collected timings, memory statistics and the error estimation are
     * reset to their initial values.
     *
     * @param[in] genmat pointer to matrix generating object
     */
    void replaceLSQMatrices(std::shared_ptr<LinLSQMatrices> genmat) {
        if (mat != nullptr) {
            resetLocalVars();
        }
        mat = genmat;
        DecompositionAvail = false;
        // Value-initialization restores every field to its in-class default,
        // so fields added later cannot be forgotten here.
        timings = LinTimings{};
        mem = MemUsage{};
        // An estimation belonging to the previous problem must not be reported
        // for the new one: esterr() promises NaN if none is available.
        error = std::numeric_limits<double>::quiet_NaN();
    }

    /**
     * Factorization step of the linear solver. Defined outside this header.
     */
    void factorize();

    /**
     * Solver interface. Defined outside this header. If the
     * factorize() method was not called before, it will be done here.
     *
     * \param[in] f right-hand side vector
     */
    GridFkt solve(const Eigen::VectorXd& f);

    /**
     * Solver interface using the right-hand side delivered by the matrix
     * generating object (genrhsq()). If the
     * factorize() method was not called before, it will be done here.
     *
     * \pre A matrix generating object must be attached; the pointer is
     *      dereferenced without a check.
     */
    GridFkt solve() {
        Eigen::VectorXd q = mat->genrhsq();
        return solve(q);
    }

    // TODO: Solve directly without saving the factors
    // This may lead to a saving in memory consumption.
    //virtual GridFkt SCsolve() = 0;

    virtual ~LinLSQSolver() = default;

protected:
    /**
     * This is a generic implementation for only one subinterval
     * with no equality constraints.
     */
    void factorize1i();

    /**
     * This is a generic implementation for only one subinterval
     * with no equality constraints. The right-hand side is taken from the
     * matrix generating object (genrhsq()). If the
     * factorize() method was not called before, it will be done here.
     *
     * \pre A matrix generating object must be attached; the pointer is
     *      dereferenced without a check.
     */
    GridFkt solve1i() {
        return solve1i(mat->genrhsq());
    }

    /**
     * This is a generic implementation for the case of only one subinterval.
     * In that case, there are no equality constraints. If the
     * factorize() method was not called before, it will be done here.
     *
     * \param[in] f right-hand side vector
     */
    GridFkt solve1i(const Eigen::VectorXd& f);

public:
    /**
     * Returns timings for the last solution process
     *
     * \returns copy of the collected timings (seconds)
     */
    LinTimings gettimings() const { return timings; }

    /**
     * Returns memory usage for the last solution process
     *
     * \returns copy of the collected memory statistics
     */
    MemUsage getmemusage() const { return mem; }

    /**
     * Sets a flag for updating the error estimation. Will only be evaluated by a few
     * solvers.
     *
     * \param[in] newval new value for errorest
     */
    void setErrorest(const bool newval) { errorest = newval; }

    /**
     * The error estimation for the last solve. If none is available, NaN is
     * returned.
     *
     * \returns error estimation
     */
    double esterr() const { return error; }
};

}

#endif
