/**
 * @file Solver.tpp
 * @author Giulio Romualdi
 * @copyright Released under the terms of the BSD 3-Clause License
 * @date 2018
 */

#include <iostream>
#ifndef OSQP_EIGEN_OSQP_IS_V1
#include <auxil.h>
#include <scaling.h>
#endif

#include "Debug.hpp"

// Update the hessian matrix of an already initialized solver.
//
// The upper triangular part of the new matrix is compared, triplet by triplet, with the hessian
// currently stored in the solver data. If the sparsity pattern is unchanged only the modified
// values are passed to osqp. Otherwise the hessian stored in m_data is replaced, the solver is
// cleared and initialized again, and the primal and dual variables read before the reset are set
// back into the new solver.
//
// hessianMatrix: new hessian matrix, it has to be an n x n matrix where n is the number of
//                variables stored in the solver data.
// Returns true in case of success, false if the solver is not initialized, the size is wrong or
// any of the intermediate steps fails.
template <typename Derived>
bool OsqpEigen::Solver::updateHessianMatrix(
    const Eigen::SparseCompressedBase<Derived>& hessianMatrix)
{
    if (!m_isSolverInitialized)
    {
        debugStream() << "[OsqpEigen::Solver::updateHessianMatrix] The solver has not been "
                         "initialized."
                      << std::endl;
        return false;
    }

    if ((static_cast<c_int>(hessianMatrix.rows()) != getData()->n)
        || (static_cast<c_int>(hessianMatrix.cols()) != getData()->n))
    {
        debugStream() << "[OsqpEigen::Solver::updateHessianMatrix] The hessian matrix has to be a "
                         "nxn matrix"
                      << std::endl;
        return false;
    }

    // evaluate the triplets from old and new hessian sparse matrices
    if (!OsqpEigen::SparseMatrixHelper::osqpSparseMatrixToTriplets(getData()->P,
                                                                   m_oldHessianTriplet))
    {
        debugStream() << "[OsqpEigen::Solver::updateHessianMatrix] Unable to evaluate triplets "
                         "from the old hessian matrix."
                      << std::endl;
        return false;
    }
    if (!OsqpEigen::SparseMatrixHelper::eigenSparseMatrixToTriplets(hessianMatrix,
                                                                    m_newHessianTriplet))
    {
        debugStream() << "[OsqpEigen::Solver::updateHessianMatrix] Unable to evaluate triplets "
                         "from the new hessian matrix."
                      << std::endl;
        return false;
    }

    // only the entries with row <= col of the new matrix are compared with the stored hessian
    selectUpperTriangularTriplets(m_newHessianTriplet, m_newUpperTriangularHessianTriplets);

    // try to update the hessian matrix without reinitialize the solver
    // according to the osqp library it can be done only if the sparsity pattern of the hessian
    // matrix does not change.
    const bool isSparsityPatternUnchanged
        = evaluateNewValues(m_oldHessianTriplet,
                            m_newUpperTriangularHessianTriplets,
                            m_hessianNewIndices,
                            m_hessianNewValues);

    if (isSparsityPatternUnchanged)
    {
        // nothing to do if no value has changed
        if (m_hessianNewValues.empty())
        {
            return true;
        }

#ifdef OSQP_EIGEN_OSQP_IS_V1
        const auto updateStatus = osqp_update_data_mat(m_solver.get(),
                                                       m_hessianNewValues.data(),
                                                       m_hessianNewIndices.data(),
                                                       m_hessianNewIndices.size(),
                                                       nullptr,
                                                       nullptr,
                                                       0);
#else
        const auto updateStatus = osqp_update_P(m_workspace.get(),
                                                m_hessianNewValues.data(),
                                                m_hessianNewIndices.data(),
                                                m_hessianNewIndices.size());
#endif
        if (updateStatus != 0)
        {
            debugStream() << "[OsqpEigen::Solver::updateHessianMatrix] Unable to update "
                             "hessian matrix."
                          << std::endl;
            return false;
        }
        return true;
    }

    // the sparsity pattern has changed
    // the solver has to be setup again

    // get the primal and the dual variables before the solver is cleared
    if (!getPrimalVariable(m_primalVariables))
    {
        debugStream() << "[OsqpEigen::Solver::updateHessianMatrix] Unable to get the primal "
                         "variable."
                      << std::endl;
        return false;
    }

    if (!getDualVariable(m_dualVariables))
    {
        debugStream() << "[OsqpEigen::Solver::updateHessianMatrix] Unable to get the dual "
                         "variable."
                      << std::endl;
        return false;
    }

    // clear old hessian matrix
    m_data->clearHessianMatrix();

    // set new hessian matrix
    if (!m_data->setHessianMatrix(hessianMatrix))
    {
        debugStream() << "[OsqpEigen::Solver::updateHessianMatrix] Unable to update the "
                         "hessian matrix in "
                      << "Data object." << std::endl;
        return false;
    }

    // clear the old solver
    clearSolver();

    // initialize a new solver
    if (!initSolver())
    {
        debugStream() << "[OsqpEigen::Solver::updateHessianMatrix] Unable to Initialize the "
                         "solver."
                      << std::endl;
        return false;
    }

    // set the old primal and dual variables
    if (!setPrimalVariable(m_primalVariables))
    {
        debugStream() << "[OsqpEigen::Solver::updateHessianMatrix] Unable to set the primal "
                         "variable."
                      << std::endl;
        return false;
    }

    if (!setDualVariable(m_dualVariables))
    {
        debugStream() << "[OsqpEigen::Solver::updateHessianMatrix] Unable to set the dual "
                         "variable."
                      << std::endl;
        return false;
    }

    return true;
}

