summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJaron Kent-Dobias <jaron@kent-dobias.com>2021-01-14 15:44:26 +0100
committerJaron Kent-Dobias <jaron@kent-dobias.com>2021-01-14 15:44:26 +0100
commitdedcfe0c9aa1ae79dc4d26c559f239571bb26394 (patch)
treefd81b21390a9e983c498ea8e3c46258744ad53cb
parente6b4d83097f4442edf2290236e1b092724d8fd74 (diff)
downloadcode-dedcfe0c9aa1ae79dc4d26c559f239571bb26394.tar.gz
code-dedcfe0c9aa1ae79dc4d26c559f239571bb26394.tar.bz2
code-dedcfe0c9aa1ae79dc4d26c559f239571bb26394.zip
Updated W and dW to use correct constrained gradient.
-rw-r--r--p-spin.hpp38
1 files changed, 13 insertions, 25 deletions
diff --git a/p-spin.hpp b/p-spin.hpp
index 3532556..91e0152 100644
--- a/p-spin.hpp
+++ b/p-spin.hpp
@@ -26,38 +26,26 @@ std::tuple<Scalar, Vector, Matrix> hamGradHess(const Tensor& J, const Vector& z)
return {hamiltonian, gradient, hessian};
}
-std::tuple<double, Vector> WdW(const Tensor& J, const Vector& z) {
- /*
- Vector gradient;
- Matrix hessian;
- std::tie(std::ignore, gradient, hessian) = hamGradHess(J, z);
-
- Scalar zGrad = gradient.transpose() * z;
- double N = z.size();
-
- Vector projGrad = gradient - (zGrad / N) * z;
- Vector projGradConj = projGrad.conjugate();
-
- Scalar zProjGrad = z.transpose() * projGradConj;
-
- double W = projGrad.norm();
- Vector dW = hessian * projGradConj - (zGrad * projGradConj + (z.transpose() * projGradConj) * (gradient + hessian * z)) / N;
- */
+Vector project(const Vector& z, const Vector& x) {
+ Scalar xz = x.transpose() * z;
+ return x - (xz / z.squaredNorm()) * z.conjugate();
+}
+std::tuple<double, Vector> WdW(const Tensor& J, const Vector& z) {
Vector dH;
Matrix ddH;
std::tie(std::ignore, dH, ddH) = hamGradHess(J, z);
- double N = z.size();
- Scalar dHz = (Scalar)(dH.transpose() * z) / N;
-
- Vector pdH = dH - dHz * z;
- Vector pdHc = pdH.conjugate();
+ double a = z.squaredNorm();
+ Vector dzdt = project(z, dH.conjugate());
- Scalar pdHcz = pdH.dot(z) / N;
+ Scalar A = (Scalar)(z.transpose() * dzdt) / a;
+ Scalar B = dH.dot(z) / a;
- double W = pdH.squaredNorm();
- Vector dW = ddH * (pdHc - pdHcz * z) - (dHz * pdHc + pdHcz * dH);
+ double W = dzdt.squaredNorm();
+ Vector dW = ddH * (dzdt - A * z.conjugate())
+ + 2 * (conj(A) * B * z).real()
+ - conj(B) * dzdt - conj(A) * dH.conjugate();
return {W, dW};
}