人工智能入门:K-近邻(KNN)算法手写体数字识别

发布于 2021-07-27  1628 次阅读


一、前言

如果下面的代码有你看不懂的地方,请善用百度、翻看评论(如有),同时也可以参阅Python官方文档(在友人帐中可以找到链接)。

如果查过、看过了还是不会,请在下方留言,我看到了会给予解答。

同时,请务必自己动手敲一遍代码,自己思考捋一遍逻辑,只有这样才能获得进步。

二、实验目的

1、输出测试集前20项的预测结果
2、输出测试集前20项的图片

三、所需文件

测试集、训练集:https://nyadoo-my.sharepoint.com/:f:/p/blogshare/EhSokcPr2hpJl47JcFVr-t0BK5DPgx6q5l3nlgpAGGCKaw?e=6XC9n7

四、实验方法

1、数据集格式化(函数dataSet

1) loadtxt导入,跳过首行
2) 利用切片索引格式化

2、K近邻数据处理部分(函数classify0

1) 直接利用Numpy中的函数linalg.norm计算欧氏距离
2) 函数Argsort进行排序并返回索引
3) 循环计数,利用函数argmax找出出现次数最多的数字并返回索引(注意计数数组初始化应在循坏外)

3、绘图部分(函数imgshow

1) 循环,每次循环将测试数据中的一行赋值给临时变量。
2) 利用函数resize升维为二维数组
3) 使用函数subplot创建子图。
4) 使用函数title将变量resultclassify0结果返回值)设置为标签,利用函数axis去除坐标轴。
5) 使用函数imshow绘图,函数subplots_adjust设置子图间隔,函数show生成图片。

五、实验代码

#引入所需模块
import numpy as np
import matplotlib.pyplot as plt

#数据集格式化
def dataSet(train, test):
    #首先进行文件的读取(使用Numpy下的loadtxt函数)
    #各参数的具体作用请善用百度
    traindata = np.loadtxt(train, delimiter=',', skiprows=1)
    testdata = np.loadtxt(test, delimiter=',', skiprows=1)
    #下面利用Python的切片索引对读取到的数据进行切分和处理
    #(此处由于我们只测试前20条数据,故testdata只取前20行)
    ytrain = traindata[:, 0]
    xtrain = traindata[:, 1:]
    testdata = testdata[:20, :]
    return ytrain, xtrain, testdata
#K近邻算法进行数据处理
def classify0(x_train, y_train, tmpdata, k):
    #计算欧氏距离
    dist = np.linalg.norm(x_train - tmpdata, axis=1)
    #从小到大排序并返回索引
    minidx = dist.argsort()
    #计数
    count = np.zeros(10)
    for j in range(k):
        count[int(y_train[minidx[j]])] += 1
    return count.argmax()
#绘图
def imgshow(data, width, high, result, k):
    for i in range(k):
        tmp = data[i]
        #利用resize将一维数组升维为二维
        tmp.resize((width, high))
        #创建画布(子图)
        plt.subplot(4, 5, i+1)
        plt.title(result[i])
        plt.axis(off)
        plt.imshow(tmp, cmap='gray')
    #此处设置设置间距是为了让各子图之间的间距增大
    plt.subplots_adjust(hspace=0.5)
    plt.show()

#入口
y_train, x_train, testData = dataSet(train.csv, x_test.csv)
result = []
for i in range(20):
    #按行提取测试数据
    tmpData = testData[i, :]
    result.append(classify0(x_train, y_train, tmpData, 5))
imgshow(testData, 28, 28, result, 20)