// Update the linear constraints matrix of an already initialized solver.
//
// The new matrix is compared, triplet by triplet, with the constraints matrix currently stored
// in the solver data. If the sparsity pattern is unchanged only the modified values are passed
// to osqp. Otherwise the matrix stored in m_data is replaced, the solver is cleared and
// initialized again, and the primal and dual variables read before the reset are set back into
// the new solver.
//
// linearConstraintsMatrix: new constraints matrix, it has to be an m x n matrix where m and n
//                          are the number of constraints and variables stored in the solver data.
// Returns true in case of success, false if the solver is not initialized, the size is wrong or
// any of the intermediate steps fails.
template <typename Derived>
bool OsqpEigen::Solver::updateLinearConstraintsMatrix(
    const Eigen::SparseCompressedBase<Derived>& linearConstraintsMatrix)
{
    if (!m_isSolverInitialized)
    {
        debugStream() << "[OsqpEigen::Solver::updateLinearConstraintsMatrix] The solver has not "
                         "been initialized."
                      << std::endl;
        return false;
    }

    if ((static_cast<c_int>(linearConstraintsMatrix.rows()) != getData()->m)
        || (static_cast<c_int>(linearConstraintsMatrix.cols()) != getData()->n))
    {
        debugStream() << "[OsqpEigen::Solver::updateLinearConstraintsMatrix] The constraints "
                         "matrix has to be a mxn matrix"
                      << std::endl;
        return false;
    }

    // evaluate the triplets from old and new linear constraints sparse matrices
    if (!OsqpEigen::SparseMatrixHelper::osqpSparseMatrixToTriplets(getData()->A,
                                                                   m_oldLinearConstraintsTriplet))
    {
        debugStream() << "[OsqpEigen::Solver::updateLinearConstraintsMatrix] Unable to evaluate "
                         "triplets from the old linear constraints matrix."
                      << std::endl;
        return false;
    }
    if (!OsqpEigen::SparseMatrixHelper::eigenSparseMatrixToTriplets(linearConstraintsMatrix,
                                                                    m_newLinearConstraintsTriplet))
    {
        debugStream() << "[OsqpEigen::Solver::updateLinearConstraintsMatrix] Unable to evaluate "
                         "triplets from the new linear constraints matrix."
                      << std::endl;
        return false;
    }

    // try to update the linear constraints matrix without reinitialize the solver
    // according to the osqp library it can be done only if the sparsity pattern of the
    // matrix does not change.
    const bool isSparsityPatternUnchanged = evaluateNewValues(m_oldLinearConstraintsTriplet,
                                                              m_newLinearConstraintsTriplet,
                                                              m_constraintsNewIndices,
                                                              m_constraintsNewValues);

    if (isSparsityPatternUnchanged)
    {
        // nothing to do if no value has changed
        if (m_constraintsNewValues.empty())
        {
            return true;
        }

#ifdef OSQP_EIGEN_OSQP_IS_V1
        const auto updateStatus = osqp_update_data_mat(m_solver.get(),
                                                       nullptr,
                                                       nullptr,
                                                       0,
                                                       m_constraintsNewValues.data(),
                                                       m_constraintsNewIndices.data(),
                                                       m_constraintsNewIndices.size());
#else
        const auto updateStatus = osqp_update_A(m_workspace.get(),
                                                m_constraintsNewValues.data(),
                                                m_constraintsNewIndices.data(),
                                                m_constraintsNewIndices.size());
#endif
        if (updateStatus != 0)
        {
            debugStream() << "[OsqpEigen::Solver::updateLinearConstraintsMatrix] Unable to "
                             "update linear constraints matrix."
                          << std::endl;
            return false;
        }
        return true;
    }

    // the sparsity pattern has changed
    // the solver has to be setup again

    // get the primal and the dual variables before the solver is cleared
    if (!getPrimalVariable(m_primalVariables))
    {
        debugStream() << "[OsqpEigen::Solver::updateLinearConstraintsMatrix] Unable to get the "
                         "primal variable."
                      << std::endl;
        return false;
    }

    if (!getDualVariable(m_dualVariables))
    {
        debugStream() << "[OsqpEigen::Solver::updateLinearConstraintsMatrix] Unable to get the "
                         "dual variable."
                      << std::endl;
        return false;
    }

    // clear old linear constraints matrix
    m_data->clearLinearConstraintsMatrix();

    // set new linear constraints matrix
    if (!m_data->setLinearConstraintsMatrix(linearConstraintsMatrix))
    {
        debugStream() << "[OsqpEigen::Solver::updateLinearConstraintsMatrix] Unable to update "
                         "the linear constraints matrix in "
                      << "Data object." << std::endl;
        return false;
    }

    // clear the old solver
    clearSolver();

    // initialize a new solver
    if (!initSolver())
    {
        debugStream() << "[OsqpEigen::Solver::updateLinearConstraintsMatrix] Unable to "
                         "Initialize the solver."
                      << std::endl;
        return false;
    }

    // set the old primal and dual variables
    if (!setPrimalVariable(m_primalVariables))
    {
        debugStream() << "[OsqpEigen::Solver::updateLinearConstraintsMatrix] Unable to set the "
                         "primal variable."
                      << std::endl;
        return false;
    }

    if (!setDualVariable(m_dualVariables))
    {
        debugStream() << "[OsqpEigen::Solver::updateLinearConstraintsMatrix] Unable to set the "
                         "dual variable."
                      << std::endl;
        return false;
    }

    return true;
}

