#ifndef SOPT_PRIMAL_DUAL_H
#define SOPT_PRIMAL_DUAL_H

#include "sopt/config.h"
#include <functional>
#include <limits>
#include <tuple> // for std::tuple<>
#include <utility> // for std::move<>
#include "sopt/exception.h"
#include "sopt/linear_transform.h"
#include "sopt/logging.h"
#include "sopt/types.h"

#ifdef SOPT_MPI
#include "sopt/mpi/communicator.h"
#include "sopt/mpi/utilities.h"
#endif

namespace sopt::algorithm {

//! \brief Primal Dual Algorithm
//! \details \f$\min_{x, z} f(x) + h(z)\f$ subject to \f$Φx + z = y\f$. \f$y\f$ is a target vector.
template <typename SCALAR>
class PrimalDual {
 public:
  //! Scalar type
  using value_type = SCALAR;
  //! Scalar type
  using Scalar = value_type;
  //! Real type
  using Real = typename real_type<Scalar>::type;
  //! Type of then underlying vectors
  using t_Vector = Vector<Scalar>;
  //! Type of the Ψ and Ψ^H operations, as well as Φ and Φ^H
  using t_LinearTransform = LinearTransform<t_Vector>;
  //! Type of the convergence function
  using t_IsConverged = std::function<bool (const t_Vector &, const t_Vector &)>;
  //! Type of the constraint function
  using t_Constraint = std::function<void (t_Vector &, const t_Vector &)>;
  //! Type of random update function
  using t_Random_Updater = std::function<bool ()>;
  //! Type of the convergence function
  using t_Proximal = ProximalFunction<Scalar>;

  //! Values indicating how the algorithm ran
  struct Diagnostic {
    //! Number of iterations
    t_uint niters;
    //! Wether convergence was achieved
    bool good;
    //! the residual from the last iteration
    t_Vector residual;

    Diagnostic(t_uint niters = 0u, bool good = false)
        : niters(niters), good(good), residual(t_Vector::Zero(0)) {}
    Diagnostic(t_uint niters, bool good, t_Vector &&residual)
        : niters(niters), good(good), residual(std::move(residual)) {}
  };
  //! Holds result vector as well
  struct DiagnosticAndResult : public Diagnostic {
    //! Output x
    t_Vector x;
  };

