diff options
Diffstat (limited to 'p-spin.hpp')
-rw-r--r-- | p-spin.hpp | 17 |
1 files changed, 8 insertions, 9 deletions
@@ -5,6 +5,11 @@ #include "tensor.hpp" #include "factorial.hpp" +template <typename Derived> +Vector<typename Derived::Scalar> normalize(const Eigen::MatrixBase<Derived>& z) { + return z * sqrt((double)z.size() / (typename Derived::Scalar)(z.transpose() * z)); +} + template <class Scalar, int p> std::tuple<Scalar, Vector<Scalar>, Matrix<Scalar>> hamGradHess(const Tensor<Scalar, p>& J, const Vector<Scalar>& z) { Matrix<Scalar> Jz = contractDown(J, z); // Contracts J into p - 2 copies of z. @@ -21,14 +26,8 @@ std::tuple<Scalar, Vector<Scalar>, Matrix<Scalar>> hamGradHess(const Tensor<Scal } template <class Scalar> -Vector<Scalar> normalize(const Vector<Scalar>& z) { - return z * sqrt((double)z.size() / (Scalar)(z.transpose() * z)); -} - -template <class Scalar> -Vector<Scalar> project(const Vector<Scalar>& z, const Vector<Scalar>& x) { - Scalar xz = x.transpose() * z; - return x - (xz / z.squaredNorm()) * z.conjugate(); +Vector<Scalar> zDot(const Vector<Scalar>& z, const Vector<Scalar>& dH) { + return -dH.conjugate() + (dH.dot(z) / z.squaredNorm()) * z.conjugate(); } template <class Scalar, int p> @@ -37,7 +36,7 @@ std::tuple<double, Vector<Scalar>> WdW(const Tensor<Scalar, p>& J, const Vector< Matrix<Scalar> ddH; std::tie(std::ignore, dH, ddH) = hamGradHess(J, z); - Vector<Scalar> dzdt = project(z, dH.conjugate().eval()); + Vector<Scalar> dzdt = zDot(z, dH); double a = z.squaredNorm(); Scalar A = (Scalar)(z.transpose() * dzdt) / a; |