summaryrefslogtreecommitdiff
path: root/octree.hpp
blob: d8f0aa3b8bc71b0a54cbcef7ce4fe9e159bef8cb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96

#pragma once

#include <array>
#include <cmath>
#include <set>
#include <unordered_map>

#include "spin.hpp"
#include "vector.hpp"

namespace std {
template <typename T, size_t N> struct hash<array<T, N>> {
  typedef array<T, N> argument_type;
  typedef size_t result_type;

  result_type operator()(const argument_type& a) const {
    hash<T> hasher;
    result_type h = 0;
    for (result_type i = 0; i < N; ++i) {
      h = h * 31 + hasher(a[i]);
    }
    return h;
  }
};
} // namespace std

template <class T, int D, class S> class Octree {
private:
  unsigned N;
  T L;
  std::unordered_map<std::array<signed, D>, std::set<Spin<T, D, S>*>> data;

public:
  Octree(T L_tmp, unsigned N_tmp) {
    L = L_tmp;
    N = N_tmp;
  }

  std::array<signed, D> ind(Vector<T, D> x) const {
    std::array<signed, D> ind;

    for (unsigned i = 0; i < D; i++) {
      ind[i] = std::floor(N * x(i) / L);
    }

    return ind;
  }

  void insert(Spin<T, D, S>* s) { data[ind(s->x)].insert(s); };

  void remove(Spin<T, D, S>* s) {
    data[ind(s->x)].erase(s);
    if (data[ind(s->x)].empty()) {
      data.erase(ind(s->x));
    }
  };

  std::set<Spin<T, D, S>*> at(const Vector<T, D>& x) const {
    auto it = data.find(ind(x));
    if (it == data.end()) {
      return {};
    } else {
      return it->second;
    }
  }

  std::set<Spin<T, D, S>*> neighbors(const Vector<T, D>& x) const {
    std::array<signed, D> i0 = ind(x);
    std::set<Spin<T, D, S>*> ns;

    nearest_neighbors_of(i0, D + 1, ns);

    return ns;
  };

  void nearest_neighbors_of(std::array<signed, D> i0, unsigned depth,
                            std::set<Spin<T, D, S>*>& ns) const {
    if (depth == 0) {
      auto it = data.find(i0);
      if (it != data.end()) {
        ns.insert(it->second.begin(), it->second.end());
      }
    } else {
      for (signed j : {-1, 0, 1}) {
        std::array<signed, D> i1 = i0;
        if (N < 2) {
          i1[depth - 1] += j;
        } else {
          i1[depth - 1] = (N + i1[depth - 1] + j) % N;
        }
        nearest_neighbors_of(i1, depth - 1, ns);
      }
    }
  };
};