  //! Setups PrimalDual
  //! \param[in] f_proximal: proximal operator of the \f$f\f$ function.
  //! \param[in] g_proximal: proximal operator of the \f$g\f$ function
  template <typename DERIVED>
  PrimalDual(t_Proximal const &f_proximal, t_Proximal const &g_proximal,
             Eigen::MatrixBase<DERIVED> const &target)
      : itermax_(std::numeric_limits<t_uint>::max()),
        sigma_(1),
        tau_(0.5),
        regulariser_strength_(0.5),
        update_scale_(1),
        xi_(1),
        rho_(1),
        is_converged_(),
        constraint_([](t_Vector &out, t_Vector const &x) { out = x; }),
        Phi_(linear_transform_identity<Scalar>()),
        Psi_(linear_transform_identity<Scalar>()),
        f_proximal_(f_proximal),
        g_proximal_(g_proximal),
        random_measurement_updater_([]() { return true; }),
        random_wavelet_updater_([]() { return true; }),
#ifdef SOPT_MPI
        v_all_sum_all_comm_(mpi::Communicator()),
        u_all_sum_all_comm_(mpi::Communicator()),
#endif
        target_(target) {
  }
  virtual ~PrimalDual() {}

// Macro helps define properties that can be initialized as in
// auto sdmm  = PrimalDual<float>().prop0(value).prop1(value);
#define SOPT_MACRO(NAME, TYPE)                 \
  TYPE const &NAME() const { return NAME##_; } \
  PrimalDual<SCALAR> &NAME(TYPE const &(NAME)) { \
    NAME##_ = NAME;                            \
    return *this;                              \
  }                                            \
                                               \
 protected:                                    \
  TYPE NAME##_;                                \
                                               \
 public:

  //! Maximum number of iterations
  SOPT_MACRO(itermax, t_uint);
  //! Update parameter
  SOPT_MACRO(update_scale, Real);
  //! γ parameter
  SOPT_MACRO(regulariser_strength, Real);
  //! sigma parameter
  SOPT_MACRO(sigma, Real);
  //! xi parameter
  SOPT_MACRO(xi, Real);
  //! rho parameter
  SOPT_MACRO(rho, Real);
  //! tau parameter
  SOPT_MACRO(tau, Real);
  //! \brief A function verifying convergence
  //! \details It takes as input two arguments: the current solution x and the current residual.
  SOPT_MACRO(is_converged, t_IsConverged);
  //! \brief A function applying a simple constraint
  SOPT_MACRO(constraint, t_Constraint);
  //! Measurement operator
  SOPT_MACRO(Phi, t_LinearTransform);
  //! Wavelet operator
  SOPT_MACRO(Psi, t_LinearTransform);
  //! First proximal
  SOPT_MACRO(f_proximal, t_Proximal);
  //! Second proximal
  SOPT_MACRO(g_proximal, t_Proximal);
  //! lambda that determines if to update measurements
  SOPT_MACRO(random_measurement_updater, t_Random_Updater);
  //! lambda that determines if to update wavelets
  SOPT_MACRO(random_wavelet_updater, t_Random_Updater);
#ifdef SOPT_MPI
  //! v space communicator
  SOPT_MACRO(v_all_sum_all_comm, mpi::Communicator);
  //! u space communicator
  SOPT_MACRO(u_all_sum_all_comm, mpi::Communicator);
#endif
#undef SOPT_MACRO
  //! \brief Simplifies calling the proximal of f.
  void f_proximal(t_Vector &out, Real regulariser_strength, t_Vector const &x) const {
    f_proximal()(out, regulariser_strength, x);
  }
  //! \brief Simplifies calling the proximal of f.
  void g_proximal(t_Vector &out, Real regulariser_strength, t_Vector const &x) const {
    g_proximal()(out, regulariser_strength, x);
  }

  //! Convergence function that takes only the output as argument
  PrimalDual<Scalar> &is_converged(std::function<bool(t_Vector const &x)> const &func) {
    return is_converged([func](t_Vector const &x, t_Vector const &) { return func(x); });
  }

  //! Vector of target measurements
  t_Vector const &target() const { return target_; }
  //! Sets the vector of target measurements
  template <typename DERIVED>
  PrimalDual<Scalar> &target(Eigen::MatrixBase<DERIVED> const &target) {
    target_ = target;
    return *this;
  }

  //! Facilitates call to user-provided convergence function
  bool is_converged(t_Vector const &x, t_Vector const &residual) const {
    return static_cast<bool>(is_converged()) and is_converged()(x, residual);
  }

  //! \brief Calls Primal Dual
  //! \param[out] out: Output vector x
  Diagnostic operator()(t_Vector &out) const { return operator()(out, initial_guess()); }
  //! \brief Calls Primal Dual
  //! \param[out] out: Output vector x
  //! \param[in] guess: initial guess
  Diagnostic operator()(t_Vector &out, std::tuple<t_Vector, t_Vector> const &guess) const {
    return operator()(out, std::get<0>(guess), std::get<1>(guess));
  }
  //! \brief Calls Primal Dual
  //! \param[out] out: Output vector x
  //! \param[in] guess: initial guess
  Diagnostic operator()(t_Vector &out,
                        std::tuple<t_Vector const &, t_Vector const &> const &guess) const {
    return operator()(out, std::get<0>(guess), std::get<1>(guess));
  }
  //! \brief Calls Primal Dual
  //! \param[in] guess: initial guess
  DiagnosticAndResult operator()(std::tuple<t_Vector, t_Vector> const &guess) const {
    return operator()(std::tie(std::get<0>(guess), std::get<1>(guess)));
  }
  //! \brief Calls Primal Dual
  //! \param[in] guess: initial guess
  DiagnosticAndResult operator()(
      std::tuple<t_Vector const &, t_Vector const &> const &guess) const {
    DiagnosticAndResult result;
    static_cast<Diagnostic &>(result) = operator()(result.x, guess);
    return result;
  }
  //! \brief Calls Primal Dual
  //! \param[in] guess: initial guess
  DiagnosticAndResult operator()() const {
    DiagnosticAndResult result;
    static_cast<Diagnostic &>(result) = operator()(result.x, initial_guess());
    return result;
  }
  //! Makes it simple to chain different calls to PD
  DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart) const {
    DiagnosticAndResult result = warmstart;
    static_cast<Diagnostic &>(result) = operator()(result.x, warmstart.x, warmstart.residual);
    return result;
  }
  //! Set Φ and Φ^† using arguments that sopt::linear_transform understands
  template <typename... ARGS>
  typename std::enable_if<sizeof...(ARGS) >= 1, PrimalDual &>::type Phi(ARGS &&... args) {
    Phi_ = linear_transform(std::forward<ARGS>(args)...);
    return *this;
  }
  //! Set Φ and Φ^† using arguments that sopt::linear_transform understands
  template <typename... ARGS>
  typename std::enable_if<sizeof...(ARGS) >= 1, PrimalDual &>::type Psi(ARGS &&... args) {
    Psi_ = linear_transform(std::forward<ARGS>(args)...);
    return *this;
  }

