#include "fourier.hpp"
#include "p-spin.hpp"
#include <fstream>
#include <fftw3.h>
#include <getopt.h>
#include <iostream>

/* Rectangle-rule sum  Δτ · Σ_σ df(λ, p, s, C[σ]) · R[σ]  over the sampled grid.
 *
 * C, R : samples on the same uniform grid of spacing Δτ (read only). If their
 *        lengths differ, only the common prefix is summed (previously a shorter
 *        R was read out of bounds).
 * λ, p, s : model parameters forwarded unchanged to df.
 * Returns the accumulated sum (0 for empty input).
 */
Real energy(const std::vector<Real>& C, const std::vector<Real>& R, Real λ, unsigned p, unsigned s, Real Δτ) {
  const std::vector<Real>::size_type n = C.size() < R.size() ? C.size() : R.size();
  Real I = 0;
  for (std::vector<Real>::size_type σ = 0; σ < n; σ++) {
    I += Δτ * df(λ, p, s, C[σ]) * R[σ];
  }
  return I;
}

int main(int argc, char* argv[]) {
  unsigned p = 3;
  unsigned s = 4;
  Real λ = 0.5;
  Real τₘₐₓ = 1e3;
  Real τ₀ = 0;
  Real β₀ = 0;
  Real βₘₐₓ = 1;
  Real Δβ = 1e-2;
  unsigned iterations = 10;
  unsigned log2n = 8;
  Real ε = 1e-14;

  int opt;

  while ((opt = getopt(argc, argv, "T:2:t:0:b:d:I:")) != -1) {
    switch (opt) {
    case 'T':
      τₘₐₓ = atof(optarg);
      break;
    case '2':
      log2n = atof(optarg);
      break;
    case 't':
      τ₀ = atof(optarg);
      break;
    case '0':
      β₀ = atof(optarg);
      break;
    case 'b':
      βₘₐₓ = atof(optarg);
      break;
    case 'd':
      Δβ = atof(optarg);
      break;
    case 'I':
      iterations = (unsigned)atof(optarg);
      break;
    default:
      exit(1);
    }
  }

  unsigned N = pow(2, log2n);

  Real Δτ = (1 + τ₀ / 2) * τₘₐₓ / M_PI / N;
  Real Δω = M_PI / ((1 + τ₀ / 2) * τₘₐₓ);

  FourierTransform fft(N, Δω, Δτ, FFTW_ESTIMATE);

  Real Γ₀ = 1;
  Real μ = 1;
  if (τ₀ > 0) {
    μ = (sqrt(1+4*Γ₀*τ₀) - 1) / (2 * τ₀);
  }

  Real τ = 0;
  std::vector<Real> C(N);
  std::vector<Real> R(N);
  std::vector<Real> Γ(N);
  std::vector<Real> Γh(N+1);

  Γh[0] = Γ₀;

  for (unsigned i = 0; i < N; i++) {
    Real τ = i * Δτ;
    Real ω = (i + 1) * Δω * M_PI;
    if (τ₀ > 0) {
      C[i] = (Γ₀ / μ) * (exp(-μ * τ) - μ * τ₀ * exp(-τ / τ₀)) / (1 - pow(μ * τ₀, 2));
      Γ[i] = (Γ₀ / τ₀) * exp(-τ / τ₀);
    } else {
      C[i] = (Γ₀ / μ) * exp(-μ * τ);
    }
    Γh[i+1] = Γ₀ / (1 + pow(ω * τ₀, 2));
    R[i] = exp(-μ * τ);
  }

  for (Real β = β₀; β < βₘₐₓ; β += Δβ) {
    Real Rerr = 100;
    while (sqrt(Rerr / N) > ε) {
      /* First step: integrate R from C */
      std::vector<Real> R₊(N);
      R₊[0] = 1;
      for (unsigned i = 1; i < N; i++) {
        Real I = 0;
        for (unsigned j = 0; j <= i; j++) {
          I += R[i - j] * ddf(λ, p, s, C[i - j]) * R[j] * Δτ;
        }
        Real dR = -μ * R₊[i - 1] + pow(β, 2) * I;
        R₊[i] = R₊[i - 1] + dR * Δτ;
      }

      Rerr = 0;
      for (unsigned i = 0; i < N; i++) {
        Rerr += pow(R[i] - R₊[i], 2);
      }

      R = R₊;

      /* Second step: integrate C from R */
      std::vector<Real> dC = fft.convolve(Γh, R);
      Real Cₜ₊₁ = 0;
      for (unsigned i = 0; i < N; i++) {
        Real Cₜ = Cₜ₊₁ + dC[N - i - 1] * Δτ;
        C[N - i - 1] = Cₜ;
        Cₜ₊₁ = Cₜ;
      }

      /* Third step: adjust μ */
      μ *= C[0];

      std::cerr << β << " " << sqrt(Rerr / N) << std::endl;
    }

    std::ofstream outfile(fourierFile("Ci", p, s, λ, τ₀, β, log2n, τₘₐₓ), std::ios::out | std::ios::binary);
    outfile.write((const char*)(C.data()), N * sizeof(Real));
    outfile.close();

    std::ofstream outfileR(fourierFile("Ri", p, s, λ, τ₀, β, log2n, τₘₐₓ), std::ios::out | std::ios::binary);
    outfileR.write((const char*)(R.data()), N * sizeof(Real));
    outfileR.close();
  }

  return 0;
}