#include "log-fourier.hpp"
#include "p-spin.hpp"
#include <complex>
#include <fstream>
#include <types.hpp>

namespace {
// Number of LogarithmicFourierTransform instances currently alive.
// FFTW_CLEANUP resets FFTW's global state and, per the FFTW manual, leaves every
// existing plan undefined, so it may only run once the *last* instance has
// destroyed its own plans. Deliberately unsynchronized: FFTW's planner is not
// thread-safe, so construction/destruction must already be serialized.
unsigned liveLogFourierTransforms = 0;
}  // namespace

/*
 * Sets up an N-point transform on logarithmically spaced grids.
 *
 *   N      number of samples on the input (τ) and output (ω) grids
 *   k      exponent entering the (1-k) input taper and the Γ(k - i s) kernel
 *   Δτ     grid spacing of τ and ω
 *   pad    padding factor; every FFT has length pad*N
 *   shift  fixes the grid offsets: τ(0) = -shift*N*Δτ and
 *          ω(0) = -(1-shift)*N*Δτ (so τₛ + ωₛ = -N)
 *
 * Allocates the two pad*N complex work buffers a and â and plans the FFTs
 * between them, importing and re-exporting planner wisdom through the file
 * "fftw.wisdom" in the current working directory.
 */
LogarithmicFourierTransform::LogarithmicFourierTransform(unsigned N, Real k, Real Δτ, unsigned pad, Real shift) : N(N), pad(pad), k(k), Δτ(Δτ) {
  τₛ = -shift * N;
  ωₛ = -(1-shift) * N;
  // Centre the conjugate (s) grid on zero: s(pad*N/2) == 0.
  sₛ = -0.5 * pad * N;
  a = reinterpret_cast<Complex*>(FFTW_ALLOC_COMPLEX(pad*N));
  â = reinterpret_cast<Complex*>(FFTW_ALLOC_COMPLEX(pad*N));
  // A failed import (e.g. no wisdom file yet) is ignored; planning just takes longer.
  FFTW_IMPORT_WISDOM("fftw.wisdom");
  // Planner flag 0 is FFTW_MEASURE. Measuring may scribble over the buffers,
  // which is harmless here because they hold no data yet.
  // NOTE(review): both plans use the FFTW_BACKWARD sign; the index bookkeeping
  // in fourier()/inverse() presumably depends on this, so confirm before changing.
  a_to_â = FFTW_PLAN_DFT_1D(pad*N, reinterpret_cast<FFTW_COMPLEX*>(a), reinterpret_cast<FFTW_COMPLEX*>(â), FFTW_BACKWARD, 0);
  â_to_a = FFTW_PLAN_DFT_1D(pad*N, reinterpret_cast<FFTW_COMPLEX*>(â), reinterpret_cast<FFTW_COMPLEX*>(a), FFTW_BACKWARD, 0);
  FFTW_EXPORT_WISDOM("fftw.wisdom");
  ++liveLogFourierTransforms;
}

/*
 * Destroys this instance's plans and frees its work buffers. FFTW's global
 * state is reset only when no other instance is still alive, because
 * FFTW_CLEANUP would otherwise invalidate the plans those instances still use.
 */
LogarithmicFourierTransform::~LogarithmicFourierTransform() {
  FFTW_DESTROY_PLAN(a_to_â);
  FFTW_DESTROY_PLAN(â_to_a);
  FFTW_FREE(a);
  FFTW_FREE(â);
  if (liveLogFourierTransforms > 0 && --liveLogFourierTransforms == 0) {
    FFTW_CLEANUP();
  }
}

// Log-time coordinate of grid point n: τ(n) = Δτ·(n + τₛ), where τₛ = -shift·N
// is fixed by the constructor.
Real LogarithmicFourierTransform::τ(unsigned n) const {
  const auto offsetIndex = τₛ + n;
  return Δτ * offsetIndex;
}

// Log-frequency coordinate of grid point n: ω(n) = Δτ·(n + ωₛ), where
// ωₛ = -(1-shift)·N is fixed by the constructor.
Real LogarithmicFourierTransform::ω(unsigned n) const {
  const auto offsetIndex = ωₛ + n;
  return Δτ * offsetIndex;
}

