在机器学习中,因为决策树的算法是十分给力,因此使用决策树能够帮助我们解决很多的问题。决策树的算法分为很多种,今天小编主要跟大家介绍一下决策树的分类算法。
一、决策树的概念
决策树,根据名字就能知道,是一种树,一种依托于策略抉择而建立起来的树。在 机器学习中,决策树是一个预测模型,代表的是对象属性与对象值之间的一种映射关系。从数据产生决策树的机器学习技术叫做决策树学习, 通俗点说就是:决策树是一种依托于分类、训练上的预测树,可以根据已知,对未来进行预测、归类。
举一个简单的例子来说明:
一个女孩选择相亲对象,通过年龄是否超过30、长相丑或不丑、收入是否低水平,以及否是公务员这几项,将相亲对象分为两个类别:见和不见。假设这个女孩对相亲对象的要求为:30岁以下、长相不丑,而且高收入,或者中等以上收入的公务员,那么女孩的决策逻辑可以用下图来表示,典型的分类决策树。
二、决策树分类算法
1. ID3选取信息增益的属性递归进行分类
“熵”表示随机变量不确定性的度量,并且熵只依赖于X的分布,与X具体取值无关,所以可以表示为,熵越大,随机变量的不确定性就越大:
信息熵: H(X)=-sigma(对每一个x)(plogp)
“条件熵H(Y|X)”表示在已知随机变量X的条件下,随机变量Y的不确定性:
H(Y|X)=sigma(对每一个x)(pH(Y|X=xi))
“信息增益”特征A对训练数据集D的信息增益g(D,A),具体定义为:集合D的经验熵H(D),和特征A给定条件下D的经验条件熵H(D)熵
信息增益:g(D,A)=H(D)-H(D|A) H(D)为整个数据集的熵
信息增益率:(H(D)-H(D|X))/H(X)
算法流程:(1)计算每一个属性的信息增益,如果信息增益小于阈值,就将该支置为叶节点,选择其中个数最多的类标签来作为该类的类标签。反之,则选择其中最大的作为分类属 性。
(2)若果各个分支中都只含有同一类数据,那么就将这支置为叶子节点, 否则 继续进行(1)。
2. C4.5算法
C4.5算法是ID3的改进算法 , 是机器学习算法中的另一个分类决策树算法,可以说是决策树核心算法。
C4.5算法特点:
C4.5用信息增益率来选择属性。
能处理非离散数据。
能够处理不完整数据进行
一个可以选择的度量标准是增益比率gain ratio(Quinlan 1986)。增益比率度量是用前面的增益度量Gain(S,A)和分裂信息度量SplitInformation(S,A)来共同定义的,如下所示:
其中,分裂信息度量被定义为(分裂信息用来衡量属性分裂数据的广度和均匀):
C4.5算法构造决策树过程:
Function C4.5(R:包含连续属性的无类别属性集合,C:类别属性,S:训练集)
/*返回一棵决策树*/
Begin
If S为空,返回一个值为Failure的单个节点;
If S是由相同类别属性值的记录组成,
返回一个带有该值的单个节点;
If R为空,则返回一个单节点,其值为在S的记录中找出的频率最高的类别属性值;
[注意未出现错误则意味着是不适合分类的记录];
For 所有的属性R(Ri) Do
If 属性Ri为连续属性,则
Begin
将Ri的最小值赋给A1:
将Rm的最大值赋给Am;/*m值手工设置*/
For j From 2 To m-1 Do Aj=A1+j*(A1Am)/m;
将Ri点的基于{< =Aj,>Aj}的最大信息增益属性(Ri,S)赋给A;
End;
将R中属性之间具有最大信息增益的属性(D,S)赋给D;
将属性D的值赋给{dj/j=1.2...m};
将分别由对应于D的值为dj的记录组成的S的子集赋给{sj/j=1.2...m};
返回一棵树,其根标记为D;树枝标记为d1.d2...dm;
再分别构造以下树:
C4.5(R-{D},C,S1),C4.5(R-{D},C,S2)...C4.5(R-{D},C,Sm);
End C4.5
3.CART算法:
基尼系数:Gini(p)=sigma(每一个类)p(1-p)
回归树:属性值为连续实数。将整个输入空间划分为m块,每一块以其平均值作为输出。f(x)=sigma(每一块)Cm*I(x属于Rm)
回归树生成:(1)选取切分变量和切分点,将输入空间分为两份。
(2)每一份分别进行第一步,直到满足停止条件。
切分变量和切分点选取:对于每一个变量进行遍历,从中选择切分点。选择一个切分点满足分类均方误差最小。然后在选出所有变量中最小分类误差最小的变量作为切分 变量。
分类树:属性值为离散值。
分类树生成:(1)根据每一个属性的每一个取值,是否取该值将样本分成两类,计算基尼系数。选择基尼系数最小的特征和属性值,将样本分成两份。
(2)递归调用(1)直到无法分割。完成CART树生成。
四、python实现
from sklearn.datasets import load_iris import numpy as np import math from collections import Counter class decisionnode: def __init__(self, d=None, thre=None, results=None, NH=None, lb=None, rb=None, max_label=None): self.d = d # d表示维度 self.thre = thre # thre表示二分时的比较值,将样本集分为2类 self.results = results # 最后的叶节点代表的类别 self.NH = NH # 存储各节点的样本量与经验熵的乘积,便于剪枝时使用 self.lb = lb # desision node,对应于样本在d维的数据小于thre时,树上相对于当前节点的子树上的节点 self.rb = rb # desision node,对应于样本在d维的数据大于thre时,树上相对于当前节点的子树上的节点 self.max_label = max_label # 记录当前节点包含的label中同类最多的label def entropy(y): ''' 计算信息熵,y为labels ''' if y.size > 1: category = list(set(y)) else: category = [y.item()] y = [y.item()] ent = 0 for label in category: p = len([label_ for label_ in y if label_ == label]) / len(y) ent += -p * math.log(p, 2) return ent def Gini(y): ''' 计算基尼指数,y为labels ''' category = list(set(y)) gini = 1 for label in category: p = len([label_ for label_ in y if label_ == label]) / len(y) gini += -p * p return gini def GainEnt_max(X, y, d): ''' 计算选择属性attr的最大信息增益,X为样本集,y为label,d为一个维度,type为int ''' ent_X = entropy(y) X_attr = X[:, d] X_attr = list(set(X_attr)) X_attr = sorted(X_attr) Gain = 0 thre = 0 for i in range(len(X_attr) - 1): thre_temp = (X_attr[i] + X_attr[i + 1]) / 2 y_small_index = [i_arg for i_arg in range( len(X[:, d])) if X[i_arg, d] <= thre_temp] y_big_index = [i_arg for i_arg in range( len(X[:, d])) if X[i_arg, d] > thre_temp] y_small = y[y_small_index] y_big = y[y_big_index] Gain_temp = ent_X - (len(y_small) / len(y)) * \ entropy(y_small) - (len(y_big) / len(y)) * entropy(y_big) ''' intrinsic_value = -(len(y_small) / len(y)) * math.log(len(y_small) / len(y), 2) - (len(y_big) / len(y)) * math.log(len(y_big) / len(y), 2) Gain_temp = Gain_temp / intrinsic_value ''' # print(Gain_temp) if Gain < Gain_temp: Gain = Gain_temp thre = thre_temp return Gain, thre def Gini_index_min(X, y, d): ''' 计算选择属性attr的最小基尼指数,X为样本集,y为label,d为一个维度,type为int ''' X = X.reshape(-1, len(X.T)) X_attr = X[:, d] X_attr = list(set(X_attr)) X_attr = sorted(X_attr) Gini_index = 1 thre = 0 for i in range(len(X_attr) - 1): thre_temp = (X_attr[i] + X_attr[i + 1]) / 2 y_small_index = [i_arg for i_arg in range( len(X[:, d])) if X[i_arg, d] <= thre_temp] y_big_index = [i_arg for i_arg in range( len(X[:, d])) if X[i_arg, d] > thre_temp] y_small = y[y_small_index] y_big = y[y_big_index] Gini_index_temp = (len(y_small) / len(y)) * \ Gini(y_small) + (len(y_big) / len(y)) * Gini(y_big) if Gini_index > Gini_index_temp: Gini_index = Gini_index_temp thre = thre_temp return Gini_index, thre def attribute_based_on_GainEnt(X, y): ''' 基于信息增益选择最优属性,X为样本集,y为label ''' D = np.arange(len(X[0])) Gain_max = 0 thre_ = 0 d_ = 0 for d in D: Gain, thre = GainEnt_max(X, y, d) if Gain_max < Gain: Gain_max = Gain thre_ = thre d_ = d # 维度标号 return Gain_max, thre_, d_ def attribute_based_on_Giniindex(X, y): ''' 基于信息增益选择最优属性,X为样本集,y为label ''' D = np.arange(len(X.T)) Gini_Index_Min = 1 thre_ = 0 d_ = 0 for d in D: Gini_index, thre = Gini_index_min(X, y, d) if Gini_Index_Min > Gini_index: Gini_Index_Min = Gini_index thre_ = thre d_ = d # 维度标号 return Gini_Index_Min, thre_, d_ def devide_group(X, y, thre, d): ''' 按照维度d下阈值为thre分为两类并返回 ''' X_in_d = X[:, d] x_small_index = [i_arg for i_arg in range( len(X[:, d])) if X[i_arg, d] <= thre] x_big_index = [i_arg for i_arg in range( len(X[:, d])) if X[i_arg, d] > thre] X_small = X[x_small_index] y_small = y[x_small_index] X_big = X[x_big_index] y_big = y[x_big_index] return X_small, y_small, X_big, y_big def NtHt(y): ''' 计算经验熵与样本数的乘积,用来剪枝,y为labels ''' ent = entropy(y) print('ent={},y_len={},all={}'.format(ent, len(y), ent * len(y))) return ent * len(y) def maxlabel(y): label_ = Counter(y).most_common(1) return label_[0][0] def buildtree(X, y, method='Gini'): ''' 递归的方式构建决策树 ''' if y.size > 1: if method == 'Gini': Gain_max, thre, d = attribute_based_on_Giniindex(X, y) elif method == 'GainEnt': Gain_max, thre, d = attribute_based_on_GainEnt(X, y) if (Gain_max > 0 and method == 'GainEnt') or (Gain_max >= 0 and len(list(set(y))) > 1 and method == 'Gini'): X_small, y_small, X_big, y_big = devide_group(X, y, thre, d) left_branch = buildtree(X_small, y_small, method=method) right_branch = buildtree(X_big, y_big, method=method) nh = NtHt(y) max_label = maxlabel(y) return decisionnode(d=d, thre=thre, NH=nh, lb=left_branch, rb=right_branch, max_label=max_label) else: nh = NtHt(y) max_label = maxlabel(y) return decisionnode(results=y[0], NH=nh, max_label=max_label) else: nh = NtHt(y) max_label = maxlabel(y) return decisionnode(results=y.item(), NH=nh, max_label=max_label) def printtree(tree, indent='-', dict_tree={}, direct='L'): # 是否是叶节点 if tree.results != None: print(tree.results) dict_tree = {direct: str(tree.results)} else: # 打印判断条件 print(str(tree.d) + ":" + str(tree.thre) + "? ") # 打印分支 print(indent + "L->",) a = printtree(tree.lb, indent=indent + "-", direct='L') aa = a.copy() print(indent + "R->",) b = printtree(tree.rb, indent=indent + "-", direct='R') bb = b.copy() aa.update(bb) stri = str(tree.d) + ":" + str(tree.thre) + "?" if indent != '-': dict_tree = {direct: {stri: aa}} else: dict_tree = {stri: aa} return dict_tree def classify(observation, tree): if tree.results != None: return tree.results else: v = observation[tree.d] branch = None if v > tree.thre: branch = tree.rb else: branch = tree.lb return classify(observation, branch) def pruning(tree, alpha=0.1): if tree.lb.results == None: pruning(tree.lb, alpha) if tree.rb.results == None: pruning(tree.rb, alpha) if tree.lb.results != None and tree.rb.results != None: before_pruning = tree.lb.NH + tree.rb.NH + 2 * alpha after_pruning = tree.NH + alpha print('before_pruning={},after_pruning={}'.format( before_pruning, after_pruning)) if after_pruning <= before_pruning: print('pruning--{}:{}?'.format(tree.d, tree.thre)) tree.lb, tree.rb = None, None tree.results = tree.max_label if __name__ == '__main__': iris = load_iris() X = iris.data y = iris.target permutation = np.random.permutation(X.shape[0]) shuffled_dataset = X[permutation, :] shuffled_labels = y[permutation] train_data = shuffled_dataset[:100, :] train_label = shuffled_labels[:100] test_data = shuffled_dataset[100:150, :] test_label = shuffled_labels[100:150] tree1 = buildtree(train_data, train_label, method='Gini') print('=============================') tree2 = buildtree(train_data, train_label, method='GainEnt') a = printtree(tree=tree1) b = printtree(tree=tree2) true_count = 0 for i in range(len(test_label)): predict = classify(test_data[i], tree1) if predict == test_label[i]: true_count += 1 print("CARTTree:{}".format(true_count)) true_count = 0 for i in range(len(test_label)): predict = classify(test_data[i], tree2) if predict == test_label[i]: true_count += 1 print("C3Tree:{}".format(true_count)) #print(attribute_based_on_Giniindex(X[49:51, :], y[49:51])) from pylab import * mpl.rcParams['font.sans-serif'] = ['SimHei'] # 指定默认字体 mpl.rcParams['axes.unicode_minus'] = False # 解决保存图像时负号'-'显示为方块的问题 import treePlotter import matplotlib.pyplot as plt treePlotter.createPlot(a, 1) treePlotter.createPlot(b, 2) # 剪枝处理 pruning(tree=tree1, alpha=4) pruning(tree=tree2, alpha=4) a = printtree(tree=tree1) b = printtree(tree=tree2) true_count = 0 for i in range(len(test_label)): predict = classify(test_data[i], tree1) if predict == test_label[i]: true_count += 1 print("CARTTree:{}".format(true_count)) true_count = 0 for i in range(len(test_label)): predict = classify(test_data[i], tree2) if predict == test_label[i]: true_count += 1 print("C3Tree:{}".format(true_count)) treePlotter.createPlot(a, 3) treePlotter.createPlot(b, 4) plt.show()
数据分析咨询请扫描二维码
若不方便扫码,搜微信号:CDAshujufenxi
你是否被统计学复杂的理论和晦涩的公式劝退过?别担心,“山有木兮:统计学极简入门(Python)” 将为你一一化解这些难题。课程 ...
2025-03-31在电商、零售、甚至内容付费业务中,你真的了解你的客户吗? 有些客户下了一两次单就消失了,有些人每个月都回购,有些人曾经是 ...
2025-03-31在数字化浪潮中,数据驱动决策已成为企业发展的核心竞争力,数据分析人才的需求持续飙升。世界经济论坛发布的《未来就业报告》, ...
2025-03-28你有没有遇到过这样的情况?流量进来了,转化率却不高,辛辛苦苦拉来的用户,最后大部分都悄无声息地离开了,这时候漏斗分析就非 ...
2025-03-27TensorFlow Datasets(TFDS)是一个用于下载、管理和预处理机器学习数据集的库。它提供了易于使用的API,允许用户从现有集合中 ...
2025-03-26"不谋全局者,不足谋一域。"在数据驱动的商业时代,战略级数据分析能力已成为职场核心竞争力。《CDA二级教材:商业策略数据分析 ...
2025-03-26当你在某宝刷到【猜你喜欢】时,当抖音精准推来你的梦中情猫时,当美团外卖弹窗刚好是你想吃的火锅店…… 恭喜你,你正在被用户 ...
2025-03-26当面试官问起随机森林时,他到底在考察什么? ""请解释随机森林的原理""——这是数据分析岗位面试中的经典问题。但你可能不知道 ...
2025-03-25在数字化浪潮席卷的当下,数据俨然成为企业的命脉,贯穿于业务运作的各个环节。从线上到线下,从平台的交易数据,到门店的运营 ...
2025-03-25在互联网和移动应用领域,DAU(日活跃用户数)是一个耳熟能详的指标。无论是产品经理、运营,还是数据分析师,DAU都是衡量产品 ...
2025-03-24ABtest做的好,产品优化效果差不了!可见ABtest在评估优化策略的效果方面地位还是很高的,那么如何在业务中应用ABtest? 结合企业 ...
2025-03-21在企业数据分析中,指标体系是至关重要的工具。不仅帮助企业统一数据标准、提升数据质量,还能为业务决策提供有力支持。本文将围 ...
2025-03-20解锁数据分析师高薪密码,CDA 脱产就业班助你逆袭! 在数字化浪潮中,数据驱动决策已成为企业发展的核心竞争力,数据分析人才的 ...
2025-03-19在 MySQL 数据库中,查询一张表但是不包含某个字段可以通过以下两种方法实现:使用 SELECT 子句以明确指定想要的字段,或者使 ...
2025-03-17在当今数字化时代,数据成为企业发展的关键驱动力,而用户画像作为数据分析的重要成果,改变了企业理解用户、开展业务的方式。无 ...
2025-03-172025年是智能体(AI Agent)的元年,大模型和智能体的发展比较迅猛。感觉年初的deepseek刚火没多久,这几天Manus又成为媒体头条 ...
2025-03-14以下的文章内容来源于柯家媛老师的专栏,如果您想阅读专栏《小白必备的数据思维课》,点击下方链接 https://edu.cda.cn/goods/sh ...
2025-03-13以下的文章内容来源于刘静老师的专栏,如果您想阅读专栏《10大业务分析模型突破业务瓶颈》,点击下方链接 https://edu.cda.cn/go ...
2025-03-12以下的文章内容来源于柯家媛老师的专栏,如果您想阅读专栏《小白必备的数据思维课》,点击下方链接 https://edu.cda.cn/goods/sh ...
2025-03-11随着数字化转型的加速,企业积累了海量数据,如何从这些数据中挖掘有价值的信息,成为企业提升竞争力的关键。CDA认证考试体系应 ...
2025-03-10