summaryrefslogtreecommitdiff
path: root/p-spin.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'p-spin.hpp')
-rw-r--r--p-spin.hpp17
1 files changed, 8 insertions, 9 deletions
diff --git a/p-spin.hpp b/p-spin.hpp
index 480b3ca..bd3cacc 100644
--- a/p-spin.hpp
+++ b/p-spin.hpp
@@ -5,6 +5,11 @@
#include "tensor.hpp"
#include "factorial.hpp"
+template <typename Derived>
+Vector<typename Derived::Scalar> normalize(const Eigen::MatrixBase<Derived>& z) {
+ return z * sqrt((double)z.size() / (typename Derived::Scalar)(z.transpose() * z));
+}
+
template <class Scalar, int p>
std::tuple<Scalar, Vector<Scalar>, Matrix<Scalar>> hamGradHess(const Tensor<Scalar, p>& J, const Vector<Scalar>& z) {
Matrix<Scalar> Jz = contractDown(J, z); // Contracts J into p - 2 copies of z.
@@ -21,14 +26,8 @@ std::tuple<Scalar, Vector<Scalar>, Matrix<Scalar>> hamGradHess(const Tensor<Scal
}
template <class Scalar>
-Vector<Scalar> normalize(const Vector<Scalar>& z) {
- return z * sqrt((double)z.size() / (Scalar)(z.transpose() * z));
-}
-
-template <class Scalar>
-Vector<Scalar> project(const Vector<Scalar>& z, const Vector<Scalar>& x) {
- Scalar xz = x.transpose() * z;
- return x - (xz / z.squaredNorm()) * z.conjugate();
+Vector<Scalar> zDot(const Vector<Scalar>& z, const Vector<Scalar>& dH) {
+ return -dH.conjugate() + (dH.dot(z) / z.squaredNorm()) * z.conjugate();
}
template <class Scalar, int p>
@@ -37,7 +36,7 @@ std::tuple<double, Vector<Scalar>> WdW(const Tensor<Scalar, p>& J, const Vector<
Matrix<Scalar> ddH;
std::tie(std::ignore, dH, ddH) = hamGradHess(J, z);
- Vector<Scalar> dzdt = project(z, dH.conjugate().eval());
+ Vector<Scalar> dzdt = zDot(z, dH);
double a = z.squaredNorm();
Scalar A = (Scalar)(z.transpose() * dzdt) / a;