From d448af5010b664025c816dc2c6e383ac7bea3491 Mon Sep 17 00:00:00 2001 From: Jaron Kent-Dobias Date: Tue, 9 Nov 2021 00:44:01 +0100 Subject: Generalized energy function to multiple tensor arguments in anticipation of the mixed p-spin. --- p-spin.hpp | 39 +++++++++++++-------------------------- 1 file changed, 13 insertions(+), 26 deletions(-) diff --git a/p-spin.hpp b/p-spin.hpp index 15d2525..6111b75 100644 --- a/p-spin.hpp +++ b/p-spin.hpp @@ -12,19 +12,26 @@ Vector normalize(const Eigen::MatrixBase& z) return z * sqrt((Real)z.size() / (typename Derived::Scalar)(z.transpose() * z)); } -template -std::tuple, Matrix> hamGradHess(const Tensor& J, const Vector& z) { +template +std::tuple, Matrix> hamGradHess(const Tensor& J, const Tensor& ... Js, const Vector& z) { Matrix Jz = contractDown(J, z); // Contracts J into p - 2 copies of z. Vector Jzz = Jz * z; Scalar Jzzz = Jzz.transpose() * z; Real pBang = factorial(p); - Matrix hessian = ((p - 1) * p / pBang) * Jz; - Vector gradient = (p / pBang) * Jzz; - Scalar hamiltonian = Jzzz / pBang; + Matrix ddH = ((p - 1) * p / pBang) * Jz; + Vector dH = (p / pBang) * Jzz; + Scalar H = Jzzz / pBang; + + if constexpr (sizeof...(Js) > 0) { + auto [Hs, dHs, ddHs] = hamGradHess(Js..., z); + H += Hs; + dH += dHs; + ddH += ddHs; + } - return {hamiltonian, gradient, hessian}; + return {H, dH, ddH}; } template @@ -71,26 +78,6 @@ Vector zDot(const Vector& z, const Vector& dH) { return -dH.conjugate() + (dH.dot(z) / z.squaredNorm()) * z.conjugate(); } -template -std::tuple> WdW(const Tensor& J, const Vector& z) { - Vector dH; - Matrix ddH; - std::tie(std::ignore, dH, ddH) = hamGradHess(J, z); - - Vector dzdt = zDot(z, dH); - - Real a = z.squaredNorm(); - Scalar A = (Scalar)(z.transpose() * dzdt) / a; - Scalar B = dH.dot(z) / a; - - Real W = dzdt.squaredNorm(); - Vector dW = ddH * (dzdt - A * z.conjugate()) - + 2 * (conj(A) * B * z).real() - - conj(B) * dzdt - conj(A) * dH.conjugate(); - - return {W, dW}; -} - template Matrix dzDot(const Vector& z, const Vector& dH) { Real z² = z.squaredNorm(); -- cgit v1.2.3-70-g09d2