Esiry.com
Focus on Machine Learning.

# Principle and Implementation of K-D TREE Algorithm

The k-d tree (k-dimensional tree) is commonly used for space partitioning and neighbor search. It is a special case of a binary space partitioning tree. As a rule of thumb, for a data set of dimension k containing N points, a k-d tree is appropriate when N ≫ 2^k.

1) Principle of k-d tree algorithm
The k-d tree is a binary tree in which every node stores a k-dimensional point. Each node also represents a hyperplane that is perpendicular to the coordinate axis of the node's split dimension and divides the space into two parts along that dimension: one part is held in the left subtree, the other in the right subtree. That is, if the split dimension of the current node is d, every point in its left subtree has a d-coordinate smaller than the current node's, and every point in its right subtree has a d-coordinate greater than or equal to it. This definition holds for every child node as well.
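
The ordering rule can be made concrete with a small checker. This is only a sketch: the `Node` class, the `is_kd_tree` and `points_in` helpers, and the assumption that the split dimension cycles with depth (`depth % k`) are illustrative choices, not something fixed by the definition above.

```python
class Node:
    """Hypothetical node type: one point plus left/right children."""
    def __init__(self, location, left=None, right=None):
        self.location = location
        self.left = left
        self.right = right

def points_in(node):
    """Return every point stored in the subtree rooted at node."""
    if node is None:
        return []
    return points_in(node.left) + [node.location] + points_in(node.right)

def is_kd_tree(node, depth=0):
    """Check: left subtree < node <= right subtree in the split dimension."""
    if node is None:
        return True
    d = depth % len(node.location)  # assumed rule: cycle through dimensions
    split = node.location[d]
    left_ok = all(p[d] < split for p in points_in(node.left))
    right_ok = all(p[d] >= split for p in points_in(node.right))
    return (left_ok and right_ok
            and is_kd_tree(node.left, depth + 1)
            and is_kd_tree(node.right, depth + 1))

# A hand-built two-dimensional tree: the root splits on x, its children on y.
root = Node((7, 2),
            Node((5, 4), Node((2, 3)), Node((4, 7))),
            Node((9, 6), Node((8, 1))))
print(is_kd_tree(root))  # True
```

The checker is quadratic in the number of points, which is fine for illustrating the rule but not for validating large trees.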

1.1) Tree construction
A balanced k-d tree is one in which all leaf nodes are at approximately the same distance from the root. However, a balanced k-d tree is not necessarily optimal for applications such as nearest neighbor search and spatial search.
The conventional construction proceeds as follows: cycle through the dimensions of the data points in order, using each in turn as the split dimension; take the point with the median value in that dimension as the splitting hyperplane; place the points to the left of the median in the left subtree and the points to the right of the median in the right subtree. Process each subtree recursively until every data point has been placed.
a) Optimizing the choice of split dimension
Before construction begins, compare how the data points are distributed in each dimension. The larger the variance in a dimension, the more spread out the points are; the smaller the variance, the more concentrated they are. Splitting on the dimension with the largest variance gives a good partition and good balance.
b) Optimizing the choice of median
First, sort the original data points once in every dimension before the algorithm starts and store the results. Later median selections then do not need to re-sort each subset, which improves performance.
Second, randomly select a fixed number of points from the original data, sort them, and each time take the median of these sample points as the splitting hyperplane. In practice this approach has been found to perform well and to produce well-balanced trees.
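
Both optimizations can be sketched in a few lines. The helper names `widest_dimension` and `sampled_median_point` and the `sample_size` parameter are hypothetical, and this is one possible reading of the sampling idea rather than a reference implementation.

```python
import random
from statistics import pvariance

def widest_dimension(points):
    """Return the index of the dimension with the largest variance."""
    k = len(points[0])
    return max(range(k), key=lambda d: pvariance([p[d] for p in points]))

def sampled_median_point(points, dim, sample_size=50):
    """Approximate the median in `dim` from a random sample of the points."""
    sample = random.sample(points, min(sample_size, len(points)))
    sample.sort(key=lambda p: p[dim])
    return sample[len(sample) // 2]

points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
dim = widest_dimension(points)  # x varies more than y for these points
pivot = sampled_median_point(points, dim, sample_size=6)
print(dim, pivot)  # 0 (7, 2)
```

With `sample_size` at least as large as the data set the sample is the whole set, so the result is deterministic; with a smaller sample the pivot is only an estimate of the median.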
This article uses the conventional construction method. Take the set of two-dimensional points (x, y): (2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2) as an example; the figure below illustrates how the k-d tree is constructed.
a) When the root node is constructed, the split dimension is x. Sorted by x in ascending order, the point set is (2, 3), (4, 7), (5, 4), (7, 2), (8, 1), (9, 6), and the median point is (7, 2). (Note: mathematically, the median of 2, 4, 5, 7, 8, 9 is (5 + 7)/2 = 6, but because the algorithm needs the median to be one of the points, this article uses index len(points)//2 = 3, so points[3] = (7, 2).)
b) (2, 3), (4, 7), (5, 4) go into the left subtree of the (7, 2) node, and (8, 1), (9, 6) go into its right subtree.
c) When the left subtree of the (7, 2) node is constructed, the point set is (2, 3), (4, 7), (5, 4) and the split dimension is y. The median point (5, 4) becomes the splitting plane; (2, 3) goes into its left subtree and (4, 7) into its right subtree.
d) When the right subtree of the (7, 2) node is constructed, the point set is (8, 1), (9, 6) and the split dimension is again y. The median point (9, 6) becomes the splitting plane, and (8, 1) goes into its left subtree. At this point the k-d tree is complete.

The above construction process can be seen in the following figure. Building a k-d tree is a process of gradually dividing a two-dimensional plane.

We can also combine the following figure (the figure is taken from Wikipedia) to look at the construction and spatial division process of k-d tree from the three-dimensional space.

First, a vertical plane with a red border divides the whole space into two parts. Each part is then divided into an upper and a lower half by a horizontal plane with a green border. Finally, each of the four subspaces is split in two by a vertical plane with a blue border, giving eight subspaces, which correspond to the leaf nodes.

The following is the build code for the k-d tree:

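from operator import itemgetter

class Node:  # minimal node type assumed by the listings in this article
    def __init__(self, location):
        self.location = location
        self.left = None
        self.right = None

def kd_tree(points, depth=0):
    if 0 == len(points):
        return None
    # Cycle through the dimensions as the depth increases.
    cutting_dim = depth % len(points[0])
    medium_index = len(points) // 2
    # Sort by the split dimension so the median point sits at medium_index.
    points.sort(key=itemgetter(cutting_dim))
    node = Node(points[medium_index])
    node.left = kd_tree(points[:medium_index], depth + 1)
    node.right = kd_tree(points[medium_index + 1:], depth + 1)
    return node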
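
To connect the build procedure with the worked example, the sketch below builds the tree for the six sample points and lists the nodes in pre-order. The `Node` class and the `preorder` helper are assumptions added for illustration, and the builder sorts a copy of the input instead of sorting in place.

```python
from operator import itemgetter

class Node:
    """Assumed node type: a point and two child links."""
    def __init__(self, location):
        self.location = location
        self.left = None
        self.right = None

def kd_tree(points, depth=0):
    if not points:
        return None
    cutting_dim = depth % len(points[0])
    points = sorted(points, key=itemgetter(cutting_dim))  # leave the caller's list untouched
    medium_index = len(points) // 2
    node = Node(points[medium_index])
    node.left = kd_tree(points[:medium_index], depth + 1)
    node.right = kd_tree(points[medium_index + 1:], depth + 1)
    return node

def preorder(node):
    """List the points as root, then left subtree, then right subtree."""
    if node is None:
        return []
    return [node.location] + preorder(node.left) + preorder(node.right)

points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
root = kd_tree(points)
print(preorder(root))
# [(7, 2), (5, 4), (2, 3), (4, 7), (9, 6), (8, 1)]
```

The output matches steps a) to d): (7, 2) is the root, (5, 4) and (9, 6) are its children, and (2, 3), (4, 7) and (8, 1) are leaves.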

1.2) Find the d-dimensional minimum coordinate point
a) If the current node's split dimension is d
Every node in the right subtree has a d-coordinate greater than or equal to the current node's, so the right subtree can be ignored and only the left subtree is searched. If there is no left subtree, the current node holds the minimum d-coordinate.
b) If the current node's split dimension is not d
Both the left and the right subtree must be searched recursively.
The following code finds the point with the minimum d-coordinate:

def findmin(n, depth, cutting_dim, current_min):
    # Check for an empty subtree first so that n.location is never read on None.
    if n is None:
        return current_min
    if current_min is None:
        current_min = n.location
    current_cutting_dim = depth % len(current_min)
    if n.location[cutting_dim] < current_min[cutting_dim]:
        current_min = n.location
    if cutting_dim == current_cutting_dim:
        # The right subtree cannot hold a smaller value in this dimension.
        return findmin(n.left, depth + 1, cutting_dim, current_min)
    else:
        leftmin = findmin(n.left, depth + 1, cutting_dim, current_min)
        rightmin = findmin(n.right, depth + 1, cutting_dim, current_min)
        if leftmin[cutting_dim] > rightmin[cutting_dim]:
            return rightmin
        else:
            return leftmin


1.3) Add node
Starting from the root, compare the new point with the current node in the current node's split dimension. If the new point's coordinate is smaller, descend into the left subtree; if it is greater than or equal, descend into the right subtree. Recurse until an empty position is reached and attach the new node there.
The following is the code for adding a node:

def insert(n, point, depth):
    # An empty tree becomes a single node.
    if n is None:
        return Node(point)
    cutting_dim = depth % len(point)
    if point[cutting_dim] < n.location[cutting_dim]:
        if n.left is None:
            n.left = Node(point)
        else:
            insert(n.left, point, depth + 1)
    else:
        if n.right is None:
            n.right = Node(point)
        else:
            insert(n.right, point, depth + 1)
    # Return the (unchanged) subtree root so callers can always write root = insert(root, ...).
    return n


Repeated insertions can leave the tree unbalanced. When the imbalance exceeds a chosen threshold, the tree needs to be rebalanced.
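
One simple way to act on this is to compare the tree height with the height a balanced tree would have and rebuild when it is exceeded. The sketch below is an assumption-laden illustration: the `slack` factor, the `height`, `collect` and `rebalance_if_needed` helpers, and the choice to rebuild the whole tree are all hypothetical.

```python
import math
from operator import itemgetter

class Node:
    def __init__(self, location):
        self.location = location
        self.left = None
        self.right = None

def build(points, depth=0):
    if not points:
        return None
    dim = depth % len(points[0])
    points = sorted(points, key=itemgetter(dim))
    mid = len(points) // 2
    node = Node(points[mid])
    node.left = build(points[:mid], depth + 1)
    node.right = build(points[mid + 1:], depth + 1)
    return node

def insert(n, point, depth=0):
    if n is None:
        return Node(point)
    dim = depth % len(point)
    if point[dim] < n.location[dim]:
        n.left = insert(n.left, point, depth + 1)
    else:
        n.right = insert(n.right, point, depth + 1)
    return n

def height(n):
    return 0 if n is None else 1 + max(height(n.left), height(n.right))

def collect(n):
    return [] if n is None else collect(n.left) + [n.location] + collect(n.right)

def rebalance_if_needed(root, slack=2.0):
    """Rebuild when the height exceeds slack * log2(n + 1) (assumed threshold)."""
    pts = collect(root)
    limit = slack * math.log2(len(pts) + 1)
    return build(pts) if height(root) > limit else root

# Inserting points in increasing order degenerates into a right-leaning chain.
root = None
for i in range(16):
    root = insert(root, (i, i))
print(height(root))  # 16
root = rebalance_if_needed(root)
print(height(root))  # 5
```

Rebuilding costs a full sort of the points at every level, so in practice the threshold trades insertion speed against query speed.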
1.4) Delete node
The simplest approach is to collect all descendants of the node to be deleted into a new set, rebuild a subtree from them, and attach the rebuilt subtree in place of the deleted node. This method performs poorly, so an optimized algorithm is considered below.
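
A minimal sketch of this simple rebuild approach is shown here. The `Node` class and the `collect` and `delete_by_rebuild` helpers are hypothetical names; the subtree is rebuilt starting at the deleted node's depth so that the split dimensions stay consistent with the rest of the tree.

```python
from operator import itemgetter

class Node:
    def __init__(self, location):
        self.location = location
        self.left = None
        self.right = None

def kd_tree(points, depth=0):
    if not points:
        return None
    dim = depth % len(points[0])
    points = sorted(points, key=itemgetter(dim))
    mid = len(points) // 2
    node = Node(points[mid])
    node.left = kd_tree(points[:mid], depth + 1)
    node.right = kd_tree(points[mid + 1:], depth + 1)
    return node

def collect(n):
    """Gather every point stored under n."""
    return [] if n is None else collect(n.left) + [n.location] + collect(n.right)

def delete_by_rebuild(n, point, depth=0):
    """Remove `point` by rebuilding the subtree below the matching node."""
    if n is None:
        return None
    if n.location == point:
        remaining = collect(n.left) + collect(n.right)
        return kd_tree(remaining, depth)  # same depth keeps the split dimension
    dim = depth % len(point)
    if point[dim] < n.location[dim]:
        n.left = delete_by_rebuild(n.left, point, depth + 1)
    else:
        n.right = delete_by_rebuild(n.right, point, depth + 1)
    return n

root = kd_tree([(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)])
root = delete_by_rebuild(root, (7, 2))
print(root.location)          # (5, 4)
print(sorted(collect(root)))  # the five remaining points
```

Deleting the root is the worst case for this method, because the whole tree is rebuilt.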
Assume the split dimension of the node T to be deleted is x. The cases below depend on which subtrees T has.
a) No subtree
T is a leaf node and is deleted directly.
b) There is a right subtree
Find the node p with the smallest x coordinate in T.right, replace T's point with p, and then recursively delete p from T.right.
c) No right subtree, but a left subtree
Find the node p with the smallest x coordinate in T.left, i.e. p = findmin(T.left, cutting-dim=x), and replace T's point with p. Then make the original T.left the new right subtree and recursively delete p from it.
(findmax(T.left, cutting-dim=x) is not used as the replacement because the left subtree may contain several nodes that share the same maximum x coordinate. Promoting one of them would leave the others in the left subtree with an x coordinate equal to the new root's, which violates the rule that every node in the left subtree must be strictly smaller than its root in the split dimension.)
The following is the code for deleting a node:

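def delete(n, point, depth):
    # Returns the root of the subtree after deletion; use as root = delete(root, point, 0).
    if n is None:
        return None
    cutting_dim = depth % len(point)
    if n.location == point:
        if n.right is not None:
            # Replace with the minimum of the right subtree, then remove it there.
            n.location = findmin(n.right, depth + 1, cutting_dim, None)
            n.right = delete(n.right, n.location, depth + 1)
        elif n.left is not None:
            # Replace with the minimum of the left subtree and move that subtree to the right.
            n.location = findmin(n.left, depth + 1, cutting_dim, None)
            n.right = delete(n.left, n.location, depth + 1)
            n.left = None
        else:
            # Leaf node: detach it by returning None to the parent.
            return None
    elif point[cutting_dim] < n.location[cutting_dim]:
        n.left = delete(n.left, point, depth + 1)
    else:
        n.right = delete(n.right, point, depth + 1)
    return n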


2) Nearest neighbor search
Given a point p, nearest neighbor search is the process of finding the point in the data set that is closest to p.
As an example, consider searching the k-d tree constructed above for the nearest neighbor of (3, 5). The two figures below illustrate the search process in two-dimensional space.
a) Start at the root node (7, 2), set the current nearest neighbor to (7, 2), and traverse the k-d tree depth-first. Draw a circle centered at (3, 5) whose radius is the distance to (7, 2) (in higher dimensions this is a hypersphere). The region to the right of (8, 1) does not intersect the circle, so the right subtree of (8, 1) is ignored.
b) Next, visit (5, 4), the root of the left subtree of (7, 2). It is closer to (3, 5) than the current nearest neighbor, so the current nearest neighbor is updated to (5, 4). Draw a circle centered at (3, 5) with the distance to (5, 4) as its radius. The region to the right of (7, 2) does not intersect this circle, so all nodes on that side can be ignored: the entire right subtree of (7, 2) is marked as ignored.
c) After the left and right leaf nodes of (5, 4) are visited, their distances to (3, 5) turn out to be equal to the current best distance, so the nearest neighbor is not updated. The nearest neighbor of (3, 5) is therefore (5, 4).

The following is the nearest neighbor search code:
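
A minimal sketch of such a search, under the assumption that nodes expose `location`, `left` and `right` as in the listings above. The `nearest` and `sq_dist` names are hypothetical, and squared distances are compared so that the integer example stays exact.

```python
from operator import itemgetter

class Node:
    def __init__(self, location):
        self.location = location
        self.left = None
        self.right = None

def kd_tree(points, depth=0):
    if not points:
        return None
    dim = depth % len(points[0])
    points = sorted(points, key=itemgetter(dim))
    mid = len(points) // 2
    node = Node(points[mid])
    node.left = kd_tree(points[:mid], depth + 1)
    node.right = kd_tree(points[mid + 1:], depth + 1)
    return node

def sq_dist(a, b):
    """Squared Euclidean distance between two points."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def nearest(n, target, depth=0, best=None):
    """Depth-first search that prunes subtrees the search circle cannot reach."""
    if n is None:
        return best
    if best is None or sq_dist(target, n.location) < sq_dist(target, best):
        best = n.location
    dim = depth % len(target)
    diff = target[dim] - n.location[dim]
    # Visit the side containing the target first, the other side only if needed.
    near, far = (n.left, n.right) if diff < 0 else (n.right, n.left)
    best = nearest(near, target, depth + 1, best)
    if diff * diff < sq_dist(target, best):
        # The splitting plane is closer than the current best, so the far side may help.
        best = nearest(far, target, depth + 1, best)
    return best

points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
root = kd_tree(points)
print(nearest(root, (3, 5)))  # (5, 4)
```

Because the comparison is strict, a point at the same distance as the current best does not replace it, which is why (5, 4) is kept even though (2, 3) and (4, 7) are equally far from (3, 5).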
3) Complexity analysis

| Operation | Average complexity | Worst-case complexity |
| --- | --- | --- |
| Add node | O(log n) | O(n) |
| Delete node | O(log n) | O(n) |
| Nearest neighbor search | O(log n) | O(n) |

4) Using scikit-learn
scikit-learn is a practical machine learning library that includes a KDTree implementation. To keep the visualization simple, the following example builds a k-d tree over two-dimensional points only, then performs a k-nearest-neighbor search and a range search within a specified radius. Searching a higher-dimensional space is called in much the same way.

#!/usr/bin/python
# -*- coding: UTF-8 -*-
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import Circle
from sklearn.neighbors import KDTree
np.random.seed(0)
points = np.random.random((100, 2))
tree = KDTree(points)
point = points[0]
# kNN
dists, indices = tree.query([point], k=3)
print(dists, indices)
print(indices)
fig = plt.figure()