diff options
author | Jaron Kent-Dobias <jaron@kent-dobias.com> | 2021-02-17 16:14:33 +0100 |
---|---|---|
committer | Jaron Kent-Dobias <jaron@kent-dobias.com> | 2021-02-17 16:14:33 +0100 |
commit | 12f15f49cd8cc4ab9c809700e8cb88a0efe198d8 (patch) | |
tree | 3b7e973ce93ce942563a55363c3edd8c883da48d /p-spin.hpp | |
parent | 95df02c90a455c2e539e795acb1921b688e8bc66 (diff) | |
download | code-12f15f49cd8cc4ab9c809700e8cb88a0efe198d8.tar.gz code-12f15f49cd8cc4ab9c809700e8cb88a0efe198d8.tar.bz2 code-12f15f49cd8cc4ab9c809700e8cb88a0efe198d8.zip |
Rearranged some functions among files, and wrote the normalize function to take generic Eigen expressions.
Diffstat (limited to 'p-spin.hpp')
-rw-r--r-- | p-spin.hpp | 17 |
1 files changed, 8 insertions, 9 deletions
@@ -5,6 +5,11 @@ #include "tensor.hpp" #include "factorial.hpp" +template <typename Derived> +Vector<typename Derived::Scalar> normalize(const Eigen::MatrixBase<Derived>& z) { + return z * sqrt((double)z.size() / (typename Derived::Scalar)(z.transpose() * z)); +} + template <class Scalar, int p> std::tuple<Scalar, Vector<Scalar>, Matrix<Scalar>> hamGradHess(const Tensor<Scalar, p>& J, const Vector<Scalar>& z) { Matrix<Scalar> Jz = contractDown(J, z); // Contracts J into p - 2 copies of z. @@ -21,14 +26,8 @@ std::tuple<Scalar, Vector<Scalar>, Matrix<Scalar>> hamGradHess(const Tensor<Scal } template <class Scalar> -Vector<Scalar> normalize(const Vector<Scalar>& z) { - return z * sqrt((double)z.size() / (Scalar)(z.transpose() * z)); -} - -template <class Scalar> -Vector<Scalar> project(const Vector<Scalar>& z, const Vector<Scalar>& x) { - Scalar xz = x.transpose() * z; - return x - (xz / z.squaredNorm()) * z.conjugate(); +Vector<Scalar> zDot(const Vector<Scalar>& z, const Vector<Scalar>& dH) { + return -dH.conjugate() + (dH.dot(z) / z.squaredNorm()) * z.conjugate(); } template <class Scalar, int p> @@ -37,7 +36,7 @@ std::tuple<double, Vector<Scalar>> WdW(const Tensor<Scalar, p>& J, const Vector< Matrix<Scalar> ddH; std::tie(std::ignore, dH, ddH) = hamGradHess(J, z); - Vector<Scalar> dzdt = project(z, dH.conjugate().eval()); + Vector<Scalar> dzdt = zDot(z, dH); double a = z.squaredNorm(); Scalar A = (Scalar)(z.transpose() * dzdt) / a; |