1、
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
class TreeNode
{
String element; // 该值为数据的属性名称
String value; // 上一个分裂属性在此结点的值
LinkedHashSet
2、 this.element=null; this.value=null; this.childs=null; } public TreeNode(String value) { this.element=null; this.value=value; this.childs=null; } public String getElement() { return this.element; }
3、
/**
 * Stores the given string in this node's {@code element} field.
 *
 * @param e the value to assign to {@code element}
 */
public void setElement(String e)
{
    element = e;
}
/**
 * Returns the string currently held in this node's {@code value} field.
 *
 * @return the current {@code value}
 */
public String getValue()
{
    return value;
}
/**
 * Stores the given string in this node's {@code value} field.
 *
 * @param v the value to assign to {@code value}
 */
public void setValue(String v)
{
    value = v;
}
public LinkedHashSet
4、this.childs;
}
public void setChilds(LinkedHashSet
5、eNode root) { this.root=root; } public TreeNode getRoot() { return root; } public void setRoot(TreeNode root) { this.root=root; } public String selectAtrribute(TreeNode node, String[][] deData, boolean flags[],
6、 LinkedHashSet
7、eturn_atrribute=null; // 计算每个未分类属性的 Gain值 int count=0; // 计算到第几个属性 for(String atrribute : atrributes) { // 该属性有多少个值,该属性有多少个分类 int values_count, class_count; // 属性值对应的下标 int index=attrIndexMap.g
8、et(atrribute);
// 存放属性的各个值和分类值
LinkedHashSet
9、 { values.add(deData[i][index]); classes.add(deData[i][class_index]); } } values_count=values.size(); class_count=classes.size(); int values_vector[]=new int[values_count
10、 class_count]; int class_vector[]=new int[class_count]; for(int i=0; i < deData.length; i++) { if(flags[i] == true) { int j=0; for(String v : values) {
11、 if(deData[i][index].equals(v)) { break; } else { j++; } }
12、 int k=0; for(String c : classes) { if(deData[i][class_index].equals(c)) { break; } else {
13、 k++; } } values_vector[j * class_count + k]++; class_vector[k]++; } } double InfoD=0.0; double class_total=0.0;
14、 for(int i=0; i < class_vector.length; i++) { class_total+=class_vector[i]; } for(int i=0; i < class_vector.length; i++) { if(class_vector[i] == 0) { continue;
15、 } else { double d=Math.log(class_vector[i] / class_total) / Math.log(2.0) * class_vector[i] / class_total; InfoD=InfoD - d; } } // 计算InfoA double InfoA=0.0;
16、 for(int i=0; i < values_count; i++) { double middle=0.0; double attr_count=0.0; for(int j=0; j < class_count; j++) { attr_count+=values_vector[i * class_count + j]; }
17、 for(int j=0; j < class_count; j++) { if(values_vector[i * class_count + j] != 0) { double k=values_vector[i * class_count + j]; middle=middle - Math.log(k / attr_count) / Math.log(2.0)
18、 * k / attr_count; } } InfoA+=middle * attr_count / class_total; } Gain[count]=InfoD - InfoA; count++; } double max=0.0; int i=0; for(String atrribute : atrributes)
19、 { if(Gain[i] > max) { max=Gain[i]; return_atrribute=atrribute; } i++; } return return_atrribute; } public void buildDecisionTree(TreeNode node, String[][] deData, boolean fla
20、gs[],
LinkedHashSet
21、classIndex=deData[0].length - 1; for(int i=0; i < deData.length; i++) { if(flags[i] == true) { if(classMap.containsKey(deData[i][classIndex])) { int count=classMap.get(deDat
22、a[i][classIndex]); classMap.put(deData[i][classIndex], count + 1); } else { classMap.put(deData[i][classIndex], 1); } } }
23、 // 选择多数类
String mostClass=null;
int mostCount=0;
Iterator
24、 { mostClass=strClass; mostCount=classMap.get(strClass); } } // 对结点进行赋值,该结点为叶结点 node.setElement(mostClass); node.setChilds(null); System.out.println("yezhi:" + node.getEle
25、ment() + ":" + node.getValue());
return;
}
// 如果待分类数据全都属于一个类
int class_index=deData[0].length - 1;
String class_name=null;
HashSet
26、f(flags[i] == true) { class_name=deData[i][class_index]; classSet.add(class_name); } } // 则该结点为叶结点,设置有关值,然后返回 if(classSet.size() == 1) { node.setElement(class_name); node.setChil
27、ds(null); System.out.println("leaf:" + node.getElement() + ":" + node.getValue()); return; } // 给定的分枝没有元组,是不是有这种情况? // 选择一个分类属性 String attribute=selectAtrribute(node, deData, flags, attributes, attrIndexMap); // 设置分裂结点的值
28、 node.setElement(attribute); // System.out.println(attribute); if(node == root) { System.out.println("root:" + node.getElement() + ":" + node.getValue()); } else { System.out.println("branch:" + node.getElement() + "
29、" + node.getValue());
}
// 生成和设置各个子结点
int attrIndex=attrIndexMap.get(attribute);
LinkedHashSet
30、 attrValues.add(deData[i][attrIndex]);
}
}
LinkedHashSet
31、Childs(childs); // 在候选分类属性中删除当前属性 attributes.remove(attribute); // 在各个子结点上递归调用本函数 if(childs.isEmpty() != true) { for(TreeNode child : childs) { // 设置子结点待分类的数据集 boolean newFlags[]=new boolean[deData.length];
// Build the child's row mask: a row stays selected only if it was selected in
// the parent and its value for the split attribute matches the child's value.
// Strings are compared with equals(); the != operator compares object
// references and only succeeds when both sides happen to be the same instance.
for(int i=0; i < deData.length; i++)
{
    newFlags[i]=flags[i] && deData[i][attrIndex].equals(child.getValue());
}
33、 // 设置子结点待分类的属性集
LinkedHashSet
34、ree(child, deData, newFlags, newAttributes, attrIndexMap); } } } // 输出决策树 public void printDecisionTree() {} } public class Data2 { public static void main(String[] args) { /* * //输入数据集1 String deData[][] = new String[12][]; deData[0
35、] = new * String[]{"Yes","No","No","Yes","Some","high","No","Yes","French","0~10","Yes"}; deData[1] = * new String[]{"Yes","No","No","Yes","Full","low","No","No","Thai","30~60","No"}; deData[2] = * new String[]{"No","Yes","No","No","Some","low","No","No","Burger","0~10","Yes"
36、}; deData[3] * = new String[]{"Yes","No","Yes","Yes","Full","low","Yes","No","Thai","10~30","Yes"}; * deData[4] = new * String[]{"Yes","No","Yes","No","Full","high","No","Yes","French",">60","No"}; deData[5] = * new String[]{"No","Yes","No","Yes","Some","middle","Yes"
37、"Yes","Italian","0~10","Yes"}; * deData[6] = new * String[]{"No","Yes","No","No","None","low","Yes","No","Burger","0~10","No"}; deData[7] = * new String[]{"No","No","No","Yes","Some","middle","Yes","Yes","Thai","0~10","Yes"}; * deData[8] = new * String[]{"No"
38、"Yes","Yes","No","Full","low","Yes","No","Burger",">60","No"}; deData[9] = * new String[]{"Yes","Yes","Yes","Yes","Full","high","No","Yes","Italian","10~30","No"}; * deData[10]= new String[]{"No","No","No","No","None","low","No","No","Thai","0~10","No"}; * deData[11]= new
39、 * String[]{"Yes","Yes","Yes","Yes","Full","low","No","No","Burger","30~60","Yes"}; //待分类的属性集1 * String attr[] = new String[]{"alt", "bar", "fri", "hun", "pat", "price", "rain", "res", * "type", "est"}; */// 输入数据集2 String deData[][]=new String[14][]; de
40、Data[0]=new String[] { "youth", "high", "no", "fair", "no" }; deData[1]=new String[] { "youth", "high", "no", "excellent", "no" }; deData[2]=new String[] { "middle_aged", "high", "no", "fair", "yes" }; deData[3]=new String[] { "senior",
41、"medium", "no", "fair", "yes" }; deData[4]=new String[] { "senior", "low", "yes", "fair", "yes" }; deData[5]=new String[] { "senior", "low", "yes", "excellent", "no" }; deData[6]=new String[] { "middle_aged", "low", "yes", "excellent", "yes" };
42、 deData[7]=new String[] { "youth", "medium", "no", "fair", "no" }; deData[8]=new String[] { "youth", "low", "yes", "fair", "yes" }; deData[9]=new String[] { "senior", "medium", "yes", "fair", "yes" }; deData[10]=new String[] { "yo
43、uth", "medium", "yes", "excellent", "yes" }; deData[11]=new String[] { "middle_aged", "medium", "no", "excellent", "yes" }; deData[12]=new String[] { "middle_aged", "high", "yes", "fair", "yes" }; deData[13]=new String[] { "senior", "medium", "no
44、", "excellent", "no" };
// 待分类的属性集2
String attr[]=new String[]
{ "age", "income", "student", "credit_rating" };
LinkedHashSet
45、i]);
} // 属性与数据集中对应数据的下标
HashMap
46、 for(int i=0; i < deData.length; i++) { flags[i]=true; } // 构造决策树 TreeNode root=new TreeNode(); DecisionTree decisionTree=new DecisionTree(root); decisionTree.buildDecisionTree(root, deData, flags, attributes, attrIndexMap); } }






