热线电话:13121318867

登录
首页精彩阅读CART分类回归树算法
CART分类回归树算法
2015-12-03
收藏

CART分类回归树算法


    CART分类回归树算法

与上次文章中提到的ID3算法和C4.5算法类似,CART算法也是一种决策树分类算法。CART分类回归树算法的本质也是对数据进行分类的,最终数据的表现形式也是以树形的模式展现的,与ID3,C4.5算法不同的是,他的分类标准所采用的算法不同了。下面列出了其中的一些不同之处:

1、CART最后形成的树是一个二叉树,每个节点会分成2个节点,左孩子节点和右孩子节点,而在ID3和C4.5中是按照分类属性的值类型进行划分,于是这就要求CART算法在所选定的属性中又要划分出最佳的属性划分值,节点如果选定了划分属性名称还要确定里面按照那个值做一个二元的划分。

2、CART算法对于属性的值采用的是基于Gini系数值的方式做比较,gini某个属性的某次值的划分的gini指数的值为:

,pk就是分别为正负实例的概率,gini系数越小说明分类纯度越高,可以想象成与熵的定义一样。因此在最后计算的时候我们只取其中值最小的做出划分。最后做比较的时候用的是gini的增益做比较,要对分类号的数据做出一个带权重的gini指数的计算。举一个网上的一个例子:

比如体温为恒温时包含哺乳类5个、鸟类2个,则:

体温为非恒温时包含爬行类3个、鱼类3个、两栖类2个,则

所以如果按照“体温为恒温和非恒温”进行划分的话,我们得到GINI的增益(类比信息增益):

最好的划分就是使得GINI_Gain最小的划分。

通过比较每个属性的最小的gini指数值,作为最后的结果。

3、CART算法在把数据进行分类之后,会对树进行一个剪枝,常用的用前剪枝和后剪枝法,而常见的后剪枝发包括代价复杂度剪枝,悲观误差剪枝等等,我写的此次算法采用的是代价复杂度剪枝法。代价复杂度剪枝的算法公式为:

α表示的是每个非叶子节点的误差增益率,可以理解为误差代价,最后选出误差代价最小的一个节点进行剪枝。

里面变量的意思为:

是子树中包含的叶子节点个数;

是节点t的误差代价,如果该节点被剪枝;

r(t)是节点t的误差率;

p(t)是节点t上的数据占所有数据的比例。

是子树Tt的误差代价,如果该节点不被剪枝。它等于子树Tt上所有叶子节点的误差代价之和。下面说说我对于这个公式的理解:其实这个公式的本质是对于剪枝前和剪枝后的样本偏差率做一个差值比较,一个好的分类当然是分类后的样本偏差率相较于没分类(就是剪枝掉的时候)的偏差率小,所以这时的值就会大,如果分类前后基本变化不大,则意味着分类不起什么效果,α值的分子位置就小,所以误差代价就小,可以被剪枝。但是一般分类后的偏差率会小于分类前的,因为偏差数在高层节点的时候肯定比子节点的多,子节点偏差数最多与父亲节点一样。

CART算法实现

首先是程序的备用数据,我是把他存在了一个文字中,通过程序进行逐行的读取:

[java] view plaincopyprint?
  1. Rid Age Income Student CreditRating BuysComputer  
  2. 1 Youth High No Fair No  
  3. 2 Youth High No Excellent No  
  4. 3 MiddleAged High No Fair Yes  
  5. 4 Senior Medium No Fair Yes  
  6. 5 Senior Low Yes Fair Yes  
  7. 6 Senior Low Yes Excellent No  
  8. 7 MiddleAged Low Yes Excellent Yes  
  9. 8 Youth Medium No Fair No  
  10. 9 Youth Low Yes Fair Yes  
  11. 10 Senior Medium Yes Fair Yes  
  12. 11 Youth Medium Yes Excellent Yes  
  13. 12 MiddleAged Medium No Excellent Yes  
  14. 13 MiddleAged High Yes Fair Yes  
  15. 14 Senior Medium No Excellent No  

下面是主程序,里面有具体的注释:

