1 |
|
2 |
|
3 | import numpy as np |
4 | from os import listdir |
5 | from PIL import Image |
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
def read_data(f_dir, f_name, f_width=32, f_height=32):
    """Read a digit bitmap stored as rows of '0'/'1' characters.

    The file ``f_dir``/``f_name`` is opened as UTF-8 text. The first
    ``f_height`` lines are parsed, and the first ``f_width`` characters of
    each line are converted to ints.

    Args:
        f_dir: Directory containing the file (trailing separator optional).
        f_name: File name inside ``f_dir``.
        f_width: Number of characters to read from each row (default 32).
        f_height: Number of rows to read (default 32).

    Returns:
        A list of ``f_height`` lists, each holding ``f_width`` ints.

    Raises:
        ValueError: If a row is shorter than ``f_width`` or contains a
            character ``int`` cannot parse.
    """
    import os

    digit_mat = []
    # ``with`` guarantees the handle is closed even if parsing fails.
    with open(os.path.join(f_dir, f_name), 'r', encoding='utf8') as f:
        for row in range(f_height):
            line = f.readline().rstrip('\r\n')
            if len(line) < f_width:
                raise ValueError(
                    f'{f_name}: row {row} has {len(line)} characters, '
                    f'expected at least {f_width}'
                )
            digit_mat.append([int(ch) for ch in line[:f_width]])
    return digit_mat
22 |
|
23 |
|
24 |
|
25 |
|
26 |
|
def prepare_digits(f_dir):
    """Load every training digit found in ``f_dir``.

    Only entries whose name starts with a digit character are loaded, and
    that leading character is used as the label. Entries are visited in
    sorted name order. ``listdir`` order is arbitrary, and sorting makes the
    resulting list positions (and any k-NN tie-breaking that depends on
    them) reproducible.

    Args:
        f_dir: Directory path, passed through unchanged to ``read_data``.

    Returns:
        A list of ``[matrix, label]`` pairs, where ``matrix`` is the value
        returned by ``read_data`` and ``label`` is a one-character string.
    """
    return [
        [read_data(f_dir, name), name[0]]
        for name in sorted(listdir(f_dir))
        if name[0].isdigit()
    ]
34 |
|
35 |
|
36 |
|
37 |
|
def img2code(img_file, code_file, size=(32, 32)):
    """Convert an image into a text bitmap of '0'/'1' characters.

    The image is resized to ``size`` and converted to 1-bit mode ('1').
    Black pixels are written as '1' and white pixels as '0', one text line
    per pixel row.

    Args:
        img_file: Path of the image to read.
        code_file: Path of the text file to write (truncated if it exists).
        size: ``(width, height)`` to resize to before thresholding. The
            default ``(32, 32)`` is the grid size the rest of this module
            reads.
    """
    with Image.open(img_file) as src:
        img = src.resize(size).convert('1')
    width, height = img.size
    rows = []
    for y in range(height):
        # Mode '1' pixels are 0 (black) or 255 (white); invert so ink is '1'.
        rows.append(''.join(
            '1' if img.getpixel((x, y)) == 0 else '0' for x in range(width)
        ))
    with open(code_file, 'w', encoding='utf8') as out:
        out.write(''.join(row + '\n' for row in rows))
57 |
|
58 |
|
59 |
|
60 |
|
61 |
|
def get_diatance(training_digits, unknown_digit):
    """Compute the distance from ``unknown_digit`` to each training digit.

    Args:
        training_digits: Sequence of ``[matrix, label]`` pairs. Only the
            matrix (index 0) is used.
        unknown_digit: Matrix to compare. It must be convertible with
            ``np.array`` and have the same shape as each training matrix.

    Returns:
        A list of ``[index, distance]`` pairs in the order of
        ``training_digits``. ``distance`` is ``np.linalg.norm`` of the
        element-wise difference (the Frobenius norm for 2-D matrices).
    """
    # Convert the query once instead of on every iteration.
    target = np.array(unknown_digit)
    return [
        [i, np.linalg.norm(np.array(t[0]) - target)]
        for i, t in enumerate(training_digits)
    ]
69 |
|
70 |
|
71 |
|
def digit_identify(k=5):
    """Recognize the digit in ./digits_img/digit.png with k-nearest neighbours.

    Training bitmaps are loaded from ./trainingDigits/. The query image is
    converted to ./digits_img/digit.txt and read back, and the ``k``
    training digits closest to it vote on the label.

    Args:
        k: Number of nearest neighbours that vote (default 5).

    Returns:
        The recognized digit as an int. It is also printed.

    Raises:
        ValueError: If ``k`` < 1 or no training digits are found.
    """
    if k < 1:
        raise ValueError(f'k must be >= 1, got {k}')
    trainingdigits_dir = './trainingDigits/'
    digits = prepare_digits(trainingdigits_dir)
    if not digits:
        # Without this check every vote count stays 0 and digit 0 is
        # reported as a "successful" recognition.
        raise ValueError(f'no training digits found in {trainingdigits_dir}')
    print(f'**{k}NN实现的手写数字识别**')
    unknown_digit_img = './digits_img/digit.png'
    unknown_digit_txt = './digits_img/digit.txt'
    img2code(unknown_digit_img, unknown_digit_txt)
    unknown_digit = read_data('./digits_img/', 'digit.txt')
    distance = get_diatance(digits, unknown_digit)
    nearest = sorted(distance, key=lambda x: x[1])[:k]

    votes = [0] * 10
    for index, _ in nearest:
        votes[int(digits[index][1])] += 1
    # list.index returns the first maximum, so ties go to the smallest digit.
    label = votes.index(max(votes))
    print('识别成功!这是数字: ' + str(label))
    return label
92 |
|
93 |
|
# Script entry point: run the recognizer only when executed directly.
if __name__ == "__main__":
    digit_identify()