// Set both the primal and the dual variable used by osqp as warm start.
//
// primalVariable: its number of rows has to match the number of variables of the problem.
// dualVariable: its number of rows has to match the number of constraints of the problem.
// Both vectors are cast to c_float and stored in the solver before being handed to osqp.
// Returns true only if the solver is initialized, the sizes match and osqp returns 0.
template <typename T, int n, int m>
bool OsqpEigen::Solver::setWarmStart(const Eigen::Matrix<T, n, 1>& primalVariable,
                                     const Eigen::Matrix<T, m, 1>& dualVariable)
{
    // the warm start is applied to the osqp solver, so it has to exist
    if (!m_isSolverInitialized)
    {
        debugStream() << "[OsqpEigen::Solver::setWarmStart] The solver is not initialized"
                      << std::endl;
        return false;
    }

    // one primal entry per optimization variable
    const bool isPrimalSizeValid = (primalVariable.rows() == getData()->n);
    if (!isPrimalSizeValid)
    {
        debugStream() << "[OsqpEigen::Solver::setWarmStart] The size of the primal variable vector "
                         "has to be equal to "
                      << " the number of variables." << std::endl;
        return false;
    }

    // one dual entry per constraint
    const bool isDualSizeValid = (dualVariable.rows() == getData()->m);
    if (!isDualSizeValid)
    {
        debugStream() << "[OsqpEigen::Solver::setWarmStart] The size of the dual variable vector "
                         "has to be equal to "
                      << " the number of constraints." << std::endl;
        return false;
    }

    // keep a c_float copy owned by the solver, its buffers are the ones passed to osqp
    m_primalVariables = primalVariable.template cast<c_float>();
    m_dualVariables = dualVariable.template cast<c_float>();

#ifdef OSQP_EIGEN_OSQP_IS_V1
    const auto exitFlag
        = osqp_warm_start(m_solver.get(), m_primalVariables.data(), m_dualVariables.data());
#else
    const auto exitFlag
        = osqp_warm_start(m_workspace.get(), m_primalVariables.data(), m_dualVariables.data());
#endif

    return exitFlag == 0;
}

