#pragma once

#include <cmath>

#include "matrix.hpp"
#include "spin.hpp"
#include "vector.hpp"

template <class T, int D> class Euclidean {
public:
  Vector<T, D> t;
  Matrix<T, D> r;

  Euclidean(T L) {
    for (unsigned i = 0; i < D; i++) {
      t(i) = 0;
      r(i, i) = 1;
      for (unsigned j = 1; j < D; j++) {
        r(i, (i + j) % D) = 0;
      }
    }
  }

  Euclidean(Vector<T, D> t0, Matrix<T, D> r0) {
    t = t0;
    r = r0;
  }

  template <class S> Spin<T, D, S> act(const Spin<T, D, S>& s) const {
    Spin<T, D, S> s_new;

    s_new.x = t + r * s.x;
    s_new.s = s.s;

    return s_new;
  }

  Euclidean act(const Euclidean& x) const {
    Vector<T, D> tnew = r * x.t + t;
    Matrix<T, D> rnew = r * x.r;

    Euclidean pnew(tnew, rnew);

    return pnew;
  }

  Euclidean inverse() const {
    Vector<T, D> tnew = -r.transpose() * t;
    Matrix<T, D> rnew = r.transpose();

    Euclidean pnew(tnew, rnew);

    return pnew;
  }
};

template <class T, int D> class TorusGroup {
private:
  T L;

public:
  Vector<T, D> t;
  Matrix<T, D> r;

  /** brief TorusGroup - default constructor, constructs the identity
   */
  TorusGroup(T L) : L(L) {
    for (unsigned i = 0; i < D; i++) {
      t(i) = 0;
      r(i, i) = 1;
      for (unsigned j = 1; j < D; j++) {
        r(i, (i + j) % D) = 0;
      }
    }
  }

  TorusGroup(T L, Vector<T, D> t0, Matrix<T, D> r0) : L(L) {
    t = t0;
    r = r0;
  }

  template <class S> Spin<T, D, S> act(const Spin<T, D, S>& s) const {
    Spin<T, D, S> s_new;

    s_new.x = t + r * s.x;
    s_new.s = s.s;

    for (unsigned i = 0; i < D; i++) {
      s_new.x(i) = fmod(L + s_new.x(i), L);
    }

    return s_new;
  }

  TorusGroup act(const TorusGroup& x) const {
    Vector<T, D> tnew = r * x.t + t;
    Matrix<T, D> rnew = r * x.r;

    for (unsigned i = 0; i < D; i++) {
      tnew(i) = fmod(L + tnew(i), L);
    }

    TorusGroup pnew(this->L, tnew, rnew);

    return pnew;
  }

  TorusGroup inverse() const {
    Vector<T, D> tnew = -r.transpose() * t;
    Matrix<T, D> rnew = r.transpose();

    TorusGroup pnew(this->L, tnew, rnew);

    return pnew;
  }
};