#include "log-fourier.hpp"
#include "p-spin.hpp"
#include <getopt.h>
#include <iostream>

/*
 * Solve the self-consistent equations for C and R by damped fixed-point
 * iteration, sweeping β = 0, Δβ, 2Δβ, ... while β < βₘₐₓ. Each β starts from
 * the converged iterate of the previous one; β = 0 starts from a closed form.
 *
 * Command-line options (each takes an argument):
 *   -p p      model parameter forwarded to df/ddf
 *   -s s      model parameter forwarded to df/ddf
 *   -2 log2n  grid size exponent, N = 2^log2n
 *   -t τ₀     used in the starting solution and the Ĉ update; also sets
 *             Γ₀ = 1 + τ₀
 *   -b βₘₐₓ   upper (exclusive) end of the β sweep
 *   -d Δβ     β increment (must be > 0 whenever βₘₐₓ > 0)
 *   -g γ      mixing factor: iterate ← iterate + γ·(candidate − iterate)
 *   -k k      forwarded to the LogarithmicFourierTransform constructor
 *   -D Δτ     forwarded to the LogarithmicFourierTransform constructor
 * λ is fixed at 0.5 and is not settable from the command line.
 *
 * Output: a live progress line on stderr during iteration, then one stderr
 * line "β μ Re(Ĉ[0]) E γ" per converged β. After the sweep, "t C(t)" pairs
 * for the final iterate are written to stdout.
 * Returns 0 on success and 1 on invalid options.
 */
