0%

CART模型与实现

对于连续型数据,CART采用二元分割的方法构建树,用平方误差的总和(即y的方差乘样本数)来度量混乱度。本文采用Numpy手动实现CART模型,包括预剪枝和后剪枝。

CART算法是树模型的经典算法,可以用于构建二元树并处理离散或者连续型数据的切分,如果采用不同的误差准则,就可以通过CART来构建模型树和回归树。回归树指的是叶子节点使用的分段常数,模型树指的是叶子节点使用线性回归方程。本文采用Numpy手动实现CART模型,构建模型树和回归树,以及预剪枝和后剪枝,参考自《机器学习实战》第九章。

1、CART

构建CART的伪代码createTree()为:

1
2
3
4
5
6
7
8
9
找到最佳的待切分特征:

如果该节点不能再分,将该节点存为叶节点

执行二元分割

在左子树调用createTree()

在右子树调用createTree()

createTree()的Python实现:

1
2
3
4
5
6
7
8
9
10
11
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):#assume dataSet is NumPy Mat so we can array filtering
feat, val = chooseBestSplit(dataSet, leafType, errType, ops)#choose the best split
if feat == None:
return val #if the splitting hit a stop condition return val
retTree = {}
retTree['spInd'] = feat
retTree['spVal'] = val
lSet, rSet = binSplitDataSet(dataSet, feat, val)
retTree['left'] = createTree(lSet, leafType, errType, ops)
retTree['right'] = createTree(rSet, leafType, errType, ops)
return retTree

这个函数的输入为一个mat格式的二维数据,还有三个可选参数,这三个参数都是在选择切分点的时候使用的,我们在下面的代码中再解释。现在,只需要看到,它输入了一个二维的数据,递归调用自身创建树就可以了。最终创建的树结构保存在一个字典中,即代码中的retTree。例如,如果构造的的树有两个叶子节点和一个根节点,表示为

{'spInd':0,'spVal':matrix([[1]]),'right':-2,'left':2}

表示根节点的特征为第0列对应的特征,阈值为1,叶子节点输出分别为-2和2。如果不只有一层,例如左子树不是叶子节点,那么左子树的值就不是数字,而是字典。如此嵌套,就可以表示一个完整的树。

寻找切分点的函数比较复杂,第2部分再介绍,先看二元切分,意思就是找到了合适的特征切分点后,将数据集合分为两半,代码如下:

1
2
3
4
def binSplitDataSet(dataSet, feature, value):
mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:]
mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:]
return mat0,mat1

该函数实现了已知一个特征和阈值,可以对数据块进行切分,例如:

1
2
3
4
5
import numpy as np
testMat = mat(eye(4))
mat0,mat1 = binSplitDataSet(testMat,1,0.5)
# mat0 = [[0,1,0,0]]
# mat1 = [[1,0,0,0],[0,0,1,0],[0,0,0,1]]

在createTree()中,还有最重要的函数没有实现,那就是chooseBestSplit()函数,它的目的是输入一个数据块,返回最合适的划分特征以及阈值,实现二分分割,或者对于不能再划分的数据,返回None和由整个数据组成的叶子节点。具体看第2部分所示。

2、寻找切分点

在连续情况下,寻找切分点的原则和离散时候相同,即使得切分后的数据块混乱度降低。在离散情况下,我们采用熵或者基尼指数等指标衡量离散程度,对于连续情况,自然的,采用总方差来衡量数据的离散程度。

寻找切分点的伪代码如下:

1
2
3
4
5
6
7
8
9
对于每个特征:

对于每个特征值:

将数据切分称为两份

计算当前切分的混乱程度

如果当前的混乱程度小于历史最小误差,那么保存此时的特征以及阈值

