#include "fourier.hpp"
#include <getopt.h>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>

int main(int argc, char* argv[]) {
  unsigned p = 2;
  unsigned s = 2;
  Real λ = 0.5;
  Real τ₀ = 0;
  Real β₀ = 0;
  Real βₘₐₓ = 0.5;
  Real Δβ = 0.05;

  unsigned log2n = 8;
  Real τₘₐₓ = 20;

  unsigned maxIterations = 1000;
  Real ε = 1e-14;
  Real γ = 1;

  bool loadData = false;

  int opt;

  while ((opt = getopt(argc, argv, "p:s:2:T:t:0:y:d:I:g:l")) != -1) {
    switch (opt) {
    case 'p':
      p = atoi(optarg);
      break;
    case 's':
      s = atoi(optarg);
      break;
    case '2':
      log2n = atoi(optarg);
      break;
    case 'T':
      τₘₐₓ = atof(optarg);
      break;
    case 't':
      τ₀ = atof(optarg);
      break;
    case '0':
      β₀ = atof(optarg);
      break;
    case 'y':
      βₘₐₓ = atof(optarg);
      break;
    case 'd':
      Δβ = atof(optarg);
      break;
    case 'I':
      maxIterations = (unsigned)atof(optarg);
      break;
    case 'g':
      γ = atof(optarg);
      break;
    case 'l':
      loadData = true;
      break;
    default:
      exit(1);
    }
  }

  unsigned n = pow(2, log2n);

  Real Δτ = (1 + τ₀ / 2) * τₘₐₓ / M_PI / n;
  Real Δω = M_PI / ((1 + τ₀ / 2) * τₘₐₓ);

  Real Γ₀ = 1.0;
  Real μ = Γ₀;
  if (τ₀ > 0) {
    μ = (sqrt(1+4*Γ₀*τ₀)-1)/(2*τ₀);
  }

  std::vector<Real> C(2 * n);
  std::vector<Real> R(2 * n);

  FourierTransform fft(n, Δω, Δτ);
  std::vector<Complex> Ct;
  std::vector<Complex> Rt;

  Real β = β₀;

  if (!loadData) {
    // start from the exact solution for τ₀ = 0
    for (unsigned i = 0; i < n; i++) {
      Real τ = i * Δτ * M_PI;
      if (τ₀ > 0) {
        C[i] = Γ₀ * (exp(-μ * τ) - μ * τ₀ * exp(-τ / τ₀)) / (μ - pow(μ, 3) * pow(τ₀, 2));
      } else {
        C[i] = Γ₀ * exp(-μ * τ) / μ;
      }
      if (i > 0) {
        C[2 * n - i] = C[i];
      }
      R[i] = exp(-μ * τ);
    }
    Ct = fft.fourier(C);
    Rt = fft.fourier(R);
  } else {
    std::ifstream cfile(fourierFile("C", p, s, λ, τ₀, β, log2n, τₘₐₓ), std::ios::binary);
    cfile.read((char*)(C.data()), (C.size() / 2) * sizeof(Real));
    cfile.close();
    for (unsigned i = 1; i < n; i++) {
      C[2 * n - i] = C[i];
    }
    std::ifstream rfile(fourierFile("R", p, s, λ, τ₀, β, log2n, τₘₐₓ), std::ios::binary);
    rfile.read((char*)(R.data()), (R.size() / 2) * sizeof(Real));
    rfile.close();

    Ct = fft.fourier(C);
    Rt = fft.fourier(R);

    μ = estimateZ(fft, C, Ct, R, Rt, p, s, λ, τ₀, β);
  }

  std::vector<Real> Cb = C;
  std::vector<Real> Rb = R;
  std::vector<Complex> Ctb = Ct;
  std::vector<Complex> Rtb = Rt;

  Real fac = 1;
  while (β += Δβ, β <= βₘₐₓ) {
    Real Δμ = 1e-2;
    Real μ₁ = 0;
    Real μ₂ = 0;
    while (true) {
    Real ΔC = 1;
    Real ΔCprev = 1000;
    unsigned it = 0;
    while (sqrt(2 * ΔC / C.size()) > ε) {
      it++;
      auto [RddfCt, dfCt] = RddfCtdfCt(fft, C, R, p, s, λ);

      for (unsigned i = 0; i < Rt.size(); i++) {
        Real ω = i * Δω;
        Rt[i] = (1.0 + pow(β, 2) * RddfCt[i] * Rt[i]) / (μ + 1i * ω);
      }

      for (unsigned i = 0; i < Ct.size(); i++) {
        Real ω = i * Δω;
        Ct[i] = (2 * Γ₀ * std::conj(Rt[i]) / (1 + pow(τ₀ * ω, 2)) + pow(β, 2) * (RddfCt[i] * Ct[i] + dfCt[i] * std::conj(Rt[i]))) / (μ + 1i * ω);
      }

      std::vector<Real> Cnew = fft.inverse(Ct);
      std::vector<Real> Rnew = fft.inverse(Rt);
      for (unsigned i = n; i < 2 * n; i++) {
        Rnew[i] = 0;
      }

      ΔC = 0;
      for (unsigned i = 0; i < Cnew.size() / 2; i++) {
        ΔC += pow(Cnew[i] - C[i], 2);
        ΔC += pow(Rnew[i] - R[i], 2);
      }

      for (unsigned i = 0; i < Cnew.size(); i++) {
        C[i] += γ * (Cnew[i] - C[i]);
      }

      for (unsigned i = 0; i < Rnew.size() / 2; i++) {
        R[i] += γ * (Rnew[i] - R[i]);
      }

      /*
      if (it % maxIterations == 0) {
        if (ΔC < ΔCprev) {
          γ = std::min(1.1 * γ, 1.0);
        } else {
          γ /= 2;
        }

        ΔCprev = ΔC;
      }
      */

      std::cerr << μ << " " << p << " " << s << " " << τ₀ << " " << β << " " << sqrt(2 * ΔC / C.size()) << " " << γ << " " << C[0];
      std::cerr << "\r";

    }
      if (std::isnan(C[0])) {
        C = Cb;
        R = Rb;
        Ct = Ctb;
        Rt = Rtb;
        μ /= sqrt(sqrt(fac*std::tanh(Cb[0]-1)+1));
        fac /= 2;
        μ₁ = 0;
        μ₂ = 0;
      } else {
        Cb = C;
        Rb = R;
        Ctb = Ct;
        Rtb = Rt;
      if (pow(C[0] - 1, 2) < ε) {
        break;
      }
      if (μ₁ == 0 || μ₂ == 0) {
        if (C[0] > 1 && μ₁ == 0) {
          /* We found a lower bound */
          μ₁ = μ;
        }
        if (C[0] < 1 && μ₂ == 0) {
          /* We found an upper bound */
          μ₂ = μ;
        }
        μ *= sqrt(sqrt(fac*std::tanh(C[0]-1)+1));
      } else {
        /* Once the bounds are set, we can use bisection */
        if (C[0] > 1) {
          μ₁ = μ;
        } else {
          μ₂ = μ;
        }
        μ = (μ₁ + μ₂) / 2;
      }
      }
    }

    Real e = energy(C, R, p, s, λ, β, Δτ);

    std::cerr << "y " << β << " " << e << " " << μ << std::endl;

    std::ofstream outfile(fourierFile("C", p, s, λ, τ₀, β, log2n, τₘₐₓ), std::ios::out | std::ios::binary);
    outfile.write((const char*)(C.data()), (C.size() / 2) * sizeof(Real));
    outfile.close();

    std::ofstream outfileR(fourierFile("R", p, s, λ, τ₀, β, log2n, τₘₐₓ), std::ios::out | std::ios::binary);
    outfileR.write((const char*)(R.data()), (R.size() / 2) * sizeof(Real));
    outfileR.close();
  }

  return 0;
}