决策分类树算法之ID3,C4.5算法系列
一、引言
在最开始的时候,我本来准备学习的是C4.5算法,后来发现C4.5算法的核心还是ID3算法,所以又辗转回到学习ID3算法了,因为C4.5是他的一个改进。至于是什么改进,在后面的描述中我会提到。
二、ID3算法
ID3算法是一种分类决策树算法。他通过一系列的规则,将数据最后分类成决策树的形式。分类的根据是用到了熵这个概念。熵在物理这门学科中就已经出现过,表示是一个物质的稳定度,在这里就是分类的纯度的一个概念。公式为:
在ID3算法中,是采用Gain信息增益来作为一个分类的判定标准的。他的定义为:
每次选择属性中信息增益最大作为划分属性,在这里本人实现了一个java版本的ID3算法,为了模拟数据的可操作性,就把数据写到一个input.txt文件中,作为数据源,格式如下:
[java] view plaincopyprint?
-
Day OutLook Temperature Humidity Wind PlayTennis
-
1 Sunny Hot High Weak No
-
2 Sunny Hot High Strong No
-
3 Overcast Hot High Weak Yes
-
4 Rainy Mild High Weak Yes
-
5 Rainy Cool Normal Weak Yes
-
6 Rainy Cool Normal Strong No
-
7 Overcast Cool Normal Strong Yes
-
8 Sunny Mild High Weak No
-
9 Sunny Cool Normal Weak Yes
-
10 Rainy Mild Normal Weak Yes
-
11 Sunny Mild Normal Strong Yes
-
12 Overcast Mild High Strong Yes
-
13 Overcast Hot Normal Weak Yes
-
14 Rainy Mild High Strong No
PalyTennis属性为结构属性,是作为类标识用的,中间的OutLool,Temperature,Humidity,Wind才是划分属性,通过将源数据与执行程序分类,这样可以模拟巨大的数据量了。下面是ID3的主程序类,本人将ID3的算法进行了包装,对外只开放了一个构建决策树的方法,在构造函数时候,只需传入一个数据路径文件即可:
[java] view plaincopyprint?
-
package DataMing_ID3;
-
-
import java.io.BufferedReader;
-
import java.io.File;
-
import java.io.FileReader;
-
import java.io.IOException;
-
import java.util.ArrayList;
-
import java.util.HashMap;
-
import java.util.Iterator;
-
import java.util.Map;
-
import java.util.Map.Entry;
-
import java.util.Set;
-
-
/**
-
* ID3算法实现类
-
*
-
* @author lyq
-
*
-
*/
-
public class ID3Tool {
-
// 类标号的值类型
-
private final String YES = "Yes";
-
private final String NO = "No";
-
-
// 所有属性的类型总数,在这里就是data源数据的列数
-
private int attrNum;
-
private String filePath;
-
// 初始源数据,用一个二维字符数组存放模仿表格数据
-
private String[][] data;
-
// 数据的属性行的名字
-
private String[] attrNames;
-
// 每个属性的值所有类型
-
private HashMap<String, ArrayList<String>> attrValue;
-
-
public ID3Tool(String filePath) {
-
this.filePath = filePath;
-
attrValue = new HashMap<>();
-
}
-
-
/**
-
* 从文件中读取数据
-
*/
-
private void readDataFile() {
-
File file = new File(filePath);
-
ArrayList<String[]> dataArray = new ArrayList<String[]>();
-
-
try {
-
BufferedReader in = new BufferedReader(new FileReader(file));
-
String str;
-
String[] tempArray;
-
while ((str = in.readLine()) != null) {
-
tempArray = str.split(" ");
-
dataArray.add(tempArray);
-
}
-
in.close();
-
} catch (IOException e) {
-
e.getStackTrace();
-
}
-
-
data = new String[dataArray.size()][];
-
dataArray.toArray(data);
-
attrNum = data[0].length;
-
attrNames = data[0];
-
-
/*
-
* for(int i=0; i<data.length;i++){ for(int j=0; j<data[0].length; j++){
-
* System.out.print(" " + data[i][j]); }
-
*
-
* System.out.print("\n"); }
-
*/
-
}
-
-
/**
-
* 首先初始化每种属性的值的所有类型,用于后面的子类熵的计算时用
-
*/
-
private void initAttrValue() {
-
ArrayList<String> tempValues;
-
-
// 按照列的方式,从左往右找
-
for (int j = 1; j < attrNum; j++) {
-
// 从一列中的上往下开始寻找值
-
tempValues = new ArrayList<>();
-
for (int i = 1; i < data.length; i++) {
-
if (!tempValues.contains(data[i][j])) {
-
// 如果这个属性的值没有添加过,则添加
-
tempValues.add(data[i][j]);
-
}
-
}
-
-
// 一列属性的值已经遍历完毕,复制到map属性表中
-
attrValue.put(data[0][j], tempValues);
-
}
-
-
/*
-
* for(Map.Entry entry : attrValue.entrySet()){
-
* System.out.println("key:value " + entry.getKey() + ":" +
-
* entry.getValue()); }
-
*/
-
}
-
-
/**
-
* 计算数据按照不同方式划分的熵
-
*
-
* @param remainData
-
* 剩余的数据
-
* @param attrName
-
* 待划分的属性,在算信息增益的时候会使用到
-
* @param attrValue
-
* 划分的子属性值
-
* @param isParent
-
* 是否分子属性划分还是原来不变的划分
-
*/
-
private double computeEntropy(String[][] remainData, String attrName,
-
String value, boolean isParent) {
-
// 实例总数
-
int total = 0;
-
// 正实例数
-
int posNum = 0;
-
// 负实例数
-
int negNum = 0;
-
-
// 还是按列从左往右遍历属性
-
for (int j = 1; j < attrNames.length; j++) {
-
// 找到了指定的属性
-
if (attrName.equals(attrNames[j])) {
-
for (int i = 1; i < remainData.length; i++) {
-
// 如果是父结点直接计算熵或者是通过子属性划分计算熵,这时要进行属性值的过滤
-
if (isParent
-
|| (!isParent && remainData[i][j].equals(value))) {
-
if (remainData[i][attrNames.length - 1].equals(YES)) {
-
// 判断此行数据是否为正实例
-
posNum++;
-
} else {
-
negNum++;
-
}
-
}
-
}
-
}
-
}
-
-
total = posNum + negNum;
-
double posProbobly = (double) posNum / total;
-
double negProbobly = (double) negNum / total;
-
-
if (posProbobly == 1 || posProbobly == 0) {
-
// 如果数据全为同种类型,则熵为0,否则带入下面的公式会报错
-
return 0;
-
}
-
-
double entropyValue = -posProbobly * Math.log(posProbobly)
-
/ Math.log(2.0) - negProbobly * Math.log(negProbobly)
-
/ Math.log(2.0);
-
-
// 返回计算所得熵
-
return entropyValue;
-
}
-
-
/**
-
* 为某个属性计算信息增益
-
*
-
* @param remainData
-
* 剩余的数据
-
* @param value
-
* 待划分的属性名称
-
* @return
-
*/
-
private double computeGain(String[][] remainData, String value) {
-
double gainValue = 0;
-
// 源熵的大小将会与属性划分后进行比较
-
double entropyOri = 0;
-
// 子划分熵和
-
double childEntropySum = 0;
-
// 属性子类型的个数
-
int childValueNum = 0;
-
// 属性值的种数
-
ArrayList<String> attrTypes = attrValue.get(value);
-
// 子属性对应的权重比
-
HashMap<String, Integer> ratioValues = new HashMap<>();
-
-
for (int i = 0; i < attrTypes.size(); i++) {
-
// 首先都统一计数为0
-
ratioValues.put(attrTypes.get(i), 0);
-
}
-
-
// 还是按照一列,从左往右遍历
-
for (int j = 1; j < attrNames.length; j++) {
-
// 判断是否到了划分的属性列
-
if (value.equals(attrNames[j])) {
-
for (int i = 1; i <= remainData.length - 1; i++) {
-
childValueNum = ratioValues.get(remainData[i][j]);
-
// 增加个数并且重新存入
-
childValueNum++;
-
ratioValues.put(remainData[i][j], childValueNum);
-
}
-
}
-
}
-
-
// 计算原熵的大小
-
entropyOri = computeEntropy(remainData, value, null, true);
-
for (int i = 0; i < attrTypes.size(); i++) {
-
double ratio = (double) ratioValues.get(attrTypes.get(i))
-
/ (remainData.length - 1);
-
childEntropySum += ratio
-
* computeEntropy(remainData, value, attrTypes.get(i), false);
-
-
// System.out.println("ratio:value: " + ratio + " " +
-
// computeEntropy(remainData, value,
-
// attrTypes.get(i), false));
-
}
-
-
// 二者熵相减就是信息增益
-
gainValue = entropyOri - childEntropySum;
-
return gainValue;
-
}
-
-
/**
-
* 计算信息增益比
-
*
-
* @param remainData
-
* 剩余数据
-
* @param value
-
* 待划分属性
-
* @return
-
*/
-
private double computeGainRatio(String[][] remainData, String value) {
-
double gain = 0;
-
double spiltInfo = 0;
-
int childValueNum = 0;
-
// 属性值的种数
-
ArrayList<String> attrTypes = attrValue.get(value);
-
// 子属性对应的权重比
-
HashMap<String, Integer> ratioValues = new HashMap<>();
-
-
for (int i = 0; i < attrTypes.size(); i++) {
-
// 首先都统一计数为0
-
ratioValues.put(attrTypes.get(i), 0);
-
}
-
-
// 还是按照一列,从左往右遍历
-
for (int j = 1; j < attrNames.length; j++) {
-
// 判断是否到了划分的属性列
-
if (value.equals(attrNames[j])) {
-
for (int i = 1; i <= remainData.length - 1; i++) {
-
childValueNum = ratioValues.get(remainData[i][j]);
-
// 增加个数并且重新存入
-
childValueNum++;
-
ratioValues.put(remainData[i][j], childValueNum);
-
}
-
}
-
}
-
-
// 计算信息增益
-
gain = computeGain(remainData, value);
-
// 计算分裂信息,分裂信息度量被定义为(分裂信息用来衡量属性分裂数据的广度和均匀):
-
for (int i = 0; i < attrTypes.size(); i++) {
-
double ratio = (double) ratioValues.get(attrTypes.get(i))
-
/ (remainData.length - 1);
-
spiltInfo += -ratio * Math.log(ratio) / Math.log(2.0);
-
}
-
-
// 计算机信息增益率
-
return gain / spiltInfo;
-
}
-
-
/**
-
* 利用源数据构造决策树
-
*/
-
private void buildDecisionTree(AttrNode node, String parentAttrValue,
-
String[][] remainData, ArrayList<String> remainAttr, boolean isID3) {
-
node.setParentAttrValue(parentAttrValue);
-
-
String attrName = "";
-
double gainValue = 0;
-
double tempValue = 0;
-
-
// 如果只有1个属性则直接返回
-
if (remainAttr.size() == 1) {
-
System.out.println("attr null");
-
return;
-
}
-
-
// 选择剩余属性中信息增益最大的作为下一个分类的属性
-
for (int i = 0; i < remainAttr.size(); i++) {
-
// 判断是否用ID3算法还是C4.5算法
-
if (isID3) {
-
// ID3算法采用的是按照信息增益的值来比
-
tempValue = computeGain(remainData, remainAttr.get(i));
-
} else {
-
// C4.5算法进行了改进,用的是信息增益率来比,克服了用信息增益选择属性时偏向选择取值多的属性的不足
-
tempValue = computeGainRatio(remainData, remainAttr.get(i));
-
}
-
-
if (tempValue > gainValue) {
-
gainValue = tempValue;
-
attrName = remainAttr.get(i);
-
}
-
}
-
-
node.setAttrName(attrName);
-
ArrayList<String> valueTypes = attrValue.get(attrName);
-
remainAttr.remove(attrName);
-
-
AttrNode[] childNode = new AttrNode[valueTypes.size()];
-
String[][] rData;
-
for (int i = 0; i < valueTypes.size(); i++) {
-
// 移除非此值类型的数据
-
rData = removeData(remainData, attrName, valueTypes.get(i));
-
-
childNode[i] = new AttrNode();
-
boolean sameClass = true;
-
ArrayList<String> indexArray = new ArrayList<>();
-
for (int k = 1; k < rData.length; k++) {
-
indexArray.add(rData[k][0]);
-
// 判断是否为同一类的
-
if (!rData[k][attrNames.length - 1]
-
.equals(rData[1][attrNames.length - 1])) {
-
// 只要有1个不相等,就不是同类型的
-
sameClass = false;
-
break;
-
}
-
}
-
-
if (!sameClass) {
-
// 创建新的对象属性,对象的同个引用会出错
-
ArrayList<String> rAttr = new ArrayList<>();
-
for (String str : remainAttr) {
-
rAttr.add(str);
-
}
-
-
buildDecisionTree(childNode[i], valueTypes.get(i), rData,
-
rAttr, isID3);
-
} else {
-
// 如果是同种类型,则直接为数据节点
-
childNode[i].setParentAttrValue(valueTypes.get(i));
-
childNode[i].setChildDataIndex(indexArray);
-
}
-
-
}
-
node.setChildAttrNode(childNode);
-
}
-
-
/**
-
* 属性划分完毕,进行数据的移除
-
*
-
* @param srcData
-
* 源数据
-
* @param attrName
-
* 划分的属性名称
-
* @param valueType
-
* 属性的值类型
-
*/
-
private String[][] removeData(String[][] srcData, String attrName,
-
String valueType) {
-
String[][] desDataArray;
-
ArrayList<String[]> desData = new ArrayList<>();
-
// 待删除数据
-
ArrayList<String[]> selectData = new ArrayList<>();
-
selectData.add(attrNames);
-
-
// 数组数据转化到列表中,方便移除
-
for (int i = 0; i < srcData.length; i++) {
-
desData.add(srcData[i]);
-
}
-
-
// 还是从左往右一列列的查找
-
for (int j = 1; j < attrNames.length; j++) {
-
if (attrNames[j].equals(attrName)) {
-
for (int i = 1; i < desData.size(); i++) {
-
if (desData.get(i)[j].equals(valueType)) {
-
// 如果匹配这个数据,则移除其他的数据
-
selectData.add(desData.get(i));
-
}
-
}
-
}
-
}
-
CDA数据分析师考试相关入口一览(建议收藏):
▷ 想报名CDA认证考试,点击>>>
“CDA报名”
了解CDA考试详情;
▷ 想加入CDA考试题库,点击>>> “CDA题库” 了解CDA考试详情;
▷ 想学习CDA考试教材,点击>>> “CDA教材” 了解CDA考试详情;
▷ 想查询CDA考试成绩,点击>>> “CDA成绩” 了解CDA考试详情;
▷ 想了解CDA考试含金量,点击>>> “CDA含金量” 了解CDA考试详情;
▷ 想获取CDA考试时间/费用/条件/大纲/通过率,点击 >>>“CDA考试官网” 了解CDA考试详情;