From a0d936b498735569c6570b020b47d0430883406b Mon Sep 17 00:00:00 2001 From: Jaron Kent-Dobias Date: Thu, 28 Mar 2024 18:12:18 +0100 Subject: Changes. --- least_squares.cpp | 90 ++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 72 insertions(+), 18 deletions(-) diff --git a/least_squares.cpp b/least_squares.cpp index 8e7d085..093892d 100644 --- a/least_squares.cpp +++ b/least_squares.cpp @@ -21,7 +21,7 @@ private: public: template Model(Real σ, unsigned N, unsigned M, Generator& r) : A(M, N), b(M) { - std::normal_distribution aDistribution(0, 1); + std::normal_distribution aDistribution(0, 1 / sqrt(N)); for (unsigned i = 0; i < M; i++) { for (unsigned j =0; j < N; j++) { @@ -36,19 +36,19 @@ public: } } - const unsigned N() { + unsigned N() const { return A.cols(); } - const unsigned M() { + unsigned M() const { return A.rows(); } - const Vector V(const Vector& x) { + Vector V(const Vector& x) const { return A * x + b; } - const Matrix dV(const Vector& x) { + Matrix dV(const Vector& x) const { return A; } @@ -56,45 +56,94 @@ public: // return Matrix::Zero(; // } - const Real H(const Vector& x) { + Real H(const Vector& x) const { return V(x).squaredNorm(); } - const Vector dH(const Vector& x) { + Vector dH(const Vector& x) const { return dV(x).transpose() * V(x); } - const Matrix ddH(const Vector& x) { + Matrix ddH(const Vector& x) const { return dV(x).transpose() * dV(x); } - const Vector ∇H(const Vector& x){ + Vector ∇H(const Vector& x) const { return dH(x) - dH(x).dot(x) * x / x.squaredNorm(); } - const Matrix HessH(const Vector& x) { + Matrix HessH(const Vector& x) const { Matrix hess = ddH(x) - x.dot(dH(x)) * Matrix::Identity(N(), N()); return hess - (hess * x) * x.transpose() / x.squaredNorm(); } + + Vector HessSpectrum(const Vector& x) const { + Eigen::EigenSolver> eigenS(HessH(x)); + return eigenS.eigenvalues().real(); + } }; +template +Vector normalize(const Eigen::MatrixBase& z) { + return z * sqrt((double)z.size() / (typename Derived::Scalar)(z.transpose() * z)); +} + +template +Vector findMinimum(const Model& M, const Vector& x0, Real ε) { + Vector x = x0; + Real λ = 100; + + Real H = M.H(x); + Vector dH = M.dH(x); + Matrix ddH = M.ddH(x); + + Vector g = dH - x.dot(dH) * x / x.squaredNorm(); + Matrix m = ddH - (dH * x.transpose() + x.dot(dH) * Matrix::Identity(M.N(), M.N()) + (ddH * x) * x.transpose()) / x.squaredNorm() + 2.0 * x * x.transpose(); + + while (g.norm() / x.size() > ε && λ < 1e8) { + Vector dz = (m + λ * (Matrix)abs(m.diagonal().array()).matrix().asDiagonal()).partialPivLu().solve(g); + dz -= x.dot(dz) * x / x.squaredNorm(); + Vector zNew = normalize(x - dz); + + Real HNew = M.H(zNew); + Vector dHNew = M.dH(zNew); + Matrix ddHNew = M.ddH(zNew); + + if (HNew * 1.0001 <= H) { + x = zNew; + H = HNew; + dH = dHNew; + ddH = ddHNew; + + g = dH - x.dot(dH) * x / (Real)x.size(); + m = ddH - (dH * x.transpose() + x.dot(dH) * Matrix::Identity(x.size(), x.size()) + (ddH * x) * x.transpose()) / (Real)x.size() + 2.0 * x * x.transpose(); + + λ /= 2; + } else { + λ *= 1.5; + } + } + + return x; +} + using Rng = randutils::random_generator; using Real = double; int main(int argc, char* argv[]) { unsigned N = 10; - unsigned M = 10; + Real α = 1; Real σ = 1; int opt; - while ((opt = getopt(argc, argv, "N:M:s:")) != -1) { + while ((opt = getopt(argc, argv, "N:a:s:")) != -1) { switch (opt) { case 'N': N = (unsigned)atof(optarg); break; - case 'M': - M = (unsigned)atof(optarg); + case 'a': + α = atof(optarg); break; case 's': σ = atof(optarg); @@ -104,16 +153,21 @@ int main(int argc, char* argv[]) { } } + unsigned M = (unsigned)(α * N); + Rng r; Model leastSquares(σ, N, M, r.engine()); Vector x = Vector::Zero(N); - x(0) = N; + x(0) = sqrt(N); + + std::cout << leastSquares.H(x) / N << std::endl; + + Vector xMin = findMinimum(leastSquares, x, 1e-12); - std::cout << leastSquares.H(x) << std::endl; - std::cout << leastSquares.∇H(x) << std::endl; - std::cout << leastSquares.HessH(x) << std::endl; + std::cout << leastSquares.H(xMin) / N << std::endl; + std::cout << leastSquares.HessSpectrum(xMin)(1) / N << std::endl; return 0; } -- cgit v1.2.3-70-g09d2