#include <getopt.h>

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <random>
#include <string>

#include "pcg-cpp/include/pcg_random.hpp"
#include "randutils/randutils.hpp"

#include "eigen/Eigen/Dense"
#include "eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h"

// Scalar type used throughout the simulation.
using Real = double;

// Random-number generator: randutils convenience wrapper around the PCG32 engine.
using Rng = randutils::random_generator<pcg32>;

// Dynamically sized column vector and matrix of Real (Eigen 3.4 aliases for
// Eigen::Matrix<Real, Dynamic, 1> and Eigen::Matrix<Real, Dynamic, Dynamic>).
using Vector = Eigen::VectorX<Real>;
using Matrix = Eigen::MatrixX<Real>;

/// Rescale x onto the sphere of radius sqrt(N), where N = x.size(), so that the
/// result satisfies |result|² = N.
/// Precondition: x is not the zero vector (the scale factor would be infinite and
/// the result NaN).
Vector normalizeVector(const Vector& x) {
  const Real targetSquaredNorm = x.size();
  const Real scale = sqrt(targetSquaredNorm / x.squaredNorm());
  return scale * x;
}

/// Draw a length-N vector whose components are independent normal variates with
/// mean 0 and standard deviation σ, generated in index order from r.
Vector randomVector(unsigned N, Rng& r, Real σ = 1) {
  Vector sample(N);
  for (unsigned i = 0; i < N; ++i) {
    sample(i) = r.variate<Real, std::normal_distribution>(0, σ);
  }
  return sample;
}

/// Orthogonal projection of v onto the line spanned by u: (v·u / |u|²) u.
///
/// If u is the zero vector the projection is defined as the zero vector, instead
/// of the NaN vector the formula would produce (0/0). This matters when u is a
/// gradient that happens to vanish, e.g. the projected ∇H at a critical point.
Vector projectionOfOn(const Vector& v, const Vector& u) {
  const Real uu = u.squaredNorm();
  if (uu == 0) {
    return Vector::Zero(u.size());
  }
  return (v.dot(u) / uu) * u;
}

/// Cumulative distribution function of the Wigner semicircle law on [-2, 2],
/// whose density is sqrt(4 - λ²) / (2π):
///   F(λ) = 1/2 + (λ sqrt(4 - λ²) / 4 + asin(λ / 2)) / π.
///
/// Returns exactly 0 for λ <= -2 and 1 for λ >= 2, rather than NaN (square root
/// of a negative number) outside the support.
Real wignerCDF(Real λ) {
  if (λ <= -2) {
    return 0;
  }
  if (λ >= 2) {
    return 1;
  }
  // On (-2, 2), atan(λ / sqrt(4 - λ²)) == asin(λ / 2); the asin form avoids
  // dividing by sqrt(4 - λ²), which vanishes at the edges of the support.
  return 0.5 + (λ * std::sqrt(4 - λ * λ) / 4 + std::asin(λ / 2)) / M_PI;
}

/// Inverse of wignerCDF: the λ in [-2, 2] with wignerCDF(λ) = p, by bisection.
///
/// @param p  target probability; expected in [0, 1] (values below 0 map to -2,
///           values above 1 map to 2).
/// @param ε  bisection stops once the bracketing interval is no wider than ε.
/// @return   midpoint of the final bracketing interval.
///
/// The loop also stops when the midpoint can no longer split the interval in
/// double precision, so an ε smaller than the floating-point spacing near the
/// root cannot cause an infinite loop.
Real wignerInverse(Real p, Real ε = 1e-14) {
  Real a = -2;
  Real b = 2;

  while (b - a > ε) {
    const Real c = (a + b) / 2;
    if (c <= a || c >= b) {
      break;  // a and b are adjacent doubles; no further progress is possible
    }
    // wignerCDF is nondecreasing, so the root lies right of c exactly when
    // CDF(c) < p. One CDF evaluation per iteration instead of two.
    if (wignerCDF(c) < p) {
      a = c;
    } else {
      b = c;
    }
  }

  return (a + b) / 2;
}

class QuadraticModel {
public:
  Vector J;
  unsigned N;

  QuadraticModel(unsigned N, Rng& r, bool diag = false) : J(N), N(N) {
    if (diag) {
      Matrix Jtmp(N, N);

      for (unsigned j = 0; j < N; j++) {
        for (unsigned i = j; i < N; i++) {
          Jtmp(i, j) = r.variate<Real, std::normal_distribution>(0, 1 / sqrt(N));
          Jtmp(j, i) = Jtmp(i, j);
        }
      }

      std::cerr << "Beginning diagonalization" << std::endl;
      Eigen::SelfAdjointEigenSolver<Matrix> es;
      es.compute(Jtmp);
      J = es.eigenvalues();
      std::cerr << "Finished diagonalization" << std::endl;
    } else {
      for (Real& Jᵢ : J) {
        Jᵢ = wignerInverse(r.uniform(0.0, 1.0));
      }
    }
  }

  Real H(const Vector& x) const {
    return 0.5 * (J.cwiseProduct(x)).dot(x);
  }

