资源描述
决策树程序实验
精品文档
决策树程序实验
众所周知,数据库技术从20世纪80年代开始,已经得到广泛的普及和应用。随着数据库容量的膨胀,特别是数据仓库以及web等新型数据源的日益普及,人们面临的主要问题不再是缺乏足够的信息可以使用,而是面对浩瀚的数据海洋如何有效地利用这些数据。
从数据中生成分类器的一个特别有效的方法是生成一个决策树(Decision Tree)。决策树表示方法是应用最广泛的逻辑方法之一,它从一组无次序、无规则的事例中推理出决策树表示形式的分类规则。决策树分类方法采用自顶向下的递归方式,在决策树的内部结点进行属性值的比较并根据不同的属性值判断从该结点向下的分支,在决策树的叶结点得到结论。所以从决策树的根到叶结点的一条路径就对应着一条合取规则,整棵决策树就对应着一组析取表达式规则。
决策树是应用非常广泛的分类方法,目前有多种决策树方法,如ID3、CN2、SLIQ、SPRINT等。
一、问题描述
1.1相关信息
决策树是一个类似于流程图的树结构,其中每个内部结点表示在一个属性上的测试,每个分支代表一个测试输入,而每个树叶结点代表类或类分布。数的最顶层结点是根结点。一棵典型的决策树如图1所示。它表示概念buys_computer,它预测顾客是否可能购买计算机。内部结点用矩形表示,而树叶结点用椭圆表示。为了对未知的样本分类,样本的属性值在决策树上测试。决策树从根到叶结点的一条路径就对应着一条合取规则,因此决策树容易转化成分类规则。
图1
ID3算法:
■ 决策树中每一个非叶结点对应着一个非类别属性,树枝代表这个属性的值。一个叶结点代表从树根到叶结点之间的路径对应的记录所属的类别属性值。
■ 每一个非叶结点都将与属性中具有最大信息量的非类别属性相关联。
■ 采用信息增益来选择能够最好地将样本分类的属性。
信息增益基于信息论中熵的概念。ID3总是选择具有最高信息增益(或最大熵压缩)的属性作为当前结点的测试属性。该属性使得对结果划分中的样本分类所需的信息量最小,并反映划分的最小随机性或“不纯性”。
1.2问题重述
1、目标概念为“寿险促销”
2、计算每个属性的信息增益
3、确定根节点的测试属性
模型求解
构造决策树的方法是采用自上而下的递归构造,其思路是:
■ 以代表训练样本的单个结点开始建树(步骤1)。
■ 如果样本都在同一类,则该结点成为树叶,并用该类标记(步骤2和3)。
■ 否则,算法使用称为信息增益的机遇熵的度量为启发信息,选择能最好地将样本分类的属性(步骤6)。该属性成为该结点的“测试”或“判定”属性(步骤7)。值得注意的是,在这类算法中,所有的属性都是分类的,即取离散值的。连续值的属性必须离散化。
■ 对测试属性的每个已知的值,创建一个分支,并据此划分样本(步骤8~10)。
■ 算法使用同样的过程,递归地形成每个划分上的样本决策树。一旦一个属性出现在一个结点上,就不必考虑该结点的任何后代(步骤13)。
■ 递归划分步骤,当下列条件之一成立时停止:
(a)给定结点的所有样本属于同一类(步骤2和3)。
(b)没有剩余属性可以用来进一步划分样本(步骤4)。在此情况下,采用多数表决(步骤5)。这涉及将给定的结点转换成树叶,并用samples中的多数所在类别标记它。换一种方式,可以存放结点样本的类分布。
(c)分支test_attribute=ai 没有样本。在这种情况下,以samples中的多数类创建一个树叶(步骤12)。
算法 Decision_Tree(samples,attribute_list)
输入 由离散值属性描述的训练样本集samples;
候选属性集合attribute_list。
输出 一棵决策树。
(1) 创建节点N;
(2) If samples 都在同一类C中then
(3) 返回N作为叶节点,以类C标记;
(4) If attribute_list为空then
(5) 返回N作为叶节点,以samples 中最普遍的类标记;//多数表决
(6) 选择attribute_list 中具有最高信息增益的属性test_attribute;
(7) 以test_attribute 标记节点N;
(8) For each test_attribute 的已知值v //划分 samples
(9) 由节点N分出一个对应test_attribute=v的分支;
(10) 令Sv为 samples中 test_attribute=v 的样本集合;//一个划分块
(11) If Sv为空 then
(12) 加上一个叶节点,以samples中最普遍的类标记;
(13) Else 加入一个由Decision_Tree(Sv,attribute_list-test_attribute)返回节点值
E(S)=(-9\15)log2(9\15)-(6\15)log2(6\15)=0.971
Values(收入范围)={20-30K,30-40k,40-50K,50-60K}
E(S(20-30K))= (-2\4)log2(2\4)- (2\4)log2(2\4)=1
E(S(30-40K))= (-4\5)log2(4\5)- (1\5)log2(1\5)=0.7219
E(S(40-50K))= (-1\4)log2(1\4)- (3\4)log2(3\4)=0.8113
E(S(50-60K))= (-2\2)log2 (2\2)- (0\2)log2(0\2)=0
所以
E(S,收入范围)=(4/15) E(S(20-30K)) +(5/15) E(S(30-40K)) +(4/15) E(S(40-50K)) +(2/15) E(S(50-60K))=0.7236
Gain(S,收入范围)=0.971-0.7236=0.2474
同理:计算“保险”,“性别”,“年龄”的信息增益为:
E(S)=(-9\15)log2(9\15)-(6\15)log2(6\15)=0.971
Insurance(保险)={yes, no}
E(S(yes))= (-3\3)log2 (3\3)- (0\3)log2(0\3)=0
E(S(no))= (-6\12)log2 (6\12)- (6\12)log2(6\12)=1
E(S, 保险)=(3/15) E(S(yes)) +(12/15) E(S(no)) =0.8
Gain(S, 保险)=0.971-0.8=0.171
E(S)=(-9\15)log2(9\15)-(6\15)log2(6\15)=0.971
sex(性别)={male, female}
E(S(male))= (-3\7)log2 (3\7)- (4\7)log2(4\7)=0.9852
E(S(female))= (-6\8)log2 (6\8)- (2\8)log2(2\8)=0.8113
E(S, 性别)=(7/15) E(S(male)) +(8/15) E(S(female)) =0.8925
Gain(S, 性别)=0.971-0.8925=0.0785
E(S)=(-9\15)log2(9\15)-(6\15)log2(6\15)=0.971
age(年龄)={15~40,41 ~60}
E(S(15~40))= (-6\7)log2 (6\7)- (1\7)log2(1\7)=0.5917
E(S(41 ~60))= (-3\8)log2 (3\8)- (5\8)log2(5\8)=0.9544
E(S, 年龄)=(7/15) E(S(15~40)) +(8/15) E(S(41 ~60)) =0.7851
Gain(S, 年龄)=0.971-0.7851=0.1859
代码
package DecisionTree;
import java.util.ArrayList;
/**
* 决策树结点类
*/
public class TreeNode {
private String name; //节点名(分裂属性的名称)
private ArrayList<String> rule; //结点的分裂规则
ArrayList<TreeNode> child; //子结点集合
private ArrayList<ArrayList<String>> datas; //划分到该结点的训练元组
private ArrayList<String> candAttr; //划分到该结点的候选属性
public TreeNode() {
this.name = "";
this.rule = new ArrayList<String>();
this.child = new ArrayList<TreeNode>();
this.datas = null;
this.candAttr = null;
}
public ArrayList<TreeNode> getChild() {
return child;
}
public void setChild(ArrayList<TreeNode> child) {
this.child = child;
}
public ArrayList<String> getRule() {
return rule;
}
public void setRule(ArrayList<String> rule) {
this.rule = rule;
}
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public ArrayList<ArrayList<String>> getDatas() {
return datas;
}
public void setDatas(ArrayList<ArrayList<String>> datas) {
this.datas = datas;
}
public ArrayList<String> getCandAttr() {
return candAttr;
}
public void setCandAttr(ArrayList<String> candAttr) {
this.candAttr = candAttr;
}
}
package DecisionTree;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.StringTokenizer;
/**
* 决策树算法测试类
*/
public class TestDecisionTree {
/**
* 读取候选属性
* @return 候选属性集合
* @throws IOException
*/
public ArrayList<String> readCandAttr() throws IOException{
ArrayList<String> candAttr = new ArrayList<String>();
BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
String str = "";
while (!(str = reader.readLine()).equals("")) {
StringTokenizer tokenizer = new StringTokenizer(str);
while (tokenizer.hasMoreTokens()) {
candAttr.add(tokenizer.nextToken());
}
}
return candAttr;
}
/**
* 读取训练元组
* @return 训练元组集合
* @throws IOException
*/
public ArrayList<ArrayList<String>> readData() throws IOException {
ArrayList<ArrayList<String>> datas = new ArrayList<ArrayList<String>>();
BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
String str = "";
while (!(str = reader.readLine()).equals("")) {
StringTokenizer tokenizer = new StringTokenizer(str);
ArrayList<String> s = new ArrayList<String>();
while (tokenizer.hasMoreTokens()) {
s.add(tokenizer.nextToken());
}
datas.add(s);
}
return datas;
}
/**
* 递归打印树结构
* @param root 当前待输出信息的结点
*/
public void printTree(TreeNode root){
System.out.println("name:" + root.getName());
ArrayList<String> rules = root.getRule();
System.out.print("node rules: {");
for (int i = 0; i < rules.size(); i++) {
System.out.print(rules.get(i) + " ");
}
System.out.print("}");
System.out.println("");
ArrayList<TreeNode> children = root.getChild();
int size =children.size();
if (size == 0) {
System.out.println("-->leaf node!<--");
} else {
System.out.println("size of children:" + children.size());
for (int i = 0; i < children.size(); i++) {
System.out.print("child " + (i + 1) + " of node " + root.getName() + ": ");
printTree(children.get(i));
}
}
}
/**
* 主函数,程序入口
* @param args
*/
public static void main(String[] args) {
TestDecisionTree tdt = new TestDecisionTree();
ArrayList<String> candAttr = null;
ArrayList<ArrayList<String>> datas = null;
try {
System.out.println("请输入候选属性");
candAttr = tdt.readCandAttr();
System.out.println("请输入训练数据");
datas = tdt.readData();
} catch (IOException e) {
e.printStackTrace();
}
DecisionTree tree = new DecisionTree();
TreeNode root = tree.buildTree(datas, candAttr);
tdt.printTree(root);
}
}
package DecisionTree;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
/**
* 选择最佳分裂属性
*/
public class Gain {
private ArrayList<ArrayList<String>> D = null; //训练元组
private ArrayList<String> attrList = null; //候选属性集
public Gain(ArrayList<ArrayList<String>> datas, ArrayList<String> attrList) {
this.D = datas;
this.attrList = attrList;
}
/**
* 获取最佳侯选属性列上的值域(假定所有属性列上的值都是有限的名词或分类类型的)
* @param attrIndex 指定的属性列的索引
* @return 值域集合
*/
public ArrayList<String> getValues(ArrayList<ArrayList<String>> datas, int attrIndex){
ArrayList<String> values = new ArrayList<String>();
String r = "";
for (int i = 0; i < datas.size(); i++) {
r = datas.get(i).get(attrIndex);
if (!values.contains(r)) {
values.add(r);
}
}
return values;
}
/**
* 获取指定数据集中指定属性列索引的域值及其计数
* @param d 指定的数据集
* @param attrIndex 指定的属性列索引
* @return 类别及其计数的map
*/
public Map<String, Integer> valueCounts(ArrayList<ArrayList<String>> datas, int attrIndex){
Map<String, Integer> valueCount = new HashMap<String, Integer>();
String c = "";
ArrayList<String> tuple = null;
for (int i = 0; i < datas.size(); i++) {
tuple = datas.get(i);
c = tuple.get(attrIndex);
if (valueCount.containsKey(c)) {
valueCount.put(c, valueCount.get(c) + 1);
} else {
valueCount.put(c, 1);
}
}
return valueCount;
}
/**
* 求对datas中元组分类所需的期望信息,即datas的熵
* @param datas 训练元组
* @return datas的熵值
*/
public double infoD(ArrayList<ArrayList<String>> datas){
double info = 0.000;
int total = datas.size();
Map<String, Integer> classes = valueCounts(datas, attrList.size());
Iterator iter = classes.entrySet().iterator();
Integer[] counts = new Integer[classes.size()];
for(int i = 0; iter.hasNext(); i++)
{
Map.Entry entry = (Map.Entry) iter.next();
Integer val = (Integer) entry.getValue();
counts[i] = val;
}
for (int i = 0; i < counts.length; i++) {
double base = DecimalCalculate.div(counts[i], total, 3);
info += (-1) * base * Math.log(base);
}
return info;
}
/**
* 获取指定属性列上指定值域的所有元组
* @param attrIndex 指定属性列索引
* @param value 指定属性列的值域
* @return 指定属性列上指定值域的所有元组
*/
public ArrayList<ArrayList<String>> datasOfValue(int attrIndex, String value){
ArrayList<ArrayList<String>> Di = new ArrayList<ArrayList<String>>();
ArrayList<String> t = null;
for (int i = 0; i < D.size(); i++) {
t = D.get(i);
if(t.get(attrIndex).equals(value)){
Di.add(t);
}
}
return Di;
}
/**
* 基于按指定属性划分对D的元组分类所需要的期望信息
* @param attrIndex 指定属性的索引
* @return 按指定属性划分的期望信息值
*/
public double infoAttr(int attrIndex){
double info = 0.000;
ArrayList<String> values = getValues(D, attrIndex);
for (int i = 0; i < values.size(); i++) {
ArrayList<ArrayList<String>> dv = datasOfValue(attrIndex, values.get(i));
info += DecimalCalculate.mul(DecimalCalculate.div(dv.size(), D.size(), 3), infoD(dv));
}
return info;
}
/**
* 获取最佳分裂属性的索引
* @return 最佳分裂属性的索引
*/
public int bestGainAttrIndex(){
int index = -1;
double gain = 0.000;
double tempGain = 0.000;
for (int i = 0; i < attrList.size(); i++) {
tempGain = infoD(D) - infoAttr(i);
if (tempGain > gain) {
gain = tempGain;
index = i;
}
}
return index;
}
}
package DecisionTree;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import javax.smartcardio.*;
/**
* 决策树构造类
*/
public class DecisionTree {
private Integer attrSelMode; //最佳分裂属性选择模式,1表示以信息增益度量,2表示以信息增益率度量。暂未实现2
public DecisionTree(){
this.attrSelMode = 1;
}
public DecisionTree(int attrSelMode) {
this.attrSelMode = attrSelMode;
}
public void setAttrSelMode(Integer attrSelMode) {
this.attrSelMode = attrSelMode;
}
/**
* 获取指定数据集中的类别及其计数
* @param datas 指定的数据集
* @return 类别及其计数的map
*/
public Map<String, Integer> classOfDatas(ArrayList<ArrayList<String>> datas){
Map<String, Integer> classes = new HashMap<String, Integer>();
String c = "";
ArrayList<String> tuple = null;
for (int i = 0; i < datas.size(); i++) {
tuple = datas.get(i);
c = tuple.get(tuple.size() - 1);
if (classes.containsKey(c)) {
classes.put(c, classes.get(c) + 1);
} else {
classes.put(c, 1);
}
}
return classes;
}
/**
* 获取具有最大计数的类名,即求多数类
* @param classes 类的键值集合
* @return 多数类的类名
*/
public String maxClass(Map<String, Integer> classes){
String maxC = "";
int max = -1;
Iterator iter = classes.entrySet().iterator();
for(int i = 0; iter.hasNext(); i++)
{
Map.Entry entry = (Map.Entry) iter.next();
String key = (String)entry.getKey();
Integer val = (Integer) entry.getValue();
if(val > max){
max = val;
maxC = key;
}
}
return maxC;
}
/**
* 构造决策树
* @param datas 训练元组集合
* @param attrList 候选属性集合
* @return 决策树根结点
*/
public TreeNode buildTree(ArrayList<ArrayList<String>> datas, ArrayList<String> attrList){
// System.out.print("候选属性列表: ");
// for (int i = 0; i < attrList.size(); i++) {
// System.out.print(" " + attrList.get(i) + " ");
// }
System.out.println();
TreeNode node = new TreeNode();
node.setDatas(datas);
node.setCandAttr(attrList);
Map<String, Integer> classes = classOfDatas(datas);
String maxC = maxClass(classes);
if (classes.size() == 1 || attrList.size() == 0) {
node.setName(maxC);
return node;
}
Gain gain = new Gain(datas, attrList);
int bestAttrIndex = gain.bestGainAttrIndex();
ArrayList<String> rules = gain.getValues(datas, bestAttrIndex);
node.setRule(rules);
node.setName(attrList.get(bestAttrIndex));
if(rules.size() > 2){ //?此处有待商榷
attrList.remove(bestAttrIndex);
}
for (int i = 0; i < rules.size(); i++) {
String rule = rules.get(i);
ArrayList<ArrayList<String>> di = gain.datasOfValue(bestAttrIndex, rule);
for (int j = 0; j < di.size(); j++) {
di.get(j).remove(bestAttrIndex);
}
if (di.size() == 0) {
TreeNode leafNode = new TreeNode();
leafNode.setName(maxC);
leafNode.setDatas(di);
leafNode.setCandAttr(attrList);
node.getChild().add(leafNode);
} else {
TreeNode newNode = buildTree(di, attrList);
node.getChild().add(newNode);
}
}
return node;
}
}
package DecisionTree;
import java.math.BigDecimal;
public class DecimalCalculate {
/**
* 由于Java的简单类型不能够精确的对浮点数进行运算,这个工具类提供精
* 确的浮点数运算,包括加减乘除和四舍五入。
*/
//默认除法运算精度
private static final int DEF_DIV_SCALE = 10;
//这个类不能实例化
private DecimalCalculate(){
}
/**
* 提供精确的加法运算。
* @param v1 被加数
* @param v2 加数
* @return 两个参数的和
*/
public static double add(double v1,double v2){
BigDecimal b1 = new BigDecimal(Double.toString(v1));
BigDecimal b2 = new BigDecimal(Double.toString(v2));
return b1.add(b2).doubleValue();
}
/**
* 提供精确的减法运算。
* @param v1 被减数
* @param v2 减数
* @return 两个参数的差
*/
public static double sub(double v1,double v2){
BigDecimal b1 = new BigDecimal(Double.toString(v1));
BigDecimal b2 = new BigDecimal(Double.toString(v2));
return b1.subtract(b2).doubleValue();
}
/**
* 提供精确的乘法运算。
* @param v1 被乘数
* @param v2 乘数
* @return 两个参数的积
*/
public static double mul(double v1,double v2){
BigDecimal b1 = new BigDecimal(Doub
展开阅读全文