summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--dynamics.hpp6
-rw-r--r--p-spin.hpp17
-rw-r--r--stokes.hpp16
3 files changed, 14 insertions, 25 deletions
diff --git a/dynamics.hpp b/dynamics.hpp
index 22d590a..d421d13 100644
--- a/dynamics.hpp
+++ b/dynamics.hpp
@@ -21,8 +21,7 @@ std::tuple<double, Vector<Scalar>> gradientDescent(const Tensor<Scalar, p>& J, c
auto [W, dW] = WdW(J, z);
while (W > ε) {
- Vector<Scalar> zNewTmp = z - γ * dW.conjugate();
- Vector<Scalar> zNew = normalize(zNewTmp);
+ Vector<Scalar> zNew = normalize(z - γ * dW.conjugate());
auto [WNew, dWNew] = WdW(J, zNew);
@@ -102,8 +101,7 @@ std::tuple<double, Vector<Scalar>> metropolis(const Tensor<Scalar, p>& J, const
std::uniform_real_distribution<double> D(0, 1);
for (unsigned i = 0; i < N; i++) {
- Vector<Scalar> zNewTmp = z + γ * randomVector<Scalar>(z.size(), d, r);
- Vector<Scalar> zNew = normalize(zNewTmp);
+ Vector<Scalar> zNew = normalize(z + γ * randomVector<Scalar>(z.size(), d, r));
double ENew = energy(J, zNew);
diff --git a/p-spin.hpp b/p-spin.hpp
index 480b3ca..bd3cacc 100644
--- a/p-spin.hpp
+++ b/p-spin.hpp
@@ -5,6 +5,11 @@
#include "tensor.hpp"
#include "factorial.hpp"
+template <typename Derived>
+Vector<typename Derived::Scalar> normalize(const Eigen::MatrixBase<Derived>& z) {
+ return z * sqrt((double)z.size() / (typename Derived::Scalar)(z.transpose() * z));
+}
+
template <class Scalar, int p>
std::tuple<Scalar, Vector<Scalar>, Matrix<Scalar>> hamGradHess(const Tensor<Scalar, p>& J, const Vector<Scalar>& z) {
Matrix<Scalar> Jz = contractDown(J, z); // Contracts J into p - 2 copies of z.
@@ -21,14 +26,8 @@ std::tuple<Scalar, Vector<Scalar>, Matrix<Scalar>> hamGradHess(const Tensor<Scal
}
template <class Scalar>
-Vector<Scalar> normalize(const Vector<Scalar>& z) {
- return z * sqrt((double)z.size() / (Scalar)(z.transpose() * z));
-}
-
-template <class Scalar>
-Vector<Scalar> project(const Vector<Scalar>& z, const Vector<Scalar>& x) {
- Scalar xz = x.transpose() * z;
- return x - (xz / z.squaredNorm()) * z.conjugate();
+Vector<Scalar> zDot(const Vector<Scalar>& z, const Vector<Scalar>& dH) {
+ return -dH.conjugate() + (dH.dot(z) / z.squaredNorm()) * z.conjugate();
}
template <class Scalar, int p>
@@ -37,7 +36,7 @@ std::tuple<double, Vector<Scalar>> WdW(const Tensor<Scalar, p>& J, const Vector<
Matrix<Scalar> ddH;
std::tie(std::ignore, dH, ddH) = hamGradHess(J, z);
- Vector<Scalar> dzdt = project(z, dH.conjugate().eval());
+ Vector<Scalar> dzdt = zDot(z, dH);
double a = z.squaredNorm();
Scalar A = (Scalar)(z.transpose() * dzdt) / a;
diff --git a/stokes.hpp b/stokes.hpp
index 95bb9c4..117f4de 100644
--- a/stokes.hpp
+++ b/stokes.hpp
@@ -1,11 +1,6 @@
#include "p-spin.hpp"
template <class Scalar>
-Vector<Scalar> zDot(const Vector<Scalar>& z, const Vector<Scalar>& dH) {
- return -dH.conjugate() + (dH.dot(z) / z.squaredNorm()) * z.conjugate();
-}
-
-template <class Scalar>
double segmentCost(const Vector<Scalar>& z, const Vector<Scalar>& dz, const Vector<Scalar>& dH) {
Vector<Scalar> zD = zDot(z, dH);
return 1.0 - pow(real(zD.dot(dz)), 2) / zD.squaredNorm() / dz.squaredNorm();
@@ -68,8 +63,7 @@ class Rope {
Rope(unsigned N, const Vector<Scalar>& z1, const Vector<Scalar>& z2) : z(N + 2) {
for (unsigned i = 0; i < N + 2; i++) {
- z[i] = z1 + (z2 - z1) * ((double)i / (N + 1.0));
- z[i] = normalize(z[i]);
+ z[i] = normalize(z1 + (z2 - z1) * ((double)i / (N + 1.0)));
}
}
@@ -106,8 +100,7 @@ class Rope {
while (rNew.cost(J) >= this->cost(J)) {
for (unsigned i = 1; i < z.size() - 1; i++) {
- rNew.z[i] = z[i] - δ * Δz[i].conjugate();
- rNew.z[i] = normalize(rNew.z[i]);
+ rNew.z[i] = normalize(z[i] - δ * Δz[i].conjugate());
}
δ /= 2;
@@ -144,8 +137,7 @@ class Rope {
Vector<Scalar> δz = z[pos] - z[pos - 1];
- zNew[i] = z[pos] - (a - b) / δz.norm() * δz;
- zNew[i] = normalize(zNew[i]);
+ zNew[i] = normalize(z[pos] - (a - b) / δz.norm() * δz);
}
z = zNew;
@@ -192,7 +184,7 @@ class Rope {
}
for (unsigned i = 0; i < z.size() - 1; i++) {
- r.z[2 * i + 1] = normalize(((z[i] + z[i + 1]) / 2.0).eval());
+ r.z[2 * i + 1] = normalize(((z[i] + z[i + 1]) / 2.0));
}
return r;