本文信息

中文名:《并行化优化KD树算法:使用C#实现高效的最近邻搜索》

英文名:”Parallelized Optimization of KD-Tree Algorithm: Implementing Efficient Nearest Neighbor Search in C#”

摘要

本文介绍了如何使用并行计算技术优化 KD 树算法,并使用 C# 编程语言实现了高效的最近邻搜索。首先,我们简要介绍了 KD 树的原理和构建过程,然后详细讨论了如何利用并行计算库在多个 CPU 核心上并行构建 KD 树,从而加速搜索过程。通过实验验证,我们证明了并行化优化能够显著提高 KD 树的构建速度和搜索效率,为大规模数据集下的最近邻搜索问题提供了一种高效的解决方案。

Summary

This article presents a parallelized optimization approach for KD-tree algorithm and demonstrates efficient nearest neighbor search implementation in C#. We first introduce the principles and construction process of KD trees, and then discuss in detail how to leverage parallel computing techniques to build KD trees concurrently on multiple CPU cores, thus accelerating the search process. Through experimental validation, we prove that parallelized optimization significantly improves the construction speed and search efficiency of KD trees, providing an efficient solution for nearest neighbor search problems on large-scale datasets.

版本信息

本文涉及到的 C# 代码使用 .Net 8.0 以及 C# 12 版本编写。

前言

思考以下场景:有 1000 个 A 类型的地点(包含地址和GPS坐标),以及 50000 个 B 类型的地点,需要找出距离每个 A 类型地点最近的 B 类型地点。

我第一时间想到的是之前做推荐系统用过的 KNN 算法,不过 KNN 的实现计算量太大了,当数据量多的时候,需要耗费很多时间,因此针对这个场景,我采用了 KD 树算法。

关于 KD 树

KD 树(K-Dimensional Tree)是一种用于多维空间中的数据结构,它是二叉树的一种变种,用于高效地组织和搜索多维数据。同时也是是 KNN 的一个高效算法,KD 树的主要优点是可以在高维空间中进行快速的最近邻搜索和范围搜索。

关于 KD 树的详细原理不在本文的讨论范围内,本文只做简单介绍。本文讨论的场景是关于 GPS 坐标间距离的计算,因此选择 KD 树的维度是二维。

1. 数据结构

  • 节点: KD 树中的每个节点代表一个数据点。节点包含一个轴(Axis)和一个分割值(Split Value),以及对应于左子树和右子树的指针。

  • 根节点: KD 树的根节点代表整个数据集的范围。

  • 叶子节点: KD 树的叶子节点代表一个单独的数据点。

2. 构建过程

KD 树的构建过程基于递归的分割策略。通常,我们会选择一个轴(Axis)和一个分割值(Split Value),将数据集分割成两个子集。然后,递归地对每个子集进行相同的分割操作,直到每个子集中的数据点数量小于某个阈值,或者达到了指定的深度。

构建过程中,通常采用以下策略来选择轴和分割值:

  • 轴的选择: 轴的选择通常是根据数据点在每个维度上的方差或者范围来确定。可以选择方差最大的维度作为轴,或者按照轮换的方式选择每个维度作为轴。

  • 分割值的选择: 分割值通常是选取当前子集中数据点在选定轴上的中位数。

3. 搜索过程

KD 树的搜索过程也是基于递归的。搜索过程从根节点开始,按照某种规则向下遍历树,直到找到目标数据点或者达到叶子节点。

搜索过程中,通常采用以下策略来确定搜索顺序:

  • 确定分割方向: 根据目标数据点在当前节点所选定的轴上的值,确定搜索方向。如果目标值小于当前节点的分割值,则向左子树搜索;否则向右子树搜索。

  • 确定搜索顺序: 根据目标数据点在当前节点所选定的轴上的距离,确定搜索顺序。首先搜索距离更近的子树,然后再搜索距离更远的子树。

  • 剪枝策略: 在搜索过程中,可以采用剪枝策略来减少搜索的分支,提高搜索效率。例如,可以计算目标点与当前节点分割超平面的距离,如果距离大于当前最近距离,则可以剪掉该分支。

4. 应用场景

KD 树常用于高维空间中的数据组织和搜索,特别是在机器学习和数据挖掘领域中。常见的应用包括最近邻搜索、范围搜索、近似最近邻搜索等。

5. 总结

KD 树是一种高效的多维空间数据结构,适用于快速的最近邻搜索和范围搜索。它的构建和搜索过程都基于递归的思想,并且可以通过选择合适的分割策略和剪枝策略来提高搜索效率。

距离计算

欧式距离