// Set the primal variable used by osqp as warm start.
//
// primalVariable: its number of rows has to match the number of variables of the problem.
// The vector is cast to c_float and stored in the solver before being handed to osqp.
// Returns true only if the solver is initialized, the size matches and osqp returns 0.
template <typename T, int n>
bool OsqpEigen::Solver::setPrimalVariable(const Eigen::Matrix<T, n, 1>& primalVariable)
{
    // the warm start is applied to the osqp solver, so it has to exist
    if (!m_isSolverInitialized)
    {
        debugStream() << "[OsqpEigen::Solver::setPrimalVariable] The solver is not initialized"
                      << std::endl;
        return false;
    }

    // one primal entry per optimization variable
    const bool isSizeValid = (primalVariable.rows() == getData()->n);
    if (!isSizeValid)
    {
        debugStream() << "[OsqpEigen::Solver::setPrimalVariable] The size of the primal variable "
                         "vector has to be equal to "
                      << " the number of variables." << std::endl;
        return false;
    }

    // keep a c_float copy owned by the solver, its buffer is the one passed to osqp
    m_primalVariables = primalVariable.template cast<c_float>();

#ifdef OSQP_EIGEN_OSQP_IS_V1
    const auto exitFlag = osqp_warm_start(m_solver.get(), m_primalVariables.data(), nullptr);
#else
    const auto exitFlag = osqp_warm_start_x(m_workspace.get(), m_primalVariables.data());
#endif

    return exitFlag == 0;
}

// Set the dual variable used by osqp as warm start.
//
// dualVariable: its number of rows has to match the number of constraints of the problem.
// The vector is cast to c_float and stored in the solver before being handed to osqp.
// Returns true only if the solver is initialized, the size matches and osqp returns 0.
template <typename T, int m>
bool OsqpEigen::Solver::setDualVariable(const Eigen::Matrix<T, m, 1>& dualVariable)
{
    // same guard used by setPrimalVariable and setWarmStart: without an initialized solver
    // there is nothing to warm start, so fail early and leave m_dualVariables untouched
    if (!m_isSolverInitialized)
    {
        debugStream() << "[OsqpEigen::Solver::setDualVariable] The solver is not initialized"
                      << std::endl;
        return false;
    }

    // one dual entry per constraint
    if (dualVariable.rows() != getData()->m)
    {
        debugStream() << "[OsqpEigen::Solver::setDualVariable] The size of the dual variable "
                         "vector has to be equal to "
                      << " the number of constraints." << std::endl;
        return false;
    }

    // keep a c_float copy owned by the solver, its buffer is the one passed to osqp
    m_dualVariables = dualVariable.template cast<c_float>();

#ifdef OSQP_EIGEN_OSQP_IS_V1
    return (osqp_warm_start(m_solver.get(), nullptr, m_dualVariables.data()) == 0);
#else
    return (osqp_warm_start_y(m_workspace.get(), m_dualVariables.data()) == 0);
#endif
}

// Copy the primal variable currently stored in the osqp solver into primalVariable.
//
// primalVariable: output vector. A dynamic vector is resized to the number of variables of the
//                 problem, a fixed size vector must already have exactly that size.
// Returns false if the solver is not initialized or the fixed size does not match.
template <typename T, int n>
bool OsqpEigen::Solver::getPrimalVariable(Eigen::Matrix<T, n, 1>& primalVariable)
{
    if (!m_isSolverInitialized)
    {
        debugStream() << "[OsqpEigen::Solver::getPrimalVariable] The solver is not initialized"
                      << std::endl;
        return false;
    }

    const bool isDynamicVector = (n == Eigen::Dynamic);

    // a fixed size vector cannot be resized, so its size has to be already the right one
    if (!isDynamicVector && (n != getData()->n))
    {
        debugStream() << "[OsqpEigen::Solver::getPrimalVariable] The size of the vector has to "
                         "be equal to the number of variables. (You can use an eigen dynamic "
                         "vector)"
                      << std::endl;
        return false;
    }

    if (isDynamicVector)
    {
        primalVariable.resize(getData()->n, 1);
    }

    // view of the raw solver buffer, converted to the scalar type requested by the caller
    using SolverVectorMap = Eigen::Map<Eigen::Matrix<c_float, n, 1>>;

#ifdef OSQP_EIGEN_OSQP_IS_V1
    c_float* const solverBuffer = m_solver->solution->x;
#else
    c_float* const solverBuffer = m_workspace->x;
#endif

    primalVariable = SolverVectorMap(solverBuffer, getData()->n).template cast<T>();

    return true;
}

