#include "fourier.hpp"
#include <getopt.h>
#include <iostream>
#include <fstream>

int main(int argc, char* argv[]) {
  unsigned p = 2;
  unsigned s = 2;
  Real λ = 0.5;
  Real τ₀ = 0;
  Real y₀ = 0;
  Real yₘₐₓ = 0.5;
  Real Δy = 0.05;

  unsigned log2n = 8;
  Real τₘₐₓ = 20;

  unsigned maxIterations = 1000;
  Real ε = 1e-14;
  Real γ = 1;

  bool loadData = false;

  int opt;

  while ((opt = getopt(argc, argv, "p:s:2:T:t:0:y:d:I:g:l")) != -1) {
    switch (opt) {
    case 'p':
      p = atoi(optarg);
      break;
    case 's':
      s = atoi(optarg);
      break;
    case '2':
      log2n = atoi(optarg);
      break;
    case 'T':
      τₘₐₓ = atof(optarg);
      break;
    case 't':
      τ₀ = atof(optarg);
      break;
    case '0':
      y₀ = atof(optarg);
      break;
    case 'y':
      yₘₐₓ = atof(optarg);
      break;
    case 'd':
      Δy = atof(optarg);
      break;
    case 'I':
      maxIterations = (unsigned)atof(optarg);
      break;
    case 'g':
      γ = atof(optarg);
      break;
    case 'l':
      loadData = true;
      break;
    default:
      exit(1);
    }
  }

  unsigned n = pow(2, log2n);

  Real Δτ = (1 + τ₀ / 2) * τₘₐₓ / M_PI / n;
  Real Δω = M_PI / ((1 + τ₀ / 2) * τₘₐₓ);

  Real z = 0.5;
  Real Γ₀ = 1 + τ₀ / 2;

  std::vector<Real> C(2 * n);
  std::vector<Real> R(2 * n);

  FourierTransform fft(n, Δω, Δτ);
  std::vector<Complex> Ct;
  std::vector<Complex> Rt;

  Real y = y₀;

  if (!loadData) {
    // start from the exact solution for τ₀ = 0
    for (unsigned i = 0; i < n; i++) {
      Real τ = i * Δτ * M_PI;
      if (τ₀ > 0) {
        C[i] = Γ₀ / 2 * (exp(-z * τ) - z * τ₀ * exp(-τ / τ₀)) / (z - pow(z, 3) * pow(τ₀, 2));
      } else {
        C[i] = Γ₀ / 2 * exp(-z * τ) / z;
      }
      if (i > 0) {
        C[2 * n - i] = C[i];
      }
      R[i] = exp(-z * τ);
    }
    Ct = fft.fourier(C);
    Rt = fft.fourier(R);
  } else {
    std::ifstream cfile(fourierFile("C", p, s, λ, τ₀, y, log2n, τₘₐₓ), std::ios::binary);
    cfile.read((char*)(C.data()), (C.size() / 2) * sizeof(Real));
    cfile.close();
    for (unsigned i = 1; i < n; i++) {
      C[2 * n - i] = C[i];
    }
    std::ifstream rfile(fourierFile("R", p, s, λ, τ₀, y, log2n, τₘₐₓ), std::ios::binary);
    rfile.read((char*)(R.data()), (R.size() / 2) * sizeof(Real));
    rfile.close();

    Ct = fft.fourier(C);
    Rt = fft.fourier(R);

    z = estimateZ(fft, C, Ct, R, Rt, p, s, λ, τ₀, y);
  }

  while (y += Δy, y <= yₘₐₓ) {
    Real ΔC = 1;
    Real ΔCprev = 1000;
    unsigned it = 0;
    while (sqrt(2 * ΔC / C.size()) > ε) {
      it++;
      auto [RddfCt, dfCt] = RddfCtdfCt(fft, C, R, p, s, λ);

      for (unsigned i = 0; i < Rt.size(); i++) {
        Real ω = i * Δω;
        Rt[i] = (1.0 + pow(y, 2) * RddfCt[i] * Rt[i]) / (z + 1i * ω);
      }

      for (unsigned i = 0; i < Ct.size(); i++) {
        Real ω = i * Δω;
        Ct[i] = (Γ₀ * std::conj(Rt[i]) / (1 + pow(τ₀ * ω, 2)) + pow(y, 2) * (RddfCt[i] * Ct[i] + dfCt[i] * std::conj(Rt[i]))) / (z + 1i * ω);
      }

      std::vector<Real> Cnew = fft.inverse(Ct);
      std::vector<Real> Rnew = fft.inverse(Rt);
      for (unsigned i = n; i < 2 * n; i++) {
        Rnew[i] = 0;
      }

      ΔC = 0;
      for (unsigned i = 0; i < Cnew.size() / 2; i++) {
        ΔC += pow(Cnew[i] - C[i], 2);
        ΔC += pow(Rnew[i] - R[i], 2);
      }

      for (unsigned i = 0; i < Cnew.size(); i++) {
        C[i] += γ * (Cnew[i] - C[i]);
      }

      for (unsigned i = 0; i < Rnew.size() / 2; i++) {
        R[i] += γ * (Rnew[i] - R[i]);
      }

      z *= Cnew[0];

      if (it % maxIterations == 0) {
        if (ΔC < ΔCprev) {
          γ = std::min(1.1 * γ, 1.0);
        } else {
          γ /= 2;
        }

        ΔCprev = ΔC;
      }

      std::cerr << it << " " << p << " " << s << " " << τ₀ << " " << y << " " << sqrt(2 * ΔC / C.size()) << " " << γ << std::endl;

    }

    Real e = energy(C, R, p, s, λ, y, Δτ);

    std::cerr << "y " << y << " " << e << " " << z << std::endl;

    std::ofstream outfile(fourierFile("C", p, s, λ, τ₀, y, log2n, τₘₐₓ), std::ios::out | std::ios::binary);
    outfile.write((const char*)(C.data()), (C.size() / 2) * sizeof(Real));
    outfile.close();

    std::ofstream outfileR(fourierFile("R", p, s, λ, τ₀, y, log2n, τₘₐₓ), std::ios::out | std::ios::binary);
    outfileR.write((const char*)(R.data()), (R.size() / 2) * sizeof(Real));
    outfileR.close();
  }

  return 0;
}