123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626 |
- #ifndef OPENCV_FLANN_KDTREE_INDEX_H_
- #define OPENCV_FLANN_KDTREE_INDEX_H_
- #include <algorithm>
- #include <map>
- #include <cassert>
- #include <cstring>
- #include "general.h"
- #include "nn_index.h"
- #include "dynamic_bitset.h"
- #include "matrix.h"
- #include "result_set.h"
- #include "heap.h"
- #include "allocator.h"
- #include "random.h"
- #include "saving.h"
- namespace cvflann
- {
- struct KDTreeIndexParams : public IndexParams
- {
- KDTreeIndexParams(int trees = 4)
- {
- (*this)["algorithm"] = FLANN_INDEX_KDTREE;
- (*this)["trees"] = trees;
- }
- };
- template <typename Distance>
- class KDTreeIndex : public NNIndex<Distance>
- {
- public:
- typedef typename Distance::ElementType ElementType;
- typedef typename Distance::ResultType DistanceType;
-
- KDTreeIndex(const Matrix<ElementType>& inputData, const IndexParams& params = KDTreeIndexParams(),
- Distance d = Distance() ) :
- dataset_(inputData), index_params_(params), distance_(d)
- {
- size_ = dataset_.rows;
- veclen_ = dataset_.cols;
- trees_ = get_param(index_params_,"trees",4);
- tree_roots_ = new NodePtr[trees_];
-
- vind_.resize(size_);
- for (size_t i = 0; i < size_; ++i) {
- vind_[i] = int(i);
- }
- mean_ = new DistanceType[veclen_];
- var_ = new DistanceType[veclen_];
- }
- KDTreeIndex(const KDTreeIndex&);
- KDTreeIndex& operator=(const KDTreeIndex&);
-
- ~KDTreeIndex()
- {
- if (tree_roots_!=NULL) {
- delete[] tree_roots_;
- }
- delete[] mean_;
- delete[] var_;
- }
-
- void buildIndex()
- {
-
- for (int i = 0; i < trees_; i++) {
-
- #ifndef OPENCV_FLANN_USE_STD_RAND
- cv::randShuffle(vind_);
- #else
- std::random_shuffle(vind_.begin(), vind_.end());
- #endif
- tree_roots_[i] = divideTree(&vind_[0], int(size_) );
- }
- }
- flann_algorithm_t getType() const
- {
- return FLANN_INDEX_KDTREE;
- }
- void saveIndex(FILE* stream)
- {
- save_value(stream, trees_);
- for (int i=0; i<trees_; ++i) {
- save_tree(stream, tree_roots_[i]);
- }
- }
- void loadIndex(FILE* stream)
- {
- load_value(stream, trees_);
- if (tree_roots_!=NULL) {
- delete[] tree_roots_;
- }
- tree_roots_ = new NodePtr[trees_];
- for (int i=0; i<trees_; ++i) {
- load_tree(stream,tree_roots_[i]);
- }
- index_params_["algorithm"] = getType();
- index_params_["trees"] = tree_roots_;
- }
-
- size_t size() const
- {
- return size_;
- }
-
- size_t veclen() const
- {
- return veclen_;
- }
-
- int usedMemory() const
- {
- return int(pool_.usedMemory+pool_.wastedMemory+dataset_.rows*sizeof(int));
- }
-
- void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams)
- {
- int maxChecks = get_param(searchParams,"checks", 32);
- float epsError = 1+get_param(searchParams,"eps",0.0f);
- if (maxChecks==FLANN_CHECKS_UNLIMITED) {
- getExactNeighbors(result, vec, epsError);
- }
- else {
- getNeighbors(result, vec, maxChecks, epsError);
- }
- }
- IndexParams getParameters() const
- {
- return index_params_;
- }
- private:
-
- struct Node
- {
-
- int divfeat;
-
- DistanceType divval;
-
- Node* child1, * child2;
- };
- typedef Node* NodePtr;
- typedef BranchStruct<NodePtr, DistanceType> BranchSt;
- typedef BranchSt* Branch;
- void save_tree(FILE* stream, NodePtr tree)
- {
- save_value(stream, *tree);
- if (tree->child1!=NULL) {
- save_tree(stream, tree->child1);
- }
- if (tree->child2!=NULL) {
- save_tree(stream, tree->child2);
- }
- }
- void load_tree(FILE* stream, NodePtr& tree)
- {
- tree = pool_.allocate<Node>();
- load_value(stream, *tree);
- if (tree->child1!=NULL) {
- load_tree(stream, tree->child1);
- }
- if (tree->child2!=NULL) {
- load_tree(stream, tree->child2);
- }
- }
-
- NodePtr divideTree(int* ind, int count)
- {
- NodePtr node = pool_.allocate<Node>();
-
- if ( count == 1) {
- node->child1 = node->child2 = NULL;
- node->divfeat = *ind;
- }
- else {
- int idx;
- int cutfeat;
- DistanceType cutval;
- meanSplit(ind, count, idx, cutfeat, cutval);
- node->divfeat = cutfeat;
- node->divval = cutval;
- node->child1 = divideTree(ind, idx);
- node->child2 = divideTree(ind+idx, count-idx);
- }
- return node;
- }
-
- void meanSplit(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval)
- {
- memset(mean_,0,veclen_*sizeof(DistanceType));
- memset(var_,0,veclen_*sizeof(DistanceType));
-
- int cnt = std::min((int)SAMPLE_MEAN+1, count);
- for (int j = 0; j < cnt; ++j) {
- ElementType* v = dataset_[ind[j]];
- for (size_t k=0; k<veclen_; ++k) {
- mean_[k] += v[k];
- }
- }
- for (size_t k=0; k<veclen_; ++k) {
- mean_[k] /= cnt;
- }
-
- for (int j = 0; j < cnt; ++j) {
- ElementType* v = dataset_[ind[j]];
- for (size_t k=0; k<veclen_; ++k) {
- DistanceType dist = v[k] - mean_[k];
- var_[k] += dist * dist;
- }
- }
-
- cutfeat = selectDivision(var_);
- cutval = mean_[cutfeat];
- int lim1, lim2;
- planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
- if (lim1>count/2) index = lim1;
- else if (lim2<count/2) index = lim2;
- else index = count/2;
-
- if ((lim1==count)||(lim2==0)) index = count/2;
- }
-
- int selectDivision(DistanceType* v)
- {
- int num = 0;
- size_t topind[RAND_DIM];
-
- for (size_t i = 0; i < veclen_; ++i) {
- if ((num < RAND_DIM)||(v[i] > v[topind[num-1]])) {
-
- if (num < RAND_DIM) {
- topind[num++] = i;
- }
- else {
- topind[num-1] = i;
- }
-
- int j = num - 1;
- while (j > 0 && v[topind[j]] > v[topind[j-1]]) {
- std::swap(topind[j], topind[j-1]);
- --j;
- }
- }
- }
-
- int rnd = rand_int(num);
- return (int)topind[rnd];
- }
-
- void planeSplit(int* ind, int count, int cutfeat, DistanceType cutval, int& lim1, int& lim2)
- {
-
- int left = 0;
- int right = count-1;
- for (;; ) {
- while (left<=right && dataset_[ind[left]][cutfeat]<cutval) ++left;
- while (left<=right && dataset_[ind[right]][cutfeat]>=cutval) --right;
- if (left>right) break;
- std::swap(ind[left], ind[right]); ++left; --right;
- }
- lim1 = left;
- right = count-1;
- for (;; ) {
- while (left<=right && dataset_[ind[left]][cutfeat]<=cutval) ++left;
- while (left<=right && dataset_[ind[right]][cutfeat]>cutval) --right;
- if (left>right) break;
- std::swap(ind[left], ind[right]); ++left; --right;
- }
- lim2 = left;
- }
-
- void getExactNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, float epsError)
- {
-
- if (trees_ > 1) {
- fprintf(stderr,"It doesn't make any sense to use more than one tree for exact search");
- }
- if (trees_>0) {
- searchLevelExact(result, vec, tree_roots_[0], 0.0, epsError);
- }
- assert(result.full());
- }
-
- void getNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, int maxCheck, float epsError)
- {
- int i;
- BranchSt branch;
- int checkCount = 0;
- Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
- DynamicBitset checked(size_);
-
- for (i = 0; i < trees_; ++i) {
- searchLevel(result, vec, tree_roots_[i], 0, checkCount, maxCheck, epsError, heap, checked);
- }
-
- while ( heap->popMin(branch) && (checkCount < maxCheck || !result.full() )) {
- searchLevel(result, vec, branch.node, branch.mindist, checkCount, maxCheck, epsError, heap, checked);
- }
- delete heap;
- assert(result.full());
- }
-
- void searchLevel(ResultSet<DistanceType>& result_set, const ElementType* vec, NodePtr node, DistanceType mindist, int& checkCount, int maxCheck,
- float epsError, Heap<BranchSt>* heap, DynamicBitset& checked)
- {
- if (result_set.worstDist()<mindist) {
-
- return;
- }
-
- if ((node->child1 == NULL)&&(node->child2 == NULL)) {
-
- int index = node->divfeat;
- if ( checked.test(index) || ((checkCount>=maxCheck)&& result_set.full()) ) return;
- checked.set(index);
- checkCount++;
- DistanceType dist = distance_(dataset_[index], vec, veclen_);
- result_set.addPoint(dist,index);
- return;
- }
-
- ElementType val = vec[node->divfeat];
- DistanceType diff = val - node->divval;
- NodePtr bestChild = (diff < 0) ? node->child1 : node->child2;
- NodePtr otherChild = (diff < 0) ? node->child2 : node->child1;
-
- DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat);
-
- if ((new_distsq*epsError < result_set.worstDist())|| !result_set.full()) {
- heap->insert( BranchSt(otherChild, new_distsq) );
- }
-
- searchLevel(result_set, vec, bestChild, mindist, checkCount, maxCheck, epsError, heap, checked);
- }
-
- void searchLevelExact(ResultSet<DistanceType>& result_set, const ElementType* vec, const NodePtr node, DistanceType mindist, const float epsError)
- {
-
- if ((node->child1 == NULL)&&(node->child2 == NULL)) {
- int index = node->divfeat;
- DistanceType dist = distance_(dataset_[index], vec, veclen_);
- result_set.addPoint(dist,index);
- return;
- }
-
- ElementType val = vec[node->divfeat];
- DistanceType diff = val - node->divval;
- NodePtr bestChild = (diff < 0) ? node->child1 : node->child2;
- NodePtr otherChild = (diff < 0) ? node->child2 : node->child1;
-
- DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat);
-
- searchLevelExact(result_set, vec, bestChild, mindist, epsError);
- if (new_distsq*epsError<=result_set.worstDist()) {
- searchLevelExact(result_set, vec, otherChild, new_distsq, epsError);
- }
- }
- private:
- enum
- {
-
- SAMPLE_MEAN = 100,
-
- RAND_DIM=5
- };
-
- int trees_;
-
- std::vector<int> vind_;
-
- const Matrix<ElementType> dataset_;
- IndexParams index_params_;
- size_t size_;
- size_t veclen_;
- DistanceType* mean_;
- DistanceType* var_;
-
- NodePtr* tree_roots_;
-
- PooledAllocator pool_;
- Distance distance_;
- };
- }
- #endif
|