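"""Handwritten-digit recognition with a convolutional neural network (TensorFlow).

Data: the MNIST handwritten-digit database, which provides 55,000 training
images. Each image has a resolution of 28×28, so the network input is a
vector of 28×28 = 784 pixel values.

This script trains, saves, and tests the model.
"""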
1
from PIL import Image
2
import matplotlib.pyplot as plt
3
import tensorflow as tf
4
from tensorflow.examples.tutorials.mnist import input_data
5
6
# Load the MNIST dataset from ./MNIST_data/ with one-hot encoded labels.
mnist = input_data.read_data_sets(
    './MNIST_data/',
    one_hot=True,
)
# Training parameters.
learning_rate = 0.001
num_steps = 5000
batch_size = 128
display_step = 10

# Network parameters.
num_input = 784     # MNIST input size (image shape: 28*28)
num_classes = 10    # MNIST number of classes (digits 0-9)
dropout = 0.75      # dropout: probability of keeping a unit
# TensorFlow graph inputs.
X = tf.placeholder(tf.float32, shape=[None, num_input])
Y = tf.placeholder(tf.float32, shape=[None, num_classes])
keep_prob = tf.placeholder(tf.float32)  # dropout (keep probability)
# Convolution.
def conv2d(x, W, b, strides=1):
    """Convolve ``x`` with ``W``, add bias ``b`` and apply ReLU.

    Args:
        x: input tensor passed to ``tf.nn.conv2d``.
        W: convolution filter.
        b: bias added with ``tf.nn.bias_add``.
        strides: stride applied along both middle dimensions.

    Returns:
        The ReLU-activated result.
    """
    stride_spec = [1, strides, strides, 1]
    convolved = tf.nn.conv2d(x, W, strides=stride_spec, padding='SAME')
    return tf.nn.relu(tf.nn.bias_add(convolved, b))
# Pooling.
def maxpool2d(x, k=2):
    """Max-pool ``x`` with a k-by-k window, stride k and 'SAME' padding."""
    window = [1, k, k, 1]
    return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')
# Build the convolutional neural network model.
def conv_net(x, weights, biases, dropout):
    """Assemble the network: two conv/pool stages, one dense layer, one output layer.

    Args:
        x: input tensor; it is reshaped to [-1, 28, 28, 1]
            (batch, height, width, channel) before the first convolution.
        weights: dict with the keys 'wc1', 'wc2', 'wd1' and 'out'.
        biases: dict with the keys 'bc1', 'bc2', 'bd1' and 'out'.
        dropout: keep probability; the dropout rate applied is ``1 - dropout``.

    Returns:
        The output-layer tensor (no softmax is applied here).
    """
    images = tf.reshape(x, shape=[-1, 28, 28, 1])

    # First convolution + pooling stage.
    stage1 = maxpool2d(conv2d(images, weights['wc1'], biases['bc1']), k=2)

    # Second convolution + pooling stage.
    stage2 = maxpool2d(conv2d(stage1, weights['wc2'], biases['bc2']), k=2)

    # Fully connected layer: flatten to the row count of its weight matrix.
    dense_weights = weights['wd1']
    flat_size = dense_weights.get_shape().as_list()[0]
    flattened = tf.reshape(stage2, [-1, flat_size])
    hidden = tf.nn.relu(tf.add(tf.matmul(flattened, dense_weights), biases['bd1']))
    hidden = tf.nn.dropout(hidden, rate=1 - dropout)

    # Output layer.
    return tf.add(tf.matmul(hidden, weights['out']), biases['out'])
# Weights of each layer, initialised with tf.random_normal.
# The variables are created in list order, which fixes their default names.
_weight_shapes = [
    ('wc1', [5, 5, 1, 32]),        # 5x5 conv, 1 input, 32 outputs
    ('wc2', [5, 5, 32, 64]),       # 5x5 conv, 32 inputs, 64 outputs
    ('wd1', [7*7*64, 1024]),       # fully connected, 7*7*64 inputs, 1024 outputs
    ('out', [1024, num_classes]),  # 1024 inputs, 10 outputs (class prediction)
]
weights = {
    layer: tf.Variable(tf.random_normal(shape))
    for layer, shape in _weight_shapes
}
# Biases of each layer, initialised with tf.random_normal.
# The variables are created in list order, which fixes their default names.
_bias_sizes = [
    ('bc1', 32),
    ('bc2', 64),
    ('bd1', 1024),
    ('out', num_classes),
]
biases = {
    layer: tf.Variable(tf.random_normal([size]))
    for layer, size in _bias_sizes
}
# Build the model.
logits = conv_net(X, weights, biases, keep_prob)
prediction = tf.nn.softmax(logits)

# Loss function and optimizer.
_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
    logits=logits,
    labels=Y,
)
loss_op = tf.reduce_mean(_cross_entropy)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(loss_op)

# Model evaluation: share of samples whose arg-max prediction matches the label.
correct_pred = tf.equal(
    tf.argmax(prediction, 1),
    tf.argmax(Y, 1),
)
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Variable initializer.
init = tf.global_variables_initializer()

# Saver used to store the model.
saver = tf.train.Saver()
# Train the model.
def train_model():
    """Train the network, save a checkpoint and print the test accuracy.

    Runs ``num_steps`` minibatch updates of ``train_op``. On step 1 and on
    every ``display_step``-th step, the loss and accuracy on the current
    minibatch are printed, evaluated with ``keep_prob`` set to 1.0. The
    model is saved as 'CNN_handwritten_digits_model', and the accuracy on
    the first 256 MNIST test images is printed at the end.
    """
    report = "Step {}, Minibatch Loss= {:.4f}, Training Accuracy= {:.3f}"

    with tf.Session() as sess:
        sess.run(init)

        for step in range(1, num_steps + 1):
            batch_x, batch_y = mnist.train.next_batch(batch_size)

            # Back-propagation / optimisation step.
            train_feed = {X: batch_x, Y: batch_y, keep_prob: dropout}
            sess.run(train_op, feed_dict=train_feed)

            if step != 1 and step % display_step != 0:
                continue

            # Batch loss and accuracy.
            eval_feed = {X: batch_x, Y: batch_y, keep_prob: 1.0}
            loss, acc = sess.run([loss_op, accuracy], feed_dict=eval_feed)
            print(report.format(step, loss, acc))

        # NOTE(review): the source this was recovered from had lost its
        # indentation, so the nesting level of this call is an assumption.
        # With the same path and no global step, each save overwrites the
        # previous one; num_steps is a multiple of display_step, so the final
        # checkpoint is the same whichever nesting level was intended.
        saver.save(sess, 'CNN_handwritten_digits_model')  # store the model
        print("Optimization Finished!")

        # Test accuracy on the first 256 MNIST test images.
        test_feed = {
            X: mnist.test.images[:256],
            Y: mnist.test.labels[:256],
            keep_prob: 1.0,
        }
        print("Testing Accuracy:", sess.run(accuracy, feed_dict=test_feed))
# Image preprocessing.
def image_prepare(picture, threshold=50, dark_background=True):
    """Load a digit image and turn it into a flat list of 784 binary pixels.

    The image is shown with matplotlib, converted to 8-bit grayscale
    (mode 'L'), resized to 28x28 and thresholded.

    Args:
        picture: path or file object accepted by ``PIL.Image.open``.
        threshold: grayscale level used for binarisation; pixels strictly
            below it count as "dark".
        dark_background: True for a light digit on a dark background (dark
            pixels map to 0, the others to 1); False for a dark digit on a
            light background (dark pixels map to 1, the others to 0).

    Returns:
        A list of 28 * 28 = 784 ints, each 0 or 1.
    """
    # The context manager closes the underlying file once the pixels are read.
    with Image.open(picture) as source:
        plt.imshow(source)  # show the image that is about to be recognised
        plt.show()
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        gray = source.convert('L').resize((28, 28), Image.LANCZOS)
        pixels = list(gray.getdata())

    dark_value, light_value = (0, 1) if dark_background else (1, 0)
    return [dark_value if level < threshold else light_value for level in pixels]
# Image of the digit to recognise.
image = image_prepare('./digits/62.png')

# Train the model (uncomment to train and write the checkpoint).
# train_model()

# Test the model.
with tf.Session() as sess:
    # The graph and its Saver already exist in this module, so the checkpoint
    # is restored straight into them. Importing the .meta file here would only
    # add a second copy of every node to the default graph.
    checkpoint_path = tf.train.latest_checkpoint('./')
    if checkpoint_path is None:
        raise FileNotFoundError(
            "No checkpoint found in './'; run train_model() first.")
    saver.restore(sess, checkpoint_path)

    # A separate name keeps the module-level `prediction` tensor intact.
    predicted_class = tf.argmax(prediction, 1)
    predint = sess.run(predicted_class, feed_dict={X: [image], keep_prob: 1.0})

    print('识别结果:', end=' ')
    print(predint[0])