jjzjj

利用 python 实现 KNN 算法(自己实现 和 sklearn)

白御空 2023-12-21 原文

利用 python 实现 KNN 算法(自己实现 和 sklearn)

创作背景

昨天有个朋友请我帮他做一个 python 的作业,作业要求如下图(翻译过)

也就是:

给定了数据集,使用 KNN 算法完成下列目标

  1. 编写 自己的 代码实现 KNN 并且用绘制图像
  2. 使用 sklearn 绘制图像(使用 KNeighborsClassifier 进行分类)

绘制的图像效果如下

  • 偷偷说一句:如果对我的答案和解析满意的话可不可以给我 点个赞点个收藏 之类的
  • Let's do it !!!

思路讲解

先开始我很懵,毕竟我也没怎么学过 KNN ,只是大概了解这个算法,想必来看文章的你也是有点不知所云,所以我们就先了解一下这个算法。

了解算法

KNN ,全称是 K-NearestNeighbors ,直译过来就是 K 个距离最近的邻居 ,专业术语是 K 最近邻分类算法
俗话说的好,物以类聚,人以群分 ,这个算法也是体现了这个思想,说的是每个样本的类别都可以用 离它最接近的 K 个邻近值的类别 来代表。
拿最常用的一个例子来说,看下边这一张图

我们要判断 绿色的圆形 也就是未知的数据属于哪个类别,我们就可以根据离它最近的几个点的类别来判断。

  • 如果 k = 3 ,也就是我们要看离这个点最近的 3 个点(如实心⚪圈住的点),其中 2 个红色三角形1 个蓝色正方形 ,那我们就可以判断这个未知的点属于 红色三角形 ,因为离它最近的三个点中 红色三角形 的点数量多。
  • 如果 k = 5 ,也就是我们要看离这个点最近的 5 个点(如虚线⚪圈住的点),其中 3 个蓝色正方形红色三角形 的数量还是 2 个 ,这时候形势逆转,那现在我们就认为未知点属于 蓝色正方形

上边的例子应该很好理解,其他数据也是类似。

作业思路(自己实现)

知道了 KNN 是怎么回事了以后我们就可以来做作业了。

第一步

Of course,导库 ,这次我们用到的库有 numpy矩阵操作pandas读取数据collections统计数量matplotlib绘图

import numpy as np
import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt

第二步

我们要 查看 一下作业 数据 ,并且进行 数据预处理 ,数据如下图所示(部分)

  • 读取完毕后的数据,其中,x.1x.2 分别是每个点的 横纵坐标y 是该点对应的 类别 ,取值为 01
  • 数据预处理,即将点的坐标转换为 二维数组np.concatenate 进行矩阵合并,axis=1 指定 横向合并 。代码如下(为了方便讲解代码逻辑,所以把一段长代码分为不同的行,文章后边也一样):
spots = np.concatenate(
	[
		np.array(df['x.1']).reshape(-1, 1), 
		np.array(df['x.2']).reshape(-1, 1)
	],
	axis=1
)
  • 画一下 散点图 ,看一下数据分布,代码如下
for i, fig in enumerate([('#87CEEB', '.'), ('orange', 'x')]):
	
	# 找到对应分类的点
    data = df.where(df['y'] == i).dropna()
	
	# 绘制散点图
    plt.scatter(data['x.1'], data['x.2'], marker=fig[1], color=fig[0])
    
plt.show()

第三步

读取完数据后就到了第三步,利用 python 实现 knn

  • 这里我们计算点之间的 欧式距离 ,并以此作为评判标准。
  • 为了提高代码的 复用性 ,我将算法封装成函数,参数为 要预测的点的坐标k 值,代码如下:
def take_nearest(grid, k):
	'''
	对传入的点进行 knn 分类
	
	:param grid: 点的坐标
	:type grid : tuple
	
	:param k: 邻居个数
	:type k : int

	:return : 点的分类
	'''
   	
   	# 计算所有已知点距离未知点的距离,即实现 欧氏距离 的计算
    distance = np.sqrt(
    	np.sum((spots - grid) ** 2, axis=1)
	)
    
    # 类别判断
    # 具体细节见下述
    cate = Counter(
    	np.take(
    		df['y'],
    		distance.argsort()[:k]
   		)
   	).most_common(1)[0][0]
    
    return cate
  • 其中:
    • distance.argsort() 得到 排序后的列表对应数据索引[:k]前 k 个 元素
    • np.take 根据第二个参数 条件 取第一个参数 数据 中对应的数据
    • Counter 计算序列中 每个类别出现的频率
    • most_common(1)频率最高类别数量