// Conjugate variable of the pad*N-point FFT at index n:
// s(n) = 2π·(n + sₛ) / (pad·N·Δτ), with sₛ = -pad·N/2 so the grid is centred on 0.
Real LogarithmicFourierTransform::s(unsigned n) const {
  const auto centredIndex = n + sₛ;
  const auto period = pad * N * Δτ;
  return centredIndex * 2 * M_PI / period;
}

// Time of grid point n: t(n) = e^{τ(n)}.
// std::exp (rather than unqualified exp, which may resolve to ::exp(double))
// picks the overload matching Real, so no precision is lost when Real is wider
// than double. This matches the std::exp(Real) calls already used in inverse().
Real LogarithmicFourierTransform::t(unsigned n) const {
  return std::exp(τ(n));
}

// Frequency of grid point n: ν(n) = e^{ω(n)}.
// std::exp (rather than unqualified exp, which may resolve to ::exp(double))
// picks the overload matching Real, so no precision is lost when Real is wider
// than double. This matches the std::exp(Real) calls already used in inverse().
Real LogarithmicFourierTransform::ν(unsigned n) const {
  return std::exp(ω(n));
}

/*
 * Complex Gamma function Γ(z).
 *
 * GSL supplies ln|Γ(z)| and arg Γ(z) separately, and the result is rebuilt as
 * exp(ln|Γ(z)| + i·arg Γ(z)). The computation goes through double precision
 * regardless of Real. GSL's returned status code is not checked.
 */
Complex Γ(Complex z) {
  const double re = (double)z.real();
  const double im = (double)z.imag();

  gsl_sf_result lnModulus;
  gsl_sf_result phase;
  gsl_sf_lngamma_complex_e(re, im, &lnModulus, &phase);

  const Real lnAbs = (Real)lnModulus.val;
  const Real arg = (Real)phase.val;
  return std::exp(lnAbs + II * arg);
}

/*
 * Forward logarithmic Fourier transform.
 *
 * c holds samples on the log-time grid τ(n), n < N (at least N values are
 * read). The result holds N samples on the log-frequency grid ω(n).
 *
 * If `symmetric` is true, c is treated as an even function of time and both
 * σ = +1 and σ = -1 passes are summed. Otherwise c is treated as zero for
 * negative times and only σ = +1 is used.
 *
 * Each pass:
 *   1. tapers the input by e^{(1-k)τ} into the padded buffer,
 *   2. transforms it,
 *   3. multiplies by the kernel e^{i(N/2+τₛ)s/Δτ}·(iσ)^{is-k}·Γ(k-is),
 *   4. transforms back.
 * Finally the whole result is shifted so that its last sample is zero.
 *
 * Uses the member work buffers a and â, so concurrent calls on the same
 * object would race.
 */
std::vector<Complex> LogarithmicFourierTransform::fourier(const std::vector<Real>& c, bool symmetric) {
  std::vector<Complex> ĉ(N);
  std::vector<Real> σs = {1};
  /* c is either even or zero for negative arguments */
  if (symmetric){
    σs.push_back(-1);
  }
  for (Real σ : σs) {
    // Tapered samples in the first N slots, a reversed copy in the last N
    // slots, zero padding in between. std::exp is used (as in inverse()) so the
    // overload matching Real is selected.
    for (unsigned n = 0; n < pad*N; n++) {
      if (n < N) {
        a[n] = c[n] * std::exp((1 - k) * τ(n));
      } else if (n >= (pad - 1) * N) {
        a[n] = c[pad*N-n-1] * std::exp((1 - k) * τ(pad*N-n-1));
      } else {
        a[n] = 0;
      }
    }
    FFTW_EXECUTE(a_to_â);
    // Apply the analytic kernel at s(n). Rotating the index by pad*N/2 maps the
    // centred s grid onto FFT ordering, so that s = 0 lands on bin 0.
    for (unsigned n = 0; n < pad*N; n++) {
      â[(pad*N / 2 + n) % (pad*N)] *= std::exp(II*(0.5 * N + τₛ) * s(n) / Δτ) * std::pow(II * σ, II * s(n) - k) * Γ(k - II * s(n));
    }
    FFTW_EXECUTE(â_to_a);
    // Take the N output samples from the tail of the buffer. Apply the e^{-kω}
    // weight and the 1/(pad*N) normalisation that FFTW leaves out.
    for (unsigned n = 0; n < N; n++) {
      ĉ[n] += std::exp(-k * ω(n)) * a[(pad - 1)*N+n] / (Real)(pad*N);
    }
  }

  // Shift so that ĉ[N-1] == 0. ĉ[N-1] is only modified in the final iteration,
  // so every element is shifted by the same (original) value.
  for (unsigned n = 0; n < N; n++) {
    ĉ[n] -= ĉ[N - 1];
  }

  return ĉ;
}

