A Fast Traversal Partition Algorithm for CART

An improved implementation of the algorithm described in this article is available in the repository: ML-Algorithm

This article describes an improvement to the CART decision-tree algorithm that speeds up tree construction over basic CART by a factor on the order of the number of samples. It derives the theoretical justification for the optimization and measures its effect in practice.

1. Problem Statement

As is well known, building a CART tree involves a key step: traversing all candidate split boundaries in the data space to find the optimal split feature $\alpha$ and threshold $c$ that minimize the total variance of the two resulting subsets, i.e. the following expression:

$$\min_{\alpha,c}\left[\sum_{x_i[\alpha]<c}(y_i-\bar{y}_1)^2+\sum_{x_i[\alpha]>c}(y_i-\bar{y}_2)^2\right]$$

where $\bar{y}_1,\bar{y}_2$ are the means of $y$ over the sample points with $x_i[\alpha]<c$ and $x_i[\alpha]>c$, respectively.

The problem is that classic CART traverses all split boundaries, so building a balanced binary tree costs $O(n_{samples}^{2}\,n_{features}\log(n_{samples}))$ time, which is unacceptable for a large dataset.

So I looked at the sklearn source code and found that it does not evaluate the expression above; instead it operates on $\sum{y_i}$. I searched the literature but could not find an explanation of why (it is probably in some older reference I failed to locate), so I derived it myself, step by step, starting from the expression above.

2. Derivation of the Optimized Algorithm

First, take one sum of squared deviations and expand it:

$$\sum_{i=1}^{n}(y_i-\bar{y})^2=\sum_{i=1}^{n}y_i^2-2\bar{y}\sum_{i=1}^{n}y_i+n\bar{y}^2=\sum_{i=1}^{n}y_i^2-\frac{(\sum_{i=1}^{n}y_i)^2}{n}$$

Then, for a split on feature $\alpha$ with threshold $c$ (say the condition $x_i[\alpha]<c$ partitions the samples into sets $L$ and $R$):

$$\sum_{(x,y)\in L}(y-\bar{y}_L)^2+\sum_{(x,y)\in R}(y-\bar{y}_R)^2=\sum{y^2}-\frac{(\sum_{(x,y)\in L}{y})^2}{n_{L}}-\frac{(\sum_{(x,y)\in R}{y})^2}{n_{R}}$$

Since $\sum{y^2}$ is the same for every split, we only need to compute $\max\limits_{\alpha,c}{\left[\frac{(\sum_{(x,y) \in L}{y})^2}{n_{L}}+\frac{(\sum_{(x,y) \in R}{y})^2}{n_{R}}\right]}$. With this expression in hand, each candidate split only requires the sample count and the sum of $y$ in each of the two sets; there is no need to recompute any variance. The time cost drops to $O(n_{features}\,n_{samples}\log(n_{samples}))$ (a full factor of $n_{samples}$ saved), and the running sums can be maintained incrementally as samples move from one side of the split to the other.
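The identity above is easy to check numerically. A minimal sketch (the split point 37 is arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
y = rng.normal(size=100)

# An arbitrary partition of the samples into L and R.
L, R = y[:37], y[37:]

# Left-hand side: total sum of squared deviations within each side.
lhs = np.sum((L - L.mean()) ** 2) + np.sum((R - R.mean()) ** 2)

# Right-hand side: sum of squares minus the two squared-sum terms.
rhs = np.sum(y ** 2) - L.sum() ** 2 / len(L) - R.sum() ** 2 / len(R)

assert np.isclose(lhs, rhs)
```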

3. Algorithm Description

The search for the best split hyperplane proceeds as follows (x denotes the input, y the output):

Compute the total sum of y;
Initialize the record of the best gain;

For each feature: feature{
  // At this point the left node set is empty and the right node set holds all samples;
  Initialize the y-sums of the left and right node sets;
  Sort all training samples by their value of feature;
  For each sample in sorted order: sample{
    // At this point move sample from the right node set into the left node set;
    Update the y-sums of the left and right node sets;
    Compute the gain of the current left/right partition: sum_left * sum_left / n_left + sum_right * sum_right / n_right;

    If the gain exceeds the current best gain{
      Record the data of the new best gain;
    }
  }
}

Return the record of the best gain;
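The steps above can be sketched as a small self-contained function. `best_split` below is my own illustrative version, not code from the repository; the name and return layout are assumptions:

```python
import numpy as np

def best_split(X, y):
    """Find the (gain, feature, threshold) maximizing the split gain.

    gain = sum_L^2 / n_L + sum_R^2 / n_R, as derived in section 2.
    """
    n, d = X.shape
    total = y.sum()
    best = (-np.inf, None, None)  # (gain, feature, threshold)
    for f in range(d):
        order = np.argsort(X[:, f])
        sum_left = 0.0
        # Move samples one by one from the right set into the left set.
        for j in range(n - 1):
            sum_left += y[order[j]]
            sum_right = total - sum_left
            gain = sum_left ** 2 / (j + 1) + sum_right ** 2 / (n - j - 1)
            if gain > best[0]:
                best = (gain, f, X[order[j], f])
    return best
```

On a toy dataset such as `X = [[0], [1], [2], [3]]`, `y = [0, 0, 10, 10]`, the function picks feature 0 at threshold 1, separating the two low-y samples from the two high-y ones.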

4. Implementation and Tests

4.1. Implementation of the Optimized Algorithm

The full decision tree is available at: cquai-ml/DecisionTreeRegressor

```python
def _build_tree(self, X, y, cur_depth, parent, is_left):
    if cur_depth == self.max_depth:
        self._build_leaf(X, y, cur_depth, parent, is_left)
        return  # stop recursing once the depth limit is reached

    best_gain = -np.inf
    best_feature = None
    best_threshold = None
    best_left_ind = best_right_ind = None

    sum_all = np.sum(y)
    step = lambda x, y, a: (x + a, y - a)  # move weight a from right to left

    for i in range(X.shape[1]):  # for each feature
        sum_left, sum_right = 0, sum_all
        n_left = 0
        n_right = X.shape[0]
        ind = np.argsort(X[:, i])

        for j in range(ind.shape[0] - 1):  # for each sample
            # move one sample from the right set into the left set
            sum_left, sum_right = step(sum_left, sum_right, y[ind[j]])
            n_left, n_right = step(n_left, n_right, 1)

            cur_gain = (
                sum_left * sum_left / n_left + sum_right * sum_right / n_right
            )
            if cur_gain > best_gain:  # found a better split
                best_gain = cur_gain
                best_feature = i
                best_threshold = X[ind[j], i]
                best_left_ind, best_right_ind = ind[: j + 1], ind[j + 1 :]

    # ...
```
    

4.2. Testing the Optimized Algorithm

4.2.1. Baseline

First, a baseline is needed. Starting from the implementation in 4.1, I replaced the gain computation with the following naive version, which recomputes the variances from scratch at every candidate split, for comparison:

```python
# ...
for i in range(X.shape[1]):  # for each feature
    ind = np.argsort(X[:, i])
    # total squared error before splitting; note the variance is taken over y,
    # and subsets are selected through the sorted index array `ind`
    init_sse = np.var(y) * y.shape[0]

    for j in range(ind.shape[0] - 1):  # for each sample
        # recompute both sides' squared error from scratch at every split
        n_left, n_right = j + 1, y.shape[0] - j - 1
        cur_gain = (
            init_sse
            - n_left * np.var(y[ind[:n_left]])
            - n_right * np.var(y[ind[n_left:]])
        )

    # ...
```
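Both criteria must select the same split position: the naive total squared error equals $\sum y^2$ minus the fast criterion, so minimizing one is maximizing the other. A quick sanity check (assuming the samples are already sorted by the feature):

```python
import numpy as np

rng = np.random.default_rng(1)
y = rng.normal(size=50)  # y values, assumed already sorted by the feature

fast, slow = [], []
for j in range(len(y) - 1):
    L, R = y[: j + 1], y[j + 1 :]
    # criterion to maximize: sum_L^2 / n_L + sum_R^2 / n_R
    fast.append(L.sum() ** 2 / len(L) + R.sum() ** 2 / len(R))
    # naive criterion to minimize: total squared error after the split
    slow.append(np.sum((L - L.mean()) ** 2) + np.sum((R - R.mean()) ** 2))

assert np.argmax(fast) == np.argmin(slow)
```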
    

4.2.2. Test Results

Using the california_housing dataset provided by sklearn, I randomly drew x samples from the dataset to form a test set, then measured the running time y of the improved algorithm and of the original algorithm separately.

The sampling scheme was:

Sample count (x)       Sampling interval   Repetitions
$500 \le x < 2100$     50                  5
$2200 \le x < 8000$    100                 3
$8000 \le x \le 20000$ 1000                1
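A timing harness along these lines can be sketched as follows. `time_fit` is my own helper, not from the repository, and the snippet uses synthetic data to stay self-contained, whereas the article draws its subsamples from california_housing:

```python
import time
import numpy as np

def time_fit(fit, X, y, repeats=3):
    """Average wall-clock seconds of fit(X, y) over several repeats."""
    elapsed = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fit(X, y)
        elapsed.append(time.perf_counter() - t0)
    return sum(elapsed) / len(elapsed)

# Synthetic stand-in for a california_housing subsample (8 features).
rng = np.random.default_rng(0)
X, y = rng.normal(size=(500, 8)), rng.normal(size=500)
t = time_fit(lambda X, y: np.argsort(X, axis=0), X, y)  # dummy fit for illustration
```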

The results are plotted below; the algorithmic improvement yields a substantial performance gain for the decision tree.

![Test results](res.png)

5. Appendix

5.1. Raw Test Results

Tests were run on an Aliyun ecs.t5-lc1m2.large instance, using only a single core and a single thread.

Sample count (x)  Baseline runtime (y1)  Optimized runtime (y2)
500 0.6455640614032745 0.03460801243782043
550 0.7065692529082298 0.03810213208198547
600 1.0056967556476593 0.0414450466632843
650 0.8232804700732231 0.044684456288814546
700 0.8901257112622261 0.047838394343852994
750 1.0026517659425735 0.05146666020154953
800 1.0485285922884942 0.05495186746120453
850 1.1395855575799942 0.058270522952079774
900 1.1657116502523421 0.06108095794916153
950 1.2251508012413979 0.06497829854488373
1000 1.4062113106250762 0.06846041977405548
1050 1.434651754796505 0.07195884585380555
1100 1.495849071443081 0.07824563086032868
1150 1.5689095377922058 0.082560795545578
1200 1.6347307205200194 0.08181308954954147
1250 1.6958067312836647 0.08498744070529937
1300 1.7384015157818795 0.08845457434654236
1350 1.8185098826885224 0.09170756489038467
1400 2.146059076488018 0.09549188911914826
1450 1.9570893928408624 0.09876007586717606
1500 2.101211351156235 0.1016637220978737
1550 2.108571472764015 0.10526050329208374
1600 2.2171461448073386 0.10812699645757676
1650 2.333981400728226 0.1120547667145729
1700 2.3909598261117937 0.11580100655555725
1750 2.498200277984142 0.11864992380142211
1800 2.5535635992884638 0.12206683903932572
1850 2.638751062750816 0.1258074700832367
1900 2.6781568259000776 0.12916499078273774
1950 2.755809526145458 0.14649802893400193
2000 2.8841904655098913 0.1365313023328781
2050 2.894517731666565 0.14018919616937636
2200 3.1409400403499603 0.15089931587378183
2300 3.3270686442653337 0.1572838177283605
2400 3.525728871424993 0.16408776491880417
2500 3.7560477579633393 0.1711770842472712
2600 3.9674307058254876 0.19605823854605356
2700 4.013482913374901 0.2282602166136106
2800 4.446668999890487 0.23765958100557327
2900 4.48147031168143 0.22356532017389932
3000 4.948460293312867 0.20575664192438126
3100 4.874664048353831 0.2128776361544927
3200 5.022817427913348 0.220118078092734
3300 5.1897163813312845 0.2260249381264051
3400 5.3919211477041245 0.2337716445326805
3500 5.67029333114624 0.24055358270804086
3600 5.770458854734898 0.24743302911520004
3700 6.001950363318126 0.25265806913375854
3800 6.317233701546987 0.2606764684120814
3900 6.6534217819571495 0.2671254873275757
4000 6.629562981426716 0.2737120861808459
4100 7.198160822192828 0.2796892672777176
4200 7.056483464936416 0.2931108921766281
4300 7.427918809155623 0.29908887296915054
4400 7.40793473025163 0.31269486993551254
4500 7.415556532641252 0.3146485487620036
4600 7.823212171594302 0.34473737825949985
4700 8.000009703139463 0.3931199833750725
4800 8.230909988284111 0.35864416758219403
4900 8.696146366496881 0.3422362531224887
5000 8.821599625051022 0.35491255422433216
5100 9.121791074673334 0.34864648431539536
5200 9.001985614498457 0.3566499973336856
5300 9.513516945143541 0.37122564017772675
5400 9.839489658673605 0.36829609672228497
5500 10.308331390221914 0.37442801396052044
5600 10.183131786684195 0.3832108899950981
5700 10.464193053543568 0.39055130382378894
5800 11.124716185033321 0.3948906287550926
5900 11.035219440857569 0.402115764717261
6000 11.121028197308382 0.4073619817694028
6100 11.415477780004343 0.41393469522396725
6200 11.608833223581314 0.4227173924446106
6300 11.858192754288515 0.43291472891966504
6400 12.654898650944233 0.4566325644652049
6500 12.49950555463632 0.4969073285659154
6600 12.536672584712505 0.6367243006825447
6700 12.912669330835342 0.8427434911330541
6800 13.501878333588442 0.8529169037938118
6900 13.264134099086126 0.868061825633049
7000 13.845865001281103 0.8798240969578425
7100 14.280490110317865 0.8885419617096583
7200 14.445830556253592 0.8153113002578417
7300 14.631624579429626 0.4957072486480077
7400 14.913921765983105 0.5014924754699072
7500 15.05869090060393 0.510867344836394
7600 15.384378271798292 0.539261760811011
7700 15.773150896032652 0.5437572970986366
7800 15.910091089705626 0.543494924902916
7900 16.582480480273563 0.556576170027256
8000 16.312348648905754 0.5606570094823837
9000 19.766967564821243 0.6264805942773819
10000 23.218108020722866 0.7066168040037155
11000 27.715855732560158 0.7624464482069016
12000 30.983531765639782 0.836162842810154
13000 37.89370854943991 0.910807304084301
14000 41.12514939159155 0.9750531241297722
15000 46.30769927799702 1.054315097630024
16000 52.22487150132656 1.126071348786354
17000 56.75384297221899 1.1827935576438904
18000 66.73259945958853 1.2493708804249763
19000 73.34698545187712 1.3327653110027313
20000 79.7716326713562 1.447318822145462