第四步

  • 这函数也弄完了,可是这题目到底要用 KNN 分类什么点呀?
  • 我当时已知没搞明白。后来,看了看上边要的效果图我才终于明白分类什么点。
  • 如果你仔细看题目要求的图就会发现图的背景是 像素点,根据不同的分类,像素点的 颜色 也不同,代表 两个不同的分类
  • 那我们就有 方向 了(插一句,选对方向 对于学习之路很重要,要不然会 找不到前进的方向)。

这一步,我们应该 生成像素点 。图中的像素点之间的间隔为 0.2 ,所以我们可以 生成 两个差值为 0.2等差数列 ,然后使用 np.meshgrid 生成网格点坐标矩阵。代码如下:

In[]:	# 生成背景像素点
		bg_x, bg_y = np.meshgrid(
			np.arange(-3.0, 5.2, 0.2),
			np.arange(-2.0, 3.2, 0.2)
		)
		
		# 拼接成二维矩阵
		bg_spots = np.concatenate(
			[
				bg_x.reshape(-1, 1),
				bg_y.reshape(-1, 1)
			],
			axis=1
		)
		
		bg_x, bg_y, bg_spots

---------------------------------------------------------------------------------

Out[]:	(array([[-3. , -2.8, -2.6, ...,  4.6,  4.8,  5. ],
	        [-3. , -2.8, -2.6, ...,  4.6,  4.8,  5. ],
	        [-3. , -2.8, -2.6, ...,  4.6,  4.8,  5. ],
	        ...,
	        [-3. , -2.8, -2.6, ...,  4.6,  4.8,  5. ],
	        [-3. , -2.8, -2.6, ...,  4.6,  4.8,  5. ],
	        [-3. , -2.8, -2.6, ...,  4.6,  4.8,  5. ]]),
		 array([[-2. , -2. , -2. , ..., -2. , -2. , -2. ],
		        [-1.8, -1.8, -1.8, ..., -1.8, -1.8, -1.8],
		        [-1.6, -1.6, -1.6, ..., -1.6, -1.6, -1.6],
		        ...,
		        [ 2.6,  2.6,  2.6, ...,  2.6,  2.6,  2.6],
		        [ 2.8,  2.8,  2.8, ...,  2.8,  2.8,  2.8],
		        [ 3. ,  3. ,  3. , ...,  3. ,  3. ,  3. ]]),
		 array([[-3. , -2. ],
		        [-2.8, -2. ],
		        [-2.6, -2. ],
		        ...,
		        [ 4.6,  3. ],
		        [ 4.8,  3. ],
		        [ 5. ,  3. ]]))

第五步

利用 第三步 封装的函数对每个像素点的类别进行判断,代码如下:

In[]:	bg_spots_df = pd.DataFrame(
			np.concatenate(
				[
					bg_spots,
					np.array(
						list(map(lambda x: take_nearest(x, 1), bg_spots))
					).reshape(-1, 1)
				], 
				axis=1
			),
			columns=data1.columns)
		
		bg_spots_df


---------------------------------------------------------------------------------

Out[]:	 	x.1 	x.2 	y
		0 	-3.0 	-2.0 	0.0
		1 	-2.8 	-2.0 	0.0
		2 	-2.6 	-2.0 	0.0
		3 	-2.4 	-2.0 	0.0
		4 	-2.2 	-2.0 	0.0
		... 	... 	... 	...
		1061 	4.2 	3.0 	1.0
		1062 	4.4 	3.0 	1.0
		1063 	4.6 	3.0 	1.0
		1064 	4.8 	3.0 	0.0
		1065 	5.0 	3.0 	0.0
		
		1066 rows × 3 columns

其中:

  • list(map(lambda x: take_nearest(x, 1), bg_spots)) 是对每个像素点进行 KNN 分类 ,并将结果存为列表,这是 k=1 的情况,如果要变化 k 值,则改为 take_nearest(x, 【k 值】)
  • np.array().reshape(-1, 1) 将分类结果转换为 n 行 1 列 的二维矩阵
  • np.concatenate() 将背景点的坐标与分类进行对应
  • pd.DataFrame() 将结果转换为 DataFrame ,为了 方便绘图

第六步(The Final Step)