  Vector ∇H(const Vector& x) const {
    Vector ∂H = J.cwiseProduct(x);
    return ∂H - projectionOfOn(∂H, x);
  }
};

/// Move x₀ along the sphere until its energy density H(x)/N is as close as
/// possible to E, by minimising f(x) = (H(x)/N - E)².
///
/// Each iteration takes g = (H/N - E) ∇H(x) / N (half the gradient of f, using the
/// sphere-projected ∇H) and runs a backtracking line search. The step x - αg is
/// renormalized onto the sphere, and α is halved until f decreases by at least
/// α|g|². After an accepted step α grows by 25% for the next search.
///
/// @param M   model providing H and the projected gradient ∇H.
/// @param x₀  starting point (callers in this file pass a point already on the sphere).
/// @param E   target energy density.
/// @param ε   convergence threshold on |g|².
/// @return    the last accepted point. Also returned early if the line search
///            underflows α to 0: no representable step decreases f, and
///            α *= 1.25 would keep α at 0, so iterating further would never end.
Vector gradientDescent(const QuadraticModel& M, const Vector& x₀, Real E, Real ε = 1e-14) {
  // Objective as a function of the energy: f = (H/N - E)².
  const auto f = [&M, E](Real H) {
    const Real δ = H / M.N - E;
    return δ * δ;
  };

  Vector xₜ = x₀;
  Real Hₜ = M.H(x₀);
  Real α = 1.0 / M.N;
  Vector g;

  while (true) {
    g = (Hₜ / M.N - E) * M.∇H(xₜ) / M.N;
    const Real m = g.squaredNorm();
    if (!(m > ε)) {  // converged (also stops on NaN)
      break;
    }

    const Real fₜ = f(Hₜ);  // invariant during the line search
    Vector xₜ₊₁;
    Real Hₜ₊₁;

    // Backtracking: halve α until the sufficient-decrease condition holds.
    while (true) {
      xₜ₊₁ = normalizeVector(xₜ - α * g);
      Hₜ₊₁ = M.H(xₜ₊₁);
      if (!(f(Hₜ₊₁) > fₜ - α * m)) {
        break;
      }
      α /= 2;
    }

    if (α == 0) {
      break;  // step size underflowed; see @return above
    }

    xₜ = xₜ₊₁;
    Hₜ = Hₜ₊₁;
    α *= 1.25;
  }

  return xₜ;
}

/// One stochastic step of duration Δt constrained to the sphere and (after
/// relaxation) to energy density E.
///
/// A Gaussian kick with per-component standard deviation sqrt(2Δt) has its
/// components along x₀ and along the projected gradient ∇H(x₀) removed. The
/// displaced point is renormalized onto the sphere and relaxed back to energy
/// density E with gradientDescent.
Vector randomStep(const QuadraticModel& M, const Vector& x₀, Real E, Rng& r, Real Δt = 1e-4) {
  Vector kick = randomVector(M.N, r, sqrt(2 * Δt));

  // Remove the radial component, then the component along the gradient.
  kick = kick - projectionOfOn(kick, x₀);
  const Vector gradient = M.∇H(x₀);
  kick = kick - projectionOfOn(kick, gradient);

  const Vector proposal = normalizeVector(x₀ + kick);
  return gradientDescent(M, proposal, E);
}

int main(int argc, char* argv[]) {
  unsigned N = 10;
  Real T = 10;
  Real E = 0;
  Real Δt = 1e-4;
  Real Δw = 1e-2;
  bool diag = true;

  int opt;

  while ((opt = getopt(argc, argv, "N:E:T:t:w:d")) != -1) {
    switch (opt) {
    case 'N':
      N = (unsigned)atof(optarg);
      break;
    case 'E':
      E = atof(optarg);
      break;
    case 'T':
      T = atof(optarg);
      break;
    case 't':
      Δt = atof(optarg);
      break;
    case 'w':
      Δw = atof(optarg);
      break;
    case 'd':
      diag = false;
      break;
    default:
      exit(1);
    }
  }

  Rng r;
  QuadraticModel model(N, r, diag);

  Vector x₀ = normalizeVector(randomVector(N, r));
  x₀ = gradientDescent(model, x₀, E);

  std::cout << std::setprecision(15);

  Vector x = x₀;

  Real timeSinceWrite = Δw;

  auto tag = std::chrono::high_resolution_clock::now();
  std::ofstream outfile(std::to_string(N) + "_" + std::to_string(E) + "_" + std::to_string(-std::log10(Δt)) + "_" + std::to_string(Δw) + "_" +  std::to_string(tag.time_since_epoch().count())+ ".dat", std::ios::out | std::ios::binary);

  for (Real Jᵢ : model.J) {
    float Jif = Jᵢ;
    outfile.write((const char*)(&Jif), sizeof(float));
  }

  for (Real t = 0; t < T; t += Δt) {
    x = randomStep(model, x, E, r, Δt);

    if (timeSinceWrite >= Δw) {
      for (Real xᵢ : x) {
        float xif = xᵢ;
        outfile.write((const char*)(&xif), sizeof(float));
      }
      timeSinceWrite = 0;
    }

    timeSinceWrite += Δt;
  }

  outfile.close();

  return 0;
}