一说到距离计算,最开始想到的就是欧氏距离(欧几里得距离 Euclidean distance),表示在m维空间中两个点之间的真实距离。公式为

\[d = \sqrt{\sum_{i=1}^{k}(x_i – y_i)^2} \]

在二维和三维空间中的欧式距离的就是两点之间的距离,其中二维空间的表示为

\[d = \sqrt{(x_1 – x_2)^2 + (y_1 – y_2)^2} \]

其中 \((x_1, x_2), (y_1, y_2)\) 是两点的坐标

我们这里计算两地之间的距离,看起来似乎是用二维空间两点之间的距离就行了,同学也是这么说的。

曲面上的两点距离

不过转念一想,不对啊,这个只是在平面上计算距离,但地球不是个平面,有曲率的啊

于是查了下资料,找到了这个 Haversine(半正矢)公式

Haversine 名字来历是 Ha-VERSINE,即 Half-Versine ,表示 sin 的一半的意思。Haversine公式给出了用两点经纬度计算两点在球面上的距离的方式。

\[haversin(\frac{d}{R}) = haversin(\varphi_2 – \varphi_1) + cos(\varphi_1)cos(\varphi_2)haversin(\Delta\lambda) \]

其中

  • d 是沿着球体大圆的两点之间的距离(参见球面距离)
  • R 为球体半径,地球半径可取平均值 6371km
  • φ1, φ2 表示两点的纬度
  • Δλ 表示两点经度的差值

上面应用到圆心角 θ 以及纬度和经度差值的半正矢函数 hav(θ) 的定义为

\[hav(\theta) = \sin^2(\frac{\theta}{2}) = \frac{1-\cos \theta}{2} \]

关于这个公式的更详细推导过程就省略不表了,感兴趣的同学请查阅参考资料。

根据半正矢的定义, archaversine(反半正弦)可以用反正弦表示:

\[archav(h) = 2 \arcsin \sqrt{h} \]

其中 \(0 \le h \le 1\)

代入半正矢公式可得到距离 d 的求解公式:

\[d = 2R \arcsin (\sqrt{\sin^2(\frac{\varphi_2-\varphi_1}{2}) + \cos \theta_2 \cdot \sin^2(\frac{\lambda_2 – \lambda_1}{2})}) \]

PS: 这部分的数学顾问: @Wyu-Cnk

接下来是使用 C# 实现 Haversine 公式计算两点之间距离(以公里为单位)

public static class DistanceCalculator {
  // Radius of the Earth in kilometers
  private const double EarthRadius = 6371;

  public static double CalculateDistance(ILocation location1, ILocation location2) {
    var lat1 = DegreeToRadian(location1.Lat);
    var lon1 = DegreeToRadian(location1.Lng);
    var lat2 = DegreeToRadian(location2.Lat);
    var lon2 = DegreeToRadian(location2.Lng);

    var dlon = lon2 - lon1;
    var dlat = lat2 - lat1;

    var a = Math.Pow(Math.Sin(dlat / 2), 2) + Math.Cos(lat1) * Math.Cos(lat2) * Math.Pow(Math.Sin(dlon / 2), 2);
    var c = 2 * Math.Atan2(Math.Sqrt(a), Math.Sqrt(1 - a));

    var distance = EarthRadius * c;
    return distance;
  }

  private static double DegreeToRadian(double degree) {
    return degree * Math.PI / 180;
  }
}

初步实现 KD 树

首先使用 C# 实现基本的 KD 树,包括 KD 树的构建,以及最近节点的搜索功能。

数据结构

为了使程序更加通用,本文使用接口以及泛型对数据结构进行了一定程度的抽象。

地点的数据结构

public interface ILocation {
  public double Lng { get; set; }
  public double Lat { get; set; }
}

public interface ILocationNode {
  public ILocation Location { get; set; }
}

KD 树节点

public class KdTreeNode<T> where T : ILocationNode {
  public T Value { get; set; }
  public KdTreeNode<T>? Left { get; set; }
  public KdTreeNode<T>? Right { get; set; }
}

KDTree 类

需要用到一个泛型参数,与上述的 KdTreeNode<T> 类中的泛型参数相同。

/// <summary>
/// k-dimensional tree
/// </summary>
/// <typeparam name="T"></typeparam>
public class KdTree<T> where T : ILocationNode {
  private readonly IProgress<string> _progress = new Progress<string>(Console.WriteLine);
  private KdTreeNode<T>? _root;
}

KD 树构建

使用递归来构建 KD 树。

public void BuildTree(List<T> items) {
  _root = BuildTree(items, 0, items.Count - 1, 0);
}

