summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--dynamics.hpp50
1 files changed, 28 insertions, 22 deletions
diff --git a/dynamics.hpp b/dynamics.hpp
index 561714e..0597f35 100644
--- a/dynamics.hpp
+++ b/dynamics.hpp
@@ -4,6 +4,8 @@
#include <eigen3/Eigen/LU>
#include <random>
+#include <iostream>
+
#include "p-spin.hpp"
#include "stereographic.hpp"
@@ -45,35 +47,39 @@ std::tuple<Real, Vector<Scalar>> gradientDescent(const Tensor<Scalar, p>& J, con
template <class Real, class Scalar, int p>
Vector<Scalar> findSaddle(const Tensor<Scalar, p>& J, const Vector<Scalar>& z0, Real ε, Real δW = 2, Real γ0 = 1, Real δγ = 2) {
Vector<Scalar> z = z0;
- Vector<Scalar> ζ = euclideanToStereographic(z);
-
- Real W;
- std::tie(W, std::ignore) = WdW(J, z);
Vector<Scalar> dH;
Matrix<Scalar> ddH;
- std::tie(std::ignore, dH, ddH) = stereographicHamGradHess(J, ζ, z);
+ std::tie(std::ignore, dH, ddH) = hamGradHess(J, z);
- while (W > ε) {
- // ddH is complex symmetric, which is (almost always) invertible, so a
- // partial pivot LU decomposition can be used.
- Vector<Scalar> dζ = ddH.partialPivLu().solve(dH);
- Vector<Scalar> ζNew = ζ - dζ;
- Vector<Scalar> zNew = stereographicToEuclidean(ζNew);
+ Scalar zz = z.transpose() * z;
+ Vector<Scalar> ż = zDot(z, dH) + z * (zz - (Real)z.size());
+ Matrix<Scalar> dż = dzDot(z, dH) + Matrix<Scalar>::Identity(z.size(), z.size()) * (zz - (Real)z.size()) + 2.0 * z * z.transpose();
+ Matrix<Scalar> dżc = dzDotConjugate(z, dH, ddH);
- Real WNew;
- std::tie(WNew, std::ignore) = WdW(J, zNew);
+ Vector<Scalar> b(2 * z.size());
+ Matrix<Scalar> M(2 * z.size(), 2 * z.size());
- if (WNew < W) { // If Newton's step lowered the objective, accept it!
- ζ = ζNew;
- z = zNew;
- W = WNew;
- } else { // Otherwise, do gradient descent until W is a factor δW smaller.
- std::tie(W, z) = gradientDescent(J, z, W / δW, γ0, δγ);
- ζ = euclideanToStereographic(z);
- }
+ b << ż.conjugate(), ż;
+ M << dż.conjugate(), dżc, dżc.conjugate(), dż;
+
+ while (ż.norm() > ε) {
+ Vector<Scalar> dz = M.partialPivLu().solve(b).tail(z.size());
+ dz -= z.conjugate().dot(dz) / z.squaredNorm() * z.conjugate();
+ z = normalize(z - dz);
+
+ std::cout << "error : " << z.transpose() * z << " "<< ż.norm() << " " << dz.norm() << std::endl;
+ getchar();
+
+ std::tie(std::ignore, dH, ddH) = hamGradHess(J, z);
+
+ zz = z.transpose() * z;
+ ż = zDot(z, dH) + z * (zz - (Real)z.size());
+ dż = dzDot(z, dH) + Matrix<Scalar>::Identity(z.size(), z.size()) * (zz - (Real)z.size()) + 2.0 * z * z.transpose();
+ dżc = dzDotConjugate(z, dH, ddH);
- std::tie(std::ignore, dH, ddH) = stereographicHamGradHess(J, ζ, z);
+ b << ż.conjugate(), ż;
+ M << dż.conjugate(), dżc, dżc.conjugate(), dż;
}
return z;