在机器学习中,因为决策树的算法是十分给力,因此使用决策树能够帮助我们解决很多的问题。决策树的算法分为很多种,今天小编主要跟大家介绍一下决策树的分类算法。
一、决策树的概念
决策树,根据名字就能知道,是一种树,一种依托于策略抉择而建立起来的树。在 机器学习中,决策树是一个预测模型,代表的是对象属性与对象值之间的一种映射关系。从数据产生决策树的机器学习技术叫做决策树学习, 通俗点说就是:决策树是一种依托于分类、训练上的预测树,可以根据已知,对未来进行预测、归类。
举一个简单的例子来说明:
一个女孩选择相亲对象,通过年龄是否超过30、长相丑或不丑、收入是否低水平,以及否是公务员这几项,将相亲对象分为两个类别:见和不见。假设这个女孩对相亲对象的要求为:30岁以下、长相不丑,而且高收入,或者中等以上收入的公务员,那么女孩的决策逻辑可以用下图来表示,典型的分类决策树。
二、决策树分类算法
1. ID3选取信息增益的属性递归进行分类
“熵”表示随机变量不确定性的度量,并且熵只依赖于X的分布,与X具体取值无关,所以可以表示为,熵越大,随机变量的不确定性就越大:
信息熵: H(X)=-sigma(对每一个x)(plogp)
“条件熵H(Y|X)”表示在已知随机变量X的条件下,随机变量Y的不确定性:
H(Y|X)=sigma(对每一个x)(pH(Y|X=xi))
“信息增益”特征A对训练数据集D的信息增益g(D,A),具体定义为:集合D的经验熵H(D),和特征A给定条件下D的经验条件熵H(D)熵
信息增益:g(D,A)=H(D)-H(D|A) H(D)为整个数据集的熵
信息增益率:(H(D)-H(D|X))/H(X)
算法流程:(1)计算每一个属性的信息增益,如果信息增益小于阈值,就将该支置为叶节点,选择其中个数最多的类标签来作为该类的类标签。反之,则选择其中最大的作为分类属 性。
(2)若果各个分支中都只含有同一类数据,那么就将这支置为叶子节点, 否则 继续进行(1)。
2. C4.5算法
C4.5算法是ID3的改进算法 , 是机器学习算法中的另一个分类决策树算法,可以说是决策树核心算法。
C4.5算法特点:
C4.5用信息增益率来选择属性。
能处理非离散数据。
能够处理不完整数据进行
一个可以选择的度量标准是增益比率gain ratio(Quinlan 1986)。增益比率度量是用前面的增益度量Gain(S,A)和分裂信息度量SplitInformation(S,A)来共同定义的,如下所示:
其中,分裂信息度量被定义为(分裂信息用来衡量属性分裂数据的广度和均匀):
C4.5算法构造决策树过程:
Function C4.5(R:包含连续属性的无类别属性集合,C:类别属性,S:训练集)
/*返回一棵决策树*/
Begin
If S为空,返回一个值为Failure的单个节点;
If S是由相同类别属性值的记录组成,
返回一个带有该值的单个节点;
If R为空,则返回一个单节点,其值为在S的记录中找出的频率最高的类别属性值;
[注意未出现错误则意味着是不适合分类的记录];
For 所有的属性R(Ri) Do
If 属性Ri为连续属性,则
Begin
将Ri的最小值赋给A1:
将Rm的最大值赋给Am;/*m值手工设置*/
For j From 2 To m-1 Do Aj=A1+j*(A1Am)/m;
将Ri点的基于{< =Aj,>Aj}的最大信息增益属性(Ri,S)赋给A;
End;
将R中属性之间具有最大信息增益的属性(D,S)赋给D;
将属性D的值赋给{dj/j=1.2...m};
将分别由对应于D的值为dj的记录组成的S的子集赋给{sj/j=1.2...m};
返回一棵树,其根标记为D;树枝标记为d1.d2...dm;
再分别构造以下树:
C4.5(R-{D},C,S1),C4.5(R-{D},C,S2)...C4.5(R-{D},C,Sm);
End C4.5
3.CART算法:
基尼系数:Gini(p)=sigma(每一个类)p(1-p)
回归树:属性值为连续实数。将整个输入空间划分为m块,每一块以其平均值作为输出。f(x)=sigma(每一块)Cm*I(x属于Rm)
回归树生成:(1)选取切分变量和切分点,将输入空间分为两份。
(2)每一份分别进行第一步,直到满足停止条件。
切分变量和切分点选取:对于每一个变量进行遍历,从中选择切分点。选择一个切分点满足分类均方误差最小。然后在选出所有变量中最小分类误差最小的变量作为切分 变量。
分类树:属性值为离散值。
分类树生成:(1)根据每一个属性的每一个取值,是否取该值将样本分成两类,计算基尼系数。选择基尼系数最小的特征和属性值,将样本分成两份。
(2)递归调用(1)直到无法分割。完成CART树生成。
四、python实现
from sklearn.datasets import load_iris import numpy as np import math from collections import Counter class decisionnode: def __init__(self, d=None, thre=None, results=None, NH=None, lb=None, rb=None, max_label=None): self.d = d # d表示维度 self.thre = thre # thre表示二分时的比较值,将样本集分为2类 self.results = results # 最后的叶节点代表的类别 self.NH = NH # 存储各节点的样本量与经验熵的乘积,便于剪枝时使用 self.lb = lb # desision node,对应于样本在d维的数据小于thre时,树上相对于当前节点的子树上的节点 self.rb = rb # desision node,对应于样本在d维的数据大于thre时,树上相对于当前节点的子树上的节点 self.max_label = max_label # 记录当前节点包含的label中同类最多的label def entropy(y): ''' 计算信息熵,y为labels ''' if y.size > 1: category = list(set(y)) else: category = [y.item()] y = [y.item()] ent = 0 for label in category: p = len([label_ for label_ in y if label_ == label]) / len(y) ent += -p * math.log(p, 2) return ent def Gini(y): ''' 计算基尼指数,y为labels ''' category = list(set(y)) gini = 1 for label in category: p = len([label_ for label_ in y if label_ == label]) / len(y) gini += -p * p return gini def GainEnt_max(X, y, d): ''' 计算选择属性attr的最大信息增益,X为样本集,y为label,d为一个维度,type为int ''' ent_X = entropy(y) X_attr = X[:, d] X_attr = list(set(X_attr)) X_attr = sorted(X_attr) Gain = 0 thre = 0 for i in range(len(X_attr) - 1): thre_temp = (X_attr[i] + X_attr[i + 1]) / 2 y_small_index = [i_arg for i_arg in range( len(X[:, d])) if X[i_arg, d] <= thre_temp] y_big_index = [i_arg for i_arg in range( len(X[:, d])) if X[i_arg, d] > thre_temp] y_small = y[y_small_index] y_big = y[y_big_index] Gain_temp = ent_X - (len(y_small) / len(y)) * \ entropy(y_small) - (len(y_big) / len(y)) * entropy(y_big) ''' intrinsic_value = -(len(y_small) / len(y)) * math.log(len(y_small) / len(y), 2) - (len(y_big) / len(y)) * math.log(len(y_big) / len(y), 2) Gain_temp = Gain_temp / intrinsic_value ''' # print(Gain_temp) if Gain < Gain_temp: Gain = Gain_temp thre = thre_temp return Gain, thre def Gini_index_min(X, y, d): ''' 计算选择属性attr的最小基尼指数,X为样本集,y为label,d为一个维度,type为int ''' X = X.reshape(-1, len(X.T)) X_attr = X[:, d] X_attr = list(set(X_attr)) X_attr = sorted(X_attr) Gini_index = 1 thre = 0 for i in range(len(X_attr) - 1): thre_temp = (X_attr[i] + X_attr[i + 1]) / 2 y_small_index = [i_arg for i_arg in range( len(X[:, d])) if X[i_arg, d] <= thre_temp] y_big_index = [i_arg for i_arg in range( len(X[:, d])) if X[i_arg, d] > thre_temp] y_small = y[y_small_index] y_big = y[y_big_index] Gini_index_temp = (len(y_small) / len(y)) * \ Gini(y_small) + (len(y_big) / len(y)) * Gini(y_big) if Gini_index > Gini_index_temp: Gini_index = Gini_index_temp thre = thre_temp return Gini_index, thre def attribute_based_on_GainEnt(X, y): ''' 基于信息增益选择最优属性,X为样本集,y为label ''' D = np.arange(len(X[0])) Gain_max = 0 thre_ = 0 d_ = 0 for d in D: Gain, thre = GainEnt_max(X, y, d) if Gain_max < Gain: Gain_max = Gain thre_ = thre d_ = d # 维度标号 return Gain_max, thre_, d_ def attribute_based_on_Giniindex(X, y): ''' 基于信息增益选择最优属性,X为样本集,y为label ''' D = np.arange(len(X.T)) Gini_Index_Min = 1 thre_ = 0 d_ = 0 for d in D: Gini_index, thre = Gini_index_min(X, y, d) if Gini_Index_Min > Gini_index: Gini_Index_Min = Gini_index thre_ = thre d_ = d # 维度标号 return Gini_Index_Min, thre_, d_ def devide_group(X, y, thre, d): ''' 按照维度d下阈值为thre分为两类并返回 ''' X_in_d = X[:, d] x_small_index = [i_arg for i_arg in range( len(X[:, d])) if X[i_arg, d] <= thre] x_big_index = [i_arg for i_arg in range( len(X[:, d])) if X[i_arg, d] > thre] X_small = X[x_small_index] y_small = y[x_small_index] X_big = X[x_big_index] y_big = y[x_big_index] return X_small, y_small, X_big, y_big def NtHt(y): ''' 计算经验熵与样本数的乘积,用来剪枝,y为labels ''' ent = entropy(y) print('ent={},y_len={},all={}'.format(ent, len(y), ent * len(y))) return ent * len(y) def maxlabel(y): label_ = Counter(y).most_common(1) return label_[0][0] def buildtree(X, y, method='Gini'): ''' 递归的方式构建决策树 ''' if y.size > 1: if method == 'Gini': Gain_max, thre, d = attribute_based_on_Giniindex(X, y) elif method == 'GainEnt': Gain_max, thre, d = attribute_based_on_GainEnt(X, y) if (Gain_max > 0 and method == 'GainEnt') or (Gain_max >= 0 and len(list(set(y))) > 1 and method == 'Gini'): X_small, y_small, X_big, y_big = devide_group(X, y, thre, d) left_branch = buildtree(X_small, y_small, method=method) right_branch = buildtree(X_big, y_big, method=method) nh = NtHt(y) max_label = maxlabel(y) return decisionnode(d=d, thre=thre, NH=nh, lb=left_branch, rb=right_branch, max_label=max_label) else: nh = NtHt(y) max_label = maxlabel(y) return decisionnode(results=y[0], NH=nh, max_label=max_label) else: nh = NtHt(y) max_label = maxlabel(y) return decisionnode(results=y.item(), NH=nh, max_label=max_label) def printtree(tree, indent='-', dict_tree={}, direct='L'): # 是否是叶节点 if tree.results != None: print(tree.results) dict_tree = {direct: str(tree.results)} else: # 打印判断条件 print(str(tree.d) + ":" + str(tree.thre) + "? ") # 打印分支 print(indent + "L->",) a = printtree(tree.lb, indent=indent + "-", direct='L') aa = a.copy() print(indent + "R->",) b = printtree(tree.rb, indent=indent + "-", direct='R') bb = b.copy() aa.update(bb) stri = str(tree.d) + ":" + str(tree.thre) + "?" if indent != '-': dict_tree = {direct: {stri: aa}} else: dict_tree = {stri: aa} return dict_tree def classify(observation, tree): if tree.results != None: return tree.results else: v = observation[tree.d] branch = None if v > tree.thre: branch = tree.rb else: branch = tree.lb return classify(observation, branch) def pruning(tree, alpha=0.1): if tree.lb.results == None: pruning(tree.lb, alpha) if tree.rb.results == None: pruning(tree.rb, alpha) if tree.lb.results != None and tree.rb.results != None: before_pruning = tree.lb.NH + tree.rb.NH + 2 * alpha after_pruning = tree.NH + alpha print('before_pruning={},after_pruning={}'.format( before_pruning, after_pruning)) if after_pruning <= before_pruning: print('pruning--{}:{}?'.format(tree.d, tree.thre)) tree.lb, tree.rb = None, None tree.results = tree.max_label if __name__ == '__main__': iris = load_iris() X = iris.data y = iris.target permutation = np.random.permutation(X.shape[0]) shuffled_dataset = X[permutation, :] shuffled_labels = y[permutation] train_data = shuffled_dataset[:100, :] train_label = shuffled_labels[:100] test_data = shuffled_dataset[100:150, :] test_label = shuffled_labels[100:150] tree1 = buildtree(train_data, train_label, method='Gini') print('=============================') tree2 = buildtree(train_data, train_label, method='GainEnt') a = printtree(tree=tree1) b = printtree(tree=tree2) true_count = 0 for i in range(len(test_label)): predict = classify(test_data[i], tree1) if predict == test_label[i]: true_count += 1 print("CARTTree:{}".format(true_count)) true_count = 0 for i in range(len(test_label)): predict = classify(test_data[i], tree2) if predict == test_label[i]: true_count += 1 print("C3Tree:{}".format(true_count)) #print(attribute_based_on_Giniindex(X[49:51, :], y[49:51])) from pylab import * mpl.rcParams['font.sans-serif'] = ['SimHei'] # 指定默认字体 mpl.rcParams['axes.unicode_minus'] = False # 解决保存图像时负号'-'显示为方块的问题 import treePlotter import matplotlib.pyplot as plt treePlotter.createPlot(a, 1) treePlotter.createPlot(b, 2) # 剪枝处理 pruning(tree=tree1, alpha=4) pruning(tree=tree2, alpha=4) a = printtree(tree=tree1) b = printtree(tree=tree2) true_count = 0 for i in range(len(test_label)): predict = classify(test_data[i], tree1) if predict == test_label[i]: true_count += 1 print("CARTTree:{}".format(true_count)) true_count = 0 for i in range(len(test_label)): predict = classify(test_data[i], tree2) if predict == test_label[i]: true_count += 1 print("C3Tree:{}".format(true_count)) treePlotter.createPlot(a, 3) treePlotter.createPlot(b, 4) plt.show()
数据分析咨询请扫描二维码
若不方便扫码,搜微信号:CDAshujufenxi
Excel是数据分析的重要工具,强大的内置功能使其成为许多分析师的首选。在日常工作中,启用Excel的数据分析工具库能够显著提升数 ...
2024-12-23在当今信息爆炸的时代,数据分析师如同一位现代社会的侦探,肩负着从海量数据中提炼出有价值信息的重任。在这个过程中,掌握一系 ...
2024-12-23在现代的职场中,制作吸引人的PPT已经成为展示信息的重要手段,而其中数据对比的有效呈现尤为关键。为了让数据在幻灯片上不仅准 ...
2024-12-23在信息泛滥的现代社会,数据分析师已成为企业决策过程中不可或缺的角色。他们的任务是从海量数据中提取有价值的洞察,帮助组织制 ...
2024-12-23在数据驱动时代,数据分析已成为各行各业的必需技能。无论是提升个人能力还是推动职业发展,选择一条适合自己的学习路线至关重要 ...
2024-12-23在准备数据分析师面试时,掌握高频考题及其解答是应对面试的关键。为了帮助大家轻松上岸,以下是10个高频考题及其详细解析,外加 ...
2024-12-20互联网数据分析师是一个热门且综合性的职业,他们通过数据挖掘和分析,为企业的业务决策和运营优化提供强有力的支持。尤其在如今 ...
2024-12-20在现代商业环境中,数据分析师是不可或缺的角色。他们的工作不仅仅是对数据进行深入分析,更是协助企业从复杂的数据信息中提炼出 ...
2024-12-20随着大数据时代的到来,数据驱动的决策方式开始受到越来越多企业的青睐。近年来,数据分析在人力资源管理中正在扮演着至关重要的 ...
2024-12-20在数据分析的世界里,表面上的技术操作只是“入门票”,而真正的高手则需要打破一些“看不见的墙”。这些“隐形天花板”限制了数 ...
2024-12-19在数据分析领域,尽管行业前景广阔、岗位需求旺盛,但实际的工作难度却远超很多人的想象。很多新手初入数据分析岗位时,常常被各 ...
2024-12-19入门数据分析,许多人都会感到“难”,但这“难”究竟难在哪儿?对于新手而言,往往不是技术不行,而是思维方式、业务理解和实践 ...
2024-12-19在如今的行业动荡背景下,数据分析师的职业前景虽然面临一些挑战,但也充满了许多新的机会。随着技术的不断发展和多领域需求的提 ...
2024-12-19在信息爆炸的时代,数据分析师如同探险家,在浩瀚的数据海洋中寻觅有价值的宝藏。这不仅需要技术上的过硬实力,还需要一种艺术家 ...
2024-12-19在当今信息化社会,大数据已成为各行各业不可或缺的宝贵资源。大数据专业应运而生,旨在培养具备扎实理论基础和实践能力,能够应 ...
2024-12-19阿里P8、P9失业都找不到工作?是我们孤陋寡闻还是世界真的已经“癫”成这样了? 案例一:本硕都是 985,所学的专业也是当红专业 ...
2024-12-19CDA持证人Louis CDA持证人基本情况 我大学是在一个二线城市的一所普通二本院校读的,专业是旅游管理,非计算机非统计学。毕业之 ...
2024-12-18最近,知乎上有个很火的话题:“一个人为何会陷入社会底层”? 有人说,这个世界上只有一个分水岭,就是“羊水”;还有人说,一 ...
2024-12-18在这个数据驱动的时代,数据分析师的技能需求快速增长。掌握适当的编程语言不仅能增强分析能力,还能帮助分析师从海量数据中提取 ...
2024-12-17在当今信息爆炸的时代,数据分析已经成为许多行业中不可或缺的一部分。想要在这个领域脱颖而出,除了热情和毅力外,你还需要掌握 ...
2024-12-17