diff options
-rw-r--r-- | p-spin.hpp | 39 |
1 files changed, 13 insertions, 26 deletions
@@ -12,19 +12,26 @@ Vector<typename Derived::Scalar> normalize(const Eigen::MatrixBase<Derived>& z) return z * sqrt((Real)z.size() / (typename Derived::Scalar)(z.transpose() * z)); } -template <class Scalar, int p> -std::tuple<Scalar, Vector<Scalar>, Matrix<Scalar>> hamGradHess(const Tensor<Scalar, p>& J, const Vector<Scalar>& z) { +template <class Scalar, int p, int ... ps> +std::tuple<Scalar, Vector<Scalar>, Matrix<Scalar>> hamGradHess(const Tensor<Scalar, p>& J, const Tensor<Scalar, ps>& ... Js, const Vector<Scalar>& z) { Matrix<Scalar> Jz = contractDown(J, z); // Contracts J into p - 2 copies of z. Vector<Scalar> Jzz = Jz * z; Scalar Jzzz = Jzz.transpose() * z; Real pBang = factorial(p); - Matrix<Scalar> hessian = ((p - 1) * p / pBang) * Jz; - Vector<Scalar> gradient = (p / pBang) * Jzz; - Scalar hamiltonian = Jzzz / pBang; + Matrix<Scalar> ddH = ((p - 1) * p / pBang) * Jz; + Vector<Scalar> dH = (p / pBang) * Jzz; + Scalar H = Jzzz / pBang; + + if constexpr (sizeof...(Js) > 0) { + auto [Hs, dHs, ddHs] = hamGradHess(Js..., z); + H += Hs; + dH += dHs; + ddH += ddHs; + } - return {hamiltonian, gradient, hessian}; + return {H, dH, ddH}; } template <class Scalar, int p> @@ -71,26 +78,6 @@ Vector<Scalar> zDot(const Vector<Scalar>& z, const Vector<Scalar>& dH) { return -dH.conjugate() + (dH.dot(z) / z.squaredNorm()) * z.conjugate(); } -template <class Scalar, int p> -std::tuple<Real, Vector<Scalar>> WdW(const Tensor<Scalar, p>& J, const Vector<Scalar>& z) { - Vector<Scalar> dH; - Matrix<Scalar> ddH; - std::tie(std::ignore, dH, ddH) = hamGradHess(J, z); - - Vector<Scalar> dzdt = zDot(z, dH); - - Real a = z.squaredNorm(); - Scalar A = (Scalar)(z.transpose() * dzdt) / a; - Scalar B = dH.dot(z) / a; - - Real W = dzdt.squaredNorm(); - Vector<Scalar> dW = ddH * (dzdt - A * z.conjugate()) - + 2 * (conj(A) * B * z).real() - - conj(B) * dzdt - conj(A) * dH.conjugate(); - - return {W, dW}; -} - template <class Scalar> Matrix<Scalar> dzDot(const Vector<Scalar>& z, const Vector<Scalar>& dH) { Real z² = z.squaredNorm(); |