卷积神经网络手写数字识别(基于TensorFlow),数据:MNIST库(手写体数字库),包含55000张训练图片,每张图片分辨率是28×28,故训练网络输入应该是28×28=784个像素数据。

训练、保存、测试模型
from PIL import Image
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load the MNIST data set with one-hot encoded labels.
mnist = input_data.read_data_sets('./MNIST_data/', one_hot=True)

# Training parameters.
learning_rate = 0.001
num_steps = 5000
batch_size = 128
display_step = 10

# Network parameters.
num_input = 784  # MNIST input size (image shape: 28*28)
num_classes = 10  # MNIST classes (digits 0-9)
dropout = 0.75  # Dropout: probability of KEEPING a unit during training

# Graph inputs.
X = tf.placeholder(tf.float32, [None, num_input])
Y = tf.placeholder(tf.float32, [None, num_classes])
keep_prob = tf.placeholder(tf.float32)  # dropout keep probability


def conv2d(x, W, b, strides=1):
    """Apply a SAME-padded 2-D convolution, add the bias and apply ReLU.

    Args:
        x: Input tensor passed to ``tf.nn.conv2d``.
        W: Convolution kernel.
        b: Bias added to the convolution output.
        strides: Stride used for both spatial dimensions.

    Returns:
        The ReLU-activated convolution output.
    """
    x = tf.nn.conv2d(x, W, strides=[1, strides, strides, 1], padding='SAME')
    x = tf.nn.bias_add(x, b)
    return tf.nn.relu(x)


def maxpool2d(x, k=2):
    """Apply SAME-padded max pooling with a k x k window and stride k."""
    return tf.nn.max_pool(x, ksize=[1, k, k, 1], strides=[1, k, k, 1],
                          padding='SAME')


def conv_net(x, weights, biases, dropout):
    """Build the CNN: two conv/pool stages, one dense layer, one output layer.

    Args:
        x: Flat input batch; reshaped here to [-1, 28, 28, 1].
        weights: Dict with the kernels/matrices 'wc1', 'wc2', 'wd1', 'out'.
        biases: Dict with the biases 'bc1', 'bc2', 'bd1', 'out'.
        dropout: Keep probability for the dense layer (1.0 disables dropout).

    Returns:
        The un-normalised class scores (logits).
    """
    # MNIST input is a 1-D vector of 784 features (28*28 pixels); reshape it
    # to the 4-D picture format [batch, height, width, channel].
    x = tf.reshape(x, shape=[-1, 28, 28, 1])

    # First convolution + pooling stage.
    conv1 = conv2d(x, weights['wc1'], biases['bc1'])
    conv1 = maxpool2d(conv1, k=2)

    # Second convolution + pooling stage.
    conv2 = conv2d(conv1, weights['wc2'], biases['bc2'])
    conv2 = maxpool2d(conv2, k=2)

    # Fully connected layer: flatten conv2 so that it matches the number of
    # rows of the dense weight matrix.
    fc1 = tf.reshape(conv2, [-1, weights['wd1'].get_shape().as_list()[0]])
    fc1 = tf.add(tf.matmul(fc1, weights['wd1']), biases['bd1'])
    fc1 = tf.nn.relu(fc1)
    # tf.nn.dropout takes the DROP rate, hence 1 - keep probability.
    fc1 = tf.nn.dropout(fc1, rate=1 - dropout)

    # Output layer.
    return tf.add(tf.matmul(fc1, weights['out']), biases['out'])