这一步也是最后一步,进行 绘图 ,代码如下:

for i, fig in enumerate([('#87CEEB', '.'), ('orange', 'x')]):
    
    # 查找对应分类的数据点
    spot = data1.where(data1['y'] == i).dropna()
    
    # 查找对应分类的背景点
    bg_spot = bg_spots_df.where(bg_spots_df['y'] == i).dropna()
	
	# 绘制散点图
    plt.scatter(bg_spot['x.1'], bg_spot['x.2'], s=0.2, color=fig[0])
    plt.scatter(spot['x.1'], spot['x.2'], marker=fig[1], color=fig[0])
    
plt.show()

绘制的图像如下

  • k = 1
  • k = 15

效果还算不错😁

使用 sklearn 实现

这就简单许多,因为 sklearn 已经封装好了 KNN 算法,我们只需要调用即可,代码如下:

from sklearn.metrics import accuracy_score
from sklearn.metrics import mean_squared_error
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split

# 将数据分为训练集和测试集,用来测试模型分类正确率
train_set, test = train_test_split(deepcopy(df), test_size = 0.2, random_state = 42)

def train(k=1):
	
	# 创建分类器
    clf = KNeighborsClassifier(n_neighbors=k)
    
	# 训练数据
    clf.fit(train_set[train_set.columns[:-1]], train_set['y'])
    
	# 测试数据
    test_predictions = clf.predict(test[test.columns[:-1]])
    print('Accuracy:', accuracy_score(test['y'], test_predictions))
    print('MSE:', mean_squared_error(test['y'], test_predictions))
    
    # 预测数据,绘图
    for i, fig in enumerate([('#87CEEB', '.'), ('orange', 'x')]):
    	
	    spots = pd.DataFrame(np.take(bg_spots, np.where(clf.predict(bg_spots) == i)[0], axis=0))
	    
	    plt.scatter(spots[0], spots[1], s=0.2, marker=fig[1], color=fig[0])
    



结尾

以上就是我要分享的内容,因为学识尚浅,会有不足,还请各位大佬指正。
有什么问题也可在评论区留言。