  //! \brief Computes initial guess for x and the residual using the targets
  //! \details with y the vector of measurements
  //! - x = Φ^T y / nu =  Φ^T y / (Φ_norm^2)
  //! - residuals = Φ x - y
  std::tuple<t_Vector, t_Vector> initial_guess() const {
    return PrimalDual<SCALAR>::initial_guess(target(), Phi());
  }

  //! \brief Computes initial guess for x and the residual using the targets
  //! \details with y the vector of measurements
  //! - x = Φ^T y / nu =  Φ^T y / (Φ_norm^2)
  //! - residuals = Φ x - y
  //!
  //! This function simplifies creating overloads for operator() in PD wrappers.
  static std::tuple<t_Vector, t_Vector> initial_guess(t_Vector const &target,
                                                      t_LinearTransform const &phi) {
    std::tuple<t_Vector, t_Vector> guess;
    std::get<0>(guess) = static_cast<t_Vector>(phi.adjoint() * target) / phi.sq_norm();
    std::get<1>(guess) = target;
    return guess;
  }

 protected:
  void iteration_step(t_Vector &out, t_Vector &out_hold, t_Vector &u, t_Vector &u_hold, t_Vector &v,
                      t_Vector &v_hold, t_Vector &residual, t_Vector &q, t_Vector &r,
                      bool &random_measurement_update, bool &random_wavelet_update,
                      t_Vector &u_update, t_Vector &v_update) const;

  //! Checks input makes sense
  void sanity_check(t_Vector const &x_guess, t_Vector const &res_guess) const {
    if ((Phi().adjoint() * target()).size() != x_guess.size())
      SOPT_THROW("target, adjoint measurement operator and input vector have inconsistent sizes");
    if (target().size() != res_guess.size())
      SOPT_THROW("target and residual vector have inconsistent sizes");
    if ((Phi() * x_guess).size() != target().size())
      SOPT_THROW("target, measurement operator and input vector have inconsistent sizes");
    if (not static_cast<bool>(is_converged()))
      SOPT_WARN("No convergence function was provided: algorithm will run for {} steps", itermax());
  }

  //! \brief Calls Primal Dual
  //! \param[out] out: Output vector x
  //! \param[in] guess: initial guess
  //! \param[in] residuals: initial residuals
  Diagnostic operator()(t_Vector &out, t_Vector const &guess, t_Vector const &res) const;