/*
 * Inverse logarithmic Fourier transform.
 *
 * ĉ holds samples on the log-frequency grid ω(n), n < N (at least N values are
 * read). The result holds N real samples on the log-time grid τ(n).
 *
 * Two passes are summed: σ = +1 uses ĉ, and σ = -1 uses its complex conjugate.
 * Each pass:
 *   1. tapers the input by e^{(1-k)ω} into the padded buffer,
 *   2. transforms it,
 *   3. multiplies by e^{-i(N/2+τₛ)s/Δτ}·(-iσ)^{is-k}·Γ(k-is),
 *   4. transforms back,
 *   5. keeps the real part.
 * Because τₛ + ωₛ = -N, the phase -(N/2+τₛ) equals N/2+ωₛ, mirroring the
 * (N/2 + input offset) phase used by fourier(). Finally the result is shifted
 * so that its last sample is zero.
 *
 * Uses the member work buffers a and â, so concurrent calls on the same
 * object would race.
 */
std::vector<Real> LogarithmicFourierTransform::inverse(const std::vector<Complex>& ĉ) {
  std::vector<Real> c(N);
  std::vector<Real> σs = {1, -1};
  for (Real σ : σs) {
    // Tapered (and, for σ = -1, conjugated) samples in the first N slots, a
    // reversed copy in the last N slots, zero padding in between.
    for (unsigned n = 0; n < pad * N; n++) {
      if (n < N) {
        a[n] = (ĉ[n].real() + II * σ * ĉ[n].imag()) * std::exp((1 - k) * ω(n));
      } else if (n >= (pad - 1) * N) {
        a[n] = (ĉ[pad*N-n-1].real() + II * σ * ĉ[pad*N-n-1].imag()) * std::exp((1 - k) * ω(pad*N-n-1));
      } else {
        a[n] = 0;
      }
    }
    FFTW_EXECUTE(a_to_â);
    // Apply the analytic kernel at s(n). Rotating the index by pad*N/2 maps the
    // centred s grid onto FFT ordering, so that s = 0 lands on bin 0.
    for (unsigned n = 0; n < pad*N; n++) {
      â[(pad*N / 2 + n) % (pad*N)] *= std::exp(-II*(0.5 * N + τₛ) * s(n) / Δτ) * std::pow(-II * σ, II * s(n) - k) * Γ(k - II * s(n));
    }
    FFTW_EXECUTE(â_to_a);
    // Take the real part of the N output samples from the tail of the buffer.
    // Apply the e^{-kτ} weight, the 1/(pad*N) normalisation FFTW leaves out,
    // and a further 1/(2π).
    for (unsigned n = 0; n < N; n++) {
      c[n] += std::exp(-k * τ(n)) * a[(pad - 1)*N+n].real() / (Real)(pad*N) / (2 * M_PI);
    }
  }

  // Shift so that c[N-1] == 0. c[N-1] is only modified in the final iteration,
  // so every element is shifted by the same (original) value.
  for (unsigned n = 0; n < N; n++) {
    c[n] -= c[N - 1];
  }

  return c;
}

/*
 * Builds the cache file name
 *   <prefix>_<p>_<s>_<λ>_<τ₀>_<β>_<log2n>_<Δτ>_<shift>.dat
 * where every numeric field is formatted with std::to_string.
 */
std::string logFourierFile(std::string prefix, unsigned p, unsigned s, Real λ, Real τ₀, Real β, unsigned log2n, Real Δτ, Real shift) {
  const std::string fields[] = {
    std::to_string(p),     std::to_string(s),  std::to_string(λ),
    std::to_string(τ₀),    std::to_string(β),  std::to_string(log2n),
    std::to_string(Δτ),    std::to_string(shift)
  };

  std::string name = prefix;
  for (const std::string& field : fields) {
    name += '_';
    name += field;
  }
  name += ".dat";
  return name;
}