// Copy the dual variable currently stored in the osqp solver into dualVariable.
//
// dualVariable: output vector. A dynamic vector is resized to the number of constraints of the
//               problem, a fixed size vector must already have exactly that size.
// Returns false if the solver is not initialized or the fixed size does not match.
template <typename T, int m>
bool OsqpEigen::Solver::getDualVariable(Eigen::Matrix<T, m, 1>& dualVariable)
{
    if (!m_isSolverInitialized)
    {
        debugStream() << "[OsqpEigen::Solver::getDualVariable] The solver is not initialized"
                      << std::endl;
        return false;
    }

    const bool isDynamicVector = (m == Eigen::Dynamic);

    // a fixed size vector cannot be resized, so its size has to be already the right one
    if (!isDynamicVector && (m != getData()->m))
    {
        debugStream() << "[OsqpEigen::Solver::getDualVariable] The size of the vector has to "
                         "be equal to the number of constraints. (You can use an eigen dynamic "
                         "vector)"
                      << std::endl;
        return false;
    }

    if (isDynamicVector)
    {
        dualVariable.resize(getData()->m, 1);
    }

    // view of the raw solver buffer, converted to the scalar type requested by the caller
    using SolverVectorMap = Eigen::Map<Eigen::Matrix<c_float, m, 1>>;

#ifdef OSQP_EIGEN_OSQP_IS_V1
    c_float* const solverBuffer = m_solver->solution->y;
#else
    c_float* const solverBuffer = m_workspace->y;
#endif

    dualVariable = SolverVectorMap(solverBuffer, getData()->m).template cast<T>();

    return true;
}

// Compare two triplet lists entry by entry and collect the values that changed.
//
// oldMatrixTriplet: triplets of the matrix currently stored in the solver.
// newMatrixTriplet: triplets of the matrix the user wants to set.
// newIndices: on success, positions (in the triplet list) of the entries whose value changed.
// newValues: on success, the new values of those entries, in the same order as newIndices.
//
// Returns true if the two lists have the same size and the same (row, col) at every position,
// i.e. the sparsity pattern is unchanged. Returns false otherwise; in that case the content of
// newIndices and newValues must not be used.
template <typename T>
bool OsqpEigen::Solver::evaluateNewValues(const std::vector<Eigen::Triplet<T>>& oldMatrixTriplet,
                                          const std::vector<Eigen::Triplet<T>>& newMatrixTriplet,
                                          std::vector<c_int>& newIndices,
                                          std::vector<c_float>& newValues) const
{
    // When updating the matrices for osqp, we need to provide the indices to modify of the value
    // vector. The following can work since, when extracting triplets from osqp sparse matrices, the
    // order of the triplets follows the same order of the value vector.

    // a different number of non zero elements means that the sparsity pattern is changed
    if (newMatrixTriplet.size() != oldMatrixTriplet.size())
    {
        return false;
    }

    // clear() does not release the allocated storage, so the buffers are reused across calls
    // and the two output vectors always end up with the same number of elements
    newIndices.clear();
    newValues.clear();

    for (size_t i = 0; i < newMatrixTriplet.size(); ++i)
    {
        const Eigen::Triplet<T>& newEntry = newMatrixTriplet[i];
        const Eigen::Triplet<T>& oldEntry = oldMatrixTriplet[i];

        // check if the sparsity pattern is changed
        if ((newEntry.row() != oldEntry.row()) || (newEntry.col() != oldEntry.col()))
        {
            return false;
        }

        // check if an old value is changed
        if (newEntry.value() != oldEntry.value())
        {
            newValues.push_back(static_cast<c_float>(newEntry.value()));
            newIndices.push_back(static_cast<c_int>(i));
        }
    }

    return true;
}

// Copy into upperTriangularMatrixTriplets the triplets of fullMatrixTriplets that belong to the
// upper triangular part of the matrix (row <= col), keeping their relative order.
//
// fullMatrixTriplets: triplets of the whole matrix.
// upperTriangularMatrixTriplets: output; its previous content is overwritten and its final size
//                                is the number of selected triplets.
template <typename T>
void OsqpEigen::Solver::selectUpperTriangularTriplets(
    const std::vector<Eigen::Triplet<T>>& fullMatrixTriplets,
    std::vector<Eigen::Triplet<T>>& upperTriangularMatrixTriplets) const
{
    // number of triplets written so far; unsigned as the container sizes it is compared with
    size_t selectedTriplets = 0;

    for (size_t i = 0; i < fullMatrixTriplets.size(); ++i)
    {
        // skip the strictly lower triangular elements
        if (fullMatrixTriplets[i].row() > fullMatrixTriplets[i].col())
        {
            continue;
        }

        // reuse the elements already stored in the output and grow it only when needed
        if (selectedTriplets < upperTriangularMatrixTriplets.size())
        {
            upperTriangularMatrixTriplets[selectedTriplets] = fullMatrixTriplets[i];
        } else
        {
            upperTriangularMatrixTriplets.push_back(fullMatrixTriplets[i]);
        }
        ++selectedTriplets;
    }

    // drop the leftovers of a previous, longer, selection
    upperTriangularMatrixTriplets.erase(upperTriangularMatrixTriplets.begin() + selectedTriplets,
                                        upperTriangularMatrixTriplets.end());
}