private KdTreeNode<T>? BuildTree(List<T> items, int start, int end, int depth) {
  if (start > end)
    return null;

  var axis = depth % 2; // Assuming latitude and longitude as 2D coordinates

  items.Sort(
    (a, b) => axis == 0
    ? a.Location.Lng.CompareTo(b.Location.Lng)
    : a.Location.Lat.CompareTo(b.Location.Lat)
  );

  var medianIndex = start + (end - start) / 2;
  var node = new KdTreeNode<T> {
    Value = items[medianIndex],
    Left = BuildTree(items, start, medianIndex - 1, depth + 1),
    Right = BuildTree(items, medianIndex + 1, end, depth + 1)
  };

  // 进度显示
  _progress.Report($"Splitting on depth {depth}, axis {axis}, median index {medianIndex}");

  return node;
}

使用方式

准备好一批地点数据,转换为 KdNode<T> 类型,作为 BuildTree 方法的参数执行。

class Store {
  public string Name { get; set; }
  public string Address { get; set; }
  public Location? Location { get; set; }
}

var stores = new List<Store>(){
  // Example coordinates for New York City
  new Store {
    Name = "Store1",
    Location = new Location { lat = 40.7128, lng = -74.0060 }
  },
  // Example coordinates for Los Angeles
  new Store {
    Name = "Store2",
    Location = new Location { lat = 34.0522, lng = -118.2437 }
  },
  // 更多数据请从外部数据源加载
};

var kdTree = new KdTree<KdNode<Store>>();
kdTree.BuildTreeParallel(stores.Select(e => new KdNode<Store> {
  Node = e,
  Location = e.Location
}).ToList());

最近节点查找

KD 树的查找过程也是基于递归,输入参数为 ILocation 地点坐标。

public KdTreeNode<T>? FindNearestNode(ILocation location) {
  return FindNearestNode(_root, location, 0, null);
}

private KdTreeNode<T>? FindNearestNode(KdTreeNode<T>? node, ILocation location, int depth, KdTreeNode<T>? best) {
  if (node == null) return best;

  var bestDistance = best != null
    ? DistanceCalculator.CalculateDistance(location, best.Value.Location)
    : double.PositiveInfinity;
  var currentNodeDistance = DistanceCalculator.CalculateDistance(location, node.Value.Location);

  if (currentNodeDistance < bestDistance)
    best = node;

  var axis = depth % 2;
  var axisDistance =
    axis == 0 ? location.Lng - node.Value.Location.Lng : location.Lat - node.Value.Location.Lat;

  var nearChild = axisDistance < 0 ? node.Left : node.Right;
  var farChild = axisDistance < 0 ? node.Right : node.Left;

  var nearest = FindNearestNode(nearChild, location, depth + 1, best);

  if (nearest != null) {
    var nearestDistance = DistanceCalculator.CalculateDistance(location, nearest.Value.Location);
    if (nearestDistance < bestDistance)
      best = nearest;
  }

  if (Math.Abs(axisDistance) < bestDistance) {
    var farthest = FindNearestNode(farChild, location, depth + 1, best);
    if (farthest != null) {
      var farthestDistance = DistanceCalculator.CalculateDistance(location, farthest.Value.Location);
      if (farthestDistance < bestDistance)
        best = farthest;
    }
  }

  return best;
}

使用方式

// Example coordinates for Tokyo City
var location = new Location { lat = 35.652832, lng = 139.839478 }
var nearestNode = kdTree.FindNearestNode(location);

构建性能

使用 C# 的 StopWatch 工具对 BuildTree 方法进行计时,数据集的大小为 45066 的情况下,上述的代码构建 KD 树耗时为 06:35.9083329 即 6 分钟 35 秒。

Building KD tree...
KD tree construction complete. 耗时: 00:06:35.9083329

以下是本文代码运行环境的 CPU 信息

12th Gen Intel(R) Core(TM) i7-12700

基准速度:	2.10 GHz
插槽:	1
内核:	12
逻辑处理器:	20
虚拟化:	已启用
L1 缓存:	1.0 MB
L2 缓存:	12.0 MB
L3 缓存:	25.0 MB

检查 CPU 各个核心的利用率情况,发现 CPU 只有 2 个核心处于 100% 利用率,整体负载为 20% 左右。

查找性能

使用上述代码中的 FindNearestNode 查找节点,执行耗时

00:00:00.0242371

性能优化

目前的 KD 树实现已经能够满足本文场景的需求,在性能方面,构建过程的耗时较长,查找过程的速度在可接受范围内,所以性能优化的重点放在了构建过程上。