void logFourierSave(const std::vector<Real>& C, const std::vector<Real>& R, const std::vector<Complex>& Ct, const std::vector<Complex>& Rt, unsigned p, unsigned s, Real λ, Real τ₀, Real β, unsigned log2n, Real Δτ, Real k) {
    unsigned N = std::pow(2, log2n);
    std::ofstream outfile(logFourierFile("C", p, s, λ, τ₀, β, log2n, Δτ, k), std::ios::out | std::ios::binary);
    outfile.write((const char*)(C.data()), N * sizeof(Real));
    outfile.close();

    std::ofstream outfileCt(logFourierFile("Ct", p, s, λ, τ₀, β, log2n, Δτ, k), std::ios::out | std::ios::binary);
    outfileCt.write((const char*)(Ct.data()), N * sizeof(Complex));
    outfileCt.close();

    std::ofstream outfileR(logFourierFile("R", p, s, λ, τ₀, β, log2n, Δτ, k), std::ios::out | std::ios::binary);
    outfileR.write((const char*)(R.data()), N * sizeof(Real));
    outfileR.close();

    std::ofstream outfileRt(logFourierFile("Rt", p, s, λ, τ₀, β, log2n, Δτ, k), std::ios::out | std::ios::binary);
    outfileRt.write((const char*)(Rt.data()), N * sizeof(Complex));
    outfileRt.close();
}

bool logFourierLoad(std::vector<Real>& C, std::vector<Real>& R, std::vector<Complex>& Ct, std::vector<Complex>& Rt, unsigned p, unsigned s, Real λ, Real τ₀, Real β, unsigned log2n, Real Δτ, Real k) {
  std::ifstream cfile(logFourierFile("C", p, s, λ, τ₀, β, log2n, Δτ, k), std::ios::binary);
  std::ifstream rfile(logFourierFile("R", p, s, λ, τ₀, β, log2n, Δτ, k), std::ios::binary);
  std::ifstream ctfile(logFourierFile("Ct", p, s, λ, τ₀, β, log2n, Δτ, k), std::ios::binary);
  std::ifstream rtfile(logFourierFile("Rt", p, s, λ, τ₀, β, log2n, Δτ, k), std::ios::binary);

  if ((!cfile.is_open() || !rfile.is_open()) || (!ctfile.is_open() || !rtfile.is_open())) {
    return false;
  }

  unsigned N = std::pow(2, log2n);

  cfile.read((char*)(C.data()), N * sizeof(Real));
  cfile.close();

  rfile.read((char*)(R.data()), N * sizeof(Real));
  rfile.close();

  ctfile.read((char*)(Ct.data()), N * sizeof(Complex));
  ctfile.close();

  rtfile.read((char*)(Rt.data()), N * sizeof(Complex));
  rtfile.close();

  return true;
}

/*
 * Builds two sequences pointwise from C and R (assumed to be the same length):
 *   R[n]·ddf(λ, p, s, C[n])  and  df(λ, p, s, C[n]).
 * Returns their logarithmic Fourier transforms as {RddfCt, dfCt}.
 *
 * The first sequence is transformed as zero for negative times
 * (symmetric = false). The second is transformed as even (symmetric = true).
 */
std::tuple<std::vector<Complex>, std::vector<Complex>> RddfCtdfCt(LogarithmicFourierTransform& fft, const std::vector<Real>& C, const std::vector<Real>& R, unsigned p, unsigned s, Real λ) {
  const auto length = C.size();
  std::vector<Real> dfC(length);
  std::vector<Real> RddfC(length);

  for (std::vector<Real>::size_type i = 0; i != length; ++i) {
    const Real Cn = C[i];
    RddfC[i] = R[i] * ddf(λ, p, s, Cn);
    dfC[i] = df(λ, p, s, Cn);
  }

  // Braced-init-list elements are evaluated left to right, so RddfC is
  // transformed before dfC.
  return {fft.fourier(RddfC, false), fft.fourier(dfC, true)};
}

/*
 * Estimates Z from the first (index 0) samples of the transforms:
 *
 *   Z = Re[ (2Γ₀·conj(Rt[0]) + β²·(RddfCt[0]·Ct[0] + dfCt[0]·conj(Rt[0]))) / Ct[0] ]
 *
 * Γ₀ is fixed to 1. RddfCt and dfCt are recomputed from C and R via
 * RddfCtdfCt. τ₀ is accepted but not used.
 */
