当前位置: 首页 > 后端技术 > Python

机器学习笔记(二)——手写数字识别的KNN算法

时间:2023-03-26 18:33:00 Python

算法介绍手写数字识别是KNN算法中一个特别经典的例子。获取数据源有两种方式,一种是从MNIST数据集获取,另一种是从UCIIrvineUniversityMachineLearningRepository下载,本文以后者为例进行讲解。基本思想是使用KNN算法推导出0到9之间的哪个数字由一个32x32的二进制矩阵表示,如下图所示。数据集由两部分组成,一是训练数据集,共有1934条数据;另一个是测试数据集,共有946条数据。所有数据命名格式统一,例如编号为5的第56个样本-5_56.txt,以方便提取样本的真实标签。数据格式也有两种,一种是像上图这样由0和1组成的文本文件;另一种是手写数字图片,需要处理转换成和上图一样的格式,下面分别介绍。算法步骤收集数据:打开数据源分析数据,构思如何处理数据,导入训练数据,转换成结构化数据格式,计算距离(欧式距离),导入测试数据,计算模型准确率,手写数字,实际套用模型,因为所有的数据都是由0和1组成的,所以不需要对数据进行标准化和归一化。该算法实现了处理数据。在计算两个样本之间的距离时,各个属性是一一对应的,所以这里将32x32的数字矩阵转换成1x1024的数字矩阵,方便计算样本之间的距离。#处理文本文件defimg_deal(file):#创建一个1*1024的一维零矩阵the_matrix=np.zeros((1,1024))fb=open(file)foriinrange(32):#linebylineReadlineStr=fb.readline()forjinrange(32):#给一维零矩阵赋值32*32=1024个元素the_matrix[0,32*i+j]=int(lineStr[j])returnthe_matrix计算欧几里得距离numpy有一个tile方法,可以将一个一维矩阵水平复制几次,垂直复制多次,所以一个测试数据经过tile方法处理后,减去训练数据,新的之后得到矩阵,将矩阵中的每条数据进行(水平)平方和和根号后,就可以得到测试数据和每条训练数据的距离。接下来就是对所有的距离进行升序排序,得到第一个K,在这个范围内,每个数字类别的个数,返回出现次数较多的数字类别的标签。defclassify(test_data,train_data,label,k):Size=train_data.shape[0]#每行测试数据减去训练数据复制Size次,水平复制Size次,垂直复制1次the_matrix=np.tile(test_data,(Size,1))-train_data#对减法得到的结果进行平方sq_the_matrix=the_matrix**2#平方和,axis=1表示水平all_the_matrix=sq_the_matrix.sum(axis=1)#结果是平方根到得到最终的Distancedistance=all_the_matrix**0.5#将距离从小到大排序,结果为索引sort_distance=distance.argsort()dis_Dict={}#获取第一个kforiinrange(k):#获取topK个labelthe_label=label[sort_distance[i]]#将label的key和value引入字典dis_Dict[the_label]=dis_Dict.get(the_label,0)+1#根据大小对字典进行排序值,从大到小,即在K的范围内,过滤出现次数最多的标签sort_Count=sorted(dis_Dict.items(),key=operator.itemgetter(1),reverse=True)#返回出现次数最多的标签returnsort_Count[0][0]测试数据集应用必须先处理训练数据集。listdir方法是返回一个文件夹下的所有文件,然后生成行数为文件数,列数为1024的训练数据矩阵,收集训练数据。每条数据的实际labelcut抽取保存在labels列表中,即传入的label计算与classify函数的距离。labels=[]#listdir方法是返回一个文件夹中包含的文件train_data=listdir('trainingDigits')#获取文件夹中的文件个数m_train=len(train_data)#生成一个列号作为train_matrix,第1024行零矩阵train_matrix=np.zeros((m_train,1024))foriinrange(m_train):file_name_str=train_data[i]file_str=file_name_str.split('.')[0]#切出训练中的每个数据setThereallabelfile_num=int(file_str.split('_')[0])labels.append(file_num)#将训练数据集中的所有数据传入train_matrixtrain_matrix[i,:]=img_deal('trainingDigits/%s'%file_name_str)然后对测试和训练数据集做和上面一样的处理,传入测试数据矩阵TestClassify,训练数据矩阵train_matrix,训练数据真实标签labels,K,共4个参数进入计算距离分类函数,最后计算模型准确率并输出错误的预测数据。error=[]test_matrix=listdir('testDigits')正确=0.0m_test=len(test_matrix)foriinrange(m_test):file_name_str=test_matrix[i]file_str=file_name_str.split('.')[0]#test数据集中每条数据的真实结果file_num=int(file_str.split('_')[0])TestClassify=img_deal('testDigits/%s'%file_name_str)classify_result=classify(TestClassify,train_matrix,labels,3)print('预测结果:%s\真实结果:%s'%(classify_result,file_num))ifclassify_result==file_num:correct+=1.0else:error.append((file_name_str,classify_result))print("正确率:{:.2f}%".format(correct/float(m_test)*100))print(error)print(len(error))代码运行部分截图如下。当K=3时,准确率达到98.94%。对于这个模型来说,准确率非常可观,但是运行效率比较低,接近30秒的运行时间。由于每个测试数据需要从近2000个训练数据中计算,每次计算包含1024维浮点运算,高频多维计算是导致模型效率低下的主要原因。K值下图展示了K值与模型精度的关系。当K=3时,模型精度达到峰值。随着K的增加,精度越来越小,所以这个数据的噪声还是比较小的。.手写数字测试建模完成,模型的准确率也不错。为什么要自己测试手写号?所以我手动写了几个数字。正常拍出来的图片都是RGB彩色图片,像素点也不一样,所以需要对图片做两个处理:转成黑白图,把像素点转成32x32,这样就满足我们的要求了。上述算法的要求;对于像素来说,这个值一般在0-255之间,255代表白色,0代表黑色,但是因为手写数字像素的颜色不规范,所以我们设置一个阈值来判断黑色和白色。图片转文字代码如下:defpic_txt():foriinrange(0,10):img=Image.open('.\handwritten\%s.png'%i)#将图片像素改为32X32img=img.resize((32,32))#将彩色图片转换成黑白图片img=img.convert('L')#保存路径='.\handwritten\%s_new.jpg'%iimg.为我在范围(0,10)中保存(路径):fb=open('.\hand_written\%s_handwritten.txt'%i,'w')new_img=Image.open('.\handwritten\%s_new.jpg'%i)#读取图片的宽高width,height=new_img.sizeforiinrange(height):forjinrange(width):#获取像素颜色=new_img.getpixel((j,i))#像素比较图中高的是白色ifcolor>170:fb.write('0')else:fb.write('1')fb.write('\n')fb.close()的整体代码运行截图如下:正确率为70%。毕竟测试数据很少。10个数字中,4、7、8这三个数字预测错了,还不错;对测试结果有一定的影响。如果避免类似情况,增加更多的测试数据,准确率肯定会提高。公众号【奶糖猫】后台回复“手写数字”即可获取源码和数据供参考,感谢阅读。