#include "log-fourier.hpp"
#include <getopt.h>
#include <iostream>

int main(int argc, char* argv[]) {
  /* Model parameters */
  unsigned p = 2;
  unsigned s = 2;
  Real λ = 0.5;
  Real τ₀ = 0;

  /* Log-Fourier parameters */
  unsigned log2n = 8;
  Real Δτ = 0.1;
  Real k = -0.01;
  Real logShift = 0;

  /* Iteration parameters */
  Real ε = 1e-15;
  Real γ₀ = 1;
  Real x = 1;
  Real β₀ = 0;
  Real βₘₐₓ = 20;
  Real Δβ = 0.01;
  bool loadData = false;
  unsigned stepsToRespond = 1e7;
  unsigned pad = 2;

  int opt;

  while ((opt = getopt(argc, argv, "p:s:2:T:t:b:d:g:k:D:e:0:lS:x:P:h:")) != -1) {
    switch (opt) {
    case 'p':
      p = atoi(optarg);
      break;
    case 's':
      s = atoi(optarg);
      break;
    case '2':
      log2n = atoi(optarg);
      break;
    case 't':
      τ₀ = atof(optarg);
      break;
    case 'b':
      βₘₐₓ = atof(optarg);
      break;
    case 'd':
      Δβ = atof(optarg);
      break;
    case 'g':
      γ₀ = atof(optarg);
      break;
    case 'k':
      k = atof(optarg);
      break;
    case 'h':
      logShift = atof(optarg);
      break;
    case 'D':
      Δτ = atof(optarg);
      break;
    case 'e':
      ε = atof(optarg);
      break;
    case '0':
      β₀ = atof(optarg);
      break;
    case 'x':
      x = atof(optarg);
      break;
    case 'P':
      pad = atoi(optarg);
      break;
    case 'l':
      loadData = true;
      break;
    case 'S':
      stepsToRespond = atoi(optarg);
      break;
    default:
      exit(1);
    }
  }

  unsigned N = pow(2, log2n);

  Real Γ₀ = 1;
  Real μ₀ = τ₀ > 0 ? (sqrt(1+4*Γ₀*τ₀)-1)/(2*τ₀) : Γ₀;

  LogarithmicFourierTransform fft(N, k, Δτ, pad, μ₀ * pow(10, logShift));

  std::cerr << "Starting, μ₀ = " << μ₀ << ", range " << fft.t(0) << " " << fft.t(N-1) << std::endl;

  Real μₜ₋₁ = μ₀;

  std::vector<Real> Cₜ₋₁(N);
  std::vector<Real> Rₜ₋₁(N);
  std::vector<Complex> Ĉₜ₋₁(N);
  std::vector<Complex> Ȓₜ₋₁(N);

  if (!loadData) {
    /* Start from the exact solution for β = 0 */
    for (unsigned n = 0; n < N; n++) {
      if (τ₀ > 0) {
        if (τ₀ == 2) {
          Cₜ₋₁[n] = Γ₀ * std::exp(-fft.t(n) / 2) * (1 + fft.t(n) / 2);
        } else {
          Cₜ₋₁[n] = Γ₀ * (std::exp(-μ₀ * fft.t(n)) - μ₀ * τ₀ * std::exp(-fft.t(n) / τ₀)) / (μ₀ - pow(μ₀, 3) * pow(τ₀, 2));
        }
      } else {
        Cₜ₋₁[n] = Γ₀ * std::exp(-μ₀ * fft.t(n)) / μ₀;
      }
      Rₜ₋₁[n] = std::exp(-μ₀ * fft.t(n));

      Ĉₜ₋₁[n] = 2 * Γ₀ / (pow(μ₀, 2) + pow(fft.ν(n), 2)) / (1 + pow(τ₀ * fft.ν(n), 2));
      Ȓₜ₋₁[n] = (Real)1.0 / (μ₀ + II * fft.ν(n));
    }
  } else {
    logFourierLoad(Cₜ₋₁, Rₜ₋₁, Ĉₜ₋₁, Ȓₜ₋₁, p, s, λ, τ₀, β₀, log2n, Δτ, logShift);
    μₜ₋₁ = estimateZ(fft, Cₜ₋₁, Ĉₜ₋₁, Rₜ₋₁, Ȓₜ₋₁, p, s, λ, τ₀, β₀);
  }

  std::vector<Real> Cₜ = Cₜ₋₁;
  std::vector<Real> Rₜ = Rₜ₋₁;
  std::vector<Complex> Ĉₜ = Ĉₜ₋₁;
  std::vector<Complex> Ȓₜ = Ȓₜ₋₁;
  Real μₜ = μₜ₋₁;

  Real β = β₀ + Δβ;
  while (β < βₘₐₓ) {
    Real γ = γ₀;
    Real ΔCmin = 1000;
    Real ΔCₜ = 100;
    unsigned stepsUp = 0;
    while (ΔCₜ > ε) {
      auto [RddfCt, dfCt] = RddfCtdfCt(fft, Cₜ, Rₜ, p, s, λ);

      std::vector<Complex> Ĉₜ₊₁(N);
      std::vector<Complex> Ȓₜ₊₁(N);

      Real C₀ = 0;
      Real μ₊ = 0;
      Real μ₋ = 0;

      while (std::abs(C₀ - 1) > ε) {
        for (unsigned n = 0; n < N; n++) {
          Ĉₜ₊₁[n] = ((2 * Γ₀ * std::conj(Ȓₜ[n]) / (1 + std::pow(τ₀ * fft.ν(n), 2)) + std::pow(β, 2) * (RddfCt[n] * Ĉₜ[n] + dfCt[n] * std::conj(Ȓₜ[n]))) / (μₜ + II * fft.ν(n))).real();
        }
        C₀ = C0(fft, Ĉₜ₊₁);
        if (C₀ > 1) {
          μ₋ = μₜ;
        } else {
          μ₊ = μₜ;
        }
        if (μ₋ > 0 && μ₊ > 0) {
          μₜ = (μ₊ + μ₋) / 2;
        } else {
          μₜ *= pow(tanh(C₀-1)+1, x);
        }
      }

      for (unsigned n = 0; n < N; n++) {
        Ȓₜ₊₁[n] = ((Real)1.0 + std::pow(β, 2) * RddfCt[n] * Ȓₜ[n]) / (μₜ + II * fft.ν(n));
      }
      std::vector<Real> Cₜ₊₁ = fft.inverse(Ĉₜ₊₁);
      std::vector<Real> Rₜ₊₁ = fft.inverse(Ȓₜ₊₁);

      if (!std::isnan(Cₜ₊₁[0])) {

      bool trigger0 = false;
      bool trigger1 = false;
      for (unsigned i = 0; i < N; i++) {
        if (Rₜ₊₁[i] < ε || trigger0) {
          Rₜ₊₁[i] = 0;
          trigger0 = true;
        }
      }


      Real Rmax = 0;
      for (unsigned i = 0; i < N; i++) {
        if (Rₜ₊₁[N-1-i] > Rmax) Rmax = Rₜ₊₁[N-1-i];
        Rₜ₊₁[N-1-i] = Rmax;
      }

      trigger0 = false;
      trigger1 = false;
      for (unsigned i = 0; i < N; i++) {
        if (Cₜ₊₁[i] < ε || trigger0) {
          Cₜ₊₁[i] = 0;
          trigger0 = true;
        }
        if (Cₜ₊₁[N-1-i] > 1 - ε || trigger1) {
          Cₜ₊₁[N-1-i] = 1;
          trigger1 = true;
        }
      }
      trigger0 = false;
      trigger1 = false;
      for (unsigned i = 0; i < N; i++) {
        if (Rₜ₊₁[N-1-i] > 1 - ε || trigger1) {
          Rₜ₊₁[N-1-i] = 1;
          trigger1 = true;
        }
      }

      Real Cmax = 0;
      for (unsigned i = 0; i < N; i++) {
        if (Cₜ₊₁[N-1-i] > Cmax) Cmax = Cₜ₊₁[N-1-i];
        Cₜ₊₁[N-1-i] = Cmax;
      }
      }

      ΔCₜ = 0;
      for (unsigned i = 0; i < N; i++) {
        ΔCₜ += std::norm(Cₜ[i] - Cₜ₊₁[i]);
        ΔCₜ += std::norm(Rₜ[i] - Rₜ₊₁[i]);
      }
      ΔCₜ = sqrt(ΔCₜ) / (2*N);

      if (ΔCₜ < 0.9 * ΔCmin) {
        ΔCmin = ΔCₜ;
        stepsUp = 0;
      } else {
        stepsUp++;
      }

      if (stepsUp > stepsToRespond) {
        γ = std::max(γ/2, (Real)1e-4);
        stepsUp = 0;
        ΔCmin = ΔCₜ;
      }

      for (unsigned i = 0; i < N; i++) {
        Cₜ[i] += γ * (Cₜ₊₁[i] - Cₜ[i]);
        Rₜ[i] += γ * (Rₜ₊₁[i] - Rₜ[i]);
        Ĉₜ[i] += γ * (Ĉₜ₊₁[i] - Ĉₜ[i]);
        Ȓₜ[i] += γ * (Ȓₜ₊₁[i] - Ȓₜ[i]);
      }

      std::cerr << "\x1b[2K" << "\r";
      std::cerr << β << " " << μₜ << " " << ΔCₜ << " " << γ << " " << C₀;
    }

    if (std::isnan(Cₜ[0])) {
      γ₀ /= 2;
      Cₜ = Cₜ₋₁;
      Rₜ = Rₜ₋₁;
      Ĉₜ = Ĉₜ₋₁;
      Ȓₜ = Ȓₜ₋₁;
      μₜ = μₜ₋₁;
    } else {
      Real E = energy(fft, Cₜ, Rₜ, p, s, λ, β);

      std::cerr << "\x1b[2K" << "\r";
      std::cerr << β << " " << μₜ << " " << Ĉₜ[0].real() << " " << E << " " << γ << std::endl;

      logFourierSave(Cₜ, Rₜ, Ĉₜ, Ȓₜ, p, s, λ, τ₀, β, log2n, Δτ, logShift);

      if (Ĉₜ[0].real() / Ĉₜ₋₁[0].real() > 1.5) {
        Δβ *= 0.1;
      }

      β = std::round(1e6 * (β + Δβ)) / 1e6;
      Cₜ₋₁ = Cₜ;
      Rₜ₋₁ = Rₜ;
      Ĉₜ₋₁ = Ĉₜ;
      Ȓₜ₋₁ = Ȓₜ;
      μₜ₋₁ = μₜ;
    }
  }

  return 0;
}