# Weights and biases of every layer.
weights = {
    # 5x5 conv, 1 input, 32 outputs
    'wc1': tf.Variable(tf.random_normal([5, 5, 1, 32])),
    # 5x5 conv, 32 inputs, 64 outputs
    'wc2': tf.Variable(tf.random_normal([5, 5, 32, 64])),
    # fully connected, 7*7*64 inputs, 1024 outputs
    'wd1': tf.Variable(tf.random_normal([7 * 7 * 64, 1024])),
    # 1024 inputs, 10 outputs (class prediction)
    'out': tf.Variable(tf.random_normal([1024, num_classes])),
}
biases = {
    'bc1': tf.Variable(tf.random_normal([32])),
    'bc2': tf.Variable(tf.random_normal([64])),
    'bd1': tf.Variable(tf.random_normal([1024])),
    'out': tf.Variable(tf.random_normal([num_classes])),
}

# Build the model.
logits = conv_net(X, weights, biases, keep_prob)
prediction = tf.nn.softmax(logits)

# Loss function and optimizer.
loss_op = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(loss_op)

# Model evaluation.
correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Variable initialiser.
init = tf.global_variables_initializer()

# Saver used both to store and to restore the model.
saver = tf.train.Saver()


def train_model():
    """Train the network on MNIST, save a checkpoint and print test accuracy.

    Runs ``num_steps`` mini-batch updates, printing the mini-batch loss and
    accuracy on step 1 and every ``display_step`` steps, then writes the
    checkpoint 'CNN_handwritten_digits_model' to the working directory and
    reports the accuracy on the first 256 MNIST test images.
    """
    with tf.Session() as sess:
        sess.run(init)

        for step in range(1, num_steps + 1):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            # Back-propagation step (dropout active).
            sess.run(train_op,
                     feed_dict={X: batch_x, Y: batch_y, keep_prob: dropout})
            if step % display_step == 0 or step == 1:
                # Batch loss and accuracy, evaluated with dropout disabled.
                loss, acc = sess.run(
                    [loss_op, accuracy],
                    feed_dict={X: batch_x, Y: batch_y, keep_prob: 1.0})
                print("Step {}, Minibatch Loss= {:.4f}, "
                      "Training Accuracy= {:.3f}".format(step, loss, acc))

        # Store the trained model once, after the final step.
        saver.save(sess, 'CNN_handwritten_digits_model')
        print("Optimization Finished!")

        # Test accuracy on the first 256 MNIST test images.
        test_feed = {X: mnist.test.images[:256],
                     Y: mnist.test.labels[:256],
                     keep_prob: 1.0}
        print("Testing Accuracy:", sess.run(accuracy, feed_dict=test_feed))


def image_prepare(picture):
    """Load a digit picture and turn it into a flat list of 784 binary pixels.

    The picture is displayed, converted to 8-bit grayscale, resized to 28x28
    and thresholded: pixels darker than 50 become 0, all others 1. This suits
    a light digit on a dark background; for a dark digit on a light
    background the two output values would have to be swapped.

    Args:
        picture: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A list of 784 ints (0 or 1) in row-major order.
    """
    # The context manager releases the underlying file handle.
    with Image.open(picture) as source:
        plt.imshow(source)  # show the picture that is going to be recognised
        plt.show()
        # Image.LANCZOS is the filter formerly exposed as Image.ANTIALIAS.
        resized = source.convert('L').resize((28, 28), Image.LANCZOS)
    return [0 if value < 50 else 1 for value in resized.getdata()]


# Picture of the digit to recognise.
image = image_prepare('./digits/62.png')

# Train the model (run this once to produce the checkpoint files).
# train_model()

# Recognise the digit with the previously saved model.
predicted_digit = tf.argmax(prediction, 1)
with tf.Session() as sess:
    checkpoint_path = tf.train.latest_checkpoint('./')
    if checkpoint_path is None:
        raise FileNotFoundError(
            "No checkpoint found in './'; run train_model() first.")
    # The graph is already built above, so the module-level saver restores
    # the variables directly; importing the .meta file again would add a
    # second copy of the ops to the default graph.
    saver.restore(sess, checkpoint_path)
    predint = sess.run(predicted_digit,
                       feed_dict={X: [image], keep_prob: 1.0})

    print('识别结果:', end=' ')
    print(predint[0])