在构建过程中检查 CPU 各个核心的利用率情况,发现 CPU 只有 2 个核心处于 100% 利用率,整体负载为 20% 左右,究其原因,如果没有特别的优化,普通的代码是以单线程的形式运行,现在的 CPU 都是多核架构,单线程只能利用到很少的 CPU 性能,因此我的优化思路是提高 CPU 的多核利用率。

在 C# 中,有多种方式来让程序实现并行执行,一般会选择 Parallel 或者是 Task 类提供的方法,Thread 可以直接对线程进行管理操作,是比较底层的实现,一般不直接管理线程。

Task 的抽象级别更高,使用 Task (包括 async/await 语句)创建的任务会在线程池中调度运行,,Task 类似于 go 语言中的「协程(Goroutine)」概念(区别在于 Task 通过编译器+状态机实现,在编译期间完成,属于无栈协程;Goroutine 则是有栈协程)。

Parallel vs Task

Parallel

  • Parallel 类提供了一种简单的方法来执行并行循环或迭代。它允许指定一个循环范围,并自动将其分割成较小的任务,然后并行执行这些任务。
  • Parallel 类的主要目的是简化并行循环的编写,使得开发者可以轻松地利用多核处理器来提高性能,而不必担心管理线程或任务的细节。
  • Parallel 类通常用于处理可迭代的数据集合,比如数组、列表等。

Task

  • Task 类提供了更加灵活和底层的并行编程机制。可以使用 Task 类来创建和管理异步操作,每个 Task 实例代表一个可执行的操作单元。
  • Task 类允许创建具有不同执行策略和调度选项的任务,例如使用线程池线程或新线程执行任务,也可以设置任务的优先级、取消任务等。
  • Task 类更适合处理异步操作,而不仅仅是并行循环。可以使用 Task 类来执行任何需要异步执行的操作,比如异步文件 I/O、网络请求等。

小结

  • Task 的扩展性比 Parallel 更强
  • Parallel 的性能比 Task 更强
  • Parallel.ForEach 更适合 CPU 密集型任务
  • Task.WhenAll 更适合 IO 密集型任务

针对本文的场景,属于 CPU 密集型任务,使用 Parallel 来实现的性能会更好一点。为了做对比,本文会分别实现 ParallelTask 两个版本。

使用 Task 优化性能

重构 BuildTree 方法

KdTree<T> 中添加 BuildTreeParallel 方法。

public void BuildTreeParallel(List<T> items) {
  Console.WriteLine("Building KD tree...");
  _root = BuildTreeParallel(items, 0, items.Count - 1, 0);
  Console.WriteLine("KD tree construction complete.");
}

private KdTreeNode<T>? BuildTreeParallel(List<T> items, int start, int end, int depth) {
  if (start > end)
    return null;

  var axis = depth % 2; // Assuming latitude and longitude as 2D coordinates

  // 复制一份新的 items 列表
  var sortedItems = items.ToList();

  sortedItems.Sort(
    (a, b) => axis == 0
    ? a.Location.Lng.CompareTo(b.Location.Lng)
    : a.Location.Lat.CompareTo(b.Location.Lat)
  );

  var medianIndex = start + (end - start) / 2;
  var node = new KdTreeNode<T> {
    Value = sortedItems[medianIndex]
  };

  var leftTask = Task.Run(() => BuildTreeParallel(sortedItems, start, medianIndex - 1, depth + 1));
  var rightTask = Task.Run(() => BuildTreeParallel(sortedItems, medianIndex + 1, end, depth + 1));

  node.Left = leftTask.Result;
  node.Right = rightTask.Result;

  return node;
}

代码逻辑与单线程版本基本相同,区别在于并行版本将每个节点的叶子节点放在一个新的 Task 中构建。

还有 items 的排序部分,在多线程版本中不能直接进行排序,而是要复制一份新的副本进行排序。高并发的场景下,对同一个 items 对象进行排序会引发一致性错误问题,具体报错表现为:

System.AggregateException: One or more errors occurred. (Unable to sort because the IComparer.Compare() method returns inconsistent result
s. Either a value does not compare equal to itself, or one value repeatedly compared to another value yields different results. IComparer:
'System.Comparison`1[StoreProximityBroadbandLocator.Services.KdNode`1[StoreProximityBroadbandLocator.Models.Device]]'.)

大意为:比较器(IComparer)的比较方法返回了不一致的结果。复制一份新的 items 副本进行排序可以解决这个问题。

重构 FindNearestNode 方法

事实上,查找节点的计算量不大,使用并行优化的效果不明显,反而会因为线程调度带来额外的开销。

本文提供基于 Task 的实现,但并不提倡使用这种方式。

