summaryrefslogtreecommitdiff
path: root/p-spin.hpp
diff options
context:
space:
mode:
authorJaron Kent-Dobias <jaron@kent-dobias.com>2021-02-25 15:28:11 +0100
committerJaron Kent-Dobias <jaron@kent-dobias.com>2021-02-25 15:28:11 +0100
commit3276bdd1e9796fec71e169e6c41d77da72b3a4fb (patch)
tree32be646f64c83751572eb867f9354e74d146ef6b /p-spin.hpp
parentc16f7fc3fd8206e5f05e07353328538b2f5c8b6b (diff)
downloadcode-3276bdd1e9796fec71e169e6c41d77da72b3a4fb.tar.gz
code-3276bdd1e9796fec71e169e6c41d77da72b3a4fb.tar.bz2
code-3276bdd1e9796fec71e169e6c41d77da72b3a4fb.zip
Many changes.
Diffstat (limited to 'p-spin.hpp')
-rw-r--r--p-spin.hpp11
1 files changed, 6 insertions, 5 deletions
diff --git a/p-spin.hpp b/p-spin.hpp
index bd3cacc..f1dc07f 100644
--- a/p-spin.hpp
+++ b/p-spin.hpp
@@ -2,12 +2,13 @@
#include <eigen3/Eigen/Dense>
+#include "types.hpp"
#include "tensor.hpp"
#include "factorial.hpp"
template <typename Derived>
Vector<typename Derived::Scalar> normalize(const Eigen::MatrixBase<Derived>& z) {
- return z * sqrt((double)z.size() / (typename Derived::Scalar)(z.transpose() * z));
+ return z * sqrt((Real)z.size() / (typename Derived::Scalar)(z.transpose() * z));
}
template <class Scalar, int p>
@@ -16,7 +17,7 @@ std::tuple<Scalar, Vector<Scalar>, Matrix<Scalar>> hamGradHess(const Tensor<Scal
Vector<Scalar> Jzz = Jz * z;
Scalar Jzzz = Jzz.transpose() * z;
- double pBang = factorial(p);
+ Real pBang = factorial(p);
Matrix<Scalar> hessian = ((p - 1) * p / pBang) * Jz;
Vector<Scalar> gradient = (p / pBang) * Jzz;
@@ -31,18 +32,18 @@ Vector<Scalar> zDot(const Vector<Scalar>& z, const Vector<Scalar>& dH) {
}
template <class Scalar, int p>
-std::tuple<double, Vector<Scalar>> WdW(const Tensor<Scalar, p>& J, const Vector<Scalar>& z) {
+std::tuple<Real, Vector<Scalar>> WdW(const Tensor<Scalar, p>& J, const Vector<Scalar>& z) {
Vector<Scalar> dH;
Matrix<Scalar> ddH;
std::tie(std::ignore, dH, ddH) = hamGradHess(J, z);
Vector<Scalar> dzdt = zDot(z, dH);
- double a = z.squaredNorm();
+ Real a = z.squaredNorm();
Scalar A = (Scalar)(z.transpose() * dzdt) / a;
Scalar B = dH.dot(z) / a;
- double W = dzdt.squaredNorm();
+ Real W = dzdt.squaredNorm();
Vector<Scalar> dW = ddH * (dzdt - A * z.conjugate())
+ 2 * (conj(A) * B * z).real()
- conj(B) * dzdt - conj(A) * dH.conjugate();