TensorFlow-4 分类学习

用TensorFlow实现手写数字分类。

Classification 分类学习

  • 数据:MNIST库(手写体数字库),包含55000张训练图片,每张图片分辨率是28×28,故训练网络输入应该是28×28=784个像素数据。

  • 网络结构:输入数据784个特征,输出数据10个特征,激励采用softmax函数

  • 损失函数(最优化目标函数):选用交叉熵函数,其用来衡量预测值和真实值的相似程度(若完全相同,它们的交叉熵等于零)。

  • 最优化算法:梯度下降法

手写数字分类

1
import tensorflow as tf
2
from tensorflow.examples.tutorials.mnist import input_data
3
4
'''
5
TensorFlow: 分类学习
6
'''
7
8
9
# 构建一个神经层函数
10
def add_layer(inputs, in_size, out_size, activation_function=None,):
11
    Weights = tf.Variable(tf.random_normal([in_size, out_size]))
12
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1,)
13
    Wx_plus_b = tf.matmul(inputs, Weights) + biases
14
    if activation_function is None:
15
        outputs = Wx_plus_b
16
    else:
17
        outputs = activation_function(Wx_plus_b,)
18
    return outputs
19
20
21
# 计算精确度
22
def compute_accuracy(v_xs, v_ys):
23
    global prediction
24
    y_pre = sess.run(prediction, feed_dict={xs: v_xs})
25
    correct_prediction = tf.equal(tf.argmax(y_pre, 1), tf.argmax(v_ys, 1))
26
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
27
    result = sess.run(accuracy, feed_dict={xs: v_xs, ys: v_ys})
28
    return result
29
30
31
# 导入手写体数字库MNIST
32
mnist = input_data.read_data_sets('./MNIST_data', one_hot=True)
33
34
# 定义传入值
35
xs = tf.placeholder(tf.float32, [None, 784])  # 28x28
36
ys = tf.placeholder(tf.float32, [None, 10])  # 输出为数字0到9共10类
37
38
# 添加层
39
prediction = add_layer(xs, 784, 10,  activation_function=tf.nn.softmax)
40
41
# 损失函数(最优化目标函数):交叉熵函数
42
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction), reduction_indices=[1]))  # loss
43
44
# 梯度下降法
45
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
46
sess = tf.Session()
47
init = tf.global_variables_initializer()
48
sess.run(init)
49
50
# 开始训练
51
for i in range(1000):
52
    # 为提高训练速度,每次只取100张图片训练
53
    batch_xs, batch_ys = mnist.train.next_batch(100)
54
    sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys})
55
    if i % 50 == 0:
56
        # 输出预测精确度
57
        print(compute_accuracy(mnist.test.images, mnist.test.labels))

分类结果

1
0.0901
2
0.6524
3
0.7598
4
0.7997
5
0.8143
6
0.833
7
0.8446
8
0.8528
9
0.8544
10
0.8559
11
0.8624
12
0.8636
13
0.8656
14
0.8691
15
0.8693
16
0.8664
17
0.8727
18
0.8768
19
0.8797
20
0.8787