public KdTreeNode<T>? FindNearestNodeParallel(ILocation location) {
  return FindNearestNodeParallel(_root, location, 0, null);
}

private KdTreeNode<T>? FindNearestNodeParallel(KdTreeNode<T>? node, ILocation location, int depth, KdTreeNode<T>? best) {
  if (node == null) return best;

  var bestDistance = best != null
    ? DistanceCalculator.CalculateDistance(location, best.Value.Location)
    : double.PositiveInfinity;
  var currentNodeDistance = DistanceCalculator.CalculateDistance(location, node.Value.Location);

  if (currentNodeDistance < bestDistance)
    best = node;

  var axis = depth % 2;
  var axisDistance =
    axis == 0 ? location.Lng - node.Value.Location.Lng : location.Lat - node.Value.Location.Lat;

  var nearChild = axisDistance < 0 ? node.Left : node.Right;
  var farChild = axisDistance < 0 ? node.Right : node.Left;

  var nearestTask = Task.Run(() => FindNearestNodeParallel(nearChild, location, depth + 1, best));
  var nearest = nearestTask.Result;

  if (nearest != null) {
    var nearestDistance = DistanceCalculator.CalculateDistance(location, nearest.Value.Location);
    if (nearestDistance < bestDistance)
      best = nearest;
  }

  if (Math.Abs(axisDistance) < bestDistance) {
    var farthestTask = Task.Run(() => FindNearestNodeParallel(farChild, location, depth + 1, best));
    var farthest = farthestTask.Result;
    if (farthest != null) {
      var farthestDistance = DistanceCalculator.CalculateDistance(location, farthest.Value.Location);
      if (farthestDistance < bestDistance)
        best = farthest;
    }
  }

  return best;
}

性能测试

与单线程版本一样,运行 KD 树构建与查找节点方法。

现在可以跑满整个 CPU 了,本文的开发环境使用的 CPU 为 Intel i7-12700,全核频率最高到 3.6GHz 左右。

基于 Task 实现的 BuildTree 方法执行耗时为 48 秒左右。

使用 Parallel 优化性能

重构 BuildTree 方法

KdTree<T> 中添加 BuildTreeParallel2 方法。

public void BuildTreeParallel2(List<T> items) {
  _root = new KdTreeNode<T>();
  BuildTreeParallel2(_root, items, 0, items.Count - 1, 0);
  Console.WriteLine("KD tree construction complete.");
}

private void BuildTreeParallel2(KdTreeNode<T> node, List<T> items, int start, int end, int depth) {
  if (start > end)
    return;

  var axis = depth % 2; // Assuming latitude and longitude as 2D coordinates

  // 复制一份新的 items 列表
  var sortedItems = items.ToList();

  sortedItems.Sort(
    (a, b) => axis == 0
    ? a.Location.Lng.CompareTo(b.Location.Lng)
    : a.Location.Lat.CompareTo(b.Location.Lat)
  );

  var medianIndex = start + (end - start) / 2;
  node.Value = sortedItems[medianIndex];

  if (start < medianIndex) {
    node.Left = new KdTreeNode<T>();
  }

  if (medianIndex < end) {
    node.Right = new KdTreeNode<T>();
  }

  Parallel.Invoke(
    () => {
      if (node.Left != null) {
        BuildTreeParallel2(node.Left, sortedItems, start, medianIndex - 1, depth + 1);
      }
    },
    () => {
      if (node.Right != null) {
        BuildTreeParallel2(node.Right, sortedItems, medianIndex + 1, end, depth + 1);
      }
    }
  );
}

错误示例

最开始的时候我写了个错误的代码

构建的时候没问题,但每次查找的时候总报 null reference 异常

经过 debug 才找出来错误的原因

这里附上最开始的代码,然后再来分析一下

node.Left = new KdTreeNode<T>();
node.Right = new KdTreeNode<T>();

Parallel.Invoke(
  () => BuildTreeParallel2(node.Left, sortedItems, start, medianIndex - 1, depth + 1),
  () => BuildTreeParallel2(node.Right, sortedItems, medianIndex + 1, end, depth + 1)
);

旧版的代码比较简洁

看起来似乎很好,但是因为最开始有一行 if (start > end) return;

所以在已经构建完成了某个分支之后,节点的末端还会生成一个空的节点

也不完全是空的,就是一个 KdTreeNode<T> 对象,但里面的 Value , Location 属性全是 null

所以只能改成我上面那种方式,既能完全利用CPU,也能解决这个问题。

性能测试

使用 C# 的 StopWatch 工具测试,基于 Parallel 实现的 BuildTree 方法执行耗时为 43 秒左右,对比 Task 的实现确实会快一些。

参考资料