Python实现如下,代码中包括四种退出方式和两种返回格式,即:

  • 如果遇到数据集中所有的标签值都相等(#1)
  • 划分完混乱度减少很小(#2)
  • 切分的数据集过于琐碎(#3)

那么就返回None, leafType(dataSet),即特征为None,和整个数据集构成的叶子节点;如果没有出现以上三种情况,那么就代表可以正常的划分数据,那么就输出划分的特征和阈值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
tolS = ops[0]; tolN = ops[1]
#if all the target variables are the same value: quit and return value
if len(set(dataSet[:,-1].T.tolist()[0])) == 1: #exit cond 1
return None, leafType(dataSet)
m,n = shape(dataSet)
#the choice of the best feature is driven by Reduction in RSS error from mean
S = errType(dataSet)
bestS = inf; bestIndex = 0; bestValue = 0
for featIndex in range(n-1):
for splitVal in set(dataSet[:,featIndex].T.tolist()[0]):
mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
newS = errType(mat0) + errType(mat1)
if newS < bestS:
bestIndex = featIndex
bestValue = splitVal
bestS = newS
#if the decrease (S-bestS) is less than a threshold don't do the split
if (S - bestS) < tolS:
return None, leafType(dataSet) #exit cond 2
mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): #exit cond 3
return None, leafType(dataSet)
return bestIndex,bestValue#returns the best feature to split on
#and the value used for that split

切分函数除了输入数据集合dataSet外,还有三个可选参数,

  • 第一个是leafType负责生成叶节点,当切分数据停止时调用该函数得到叶节点模型
  • 第二个时errType,就是我们前面说的计算混乱程度的函数
  • 最后一个是用户指定的参数,用来控制程序停止,可以认为预剪枝,包括两个部分,第一个是容许的最小误差下降值,即如果切分完比不切分时混乱度下降的小于这个阈值,就停止切分,第二个是切分的最小样本数,可以认为是避免切分的过于琐碎。

以上代码还缺少了两个函数,即leafType和errType,补上就完整了,leafType目的返回一个叶子几点数值,即接受一个数据集输入,返回一个数值,为所有标签的平均。errType则是给定一个数据集,计算混论程度。

1
2
3
4
5
def regLeaf(dataSet):#returns the value used for each leaf
return mean(dataSet[:,-1])

def regErr(dataSet):
return var(dataSet[:,-1]) * shape(dataSet)[0]

3 运行代码

将以上所有代码放在一起,就可以正确的产生CART树了。注意输入的数据格式为二维mat数组,每一行为一个样本,其中前n-1

行为特征,最后一行为标签,例如这样的输入:

1
2
3
feature = np.random.random((100,10))
labels = np.ones((100,1))
DataMat = mat(np.hstack((feature,labels)))

其他情况下,自行导入数据并处理成以上格式即可。

如果是从给定的txt文件导入数据,可以参考如下代码,注意一下代码输出是二维列表,注意转换为mat格式。

1
2
3
4
5
6
7
8
def loadDataSet(fileName):      #general function to parse tab -delimited floats
dataMat = [] #assume last column is target value
fr = open(fileName)
for line in fr.readlines():
curLine = line.strip().split('\t')
fltLine = list(map(float,curLine)) #map all elements to float()
dataMat.append(fltLine)
return dataMat

4 树剪枝

如果树的节点过多,直观看就是过于琐碎,表示可能出现了过拟合。为了防止过拟合,可以采用剪枝的方法,在前面的代码里面,已经通过定义一些参数防止过拟合,即chooseBestSplit中的ops参数,这种方法称为预剪枝。但是预剪枝存在以下问题:容许的最小误差下降值采用硬编码的方式输入,这样对于label的数量集就非常的敏感,因此不方便的调参。

此处介绍后剪枝的方法,思路就是,通过测试集数据来评判一个训练好的树结构,在合并了一些节点的情况下,会不会表现的更好,后剪枝伪代码如下:

1
2
3
4
5
6
7
8
9
基于已有的树切分测试数据:

如果存在任一子集是一棵树,则在该子集递归剪枝过程

计算将当前两个叶子节点合并后的混乱度

计算不合并的混乱度

如果合并后混乱度降低的化,就合并

python实现如下,该函数输入一个树(以字典的方式存储),和测试集合,输出剪枝完的结果。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def prune(tree, testData):
if shape(testData)[0] == 0: return getMean(tree) #if we have no test data collapse the tree
if (isTree(tree['right']) or isTree(tree['left'])):#if the branches are not trees try to prune them
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
#if they are now both leafs, see if we can merge them
if not isTree(tree['left']) and not isTree(tree['right']):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +\
sum(power(rSet[:,-1] - tree['right'],2))
treeMean = (tree['left']+tree['right'])/2.0
errorMerge = sum(power(testData[:,-1] - treeMean,2))
if errorMerge < errorNoMerge:
print("merging")
return treeMean
else: return tree
else: return tree

以上代码调用了两个函数,分别是isTree和getMean,实现如下,第一个函数顾名思义,判断节点是否是叶子节点(是的话返回flase);第二个函数则是一个递归函数,从上到下遍历直到叶节点位置。如果找到两个叶子节点则返回平均值,

1
2
3
4
5
6
7
def isTree(obj):
return (type(obj).__name__=='dict')

def getMean(tree):
if isTree(tree['right']): tree['right'] = getMean(tree['right'])
if isTree(tree['left']): tree['left'] = getMean(tree['left'])
return (tree['left']+tree['right'])/2.0

以上就是剪枝的全部代码。

5 模型树

模型树与上述回归树的区别在于,对叶子节点内的数据,用一个线性函数去拟合它,而不是仅输出所有样本的均值。

模型树就是在上述代码的基础上,改变两个参数即可。

在createsTree()的代码中,我们固定了leafType=regLeaf, errType=regErr,还记得这两个变量分别是用来返回叶子节点和定义误差类型的,我们将其改为modelLeafmodelErr,为了实现这两个函数需要定义一个线性拟合函数,输入数据集,返回拟合的参数和拆分的X和Y。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def linearSolve(dataSet):   #helper function used in two places
m,n = shape(dataSet)
X = mat(ones((m,n))); Y = mat(ones((m,1)))#create a copy of data with 1 in 0th postion
X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1]#and strip out Y
xTx = X.T*X
if linalg.det(xTx) == 0.0:
raise NameError('This matrix is singular, cannot do inverse,\n\
try increasing the second value of ops')
ws = xTx.I * (X.T * Y)
return ws,X,Y

def modelLeaf(dataSet):#create linear model and return coeficients
ws,X,Y = linearSolve(dataSet)
return ws

def modelErr(dataSet):
ws,X,Y = linearSolve(dataSet)
yHat = X * ws
return sum(power(Y - yHat,2))

这样,就实现了模型树。

6 总结

本文介绍了利用CART进行连续特征回归的代码实现,以及预剪枝和后剪枝的操作。

参考文献:

  • 《机器学习实战》第九章