[java] view plaincopyprint?
  1. package DataMing_CART;  
  2.   
  3. import java.io.BufferedReader;  
  4. import java.io.File;  
  5. import java.io.FileReader;  
  6. import java.io.IOException;  
  7. import java.util.ArrayList;  
  8. import java.util.HashMap;  
  9. import java.util.LinkedList;  
  10. import java.util.Map;  
  11. import java.util.Queue;  
  12.   
  13. import javax.lang.model.element.NestingKind;  
  14. import javax.swing.text.DefaultEditorKit.CutAction;  
  15. import javax.swing.text.html.MinimalHTMLWriter;  
  16.   
  17. /** 
  18.  * CART分类回归树算法工具类 
  19.  *  
  20.  * @author lyq 
  21.  *  
  22.  */  
  23. public class CARTTool {  
  24.     // 类标号的值类型  
  25.     private final String YES = "Yes";  
  26.     private final String NO = "No";  
  27.   
  28.     // 所有属性的类型总数,在这里就是data源数据的列数  
  29.     private int attrNum;  
  30.     private String filePath;  
  31.     // 初始源数据,用一个二维字符数组存放模仿表格数据  
  32.     private String[][] data;  
  33.     // 数据的属性行的名字  
  34.     private String[] attrNames;  
  35.     // 每个属性的值所有类型  
  36.     private HashMap<String, ArrayList<String>> attrValue;  
  37.   
  38.     public CARTTool(String filePath) {  
  39.         this.filePath = filePath;  
  40.         attrValue = new HashMap<>();  
  41.     }  
  42.   
  43.     /** 
  44.      * 从文件中读取数据 
  45.      */  
  46.     public void readDataFile() {  
  47.         File file = new File(filePath);  
  48.         ArrayList<String[]> dataArray = new ArrayList<String[]>();  
  49.   
  50.         try {  
  51.             BufferedReader in = new BufferedReader(new FileReader(file));  
  52.             String str;  
  53.             String[] tempArray;  
  54.             while ((str = in.readLine()) != null) {  
  55.                 tempArray = str.split(" ");  
  56.                 dataArray.add(tempArray);  
  57.             }  
  58.             in.close();  
  59.         } catch (IOException e) {  
  60.             e.getStackTrace();  
  61.         }  
  62.   
  63.         data = new String[dataArray.size()][];  
  64.         dataArray.toArray(data);  
  65.         attrNum = data[0].length;  
  66.         attrNames = data[0];  
  67.   
  68.         /* 
  69.          * for (int i = 0; i < data.length; i++) { for (int j = 0; j < 
  70.          * data[0].length; j++) { System.out.print(" " + data[i][j]); } 
  71.          * System.out.print("\n"); } 
  72.          */  
  73.   
  74.     }  
  75.   
  76.     /** 
  77.      * 首先初始化每种属性的值的所有类型,用于后面的子类熵的计算时用 
  78.      */  
  79.     public void initAttrValue() {  
  80.         ArrayList<String> tempValues;  
  81.   
  82.         // 按照列的方式,从左往右找  
  83.         for (int j = 1; j < attrNum; j++) {  
  84.             // 从一列中的上往下开始寻找值  
  85.             tempValues = new ArrayList<>();  
  86.             for (int i = 1; i < data.length; i++) {  
  87.                 if (!tempValues.contains(data[i][j])) {  
  88.                     // 如果这个属性的值没有添加过,则添加  
  89.                     tempValues.add(data[i][j]);  
  90.                 }  
  91.             }  
  92.   
  93.             // 一列属性的值已经遍历完毕,复制到map属性表中  
  94.             attrValue.put(data[0][j], tempValues);  
  95.         }  
  96.   
  97.         /* 
  98.          * for (Map.Entry entry : attrValue.entrySet()) { 
  99.          * System.out.println("key:value " + entry.getKey() + ":" + 
  100.          * entry.getValue()); } 
  101.          */  
  102.     }  
  103.   
  104.     /** 
  105.      * 计算机基尼指数 
  106.      *  
  107.      * @param remainData 
  108.      *            剩余数据 
  109.      * @param attrName 
  110.      *            属性名称 
  111.      * @param value 
  112.      *            属性值 
  113.      * @param beLongValue 
  114.      *            分类是否属于此属性值 
  115.      * @return 
  116.      */  
  117.     public double computeGini(String[][] remainData, String attrName,  
  118.             String value, boolean beLongValue) {  
  119.         // 实例总数  
  120.         int total = 0;  
  121.         // 正实例数  
  122.         int posNum = 0;  
  123.         // 负实例数  
  124.         int negNum = 0;  
  125.         // 基尼指数  
  126.         double gini = 0;  
  127.   
  128.         // 还是按列从左往右遍历属性  
  129.         for (int j = 1; j < attrNames.length; j++) {  
  130.             // 找到了指定的属性  
  131.             if (attrName.equals(attrNames[j])) {  
  132.                 for (int i = 1; i < remainData.length; i++) {  
  133.                     // 统计正负实例按照属于和不属于值类型进行划分  
  134.                     if ((beLongValue && remainData[i][j].equals(value))  
  135.                             || (!beLongValue && !remainData[i][j].equals(value))) {  
  136.                         if (remainData[i][attrNames.length - 1].equals(YES)) {  
  137.                             // 判断此行数据是否为正实例  
  138.                             posNum++;  
  139.                         } else {  
  140.                             negNum++;  
  141.                         }  
  142.                     }  
  143.                 }  
  144.             }  
  145.         }  
  146.   
  147.         total = posNum + negNum;  
  148.         double posProbobly = (double) posNum / total;  
  149.         double negProbobly = (double) negNum / total;  
  150.         gini = 1 - posProbobly * posProbobly - negProbobly * negProbobly;  
  151.   
  152.         // 返回计算基尼指数  
  153.         return gini;  
  154.     }  
  155.   
  156.     /** 
  157.      * 计算属性划分的最小基尼指数,返回最小的属性值划分和最小的基尼指数,保存在一个数组中 
  158.      *  
  159.      * @param remainData 
  160.      *            剩余谁 
  161.      * @param attrName 
  162.      *            属性名称 
  163.      * @return 
  164.      */  
  165.     public String[] computeAttrGini(String[][] remainData, String attrName) {  
  166.         String[] str = new String[2];  
  167.         // 最终该属性的划分类型值  
  168.         String spiltValue = "";  
  169.         // 临时变量  
  170.         int tempNum = 0;  
  171.         // 保存属性的值划分时的最小的基尼指数  
  172.         double minGini = Integer.MAX_VALUE;  
  173.         ArrayList<String> valueTypes = attrValue.get(attrName);  
  174.         // 属于此属性值的实例数  
  175.         HashMap<String, Integer> belongNum = new HashMap<>();  
  176.   
  177.         for (String string : valueTypes) {  
  178.             // 重新计数的时候,数字归0  
  179.             tempNum = 0;  
  180.             // 按列从左往右遍历属性  
  181.             for (int j = 1; j < attrNames.length; j++) {  
  182.                 // 找到了指定的属性  
  183.                 if (attrName.equals(attrNames[j])) {  
  184.                     for (int i = 1; i < remainData.length; i++) {  
  185.                         // 统计正负实例按照属于和不属于值类型进行划分  
  186.                         if (remainData[i][j].equals(string)) {  
  187.                             tempNum++;  
  188.                         }  
  189.                     }  
  190.                 }  
  191.             }  
  192.   
  193.             belongNum.put(string, tempNum);  
  194.         }  
  195.   
  196.         double tempGini = 0;  
  197.         double posProbably = 1.0;  
  198.         double negProbably = 1.0;  
  199.         for (String string : valueTypes) {  
  200.             tempGini = 0;  
  201.   
  202.             posProbably = 1.0 * belongNum.get(string) / (remainData.length - 1);  
  203.             negProbably = 1 - posProbably;  
  204.   
  205.             tempGini += posProbably  
  206.                     * computeGini(remainData, attrName, string, true);  
  207.             tempGini += negProbably  
  208.                     * computeGini(remainData, attrName, string, false);  
  209.   
  210.             if (tempGini < minGini) {  
  211.                 minGini = tempGini;  
  212.                 spiltValue = string;  
  213.             }  
  214.         }  
  215.   
  216.         str[0] = spiltValue;  
  217.         str[1] = minGini + "";  
  218.   
  219.         return str;  
  220.     }  
  221.   
  222.     public void buildDecisionTree(AttrNode node, String parentAttrValue,  
  223.             String[][] remainData, ArrayList<String> remainAttr,  
  224.             boolean beLongParentValue) {  
  225.         // 属性划分值  
  226.         String valueType = "";  
  227.         // 划分属性名称  
  228.         String spiltAttrName = "";  
  229.         double minGini = Integer.MAX_VALUE;  
  230.         double tempGini = 0;  
  231.         // 基尼指数数组,保存了基尼指数和此基尼指数的划分属性值  
  232.         String[] giniArray;  
  233.   
  234.         if (beLongParentValue) {  
  235.             node.setParentAttrValue(parentAttrValue);  
  236.         } else {  
  237.             node.setParentAttrValue("!" + parentAttrValue);  
  238.         }  
  239.   
  240.         if (remainAttr.size() == 0) {  
  241.             if (remainData.length > 1) {  
  242.                 ArrayList<String> indexArray = new ArrayList<>();  
  243.                 for (int i = 1; i < remainData.length; i++) {  
  244.                     indexArray.add(remainData[i][0]);  
  245.                 }  
  246.                 node.setDataIndex(indexArray);  
  247.             }  
  248.             System.out.println("attr remain null");  
  249.             return;  
  250.         }  
  251.   
  252.         for (String str : remainAttr) {  
  253.             giniArray = computeAttrGini(remainData, str);  
  254.             tempGini = Double.parseDouble(giniArray[1]);  
  255.   
  256.             if (tempGini < minGini) {  
  257.                 spiltAttrName = str;  
  258.                 minGini = tempGini;  
  259.                 valueType = giniArray[0];  
  260.             }  
  261.         }  
  262.         // 移除划分属性  
  263.         remainAttr.remove(spiltAttrName);  
  264.         node.setAttrName(spiltAttrName);  
  265.   
  266.         // 孩子节点,分类回归树中,每次二元划分,分出2个孩子节点  
  267.         AttrNode[] childNode = new AttrNode[2];  
  268.         String[][] rData;  
  269.   
  270.         boolean[] bArray = new boolean[] { true, false };  
  271.         for (int i = 0; i < bArray.length; i++) {  
  272.             // 二元划分属于属性值的划分  
  273.             rData = removeData(remainData, spiltAttrName, valueType, bArray[i]);  
  274.   
  275.             boolean sameClass = true;  
  276.             ArrayList<String> indexArray = new ArrayList<>();  
  277.             for (int k = 1; k < rData.length; k++) {  
  278.                 indexArray.add(rData[k][0]);  
  279.                 // 判断是否为同一类的  
  280.                 if (!rData[k][attrNames.length - 1]  
  281.                         .equals(rData[1][attrNames.length - 1])) {  
  282.                     // 只要有1个不相等,就不是同类型的  
  283.                     sameClass = false;  
  284.                     break;  
  285.                 }  
  286.             }  
  287.   
  288.             childNode[i] = new AttrNode();  
  289.             if (!sameClass) {  
  290.                 // 创建新的对象属性,对象的同个引用会出错  
  291.                 ArrayList<String> rAttr = new ArrayList<>();  
  292.                 for (String str : remainAttr) {  
  293.                     rAttr.add(str);  
  294.                 }  
  295.                 buildDecisionTree(childNode[i], valueType, rData, rAttr,  
  296.                         bArray[i]);  
  297.             } else {  
  298.                 String pAtr = (bArray[i] ? valueType : "!" + valueType);  
  299.                 childNode[i].setParentAttrValue(pAtr);  
  300.                 childNode[i].setDataIndex(indexArray);  
  301.             }  
  302.         }  
  303.   
  304.         node.setChildAttrNode(childNode);  
  305.     }  
  306.   
  307.     /** 
  308.      * 属性划分完毕,进行数据的移除 
  309.      *  
  310.      * @param srcData 
  311.      *            源数据 
  312.      * @param attrName 
  313.      *            划分的属性名称 
  314.      * @param valueType 
  315.      *            属性的值类型 
  316.      * @parame beLongValue 分类是否属于此值类型 
  317.      */  
  318.     private String[][] removeData(String[][] srcData, String attrName,  
  319.             String valueType, boolean beLongValue) {  
  320.         String[][] desDataArray;  
  321.         ArrayList<String[]> desData = new ArrayList<>();  
  322.         // 待删除数据  
  323.         ArrayList<String[]> selectData = new ArrayList<>();  
  324.         selectData.add(attrNames);  
  325.   
  326.         // 数组数据转化到列表中,方便移除  
  327.         for (int i = 0; i < srcData.length; i++) {  
  328.             desData.add(srcData[i]);  
  329.         }  
  330.   
  331.         // 还是从左往右一列列的查找  
  332.         for (int j = 1; j < attrNames.length; j++) {  
  333.             if (attrNames[j].equals(attrName)) {  
  334.                 for (int i = 1; i < desData.size(); i++) {  
  335.                     if (desData.get(i)[j].equals(valueType)) {  
  336.                         // 如果匹配这个数据,则移除其他的数据  
  337.                         selectData.add(desData.get(i));  
  338.                     }  
  339.                 }  
  340.             }  
  341.         }  
  342.   
  343.         if (beLongValue) {  
  344.             desDataArray = new String[selectData.size()][];  
  345.             selectData.toArray(desDataArray);  
  346.         } else {  
  347.             // 属性名称行不移除  
  348.             selectData.remove(attrNames);  
  349.             // 如果是划分不属于此类型的数据时,进行移除  
  350.             desData.removeAll(selectData);  
  351.             desDataArray = new String[desData.size()][];  
  352.      &

数据分析咨询请扫描二维码

最新资讯
更多
客服在线
立即咨询