diff options
Diffstat (limited to 'p-spin.hpp')
-rw-r--r-- | p-spin.hpp | 19 |
1 files changed, 18 insertions, 1 deletions
@@ -27,6 +27,7 @@ std::tuple<Scalar, Vector, Matrix> hamGradHess(const Tensor& J, const Vector& z) } std::tuple<double, Vector> WdW(const Tensor& J, const Vector& z) { + /* Vector gradient; Matrix hessian; std::tie(std::ignore, gradient, hessian) = hamGradHess(J, z); @@ -40,7 +41,23 @@ std::tuple<double, Vector> WdW(const Tensor& J, const Vector& z) { Scalar zProjGrad = z.transpose() * projGradConj; double W = projGrad.norm(); - Vector dW = hessian * (projGradConj - (zProjGrad / N) * z) - (zGrad * projGradConj + zProjGrad * gradient) / N; + Vector dW = hessian * projGradConj - (zGrad * projGradConj + (z.transpose() * projGradConj) * (gradient + hessian * z)) / N; + */ + + Vector dH; + Matrix ddH; + std::tie(std::ignore, dH, ddH) = hamGradHess(J, z); + + double N = z.size(); + Scalar dHz = (Scalar)(dH.transpose() * z) / N; + + Vector pdH = dH - dHz * z; + Vector pdHc = pdH.conjugate(); + + Scalar pdHcz = pdH.dot(z) / N; + + double W = pdH.squaredNorm(); + Vector dW = ddH * (pdHc - pdHcz * z) - (dHz * pdHc + pdHcz * dH); return {W, dW}; } |