有关利用 python 实现 KNN 算法(自己实现 和 sklearn)的更多相关文章

  1. python - 如何使用 Ruby 或 Python 创建一系列高音调和低音调的蜂鸣声? - 2

    关闭。这个问题是opinion-based.它目前不接受答案。想要改进这个问题?更新问题,以便editingthispost可以用事实和引用来回答它.关闭4年前。Improvethisquestion我想在固定时间创建一系列低音和高音调的哔哔声。例如:在150毫秒时发出高音调的蜂鸣声在151毫秒时发出低音调的蜂鸣声200毫秒时发出低音调的蜂鸣声250毫秒的高音调蜂鸣声有没有办法在Ruby或Python中做到这一点?我真的不在乎输出编码是什么(.wav、.mp3、.ogg等等),但我确实想创建一个输出文件。

  2. ruby - 如何根据特征实现 FactoryGirl 的条件行为 - 2

    我有一个用户工厂。我希望默认情况下确认用户。但是鉴于unconfirmed特征,我不希望它们被确认。虽然我有一个基于实现细节而不是抽象的工作实现,但我想知道如何正确地做到这一点。factory:userdoafter(:create)do|user,evaluator|#unwantedimplementationdetailshereunlessFactoryGirl.factories[:user].defined_traits.map(&:name).include?(:unconfirmed)user.confirm!endendtrait:unconfirmeddoenden

  3. Python 相当于 Perl/Ruby ||= - 2

    这个问题在这里已经有了答案:关闭10年前。PossibleDuplicate:Pythonconditionalassignmentoperator对于这样一个简单的问题表示歉意,但是谷歌搜索||=并不是很有帮助;)Python中是否有与Ruby和Perl中的||=语句等效的语句?例如:foo="hey"foo||="what"#assignfooifit'sundefined#fooisstill"hey"bar||="yeah"#baris"yeah"另外,类似这样的东西的通用术语是什么?条件分配是我的第一个猜测,但Wikipediapage跟我想的不太一样。

  4. java - 什么相当于 ruby​​ 的 rack 或 python 的 Java wsgi? - 2

    什么是ruby​​的rack或python的Java的wsgi?还有一个路由库。 最佳答案 来自Python标准PEP333:Bycontrast,althoughJavahasjustasmanywebapplicationframeworksavailable,Java's"servlet"APImakesitpossibleforapplicationswrittenwithanyJavawebapplicationframeworktoruninanywebserverthatsupportstheservletAPI.ht

  5. 区块链之加解密算法&数字证书 - 2

    目录一.加解密算法数字签名对称加密DES(DataEncryptionStandard)3DES(TripleDES)AES(AdvancedEncryptionStandard)RSA加密法DSA(DigitalSignatureAlgorithm)ECC(EllipticCurvesCryptography)非对称加密签名与加密过程非对称加密的应用对称加密与非对称加密的结合二.数字证书图解一.加解密算法加密简单而言就是通过一种算法将明文信息转换成密文信息,信息的的接收方能够通过密钥对密文信息进行解密获得明文信息的过程。根据加解密的密钥是否相同,算法可以分为对称加密、非对称加密、对称加密和非

  6. 华为OD机试用Python实现 -【明明的随机数】 2023Q1A - 2

    华为OD机试题本篇题目:明明的随机数题目输入描述输出描述:示例1输入输出说明代码编写思路最近更新的博客华为od2023|什么是华为od,od薪资待遇,od机试题清单华为OD机试真题大全,用Python解华为机试题|机试宝典【华为OD机试】全流程解析+经验分享,题型分享,防作弊指南华为o

  7. python - 如何读取 MIDI 文件、更改其乐器并将其写回? - 2

    我想解析一个已经存在的.mid文件,改变它的乐器,例如从“acousticgrandpiano”到“violin”,然后将它保存回去或作为另一个.mid文件。根据我在文档中看到的内容,该乐器通过program_change或patch_change指令进行了更改,但我找不到任何在已经存在的MIDI文件中执行此操作的库.他们似乎都只支持从头开始创建的MIDI文件。 最佳答案 MIDIpackage会为您完成此操作,但具体方法取决于midi文件的原始内容。一个MIDI文件由一个或多个音轨组成,每个音轨是十六个channel中任何一个上的

  8. 基于C#实现简易绘图工具【100010177】 - 2

    C#实现简易绘图工具一.引言实验目的:通过制作窗体应用程序(C#画图软件),熟悉基本的窗体设计过程以及控件设计,事件处理等,熟悉使用C#的winform窗体进行绘图的基本步骤,对于面向对象编程有更加深刻的体会.Tutorial任务设计一个具有基本功能的画图软件**·包括简单的新建文件,保存,重新绘图等功能**·实现一些基本图形的绘制,包括铅笔和基本形状等,学习橡皮工具的创建**·设计一个合理舒适的UI界面**注明:你可能需要先了解一些关于winform窗体应用程序绘图的基本知识,以及关于GDI+类和结构的知识二.实验环境Windows系统下的visualstudio2017C#窗体应用程序三.

  9. 「Python|Selenium|场景案例」如何定位iframe中的元素? - 2

    本文主要介绍在使用Selenium进行自动化测试或者任务时,对于使用了iframe的页面,如何定位iframe中的元素文章目录场景描述解决方案具体代码场景描述当我们在使用Selenium进行自动化测试的时候,可能会遇到一些界面或者窗体是使用HTML的iframe标签进行承载的。对于iframe中的标签,如果直接查找是无法找到的,会抛出没有找到元素的异常。比如近在咫尺的例子就是,CSDN的登录窗体就是使用的iframe,大家可以尝试通过F12开发者模式查看到的tag_name,class_name,id或者xpath来定位中的页面元素,会抛出NoSuchElementException异常。解决

  10. MIMO-OFDM无线通信技术及MATLAB实现(1)无线信道:传播和衰落 - 2

     MIMO技术的优缺点优点通过下面三个增益来总体概括:阵列增益。阵列增益是指由于接收机通过对接收信号的相干合并而活得的平均SNR的提高。在发射机不知道信道信息的情况下,MIMO系统可以获得的阵列增益与接收天线数成正比复用增益。在采用空间复用方案的MIMO系统中,可以获得复用增益,即信道容量成倍增加。信道容量的增加与min(Nt,Nr)成正比分集增益。在采用空间分集方案的MIMO系统中,可以获得分集增益,即可靠性性能的改善。分集增益用独立衰落支路数来描述,即分集指数。在使用了空时编码的MIMO系统中,由于接收天线或发射天线之间的间距较远,可认为它们各自的大尺度衰落是相互独立的,因此分布式MIMO

随机推荐