介绍K近邻算法基本原理
K近邻算法
一、什么是k近邻算法
给定一个训练数据集、对新的输入实例,在训练集中找到与该实例最邻近的k个实例,这k个实例的多数属于某个类,就把该输入实例分为这个类。如下图所示:输入新的❓点判断它属于classA还是classB
那么问题来了k近邻算法市最近邻,那么这个最近邻怎么判断?计算距离!怎么计算距离
二、距离度量
1、欧几里得距离
$$
d(x,y)=\sqrt{\displaystyle\sum_{i=1}^{n}(y_i-x_i)^2}
$$
3、曼哈顿距离
$$
d(x,y)=(\displaystyle\sum_{i=0}^{n}{|y_i-x_i|})
$$
4、闵可夫斯基距离
$$
Minkowski Distance=(\displaystyle\sum_{i=0}^{n}{|y_i-x_i|})^\frac{1}{p}
$$
现在我们知道了如何确定距离,但是问题又来了,我们应该怎么去定义我们的K值呢?
三、K值选择
如果选择较小的 k 值,就相当于用较小的邻域中的训练实例进行预测,“学习”的近似误差(approximation error)会减小,只有与输入实例较近的(相似的)训练实例才会对预测结果起作用。但缺点是“学习”的估计误差(estimation error)会增大,预测结果会对近邻的实例点非常敏感 。如果邻近的实例点恰巧是噪声,预测就会出错。换句话说,k 值的减小就意味着整体模型变得复杂,容易发生过拟合。
如果选择较大的 k 值,就相当于用较大邻域中的训练实例进行预测。其优点是可以减少学习的估计误差,但缺点是学习的近似误差会增大。这时与输入实例较远的(不相似的)训练实例也会对预测起作用,使预测发生错误。k 值的增大就意味着整体的模型变得简单。
如果k =N,那么无论输入实例是什么,都将简单地预测它属于在训练实例中最多的类。这时,模型过于简单,完全忽略训练实例中的大量有用信息,是不可取的。在应用中,k 值一般取一个比较小的数值。通常采用交叉验证法来选取最优的k值。
四、算法流程(python)
第一步:确定KNN算法需要确定的参数:1、输入的待分类变量(训练集);2、k值;3、计算距离的方法;4、训练集数据;4、训练集标签
第二步:计算距离,KNN算法的核心就是计算与每一个变量的距离,而后挑选前K个。所以我们有时候需要对输入数据维度拓展,使其达到和测试集形状相同。
第三步:判断KNN的准确率
五、KNN评价
优势
易于实现:鉴于算法的简单性和准确性,它是新数据科学家将学习的首批分类器之一
轻松适应:随着新训练样本的增加,算法会根据任何新数据进行调整,因为所有训练数据都存储在内存中
很少的超参数:KNN 只需要 k值和距离度量,与其他机器学习算法相比,所需的超参数很少
缺点
不能很好地扩展:由于 KNN 是一种惰性算法,因此与其他分类器相比,它占用了更多的内存和数据存储。 从时间和金钱的角度来看,这可能是昂贵的。 更多的内存和存储将增加业务开支,而更多的数据可能需要更长的时间来计算。 虽然已经创建了不同的数据结构(例如 Ball-Tree)来解决计算效率低下的问题,但分类器是否理想可能取决于业务问题
维度的诅咒:KNN 算法容易成为维度诅咒的受害者,这意味着它在高维数据输入时表现不佳。 这有时也称为峰值现象,在算法达到最佳特征数量后,额外的特征会增加分类错误的数量,尤其是当样本尺寸较小时
容易过拟合:由于”维度的诅咒”,KNN 也更容易过拟合。 虽然利用特征选择和降维技术来防止这种情况发生
六、python代码实现(手写字体识别)
需要调用的包(可以尝试利用pytorch去提高计算速度、原理也很简单pytorch可以和numpy几乎无缝衔接)
1 2 3 4
| import torch import numpy as np import operator from os import listdir
|
核心代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
| def KNN(inx,dataset,labels,k,distances_way): """ inx:输入需要分类的数字 datset:输入样本训练集 labels:标签向量 k:选择最近邻的数目,其中标签数量和矩阵dataset的行数相同 distance:计算距离的方式 """ datasize_h = dataset.shape[0] inx = np.tile(inx,(datasize_h,1)) if distances_way == str('o'): diffmat = inx - dataset sq_diffmat = diffmat**2 sq_distances = sq_diffmat.sum(axis=1) distance = sq_distances**0.5 elif distances_way == str('man'): diffmat = inx - dataset abs_diffmat = abs(diffmat) distance = abs_diffmat.sum(axis=1) elif distances_way == str('min'): p = int(input('输入p值:')) diffmat = inx - dataset sq_diffmat = diffmat**2 sq_distances = sq_diffmat.sum(axis=1) distance = sq_distances**(1/p) distance_sort = distance.argsort() dic = {} for i in range(k): diff_label = labels[distance_sort[i]] dic[diff_label] = dic.get(diff_label,0)+1 dic_sort = sorted(dic.items(), key=operator.itemgetter(1),reverse=True) return dic_sort[0][0]
|
手写字体识别(列子来源于《机器学习实战》自己做了部分改变)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
| def data_read(path): data = open(path) data_use = np.zeros((1,1024)) for i in range(32): data_line = data.readline() for j in range(32): data_use[0,32*i+j] = int(data_line[j]) return data_use
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
| def handwritingClassTest(): """ hwLabels:手写数字真实值 m:训练集文件个数 mTset:测试集的文件个数 trainingMat:存储训练集的全部一维化的数据 """ hwLabels = [] trainingFileList = listdir('D:/Github-code/python-gogogo/机器学习/分类问题/K近邻算法/手写数据文件/training_handwriting') m = len(trainingFileList) trainingMat = np.zeros((m, 1024)) for i in range(m): fileNameStr = trainingFileList[i] fileStr = fileNameStr.split('.')[0] classNumStr = int(fileStr.split('_')[0]) hwLabels.append(classNumStr) trainingMat[i,:] = data_read('D:/Github-code/python-gogogo/机器学习/分类问题/K近邻算法/手写数据文件/training_handwriting/%s'% fileNameStr)
testFileList = listdir('D:/Github-code/python-gogogo/机器学习/分类问题/K近邻算法/手写数据文件/test_handwriting') errorCount = 0.0 mTest = len(testFileList) for i in range(mTest): fileNameStr = testFileList[i] fileStr = fileNameStr.split('.')[0] classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = data_read('D:/Github-code/python-gogogo/机器学习/分类问题/K近邻算法/手写数据文件/test_handwriting/%s' % fileNameStr) print(vectorUnderTest.shape) classifierResult = KNN(vectorUnderTest, trainingMat, hwLabels, 3,'o') print ("KNN分类结果: %d, 实际结果: %d" % (classifierResult, classNumStr)) if (classifierResult != classNumStr): errorCount += 1.0 print ("\n错误数字数量: %d" % errorCount) print ("\n错误比率: %f" % (errorCount/float(mTest)))
|
七、个人思考
既然过程中涉及到了pytorch,并且做的是”手写字体识别“,那么的话我们不妨自己尝试使用自己随便拍一张照片,而后去识别自己手写字体(将图片二值化,我暂时想到这样,因为深度学习自己也不是太过了了解),如果利用KNN算法可能效率没有使用CNN算法效率那么高,其中可能还会涉及到opencv库的使用。
评论区大佬有想法不妨踢我哈哈哈哈哈。😀😀文章有什么不足欢迎留言。
参考
1、https://www.ibm.com/cn-zh/topics/knn#:~:text=k%2D%E6%9C%80%E8%BF%91%E9%82%BB%E7%AE%97%E6%B3%95%EF%BC%8C%E4%B9%9F,%E6%9C%80%E5%B8%B8%E8%A1%A8%E7%A4%BA%E7%9A%84%E6%A0%87%E7%AD%BE%E3%80%82
2、李航《统计学学习方法第二版》