用TensorFlow实现手写数字分类。
Classification 分类学习
数据:MNIST库(手写体数字库),包含55000张训练图片,每张图片分辨率是28×28,故训练网络输入应该是28×28=784个像素数据。

网络结构:输入数据784个特征,输出数据10个特征,激励采用softmax函数

损失函数(最优化目标函数):选用交叉熵函数,其用来衡量预测值和真实值的相似程度(若完全相同,它们的交叉熵等于零)。
最优化算法:梯度下降法
手写数字分类
1 | import tensorflow as tf |
2 | from tensorflow.examples.tutorials.mnist import input_data |
3 | |
4 | ''' |
5 | TensorFlow: 分类学习 |
6 | ''' |
7 | |
8 | |
9 | # 构建一个神经层函数 |
10 | def add_layer(inputs, in_size, out_size, activation_function=None,): |
11 | Weights = tf.Variable(tf.random_normal([in_size, out_size])) |
12 | biases = tf.Variable(tf.zeros([1, out_size]) + 0.1,) |
13 | Wx_plus_b = tf.matmul(inputs, Weights) + biases |
14 | if activation_function is None: |
15 | outputs = Wx_plus_b |
16 | else: |
17 | outputs = activation_function(Wx_plus_b,) |
18 | return outputs |
19 | |
20 | |
21 | # 计算精确度 |
22 | def compute_accuracy(v_xs, v_ys): |
23 | global prediction |
24 | y_pre = sess.run(prediction, feed_dict={xs: v_xs}) |
25 | correct_prediction = tf.equal(tf.argmax(y_pre, 1), tf.argmax(v_ys, 1)) |
26 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) |
27 | result = sess.run(accuracy, feed_dict={xs: v_xs, ys: v_ys}) |
28 | return result |
29 | |
30 | |
31 | # 导入手写体数字库MNIST |
32 | mnist = input_data.read_data_sets('./MNIST_data', one_hot=True) |
33 | |
34 | # 定义传入值 |
35 | xs = tf.placeholder(tf.float32, [None, 784]) # 28x28 |
36 | ys = tf.placeholder(tf.float32, [None, 10]) # 输出为数字0到9共10类 |
37 | |
38 | # 添加层 |
39 | prediction = add_layer(xs, 784, 10, activation_function=tf.nn.softmax) |
40 | |
41 | # 损失函数(最优化目标函数):交叉熵函数 |
42 | cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction), reduction_indices=[1])) # loss |
43 | |
44 | # 梯度下降法 |
45 | train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) |
46 | sess = tf.Session() |
47 | init = tf.global_variables_initializer() |
48 | sess.run(init) |
49 | |
50 | # 开始训练 |
51 | for i in range(1000): |
52 | # 为提高训练速度,每次只取100张图片训练 |
53 | batch_xs, batch_ys = mnist.train.next_batch(100) |
54 | sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys}) |
55 | if i % 50 == 0: |
56 | # 输出预测精确度 |
57 | print(compute_accuracy(mnist.test.images, mnist.test.labels)) |
分类结果
1 | 0.0901 |
2 | 0.6524 |
3 | 0.7598 |
4 | 0.7997 |
5 | 0.8143 |
6 | 0.833 |
7 | 0.8446 |
8 | 0.8528 |
9 | 0.8544 |
10 | 0.8559 |
11 | 0.8624 |
12 | 0.8636 |
13 | 0.8656 |
14 | 0.8691 |
15 | 0.8693 |
16 | 0.8664 |
17 | 0.8727 |
18 | 0.8768 |
19 | 0.8797 |
20 | 0.8787 |