int main(int argc, char* argv[]) {
  /* Model parameters */
  unsigned p = 2;
  unsigned s = 2;
  Real λ = 0.5;
  Real τ₀ = 0;

  /* Log-Fourier parameters */
  unsigned log2n = 8;
  Real Δτ = 0.1;
  Real k = 0.1;

  /* Iteration parameters */
  Real ε = 1e-13;   // convergence threshold on ΔC
  Real γ = 1;       // mixing factor for the iterate update
  Real βₘₐₓ = 0.7;
  Real Δβ = 0.01;

  int opt;

  // Every option letter listed here has a matching case below; anything else
  // is reported by getopt itself and ends up in `default`.
  while ((opt = getopt(argc, argv, "p:s:2:t:b:d:g:k:D:")) != -1) {
    switch (opt) {
    case 'p':
      p = atoi(optarg);
      break;
    case 's':
      s = atoi(optarg);
      break;
    case '2':
      log2n = atoi(optarg);
      break;
    case 't':
      τ₀ = atof(optarg);
      break;
    case 'b':
      βₘₐₓ = atof(optarg);
      break;
    case 'd':
      Δβ = atof(optarg);
      break;
    case 'g':
      γ = atof(optarg);
      break;
    case 'k':
      k = atof(optarg);
      break;
    case 'D':
      Δτ = atof(optarg);
      break;
    default:
      exit(1);
    }
  }

  // N = 2^log2n must fit in an unsigned; shifting by >= its width is UB.
  if (log2n >= 8 * sizeof(unsigned)) {
    std::cerr << "log2n must be smaller than " << 8 * sizeof(unsigned) << std::endl;
    return 1;
  }
  // The β sweep runs while β < βₘₐₓ; a non-positive (or NaN) step would never
  // get there.
  if (βₘₐₓ > 0 && !(Δβ > 0)) {
    std::cerr << "Δβ must be positive" << std::endl;
    return 1;
  }

  unsigned N = 1u << log2n;

  LogarithmicFourierTransform fft(N, k, Δτ, 4);

  Real Γ₀ = 1.0 + τ₀;
  Real μ = 1.0;
/*  Real μ = Γ₀;
  if (τ₀ > 0) {
    μ = (sqrt(1+4*Γ₀*τ₀)-1)/(2*τ₀);
  }
*/

  std::vector<Real> Cₜ₋₁(N);
  std::vector<Real> Rₜ₋₁(N);
  std::vector<Complex> Ĉₜ₋₁(N);
  std::vector<Complex> Ȓₜ₋₁(N);

  /* Start from the exact solution for β = 0: Ĉ(ν) = 2Γ₀ / ((μ² + ν²)(1 + τ₀²ν²))
     and Ȓ(ν) = 1 / (μ + iν), together with their time-domain counterparts. */
  for (unsigned n = 0; n < N; n++) {
    // The general formula is singular at μτ₀ = 1; since μ = 1 here, that is τ₀ = 1.
    if (τ₀ != 1) {
      Cₜ₋₁[n] = Γ₀ * (exp(-μ * fft.t(n)) - μ * τ₀ * exp(-fft.t(n) / τ₀)) / (μ - pow(μ, 3) * pow(τ₀, 2));
    } else {
      // τ₀ → 1 limit of the expression above (μ = 1). The 1/2 makes it agree
      // with Ĉ below: at t = 0 both branches give Γ₀ / (1 + τ₀) = 1.
      Cₜ₋₁[n] = Γ₀ * exp(-fft.t(n)) * (1 + fft.t(n)) / 2;
    }
    Rₜ₋₁[n] = exp(-μ * fft.t(n));

    Ĉₜ₋₁[n] = 2 * Γ₀ / (pow(μ, 2) + pow(fft.ν(n), 2)) / (1 + pow(τ₀ * fft.ν(n), 2));
    Ȓₜ₋₁[n] = 1.0 / (μ + 1i * fft.ν(n));
  }

  std::vector<Real> Cₜ = Cₜ₋₁;
  std::vector<Real> Rₜ = Rₜ₋₁;
  std::vector<Complex> Ĉₜ = Ĉₜ₋₁;
  std::vector<Complex> Ȓₜ = Ȓₜ₋₁;

  Real β = 0;
  while (β < βₘₐₓ) {
    Real ΔC = 100;
    // ΔC₀ and it are only used by the disabled adaptive-γ code below.
    Real ΔC₀ = 100;
    unsigned it = 0;
    while (ΔC > ε) {
      std::vector<Real> dfC(N);
      std::vector<Real> RddfC(N);
      for (unsigned n = 0; n < N; n++) {
        RddfC[n] = Rₜ[n] * ddf(λ, p, s, Cₜ[n]);
        dfC[n] = df(λ, p, s, Cₜ[n]);
      }
      std::vector<Complex> RddfCt = fft.fourier(RddfC, false);
      std::vector<Complex> dfCt = fft.fourier(dfC, true);

      std::vector<Complex> Ȓₜ₊₁(N);
      std::vector<Complex> Ĉₜ₊₁(N);

      for (unsigned n = 0; n < N; n++) {
        Ȓₜ₊₁[n] = (1.0 + pow(β, 2) * RddfCt[n] * Ȓₜ[n]) / (μ + 1i * fft.ν(n));
      }

      std::vector<Real> Rₜ₊₁ = fft.inverse(Ȓₜ₊₁);

      // The Ĉ update uses the freshly computed R, so R·f''(C) is transformed again.
      for (unsigned n = 0; n < N; n++) {
        RddfC[n] = Rₜ₊₁[n] * ddf(λ, p, s, Cₜ[n]);
      }
      RddfCt = fft.fourier(RddfC, false);
      for (unsigned n = 0; n < N; n++) {
        Ĉₜ₊₁[n] = (2 * Γ₀ * std::conj(Ȓₜ₊₁[n]) / (1 + pow(τ₀ * fft.ν(n), 2)) + pow(β, 2) * (RddfCt[n] * Ĉₜ[n] + dfCt[n] * std::conj(Ȓₜ₊₁[n]))) / (μ + 1i * fft.ν(n));
      }
      std::vector<Real> Cₜ₊₁ = fft.inverse(Ĉₜ₊₁);

      // tanh(x - 1) + 1 exceeds 1 exactly when x > 1, so μ is scaled up when
      // C[0] > 1 and down when C[0] < 1; the exponent 0.05 damps the step.
      μ *= pow(tanh(Cₜ₊₁[0]-1)+1, 0.05);

      // Distance between successive Fourier-space iterates (Ĉ and Ȓ together).
      ΔC = 0;
      for (unsigned i = 0; i < N; i++) {
        ΔC += std::norm(Ĉₜ[i] - Ĉₜ₊₁[i]);
        ΔC += std::norm(Ȓₜ[i] - Ȓₜ₊₁[i]);
      }
      ΔC = sqrt(ΔC) / (2*N);

      // Linear mixing: γ = 1 replaces the iterate by the candidate outright.
      for (unsigned i = 0; i < N; i++) {
        Cₜ[i] += γ * (Cₜ₊₁[i] - Cₜ[i]);
        Rₜ[i] += γ * (Rₜ₊₁[i] - Rₜ[i]);
        Ĉₜ[i] += γ * (Ĉₜ₊₁[i] - Ĉₜ[i]);
        Ȓₜ[i] += γ * (Ȓₜ₊₁[i] - Ȓₜ[i]);
      }

      /*
      if (ΔC < ΔC₀) {
        ΔC₀ = ΔC;
        it = 0;
        γ = std::min(1.001 * γ, 1.0);
      } else {
        it++;
      }

      if (it > 100) {
        γ = std::max(0.5 * γ, 1e-3);
        it = 0;
        ΔC₀ = ΔC;
      }
      */

      std::cerr << "\x1b[2K" << "\r";
      std::cerr << β << " " << μ << " " << ΔC << " " << γ << " " << Cₜ[0];
    }

    /* Integrate the energy using Simpson's rule for non-uniform spacing over
       consecutive triples (t[2n], t[2n+1], t[2n+2]). For even N the last
       interval [t[N-2], t[N-1]] is not included. The bound is written as
       2n + 2 < N (same range as n < N/2 - 1) so it cannot wrap around when N < 2. */
    Real E = 0;
    for (unsigned n = 0; 2*n + 2 < N; n++) {
      Real h₂ₙ   = fft.t(2*n+1) - fft.t(2*n);
      Real h₂ₙ₊₁ = fft.t(2*n+2) - fft.t(2*n+1);
      Real f₂ₙ   = Rₜ[2*n]   * df(λ, p, s, Cₜ[2*n]);
      Real f₂ₙ₊₁ = Rₜ[2*n+1] * df(λ, p, s, Cₜ[2*n+1]);
      Real f₂ₙ₊₂ = Rₜ[2*n+2] * df(λ, p, s, Cₜ[2*n+2]);
      E += (h₂ₙ + h₂ₙ₊₁) / 6 * (
            (2 - h₂ₙ₊₁ / h₂ₙ) * f₂ₙ
            + pow(h₂ₙ + h₂ₙ₊₁, 2) / (h₂ₙ * h₂ₙ₊₁) * f₂ₙ₊₁
            + (2 - h₂ₙ / h₂ₙ₊₁) * f₂ₙ₊₂
          );
    }
    E *= β;

    std::cerr << "\x1b[2K" << "\r";
    std::cerr << β << " " << μ << " " << Ĉₜ[0].real() << " " << E << " " << γ << std::endl;
    β += Δβ;
    Cₜ₋₁ = Cₜ;
    Rₜ₋₁ = Rₜ;
    Ĉₜ₋₁ = Ĉₜ;
    Ȓₜ₋₁ = Ȓₜ;
  }

  // '\n' rather than std::endl: no per-line flush; std::cout is flushed at exit.
  for (unsigned i = 0; i < N; i++) {
    std::cout << fft.t(i) << " " << Cₜ[i] << '\n';
  }

  return 0;
}