Real estimateZ(LogarithmicFourierTransform& fft, const std::vector<Real>& C, const std::vector<Complex>& Ct, const std::vector<Real>& R, const std::vector<Complex>& Rt, unsigned p, unsigned s, Real λ, Real τ₀, Real β) {
  const auto transforms = RddfCtdfCt(fft, C, R, p, s, λ);
  const std::vector<Complex>& RddfCt = std::get<0>(transforms);
  const std::vector<Complex>& dfCt = std::get<1>(transforms);
  const Real Γ₀ = 1.0;

  const Complex Rt0Conj = std::conj(Rt[0]);
  const Complex bracket = RddfCt[0] * Ct[0] + dfCt[0] * Rt0Conj;
  const Complex numerator = 2 * Γ₀ * Rt0Conj + std::pow(β, 2) * bracket;
  return (numerator / Ct[0]).real();
}

/*
 * Energy estimate β·E, where
 *
 *   E = t(2n₀)·df(λ, p, s, 1) + Σ_panels ∫ R(t)·df(λ, p, s, C(t)) dt.
 *
 * The first term treats the integrand on [0, t(2n₀)] as the constant
 * df(λ, p, s, 1). The remaining integral uses composite Simpson's rule for
 * non-uniform spacing, over panels (t(2n), t(2n+1), t(2n+2)) of the grid.
 *
 * C and R hold samples on that grid. Only indices present in both are used.
 * C is taken by non-const reference but is only read.
 */
Real energy(const LogarithmicFourierTransform& fft, std::vector<Real>&  C, const std::vector<Real>& R, unsigned p, unsigned s, Real λ, Real β) {
  unsigned n₀ = 0;
  // Disabled heuristic for skipping an initial window; n₀ is currently fixed at 0.
  /*
  for (unsigned n = 0; n < C.size(); n++) {
    if (C[n] > 1 || R[n] > 1) n₀ = n % 2 == 0 ? n / 2 : (n + 1) / 2;
  }
  */
  const auto samples = C.size() < R.size() ? C.size() : R.size();

  Real E = fft.t(2*n₀) * df(λ, p, s, 1);
  // Each panel needs indices 2n, 2n+1 and 2n+2, so loop while 2n+2 is valid.
  // Writing the bound this way, instead of n < size()/2 - 1 on an unsigned
  // size, cannot wrap around when there are fewer than two samples. It also
  // covers the final panel when the sample count is odd.
  for (unsigned n = n₀; 2*n + 2 < samples; n++) {
    Real R₂ₙ   = R[2*n];
    Real R₂ₙ₊₁ = R[2*n+1];
    Real R₂ₙ₊₂ = R[2*n+2];
    Real C₂ₙ   = C[2*n];
    Real C₂ₙ₊₁ = C[2*n+1];
    Real C₂ₙ₊₂ = C[2*n+2];

    //if (C₂ₙ₊₂ < 0 || R₂ₙ₊₂ < 0) break;

    Real h₂ₙ   = fft.t(2*n+1) - fft.t(2*n);
    Real h₂ₙ₊₁ = fft.t(2*n+2) - fft.t(2*n+1);
    Real f₂ₙ   = R₂ₙ   * df(λ, p, s, C₂ₙ);
    Real f₂ₙ₊₁ = R₂ₙ₊₁ * df(λ, p, s, C₂ₙ₊₁);
    Real f₂ₙ₊₂ = R₂ₙ₊₂ * df(λ, p, s, C₂ₙ₊₂);

    // Simpson's rule on the unequal sub-intervals h₂ₙ and h₂ₙ₊₁.
    E += (h₂ₙ + h₂ₙ₊₁) / 6 * (
          (2 - h₂ₙ₊₁ / h₂ₙ) * f₂ₙ
          + std::pow(h₂ₙ + h₂ₙ₊₁, 2) / (h₂ₙ * h₂ₙ₊₁) * f₂ₙ₊₁
          + (2 - h₂ₙ / h₂ₙ₊₁) * f₂ₙ₊₂
        );
  }
  return  β * E;
}