机器学习之k-近邻(kNN)算法与Python实现
k-近邻算法(kNN,k-NearestNeighbor),是最简单的机器学习分类算法之一,其核心思想在于用距离目标最近的k个样本数据的分类来代表目标的分类(这k个样本数据和目标数据最为相似)。
一 k-近邻(kNN)算法概述
1.概念
kNN算法的核心思想是用距离最近的k个样本数据的分类来代表目标数据的分类。
其原理具体地讲,存在一个训练样本集,这个数据训练样本的数据集合中的每个样本都包含数据的特征和目标变量(即分类值),输入新的不含目标变量的数据,将该数据的特征与训练样本集中每一个样本进行比较,找到最相似的k个数据,这k个数据出席那次数最多的分类,即输入的具有特征值的数据的分类。
例如,训练样本集中包含一系列数据,这个数据包括样本空间位置(特征)和分类信息(即目标变量,属于红色三角形还是蓝色正方形),要对中心的绿色数据的分类。运用kNN算法思想,距离最近的k个样本的分类来代表测试数据的分类,那么:
当k=3时,距离最近的3个样本在实线内,具有2个红色三角和1个蓝色正方形**,因此将它归为红色三角。
当k=5时,距离最近的5个样本在虚线内,具有2个红色三角和3个蓝色正方形**,因此将它归为蓝色正方形。
2.特点
优点
(1)监督学习:可以看到,kNN算法首先需要一个训练样本集,这个集合中含有分类信息,因此它属于监督学习。
(2)通过计算距离来衡量样本之间相似度,算法简单,易于理解和实现。
(3)对异常值不敏感
缺点 (4)需要设定k值,结果会受到k值的影响,通过上面的例子可以看到,不同的k值,最后得到的分类结果不尽相同。k一般不超过20。(5)计算量大,需要计算样本集中每个样本的距离,才能得到k个最近的数据样本。 (6)训练样本集不平衡导致结果不准确问题。当样本集中主要是某个分类,该分类数量太大,导致近邻的k个样本总是该类,而不接近目标分类。
3.kNN算法流程
一般情况下,kNN有如下流程:
(1)收集数据:确定训练样本集合测试数据;
(2)计算测试数据和训练样本集中每个样本数据的距离;
常用的距离计算公式:
(3)按照距离递增的顺序排序;
(4)选取距离最近的k个点;
(5)确定这k个点中分类信息的频率;
(6)返回前k个点中出现频率最高的分类,作为当前测试数据的分类。二 、Python算法实现
1.KNN算法分类器
建立一个名为“KNN.py”的文件,构造一个kNN算法分类器的函数:
from numpy import *
import operator
#定义KNN算法分类器函数
#函数参数包括:(测试数据,训练数据,分类,k值)
def classify(inX,dataSet, labels, k):
dataSetSize = dataSet.shape[0]
diffMat = tile(inX,(dataSetSize,1))-dataSet
sqDiffMat=diffMat**2
sqDistances=sqDiffMat.sum(axis=1)
distances=sqDistances**0.5 #计算欧式距离
sortedDistIndicies=distances.argsort() #排序并返回index
#选择距离最近的k个值
classCount={}
for i in range(k):
voteIlabel=labels[sortedDistIndicies[i]]
#D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None.
classCount[voteIlabel]=classCount.get(voteIlabel,0)+1
#排序
sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]
在KNN.py中定义一个生成“训练样本集”的函数:
#定义一个生成“训练样本集”的函数,包含特征和分类信息在Python控制台先将当前目录设置为“KNN.py”所在的文件目录,将测试数据[0,0]进行KNN算法分类测试,输入:
import KNN
#生成训练样本
group,labels=KNN.createDataSet()
#对测试数据[0,0]进行KNN算法分类测试
KNN.classify([0,0],group,labels,3)
Out[3]: 'B'
可以看到该分类器函数将[0,0]分类为B组,符合实际情况,分入了符合逻辑的正确的类别。但如何知道KNN分类的正确性呢?
2.kNN算法用于约会网站配对
2.1准备数据
该数据在文本文件datingTestSet2.txt中,该数据具有1000行,4列,分别是特征数据(每年获得的飞行常客里程数,玩视频游戏所耗时间百分比,每周消费的冰淇淋公升数),和目标变量/分类数据(是否喜欢(1表示不喜欢,2表示魅力一般,3表示极具魅力)),部分数据展示如下:
完整地数据下载地址如下:
约会网站测试数据
(1)将文本记录转为成numpy
在python控制台输入:
in [5]:datingDataMat,datingLabels=KNN.file2matrix('G:\Workspaces\MachineLearning\machinelearninginaction\Ch02\datingTestSet2.txt')#括号是文件路径
(2)可视化分析数据
运用Matplotlib创建散点图来分析数据:
import matplotlib
import matplotlib.pyplot as plt
#对第二列和第三列数据进行分析:
fig=plt.figure()
ax=fig.add_subplot(111)
ax.scatter(datingDataMat[:,1],datingDataMat[:,2],c=datingLabels)
plt.xlabel('Percentage of Time Spent Playing Video Games')
plt.ylabel('Liters of Ice Cream Consumed Per Week')
#对第一列和第二列进行分析:
fig=plt.figure()
ax=fig.add_subplot(111)
ax.scatter(datingDataMat[:,0],datingDataMat[:,1],c=datingLabels)
plt.xlabel('Miles of plane Per year')
plt.ylabel('Percentage of Time Spent Playing Video Games')
ax.legend(loc='best')
(3)数据归一化
由于不同的数据在大小上差别较大,在计算欧式距离,整体较大的数据明细所占的比重更高,因此需要对数据进行归一化处理。
在Python控制台输入:
reload(KNN)数据的准备工作完成,下一步对算法进行测试。
2.2 算法测试
kNN算法分类的结果的效果,可以使用正确率/错误率来衡量,错误率为0,则表示分类很完美,如果错误率为1,表示分类完全错误。我们使用1000条数据中的90%作为训练样本集,其中的10%来测试错误率。
#定义测试算法的函数在控制台输入命令来测试错误率:
reload(KNN)
Out[150]: <module 'KNN' from 'G:\\Workspaces\\MachineLearning\\KNN.py'>
KNN.datingClassTest()
the classifier came back with: 3,the real answer is: 3
the classifier came back with: 2,the real answer is: 2
the classifier came back with: 1,the real answer is: 1
... ...
the classifier came back with: 2,the real answer is: 2
the classifier came back with: 1,the real answer is: 1
the classifier came back with: 3,the real answer is: 1
the total error rate is : 0.050000
可以看到KNN算法分类器处理约会数据的错误率是5%,具有较高额正确率。
可以在datingClassTest函数中传入参数h来改变测试数据比例,来看修改后Ration后错误率有什么样的变化。
KNN.datingClassTest(0.2)
the classifier came back with: 3,the real answer is: 3
the classifier came back with: 2,the real answer is: 2
the classifier came back with: 1,the real answer is: 1
... ...
the classifier came back with: 2,the real answer is: 2
the classifier came back with: 3,the real answer is: 3
the classifier came back with: 2,the real answer is: 2
the total error rate is : 0.080000
减小训练样本集数据,增加测试数据,错误率增加到8%。
2.3 使用KNN算法进行预测
def classifypersion():测试一下:
reload(KNN)
Out[153]: <module 'KNN' from 'G:\\Workspaces\\MachineLearning\\KNN.py'>
KNN.classifypersion()
percentage of time spent playing video games?10
frequent flier miles earned per year?10000
liters of ice creamconsued per year?0.5
You will probably like this persion :not at all
3. KNN算法用于手写识别系统
已经将图片转化为32*32 的文本格式,文本格式如下:
00000000000111110000000000000000
00000000001111111000000000000000
00000000011111111100000000000000
00000000111111111110000000000000
00000001111111111111000000000000
00000011111110111111100000000000
00000011111100011111110000000000
00000011111100001111110000000000
00000111111100000111111000000000
00000111111100000011111000000000
00000011111100000001111110000000
00000111111100000000111111000000
00000111111000000000011111000000
00000111111000000000011111100000
00000111111000000000011111100000
00000111111000000000001111100000
00000111111000000000001111100000
00000111111000000000001111100000
00000111111000000000001111100000
00000111111000000000001111100000
00000011111000000000001111100000
00000011111100000000011111100000
00000011111100000000111111000000
00000001111110000000111111100000
00000000111110000001111111000000
00000000111110000011111110000000
00000000111111000111111100000000
00000000111111111111111000000000
00000000111111111111110000000000
00000000011111111111100000000000
00000000001111111111000000000000
00000000000111111110000000000000
3.1数据准备
(1)将32*32的文本格式转为成1*2014的向量
在控制台中输入命令测试下函数:
reload(KNN)
3.2 算法测试
使用kNN算法测试手写数字识别
#引入os模块的listdir函数,列出给定目录的文件名
from os impor listdir
def handwritingClassTest():
hwLabels=[]
trainingFileList=listdir('G:/Workspaces/MachineLearning/machinelearninginaction/Ch02/trainingDigits')#列出文件名
m=len(trainingFileList) #文件数目
trainMat=zeros((m,1024))
#从文件名中解析分类信息,如0_13.txt
for i in range(m):
fileNameStr=trainingFileList[i]
fileStr=fileNameStr.split('.')[0]
classNumber=int(fileStr.split('_')[0])
hwLabels.append(classNumber)
trainMat[i]=img2vector('G:/Workspaces/MachineLearning/machinelearninginaction/Ch02/trainingDigits/%s'%fileNameStr)
testFileList=listdir('G:/Workspaces/MachineLearning/machinelearninginaction/Ch02/testDigits')
errorCount=0
#同上,解析测试数据的分类信息
mTest=len(testFileList)
for i in range(mTest):
fileNameStr=testFileList[i]
fileStr=fileNameStr.split('.')[0]
classNumber=int(fileStr.split('_')[0])
vectorUnderTest=img2vector('G:/Workspaces/MachineLearning/machinelearninginaction/Ch02/testDigits/%s'%fileNameStr)
classifierResult=classify(vectorUnderTest,trainMat,hwLabels,3)
print('the classifier came back with :%d,the real answer is:%d'%(classifierResult,classNumber))
if(classifierResult!=classNumber):errorCount+=1
print('\n the total number of errors is: %d'%errorCount)
print('\n total error rate is %f'%(errorCount/float(mTest)))
接下来在Python控制台输入命令来测试手写数字识别:
reload(KNN)
KNN.handwritingClassTest()
the classifier came back with :0,the real answer is:0
the classifier came back with :0,the real answer is:0
the classifier came back with :0,the real answer is:0
... ...
the classifier came back with :9,the real answer is:9
the classifier came back with :9,the real answer is:9
the classifier came back with :9,the real answer is:9
the total number of errors is: 10
total error rate is 0.010571
错误利率1.057%,具有较高的准确率。
数据分析咨询请扫描二维码
若不方便扫码,搜微信号:CDAshujufenxi
Excel是数据分析的重要工具,强大的内置功能使其成为许多分析师的首选。在日常工作中,启用Excel的数据分析工具库能够显著提升数 ...
2024-12-23在当今信息爆炸的时代,数据分析师如同一位现代社会的侦探,肩负着从海量数据中提炼出有价值信息的重任。在这个过程中,掌握一系 ...
2024-12-23在现代的职场中,制作吸引人的PPT已经成为展示信息的重要手段,而其中数据对比的有效呈现尤为关键。为了让数据在幻灯片上不仅准 ...
2024-12-23在信息泛滥的现代社会,数据分析师已成为企业决策过程中不可或缺的角色。他们的任务是从海量数据中提取有价值的洞察,帮助组织制 ...
2024-12-23在数据驱动时代,数据分析已成为各行各业的必需技能。无论是提升个人能力还是推动职业发展,选择一条适合自己的学习路线至关重要 ...
2024-12-23在准备数据分析师面试时,掌握高频考题及其解答是应对面试的关键。为了帮助大家轻松上岸,以下是10个高频考题及其详细解析,外加 ...
2024-12-20互联网数据分析师是一个热门且综合性的职业,他们通过数据挖掘和分析,为企业的业务决策和运营优化提供强有力的支持。尤其在如今 ...
2024-12-20在现代商业环境中,数据分析师是不可或缺的角色。他们的工作不仅仅是对数据进行深入分析,更是协助企业从复杂的数据信息中提炼出 ...
2024-12-20随着大数据时代的到来,数据驱动的决策方式开始受到越来越多企业的青睐。近年来,数据分析在人力资源管理中正在扮演着至关重要的 ...
2024-12-20在数据分析的世界里,表面上的技术操作只是“入门票”,而真正的高手则需要打破一些“看不见的墙”。这些“隐形天花板”限制了数 ...
2024-12-19在数据分析领域,尽管行业前景广阔、岗位需求旺盛,但实际的工作难度却远超很多人的想象。很多新手初入数据分析岗位时,常常被各 ...
2024-12-19入门数据分析,许多人都会感到“难”,但这“难”究竟难在哪儿?对于新手而言,往往不是技术不行,而是思维方式、业务理解和实践 ...
2024-12-19在如今的行业动荡背景下,数据分析师的职业前景虽然面临一些挑战,但也充满了许多新的机会。随着技术的不断发展和多领域需求的提 ...
2024-12-19在信息爆炸的时代,数据分析师如同探险家,在浩瀚的数据海洋中寻觅有价值的宝藏。这不仅需要技术上的过硬实力,还需要一种艺术家 ...
2024-12-19在当今信息化社会,大数据已成为各行各业不可或缺的宝贵资源。大数据专业应运而生,旨在培养具备扎实理论基础和实践能力,能够应 ...
2024-12-19阿里P8、P9失业都找不到工作?是我们孤陋寡闻还是世界真的已经“癫”成这样了? 案例一:本硕都是 985,所学的专业也是当红专业 ...
2024-12-19CDA持证人Louis CDA持证人基本情况 我大学是在一个二线城市的一所普通二本院校读的,专业是旅游管理,非计算机非统计学。毕业之 ...
2024-12-18最近,知乎上有个很火的话题:“一个人为何会陷入社会底层”? 有人说,这个世界上只有一个分水岭,就是“羊水”;还有人说,一 ...
2024-12-18在这个数据驱动的时代,数据分析师的技能需求快速增长。掌握适当的编程语言不仅能增强分析能力,还能帮助分析师从海量数据中提取 ...
2024-12-17在当今信息爆炸的时代,数据分析已经成为许多行业中不可或缺的一部分。想要在这个领域脱颖而出,除了热情和毅力外,你还需要掌握 ...
2024-12-17