Program Listing for File LinearSolver.hpp

Return to documentation for file (include/networkit/numerics/LinearSolver.hpp)

/*
 * LinearSolver.hpp
 *
 *  Created on: 30.10.2014
 *      Author: Michael Wegner (michael.wegner@student.kit.edu)
 */

#ifndef NETWORKIT_NUMERICS_LINEAR_SOLVER_HPP_
#define NETWORKIT_NUMERICS_LINEAR_SOLVER_HPP_

#include <functional>
#include <limits>
#include <stdexcept>
#include <utility>
#include <vector>

#include <networkit/algebraic/Vector.hpp>
#include <networkit/graph/Graph.hpp>

namespace NetworKit {

/**
 * Outcome of a single LinearSolver::solve() call.
 *
 * Members carry default initializers so that a default-initialized status
 * (`SolverStatus s;`) is well defined and equal to a value-initialized one
 * (`SolverStatus{}`), instead of holding indeterminate values.
 */
struct SolverStatus {
    count numIters = 0;     // number of iterations needed during solve phase
    double residual = 0.0;  // absolute final residual
    bool converged = false; // flag of convergence status
};

template <class Matrix>
class LinearSolver {
protected:
    double tolerance;

public:
    LinearSolver(const double tolerance) : tolerance(tolerance) {}
    virtual ~LinearSolver() = default;

    virtual void setup(const Matrix &matrix) = 0;

    virtual void setup(const Graph &graph);

    virtual void setupConnected(const Matrix &matrix) = 0;

    virtual void setupConnected(const Graph &graph);

    virtual SolverStatus solve(const Vector &rhs, Vector &result,
                               count maxConvergenceTime = 5 * 60 * 1000,
                               count maxIterations = std::numeric_limits<count>::max()) const = 0;

    virtual std::vector<SolverStatus>
    parallelSolve(const std::vector<Vector> &rhs, std::vector<Vector> &results,
                  count maxConvergenceTime = 5 * 60 * 1000,
                  count maxIterations = std::numeric_limits<count>::max()) const;

    template <typename RHSLoader, typename ResultProcessor>
    void parallelSolve(const RHSLoader &rhsLoader, const ResultProcessor &resultProcessor,
                       std::pair<count, count> rhsSize, count maxConvergenceTime = 5 * 60 * 1000,
                       count maxIterations = std::numeric_limits<count>::max()) const {
        count n = rhsSize.first;
        count m = rhsSize.second;
        Vector rhs(m);
        Vector result(m);
        for (index i = 0; i < n; ++i) {
            solve(rhsLoader(i, rhs), result, maxConvergenceTime, maxIterations);
            resultProcessor(i, result);
        }
    }
};

/**
 * Graph convenience overload: obtains Matrix::laplacianMatrix(graph) and forwards it to the
 * (pure virtual) setup(const Matrix&), so the derived solver's matrix setup is used.
 */
template <class Matrix>
void LinearSolver<Matrix>::setup(const Graph &graph) {
    const auto laplacian = Matrix::laplacianMatrix(graph);
    this->setup(laplacian);
}

/**
 * Graph convenience overload: obtains Matrix::laplacianMatrix(graph) and forwards it to the
 * (pure virtual) setupConnected(const Matrix&) implemented by the derived solver.
 */
template <class Matrix>
void LinearSolver<Matrix>::setupConnected(const Graph &graph) {
    const auto laplacian = Matrix::laplacianMatrix(graph);
    this->setupConnected(laplacian);
}

/**
 * Solves one system per right-hand side, sequentially, writing the i-th solution to results[i].
 *
 * @param rhs Right-hand sides.
 * @param results Solution vectors; must contain at least rhs.size() entries.
 * @param maxConvergenceTime Forwarded to every solve() call.
 * @param maxIterations Forwarded to every solve() call.
 * @return One SolverStatus per right-hand side, in the order of @a rhs.
 * @throws std::invalid_argument if results.size() < rhs.size().
 */
template <class Matrix>
std::vector<SolverStatus>
LinearSolver<Matrix>::parallelSolve(const std::vector<Vector> &rhs, std::vector<Vector> &results,
                                    count maxConvergenceTime, count maxIterations) const {
    // results[i] is passed to solve() for every i < rhs.size(); a shorter results vector
    // would otherwise be indexed out of bounds (undefined behavior).
    if (results.size() < rhs.size()) {
        throw std::invalid_argument(
            "LinearSolver::parallelSolve: results must contain at least rhs.size() vectors");
    }

    std::vector<SolverStatus> stats(rhs.size());
    for (index i = 0; i < rhs.size(); ++i) {
        stats[i] = solve(rhs[i], results[i], maxConvergenceTime, maxIterations);
    }
    return stats;
}

} /* namespace NetworKit */

#endif // NETWORKIT_NUMERICS_LINEAR_SOLVER_HPP_