backported kd tree improvements from monav project: faster with base case 8

This commit is contained in:
Dennis Luxen 2010-08-12 11:39:06 +00:00
parent b87d6f3c66
commit 171815c9b7

View File

@ -34,18 +34,20 @@ KD Tree coded by Christian Vetter, Monav Project
namespace KDTree { namespace KDTree {
#define KDTREE_BASESIZE (8)
template< unsigned k, typename T > template< unsigned k, typename T >
class BoundingBox { class BoundingBox {
public: public:
BoundingBox() { BoundingBox() {
for ( unsigned dim = 0; dim < k; ++dim ) { for ( unsigned dim = 0; dim < k; ++dim ) {
min[dim] = std::numeric_limits< T >::min(); min[dim] = std::numeric_limits< T >::min();
max[dim] = std::numeric_limits< T >::max(); max[dim] = std::numeric_limits< T >::max();
} }
} }
T min[k]; T min[k];
T max[k]; T max[k];
}; };
struct NoData {}; struct NoData {};
@ -53,158 +55,174 @@ struct NoData {};
template< unsigned k, typename T > template< unsigned k, typename T >
class EuclidianMetric { class EuclidianMetric {
public: public:
double operator() ( const T left[k], const T right[k] ) { double operator() ( const T left[k], const T right[k] ) {
double result = 0; double result = 0;
for ( unsigned i = 0; i < k; ++i ) { for ( unsigned i = 0; i < k; ++i ) {
double temp = (double)left[i] - (double)right[i]; double temp = (double)left[i] - (double)right[i];
result += temp * temp; result += temp * temp;
} }
return result; return result;
} }
double operator() ( const BoundingBox< k, T > &box, const T point[k] ) { double operator() ( const BoundingBox< k, T > &box, const T point[k] ) {
T nearest[k]; T nearest[k];
for ( unsigned dim = 0; dim < k; ++dim ) { for ( unsigned dim = 0; dim < k; ++dim ) {
if ( point[dim] < box.min[dim] ) if ( point[dim] < box.min[dim] )
nearest[dim] = box.min[dim]; nearest[dim] = box.min[dim];
else if ( point[dim] > box.max[dim] ) else if ( point[dim] > box.max[dim] )
nearest[dim] = box.max[dim]; nearest[dim] = box.max[dim];
else else
nearest[dim] = point[dim]; nearest[dim] = point[dim];
} }
return operator() ( point, nearest ); return operator() ( point, nearest );
} }
}; };
template < unsigned k, typename T, typename Data = NoData, typename Metric = EuclidianMetric< k, T > > template < unsigned k, typename T, typename Data = NoData, typename Metric = EuclidianMetric< k, T > >
class StaticKDTree { class StaticKDTree {
public: public:
struct InputPoint { struct InputPoint {
T coordinates[k]; T coordinates[k];
Data data; Data data;
}; bool operator==( const InputPoint& right )
{
for ( int i = 0; i < k; i++ ) {
if ( coordinates[i] != right.coordinates[i] )
return false;
}
return true;
}
};
StaticKDTree( std::vector< InputPoint > * points ){ StaticKDTree( std::vector< InputPoint > * points ){
assert( k > 0 ); assert( k > 0 );
assert ( points->size() > 0 ); assert ( points->size() > 0 );
size = points->size(); size = points->size();
kdtree = new InputPoint[size]; kdtree = new InputPoint[size];
for ( Iterator i = 0; i != size; ++i ) { for ( Iterator i = 0; i != size; ++i ) {
kdtree[i] = points->at(i); kdtree[i] = points->at(i);
for ( unsigned dim = 0; dim < k; ++dim ) { for ( unsigned dim = 0; dim < k; ++dim ) {
if ( kdtree[i].coordinates[dim] < boundingBox.min[dim] ) if ( kdtree[i].coordinates[dim] < boundingBox.min[dim] )
boundingBox.min[dim] = kdtree[i].coordinates[dim]; boundingBox.min[dim] = kdtree[i].coordinates[dim];
if ( kdtree[i].coordinates[dim] > boundingBox.max[dim] ) if ( kdtree[i].coordinates[dim] > boundingBox.max[dim] )
boundingBox.max[dim] = kdtree[i].coordinates[dim]; boundingBox.max[dim] = kdtree[i].coordinates[dim];
} }
} }
std::stack< Tree > s; std::stack< Tree > s;
s.push ( Tree ( 0, size, 0 ) ); s.push ( Tree ( 0, size, 0 ) );
while ( !s.empty() ) { while ( !s.empty() ) {
Tree tree = s.top(); Tree tree = s.top();
s.pop(); s.pop();
if ( tree.left == tree.right ) if ( tree.right - tree.left < KDTREE_BASESIZE )
continue; continue;
Iterator middle = tree.left + ( tree.right - tree.left ) / 2; Iterator middle = tree.left + ( tree.right - tree.left ) / 2;
#ifdef _GLIBCXX_PARALLEL #ifdef _GLIBCXX_PARALLEL
__gnu_parallel::nth_element( kdtree + tree.left, kdtree + middle, kdtree + tree.right, Less( tree.dimension ) ); __gnu_parallel::nth_element( kdtree + tree.left, kdtree + middle, kdtree + tree.right, Less( tree.dimension ) );
#else #else
std::nth_element( kdtree + tree.left, kdtree + middle, kdtree + tree.right, Less( tree.dimension ) ); std::nth_element( kdtree + tree.left, kdtree + middle, kdtree + tree.right, Less( tree.dimension ) );
#endif #endif
s.push( Tree( tree.left, middle, ( tree.dimension + 1 ) % k ) ); s.push( Tree( tree.left, middle, ( tree.dimension + 1 ) % k ) );
s.push( Tree( middle + 1, tree.right, ( tree.dimension + 1 ) % k ) ); s.push( Tree( middle + 1, tree.right, ( tree.dimension + 1 ) % k ) );
} }
} }
~StaticKDTree(){ ~StaticKDTree(){
delete[] kdtree; delete[] kdtree;
} }
bool NearestNeighbor( InputPoint* result, const InputPoint& point, double radius = std::numeric_limits< T >::max() ) { bool NearestNeighbor( InputPoint* result, const InputPoint& point, double radius = std::numeric_limits< T >::max() ) {
Metric distance; Metric distance;
bool found = false; bool found = false;
double nearestDistance = radius; double nearestDistance = radius;
std::stack< NNTree > s; std::stack< NNTree > s;
s.push ( NNTree ( 0, size, 0, boundingBox ) ); s.push ( NNTree ( 0, size, 0, boundingBox ) );
while ( !s.empty() ) { while ( !s.empty() ) {
NNTree tree = s.top(); NNTree tree = s.top();
s.pop(); s.pop();
if ( distance( tree.box, point.coordinates ) >= nearestDistance ) if ( distance( tree.box, point.coordinates ) >= nearestDistance )
continue; continue;
if ( tree.left == tree.right ) if ( tree.right - tree.left < KDTREE_BASESIZE ) {
continue; for ( unsigned i = tree.left; i < tree.right; i++ ) {
double newDistance = distance( kdtree[i].coordinates, point.coordinates );
if ( newDistance < nearestDistance ) {
nearestDistance = newDistance;
*result = kdtree[i];
found = true;
}
}
continue;
}
Iterator middle = tree.left + ( tree.right - tree.left ) / 2; Iterator middle = tree.left + ( tree.right - tree.left ) / 2;
double newDistance = distance( kdtree[middle].coordinates, point.coordinates ); double newDistance = distance( kdtree[middle].coordinates, point.coordinates );
if ( newDistance < nearestDistance ) { if ( newDistance < nearestDistance ) {
nearestDistance = newDistance; nearestDistance = newDistance;
*result = kdtree[middle]; *result = kdtree[middle];
found = true; found = true;
} }
Less comperator( tree.dimension ); Less comperator( tree.dimension );
if ( !comperator( point, kdtree[middle] ) ) { if ( !comperator( point, kdtree[middle] ) ) {
NNTree first( middle + 1, tree.right, ( tree.dimension + 1 ) % k, tree.box ); NNTree first( middle + 1, tree.right, ( tree.dimension + 1 ) % k, tree.box );
NNTree second( tree.left, middle, ( tree.dimension + 1 ) % k, tree.box ); NNTree second( tree.left, middle, ( tree.dimension + 1 ) % k, tree.box );
first.box.min[tree.dimension] = kdtree[middle].coordinates[tree.dimension]; first.box.min[tree.dimension] = kdtree[middle].coordinates[tree.dimension];
second.box.max[tree.dimension] = kdtree[middle].coordinates[tree.dimension]; second.box.max[tree.dimension] = kdtree[middle].coordinates[tree.dimension];
s.push( second ); s.push( second );
s.push( first ); s.push( first );
} }
else { else {
NNTree first( middle + 1, tree.right, ( tree.dimension + 1 ) % k, tree.box ); NNTree first( middle + 1, tree.right, ( tree.dimension + 1 ) % k, tree.box );
NNTree second( tree.left, middle, ( tree.dimension + 1 ) % k, tree.box ); NNTree second( tree.left, middle, ( tree.dimension + 1 ) % k, tree.box );
first.box.min[tree.dimension] = kdtree[middle].coordinates[tree.dimension]; first.box.min[tree.dimension] = kdtree[middle].coordinates[tree.dimension];
second.box.max[tree.dimension] = kdtree[middle].coordinates[tree.dimension]; second.box.max[tree.dimension] = kdtree[middle].coordinates[tree.dimension];
s.push( first ); s.push( first );
s.push( second ); s.push( second );
} }
} }
return found;
return found; }
}
private: private:
typedef unsigned Iterator; typedef unsigned Iterator;
struct Tree { struct Tree {
Iterator left; Iterator left;
Iterator right; Iterator right;
unsigned dimension; unsigned dimension;
Tree() {} Tree() {}
Tree( Iterator l, Iterator r, unsigned d ): left( l ), right( r ), dimension( d ) {} Tree( Iterator l, Iterator r, unsigned d ): left( l ), right( r ), dimension( d ) {}
}; };
struct NNTree { struct NNTree {
Iterator left; Iterator left;
Iterator right; Iterator right;
unsigned dimension; unsigned dimension;
BoundingBox< k, T > box; BoundingBox< k, T > box;
NNTree() {} NNTree() {}
NNTree( Iterator l, Iterator r, unsigned d, const BoundingBox< k, T >& b ): left( l ), right( r ), dimension( d ), box ( b ) {} NNTree( Iterator l, Iterator r, unsigned d, const BoundingBox< k, T >& b ): left( l ), right( r ), dimension( d ), box ( b ) {}
}; };
class Less { class Less {
public: public:
Less( unsigned d ) { Less( unsigned d ) {
dimension = d; dimension = d;
assert( dimension < k ); assert( dimension < k );
} }
bool operator() ( const InputPoint& left, const InputPoint& right ) { bool operator() ( const InputPoint& left, const InputPoint& right ) {
assert( dimension < k ); assert( dimension < k );
return left.coordinates[dimension] < right.coordinates[dimension]; return left.coordinates[dimension] < right.coordinates[dimension];
} }
private: private:
unsigned dimension; unsigned dimension;
}; };
BoundingBox< k, T > boundingBox; BoundingBox< k, T > boundingBox;
InputPoint* kdtree; InputPoint* kdtree;
Iterator size; Iterator size;
}; };
} }