diff options
author | Jaron Kent-Dobias <jaron@kent-dobias.com> | 2021-02-25 15:28:11 +0100 |
---|---|---|
committer | Jaron Kent-Dobias <jaron@kent-dobias.com> | 2021-02-25 15:28:11 +0100 |
commit | 3276bdd1e9796fec71e169e6c41d77da72b3a4fb (patch) | |
tree | 32be646f64c83751572eb867f9354e74d146ef6b /p-spin.hpp | |
parent | c16f7fc3fd8206e5f05e07353328538b2f5c8b6b (diff) | |
download | code-3276bdd1e9796fec71e169e6c41d77da72b3a4fb.tar.gz code-3276bdd1e9796fec71e169e6c41d77da72b3a4fb.tar.bz2 code-3276bdd1e9796fec71e169e6c41d77da72b3a4fb.zip |
Many changes.
Diffstat (limited to 'p-spin.hpp')
-rw-r--r-- | p-spin.hpp | 11 |
1 files changed, 6 insertions, 5 deletions
@@ -2,12 +2,13 @@ #include <eigen3/Eigen/Dense> +#include "types.hpp" #include "tensor.hpp" #include "factorial.hpp" template <typename Derived> Vector<typename Derived::Scalar> normalize(const Eigen::MatrixBase<Derived>& z) { - return z * sqrt((double)z.size() / (typename Derived::Scalar)(z.transpose() * z)); + return z * sqrt((Real)z.size() / (typename Derived::Scalar)(z.transpose() * z)); } template <class Scalar, int p> @@ -16,7 +17,7 @@ std::tuple<Scalar, Vector<Scalar>, Matrix<Scalar>> hamGradHess(const Tensor<Scal Vector<Scalar> Jzz = Jz * z; Scalar Jzzz = Jzz.transpose() * z; - double pBang = factorial(p); + Real pBang = factorial(p); Matrix<Scalar> hessian = ((p - 1) * p / pBang) * Jz; Vector<Scalar> gradient = (p / pBang) * Jzz; @@ -31,18 +32,18 @@ Vector<Scalar> zDot(const Vector<Scalar>& z, const Vector<Scalar>& dH) { } template <class Scalar, int p> -std::tuple<double, Vector<Scalar>> WdW(const Tensor<Scalar, p>& J, const Vector<Scalar>& z) { +std::tuple<Real, Vector<Scalar>> WdW(const Tensor<Scalar, p>& J, const Vector<Scalar>& z) { Vector<Scalar> dH; Matrix<Scalar> ddH; std::tie(std::ignore, dH, ddH) = hamGradHess(J, z); Vector<Scalar> dzdt = zDot(z, dH); - double a = z.squaredNorm(); + Real a = z.squaredNorm(); Scalar A = (Scalar)(z.transpose() * dzdt) / a; Scalar B = dH.dot(z) / a; - double W = dzdt.squaredNorm(); + Real W = dzdt.squaredNorm(); Vector<Scalar> dW = ddH * (dzdt - A * z.conjugate()) + 2 * (conj(A) * B * z).real() - conj(B) * dzdt - conj(A) * dH.conjugate(); |