资源描述
决策树算法总结
精品文档
决策树
研发二部
收集于网络,如有侵权请联系管理员删除
文件状态:
[ ] 草稿
[ ] 正式发布
[ ] 正在修改
文件标识:
当前版本:
1.0
作者:
张宏超
完成日期:
2019年3月8日
目录
1. 算法介绍 1
1.1. 分支节点选取 1
1.2. 构建树 3
1.3. 剪枝 10
2. sk-learn中的使用 12
3. sk-learn中源码分析 13
1. 算法介绍
决策树算法是机器学习中的经典算法之一,既可以作为分类算法,也可以作为回归算法。决策树算法又被发展出很多不同的版本,按照时间上分,目前主要包括,ID3、C4.5和CART版本算法。其中ID3版本的决策树算法是最早出现的,可以用来做分类算法。C4.5是针对ID3的不足出现的优化版本,也用来做分类。CART也是针对ID3优化出现的,既可以做分类,可以做回归。
决策树算法的本质其实很类似我们的if-elseif-else语句,通过条件作为分支依据,最终的数学模型就是一颗树。不过在决策树算法中我们需要重点考虑选取分支条件的理由,以及谁先判断谁后判断,包括最后对过拟合的处理,也就是剪枝。这是我们之前写if语句时不会考虑的问题。
决策树算法主要分为以下3个步骤:
1. 分支节点选取
2. 构建树
3. 剪枝
1.1. 分支节点选取
分支节点选取,也就是寻找分支节点的最优解。既然要寻找最优,那么必须要有一个衡量标准,也就是需要量化这个优劣性。常用的衡量指标有熵和基尼系数。
熵:熵用来表示信息的混乱程度,值越大表示越混乱,包含的信息量也就越多。比如,A班有10个男生1个女生,B班有5个男生5个女生,那么B班的熵值就比A班大,也就是B班信息越混乱。
基尼系数:同上,也可以作为信息混乱程度的衡量指标。
有了量化指标后,就可以衡量使用某个分支条件前后,信息混乱程度的收敛效果了。使用分支前的混乱程度,减去分支后的混乱程度,结果越大,表示效果越好。
#计算熵值
def entropy(dataSet):
tNum = len(dataSet)
print(tNum)
#用来保存标签对应的个数的,比如,男:6,女:5
labels = {}
for node in dataSet:
curL = node[-1] #获取标签
if curL not in labels.keys():
labels[curL] = 0 #如果没有记录过该种标签,就记录并初始化为0
labels[curL] += 1 #将标签记录个数加1
#此时labels中保存了所有标签和对应的个数
res = 0
#计算公式为-p*logp,p为标签出现概率
for node in labels:
p = float(labels[node]) / tNum
res -= p * log(p, 2)
return res
#计算基尼系数
def gini(dataSet):
tNum = len(dataSet)
print(tNum)
# 用来保存标签对应的个数的,比如,男:6,女:5
labels = {}
for node in dataSet:
curL = node[-1] # 获取标签
if curL not in labels.keys():
labels[curL] = 0 # 如果没有记录过该种标签,就记录并初始化为0
labels[curL] += 1 # 将标签记录个数加1
# 此时labels中保存了所有标签和对应的个数
res = 1
# 计算公式为-p*logp,p为标签出现概率
for node in labels:
p = float(labels[node]) / tNum
res -= p * p
return res
1.2. 构建树
ID3算法:利用信息熵增益,决定选取哪个特征作为分支节点。分支前的总样本熵值-分支后的熵值总和=信息熵增益。
A:10个
B:10个
特征T1
A:5个
B:8个
A:5个
B:2个
A:10个
B:10个
特征T2
A:3个
B:9个
A:7个
B:1个
T1的信息熵增益:1 – 13/20*0.961 - 7/20*0.863 = 0.073
T2的信息熵增益:1 – 12/20*0.812 - 8/20*0.544 = 0.295
所以使用T2作为分支特征更优。
ID3算法建树:
依据前面的逻辑,递归寻找最优分支节点,直到下面情况结束
1. 叶节点已经属于同一标签
2. 虽然叶节点不属于同一标签,但是特征已经用完了
3. 熵小于预先设置的阈值
4. 树的深度达到了预先设置的阈值
ID3算法的不足:
1. 取值多的特征比取值少的特征更容易被选取。
2. 不包含剪枝操作,过拟合严重
3. 特征取值必须是离散的,或者有限的区间的。
于是有了改进算法C4.5
C4.5算法:基于ID3算法进行了改进,首先,针对ID3的不足1,采用信息增益率取代ID3中使用信息增益而造成的偏向于选取取值较多的特征作为分裂点的问题。针对ID3的不足2,采用剪枝操作,缓解过拟合问题。针对ID3的不足3,采用将连续值先排列,然后逐个尝试分裂,找到连续值中的最佳分裂点。
信息增益率的计算:先计算信息增益,然后除以spliteInfo。spliteInfo为分裂后的子集合的函数,假设分裂后的子集合个数为sub1和sub2,total为分裂前的个数。spliteInfo = -sub1 / total * log(sub1 / total) – sub2 / total * log(sub2 / total)
#index:特征序号
#value:特征值
#该方法表示将index对应特征的值为value的集合返回,返回集合中不包含index对应的特征
def spliteDataSet(dataSet, index, value):
newDataSet = []
for node in dataSet:
if node[index] == value:
#[0,index)列的数据
newData = node[:index]
#[index+1,最后]列的数据
newData.extend(node[index + 1:])
newDataSet.append(newData)
return newDataSet;
#选择最优分裂项
def chooseBestFeature(dataSet):
#特征个数
featureNum = len(dataSet[0]) - 1
#计算整体样本的熵值
baseEntropy = entropy(dataSet)
print("baseEntropy = %f"%(baseEntropy))
#保存最大的信息增益率
maxInfoGainRatio = 0.0
bestFeatureId = -1
for i in range(featureNum):
#获取特征所有可能的值
featureValues = []
for node in dataSet:
featureValues.append(node[i])
print(featureValues)
#将特征值去除重复
uniqueFeatureValues = set(featureValues)
print(uniqueFeatureValues)
#按照i特征分裂之后的熵值
newEntropy = 0.0
#分裂信息
spliteInfo = 0.0
#按照i所表示的特征,开始分裂数据集
for value in uniqueFeatureValues:
#当i属性等于value时的分裂结果
subDataSet = spliteDataSet(dataSet, i, value)
print(subDataSet)
#计算占比
p = float(len(subDataSet)) / float(len(dataSet))
newEntropy += p * entropy(subDataSet)
spliteInfo += -p * log(p, 2)
#计算信息增益
infoGain = baseEntropy - newEntropy
#计算信息增益率
if spliteInfo == 0:
continue
infoGainRatio = infoGain / spliteInfo
if infoGainRatio > maxInfoGainRatio:
maxInfoGainRatio = infoGainRatio
bestFeatureId = i
return bestFeatureId
C4.5算法的不足:
1. 如果存在连续值的特征需要做排序等处理,计算比较耗时
2. 只能用于分类使用
于是有了CART算法
CART算法:也是基于ID3算法优化而来,支持分类和回归,使用基尼系数(分类树)或者均方差(回归树)替代熵的作用,减少运算难度。使用二叉树代替多叉树建模,降低复杂度。
基尼系数的计算:
均方差的计算:
计算举例,假设有如下数据源
看电视时间
婚姻情况
职业
年龄
3
未婚
学生
12
4
未婚
学生
18
2
已婚
老师
26
5
已婚
上班族
47
2.5
已婚
上班族
36
3.5
未婚
老师
29
4
已婚
学生
21
如果将婚否作为标签,该问题是一个分类问题,所以使用基尼系数
假设使用职业作为特征分支,对于看电视和年龄,都是连续数据,需要按照C4.5的算法排序后处理,这里先分析简单的按照职业开始划分。
又因为,CART算法的建模是二叉树,所以,针对职业来说,有以下组合,学生|非学生,老师|非老师,上班族|非上班族,到底怎么划分,就要通过基尼系数来判断了。
gini = 3 / 7 * (1 – 2 / 3 * 2 /3 – 1 / 3 * 1 / 3) + 4 / 7 * (1 – 3 / 4 * 3 / 4 – 1 / 4 * 1 / 4) = 0.4
gini = 2 / 7 * (1 – 1 / 2 * 1 / 2 – 1 / 2 * 1 / 2) + 5 / 7 * (1 – 2 / 5 * 2 / 5 – 3 / 5 * 3 / 5) = 0.49
gini = 2 / 7 * (1 – 1 * 1) + 5 / 7 * (1 – 3 / 5 * 3 / 5 – 2 / 5 * 2 / 5) = 0.34
所以,如果选择职业来划分,那么首先应该按照上班族|非上班族划分
如果将年龄作为标签,该问题是一个回归问题,所以使用均方差
同样,先考虑使用职业来划分
mean = 开方(12 * 12 + 18 * 18 + 21 * 21 – 3 * 17 * 17) + 开方(26 * 26 + 47 * 47 + 36 * 36 + 29 * 29 – 5 * 32.5 * 32.5) = 34.71
其他情况略。
可以看到选择分裂属性这一步骤会比较麻烦,首先要遍历所有特征,找到每一个特征的最优分裂方法,然后在选择最优的分裂特征。
功能
树结构
特征选取
连续值处理
缺失值处理
剪枝
ID3
分类
多叉
信息增益
不支持
不支持
不支持
C4.5
分类
多叉
信息增益率
支持
支持
支持
CART
分类/回归
二叉
基尼系数(分类)
,均方差(回归)
支持
支持
支持
1.3. 剪枝
CCP(Cost Complexity Pruning)代价复杂性剪枝法(CART常用)
REP(Reduced Error Pruning)错误降低剪枝法
PEP(Pessimistic Error Pruning)悲观错误剪枝法(C4.5使用)
MEP(Minimum Error Pruning)最小错误剪枝法
这里以CCP为例讲解其原理
CCP选择节点表面误差率增益值最小的非叶子节点,删除该节点的子节点。若多个非叶子节点的表面误差率增益值相同,则选择子节点最多的非叶子节点进行裁剪。
表面误差率增益值计算:
R(t)表示非叶子节点的错误率,比如,总样本20,在A节点上a类5个,b类2个,所以可以认为A节点代表的是a类,那么错误率就是2 / 7 * 7 / 20
R(T)表示叶子节点的错误率累积和
N(T)表示叶子节点的个数
剪枝步骤:
1. 构建子树序列
2. 找到最优子树,作为我们的决策树(交叉验证等)
举例:
t1是根节点
t2,t3,t4,t5是非叶子节点
t6,t7,t8,t9,t10,t11是叶子节点
首先我们计算所有非叶子节点误差率增益值
t4: (4/50 * 50/80 – 1/45 * 45/80 – 2/5 * 5/80) / (2 – 1) = 0.0125
t5: (4/10 * 10/80 – 0 - 0) / (2 - 1) = 0.05
t2: (10/60 * 60/80 – 1/45 * 45/80 – 2/5 * 5/80 – 0 - 0) / (4 - 1) = 0.0292
t3: 0.0375
因此得到第1颗子树:T0 = t4(0.0125),t5(0.05),t2(0.0292),t3(0.0375)
比较发现可以将t4裁剪掉
得到第2颗子树
t5: 0.05
t3: 0.0375
t2: (10/60 * 60/80 – 4/50 * 50/80 – 0 - 0) / (3 -1) = 0.0375
此时t2与t3相同,那么裁剪叶子节点较多的,因此t2被裁剪
得到第3颗树
然后对上面3颗子树进行验证,找到效果最后的作为剪枝之后的决策树。
2. sk-learn中的使用
from sklearn.datasets import load_iris
from sklearn import tree
import pydotplus
import graphviz
iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf.fit(iris.data, iris.target)
dot_data = tree.export_graphviz(clf, out_file=None)
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_pdf("iris.pdf")
3. sk-learn中源码分析
主要分析tree的相关函数代码,使用pycharm下载sklearn包中tree文件,引用了_tree.pxd,pxd相当于头文件,其实现在_tree.pyd中,pyd是加密文件,无法查看。从github上下载源码中有_tree.pyx相当于c文件,因此可以查看。
.pxd:相当于.h
.pyx:相当于.c
.pyd:相当于dll
tree.DecisionTreeClassifier() 创建分类决策树对象
DecisionTreeClassifier继承BaseDecisionTree
clf.fit(iris.data, iris.target) 建树
DecisionTreeClassifier直接使用了父类BaseDecisionTree的方法
super().fit(
X, y,
sample_weight=sample_weight,
check_input=check_input,
X_idx_sorted=X_idx_sorted)
查看DecisionTreeClassifier的fit,学习建树过程
代码前面是对参数的校验之类的工作
criterion:表示选择分裂节点的准则,CLF表示分类使用gini系数、熵等,REG表示回归使用均方差等。他们的定义在
对于这些准则的计算,在_criterion.Gini或者其他文件中实现,使用Cpython实现的。以Gini的计算为例
同理,分裂的规则定义在splitter中,具体实现也是在Cpython中
最后是构造器,这也是面向对象设计模式中的一种设计模式,构造器模式。思想是,构造器中根据加入的原料,产出不同的东西。
builder = DepthFirstTreeBuilder (优先深度)
builder = BestFirstTreeBuilder (优先最优)
他们的代码实现在_tree.pyx中
展开阅读全文