kNN及手写数字识别

kNN(k-近邻, k-Nearest Neighbor)算法是一种基本的分类与回归方法,是一个简单的无显示学习过程、非泛化学习的监督学习模型。

算法

  • 获取带有标签的样本数据集
  • 输入没有标签的新数据,将新数据的每个特征与样本集中数据对应的特征进行比较:
    • 计算新数据与样本数据集中每条数据的距离
    • 对求得的所有距离进行升序排序(从小到大,越小表示越相似)
    • 取前 k (k 一般小于等于 20 )个样本数据所对应的分类标签
  • 求 k 个数据中出现次数最多的分类标签作为新数据的分类标签
  • 完成分类

特点

  • 优点:精度高、对异常值不敏感、无数据输入假定
  • 缺点:计算复杂度高、空间复杂度高
  • 适用数据范围:数值型和标称型

k值选取

  • k值小:近似误差小,估计误差大,整体模型变得复杂,容易发生过拟合。
  • k值大:近似误差大,估计误差小,整体的模型变得简单。
  • 近似误差:对训练集的训练误差,关注训练集,近似误差小了会出现过拟合,对现有训练集有很好预测,但对未知测试样本的预测会有较大偏差,模型本身不是最接近最佳模型。
  • 测试误差:对测试集的测试误差,关注测试集,估计误差小了说明对未知数据的预测能力好,模型本身最接近最佳模型。
  • 通过交叉验证(cross validation)来选取适合的k值 。

kNN应用:手写数字的识别

示例

手写数字图片 二进制文件
digit digit

识别结果

1
**1NN实现的手写数字识别**
2
识别成功!这是数字: 9

实现

1
#!/usr/local/bin/python3
2
# -*- coding: utf-8 -*-
3
import numpy as np
4
from os import listdir
5
from PIL import Image
6
7
8
# 读取数据文件
9
# 参数:目录,文件名
10
# 返回:数据矩阵
11
def read_data(f_dir, f_name):
12
    f = open(f_dir + f_name, 'r', encoding='utf8')
13
    f_width, f_height = 32, 32  # 数字数据尺寸为32*32
14
    digit_mat = []
15
    for i in range(f_height):
16
        mat = []
17
        str = f.readline()
18
        for j in range(f_width):
19
            mat.append(int(str[j]))
20
        digit_mat.append(mat)
21
    return digit_mat
22
23
24
# 准备数字数据文件
25
# 参数:训练数字文件目录
26
# 返回:数字样本及类别矩阵
27
def prepare_digits(f_dir):
28
    f_list = listdir(f_dir)
29
    digits = []
30
    for i in f_list:
31
        if i[0].isdigit():
32
            digits.append([read_data(f_dir, i), i[0]])
33
    return digits
34
35
36
# 将图像转为只包含0和1的txt文件,内容区域用1填充
37
# 参数:输入的图片文件,输出的txt文件
38
def img2code(img_file, code_file):
39
    img = Image.open(img_file)
40
    img = img.resize((32, 32))
41
    img = img.convert('1')  # 转换为黑白图像
42
    width, height = img.size
43
    f1 = open(img_file, 'r')
44
    f2 = open(code_file, 'w')
45
    for i in range(height):
46
        for j in range(width):
47
            pixel = int(img.getpixel((j, i)) / 255)  # 获取每个像素值
48
            if pixel == 0:
49
                pixel = 1
50
            elif pixel == 1:
51
                pixel = 0
52
            f2.write(str(pixel))
53
            if j == width - 1:
54
                f2.write('\n')
55
    f1.close()
56
    f2.close()
57
58
59
# 欧式距离计算
60
# 参数:训练数字数据集,未知数字数据
61
# 返回:距离列表(序号, 距离)
62
def get_diatance(training_digits, unknown_digit):
63
    distance = []
64
    i = 0
65
    for t in training_digits:
66
        distance.append([i, np.linalg.norm(np.array(t[0]) - np.array(unknown_digit))])
67
        i = i + 1
68
    return distance
69
70
71
# 手写数字识别
72
def digit_identify():
73
    trainingdigits_dir = './trainingDigits/'  # 训练数据目录
74
    digits = prepare_digits(trainingdigits_dir)  # 训练数据列表
75
    print('**1NN实现的手写数字识别**')
76
    unknown_digit_img = './digits_img/digit.png'  # 图片文件
77
    unknown_digit_txt = './digits_img/digit.txt'  # 转化成二进制文件后的文件名
78
    img2code(unknown_digit_img, unknown_digit_txt)  # 转化成二进制文件
79
    unknown_digit = read_data('./digits_img/', 'digit.txt')  # 读取待识别数据
80
    distance = get_diatance(digits, unknown_digit)  # 计算距离矩阵(序号, 距离)
81
    distance = sorted(distance, key=lambda x:x[1])  # 根据距离排序
82
    k = 5  # k值选取
83
    labels = distance[0:k]  # 截取前k个最相似的数据样本
84
    count = list(np.zeros(10))  # 各个标记的次数列表
85
    # 统计次数
86
    for i in labels:
87
        index = i[0]
88
        l = int(digits[index][1])
89
        count[l] = count[l] + 1
90
    label = count.index(max(count))  # 待识别数据的标记
91
    print('识别成功!这是数字: ' + str(label))
92
93
94
if __name__ == '__main__':
95
    digit_identify()