  //! Vector of measurements
  t_Vector target_;
};

//! \brief Performs one primal-dual iteration, updating every state vector in place.
//! \details The measurement-space dual `v` and the wavelet-space dual `u` are each refreshed only
//! when their flag is set. Otherwise the cached contributions `v_update` (Φ^† v) and `u_update`
//! (Ψ u) from an earlier iteration feed the primal step unchanged. On exit, both flags hold fresh
//! draws from the random updaters, which decide what the next call refreshes. `residual` is
//! recomputed only when the next call will refresh the measurement dual.
template <typename SCALAR>
void PrimalDual<SCALAR>::iteration_step(t_Vector &out, t_Vector &out_hold, t_Vector &u,
                                        t_Vector &u_hold, t_Vector &v, t_Vector &v_hold,
                                        t_Vector &residual, t_Vector &q, t_Vector &r,
                                        bool &random_measurement_update,
                                        bool &random_wavelet_update, t_Vector &u_update,
                                        t_Vector &v_update) const {
  Real const relaxation = update_scale();

  // Dual step in measurement space.
  if (random_measurement_update) {
    t_Vector const v_shifted = v + residual;
    g_proximal(v_hold, rho(), v_shifted);
    // Keep what the proximal step removed from (v + residual).
    v_hold = v_shifted - v_hold;
    v = v + relaxation * (v_hold - v);
    v_update = static_cast<t_Vector>(Phi().adjoint() * v);
  }

  // Dual step in the Ψ^† (wavelet) domain.
  if (random_wavelet_update) {
    q = static_cast<t_Vector>(Psi().adjoint() * out_hold) * sigma();
    t_Vector const u_shifted = u + q;
    f_proximal(u_hold, regulariser_strength(), u_shifted);
    // Keep what the proximal step removed from (u + q).
    u_hold = u_shifted - u_hold;
    u = u + relaxation * (u_hold - u);
    u_update = static_cast<t_Vector>(Psi() * u);
  }

  // Primal step: move against the summed dual contributions, then apply the constraint.
  r = out;
#ifdef SOPT_MPI
  bool const distributed =
      v_all_sum_all_comm().size() > 0 and u_all_sum_all_comm().size() > 0;
  if (distributed) {
    // Dual contributions are spread over ranks and are summed before the step.
    constraint()(
        out_hold,
        r - tau() * (u_all_sum_all_comm().all_sum_all(static_cast<const t_Vector>(u_update)) +
                     v_all_sum_all_comm().all_sum_all(static_cast<const t_Vector>(v_update))));
  } else {
    constraint()(out_hold, r - tau() * (u_update + v_update));
  }
#else
  constraint()(out_hold, r - tau() * (u_update + v_update));
#endif
  out = r + relaxation * (out_hold - r);
  // Extrapolated point used by the next dual steps.
  out_hold = 2 * out_hold - r;

  // Decide which dual blocks the next call refreshes.
  random_measurement_update = random_measurement_updater()();
  random_wavelet_update = random_wavelet_updater()();
  if (not random_measurement_update) return;
  residual = static_cast<t_Vector>(Phi() * out_hold) * xi() - target();
}

//! \brief Runs the primal-dual iterations from an explicit starting point.
//! \param[out] out: on return, the last primal iterate x
//! \param[in] x_guess: initial primal iterate
//! \param[in] res_guess: initial residual; it also seeds the measurement-space dual variable v
//! \returns the number of iterations run, whether the user convergence function reported
//!          convergence, and the last residual
//! \details Throws (through sanity_check) when target, Φ and the guesses have inconsistent sizes.
template <typename SCALAR>
typename PrimalDual<SCALAR>::Diagnostic PrimalDual<SCALAR>::operator()(
    t_Vector &out, t_Vector const &x_guess, t_Vector const &res_guess) const {
  SOPT_HIGH_LOG("Performing Primal Dual");
  sanity_check(x_guess, res_guess);
  // Decide which dual blocks the first iteration refreshes.
  bool random_measurement_update = random_measurement_updater_();
  bool random_wavelet_update = random_wavelet_updater_();
  t_Vector residual = res_guess;
  out = x_guess;
  t_Vector out_hold = x_guess;
  t_Vector r = x_guess;
  // Measurement-space dual variable.
  t_Vector v = residual;
  t_Vector v_hold = residual;
  // iteration_step reuses v_update (= Φ^† v) and u_update (= Ψ u) whenever the matching dual
  // block is not refreshed. They must therefore start consistent with v and u. Seeding them with
  // x_guess would inject the primal guess as a dual contribution when a random updater skips
  // the first iteration.
  t_Vector v_update = static_cast<t_Vector>(Phi().adjoint() * v);
  // Wavelet-space dual variable.
  t_Vector u = Psi().adjoint() * out;
  t_Vector u_hold = u;
  t_Vector u_update = static_cast<t_Vector>(Psi() * u);
  t_Vector q = u;

  t_uint niters(0);
  bool converged = false;
  for (; (not converged) && (niters < itermax()); ++niters) {
    SOPT_LOW_LOG("    - [Primal Dual] Iteration {}/{}", niters, itermax());
    iteration_step(out, out_hold, u, u_hold, v, v_hold, residual, q, r, random_measurement_update,
                   random_wavelet_update, u_update, v_update);
    // residual is already a t_Vector: no need to copy it just to log its L1 norm.
    SOPT_LOW_LOG("      - [Primal Dual] Sum of residuals: {}", residual.array().abs().sum());
    converged = is_converged(out, residual);
  }

  if (converged) {
    SOPT_MEDIUM_LOG("    - [Primal Dual] converged in {} of {} iterations", niters, itermax());
  } else if (static_cast<bool>(is_converged())) {
    // not meaningful if not convergence function
    SOPT_ERROR("    - [Primal Dual] did not converge within {} iterations", itermax());
  }
  return {niters, converged, std::move(residual)};
}
} // namespace sopt::algorithm
#endif
