你好,游客 登录
背景:
阅读新闻

教程 | 使用MNIST数据集,在TensorFlow上实现基础LSTM网络

[日期:2017-09-30] 来源:机器之心  作者: [字体: ]

本文介绍了如何在 TensorFlow 上实现基础 LSTM 网络的详细过程。作者选用了 MNIST 数据集,本文详细介绍了实现过程。

长短期记忆(LSTM)是目前循环神经网络最普遍使用的类型,在处理时间序列数据时使用最为频繁。关于 LSTM 的更加深刻的洞察可以看看这篇优秀的博客:http://colah.github.io/posts/2015-08-Understanding-LSTMs/。

我们的目的

这篇博客的主要目的就是使读者熟悉在 TensorFlow 上实现基础 LSTM 网络的详细过程。

我们将选用 MNIST 作为数据集。

  1. fromtensorflow.examples.tutorials.mnist importinput_data

  2. mnist =input_data.read_data_sets("/tmp/data/",one_hot=True)

MNIST 数据集

MNIST 数据集包括手写数字的图像和对应的标签。我们可以根据以下内置功能从 TensorFlow 上下载并读取数据。

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

数据被分成 3 个部分:

1. 训练数据(mnist.train):55000 张图像

2. 测试数据(mnist.test):10000 张图像

3. 验证数据(mnist.validation):5000 张图像

数据的形态

讨论一下 MNIST 数据集中的训练数据的形态。数据集的这三个部分的形态都是一样的。

训练数据集包括 55000 张 28x28 像素的图像,这些 784(28x28)像素值被展开成一个维度为 784 的单一向量,所有 55000 个像素向量(每个图像一个)被储存为形态为 (55000,784) 的 numpy 数组,并命名为 mnist.train.images。

所有这 55000 张图像都关联了一个类别标签(表示其所属类别),一共有 10 个类别(0,1,2...9),类别标签使用独热编码的形式表示。因此标签将作为形态为 (55000,10) 的数组保存,并命名为 mnist.train.labels。

为什么要选择 MNIST?

LSTM 通常用来解决复杂的序列处理问题,比如包含了 NLP 概念(词嵌入、编码器等)的语言建模问题。这些问题本身需要大量理解,那么将问题简化并集中于在 TensorFlow 上实现 LSTM 的细节(比如输入格式化、LSTM 单元格以及网络结构设计),会是个不错的选择。

MNIST 就正好提供了这样的机会。其中的输入数据是一个像素值的集合。我们可以轻易地将其格式化,将注意力集中在 LSTM 实现细节上。

实现

在动手写代码之前,先规划一下实现的蓝图,可以使写代码的过程更加直观。

VANILLA RNN

循环神经网络按时间轴展开的时候,如下图所示:

图中:

1.x_t 代表时间步 t 的输入;

2.s_t 代表时间步 t 的隐藏状态,可看作该网络的「记忆」;

3.o_t 作为时间步 t 时刻的输出;

4.U、V、W 是所有时间步共享的参数,共享的重要性在于我们的模型在每一时间步以不同的输入执行相同的任务。

当把 RNN 展开的时候,网络可被看作每一个时间步都受上一时间步输出影响(时间步之间存在连接)的前馈网络。

两个注意事项

为了更顺利的进行实现,需要清楚两个概念的含义:

1.TensorFlow 中 LSTM 单元格的解释;

2. 数据输入 TensorFlow RNN 之前先格式化。

TensorFlow 中 LSTM 单元格的解释

在 TensorFlow 中,基础的 LSTM 单元格声明为:

  1. tf.contrib.rnn.BasicLSTMCell(num_units)

这里,num_units 指一个 LSTM 单元格中的单元数。num_units 可以比作前馈神经网络中的隐藏层,前馈神经网络的隐藏层的节点数量等于每一个时间步中一个 LSTM 单元格内 LSTM 单元的 num_units 数量。下图可以帮助直观理解:

每一个 num_units LSTM 单元都可以看作一个标准的 LSTM 单元:

以上图表来自博客(地址:http://colah.github.io/posts/2015-08-Understanding-LSTMs/),该博客有效介绍了 LSTM 的概念。

数据输入 TensorFlow RNN 之前先格式化

在 TensorFlow 中最简单的 RNN 形式是 static_rnn,在 TensorFlow 中定义如下:

  1. tf.static_rnn(cell,inputs)

虽然还有其它的注意事项,但在这里我们仅关注这两个。

inputs 引数接受形态为 [batch_size,input_size] 的张量列表。列表的长度为将网络展开后的时间步数,即列表中每一个元素都分别对应网络展开的时间步。比如在 MNIST 数据集中,我们有 28x28 像素的图像,每一张都可以看成拥有 28 行 28 个像素的图像。我们将网络按 28 个时间步展开,以使在每一个时间步中,可以输入一行 28 个像素(input_size),从而经过 28 个时间步输入整张图像。给定图像的 batch_size 值,则每一个时间步将分别收到 batch_size 个图像。详见下图说明:

由 static_rnn 生成的输出是一个形态为 [batch_size,n_hidden] 的张量列表。列表的长度为将网络展开后的时间步数,即每一个时间步输出一个张量。在这个实现中我们只需关心最后一个时间步的输出,因为一张图像的所有行都输入到 RNN,预测即将在最后一个时间步生成。

现在,所有的困难部分都已经完成,可以开始写代码了。只要理清了概念,写代码过程是很直观的。

代码

在开始的时候,先导入一些必要的依赖关系、数据集,并声明一些常量。设定 batch_size=128 、 num_units=128。

  1. importtensorflow astf

  2. fromtensorflow.contrib importrnn

  3.  

  4. #import mnist dataset

  5. fromtensorflow.examples.tutorials.mnist importinput_data

  6. mnist=input_data.read_data_sets("/tmp/data/",one_hot=True)

  7.  

  8. #define constants

  9. #unrolled through 28 time steps

  10. time_steps=28

  11. #hidden LSTM units

  12. num_units=128

  13. #rows of 28 pixels

  14. n_input=28

  15. #learning rate for adam

  16. learning_rate=0.001

  17. #mnist is meant to be classified in 10 classes(0-9).

  18. n_classes=10

  19. #size of batch

  20. batch_size=128

现在设置占位、权重以及偏置变量(用于将输出的形态从 [batch_size,num_units] 转换为 [batch_size,n_classes]),从而可以预测正确的类别。

  1. #weights and biases of appropriate shape to accomplish above task

  2. out_weights=tf.Variable(tf.random_normal([num_units,n_classes]))

  3. out_bias=tf.Variable(tf.random_normal([n_classes]))

  4.  

  5. #defining placeholders

  6. #input image placeholder

  7. x=tf.placeholder("float",[None,time_steps,n_input])

  8. #input label placeholder

  9. y=tf.placeholder("float",[None,n_classes])

现在我们得到了形态为 [batch_size,time_steps,n_input] 的输入,我们需要将其转换成形态为 [batch_size,n_inputs] 、长度为 time_steps 的张量列表,从而可以将其输入 static_rnn。

  1. #processing the input tensor from [batch_size,n_steps,n_input] to "time_steps" number of [batch_size,n_input] tensors

  2. input=tf.unstack(x ,time_steps,1)

现在我们可以定义网络了。我们将利用 BasicLSTMCell 的一个层,将我们的 static_rnn 从中提取出来。

  1. #defining the network

  2. lstm_layer=rnn.BasicLSTMCell(num_units,forget_bias=1)

  3. outputs,_=rnn.static_rnn(lstm_layer,input,dtype="float32")

我们只考虑最后一个时间步的输入,从中生成预测。

  1. #converting last output of dimension [batch_size,num_units] to [batch_size,n_classes] by out_weight multiplication

  2. prediction=tf.matmul(outputs[-1],out_weights)+out_bias

定义损失函数、优化器和准确率。

  1. #loss_function

  2. loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y))

  3. #optimization

  4. opt=tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

  5.  

  6. #model evaluation

  7. correct_prediction=tf.equal(tf.argmax(prediction,1),tf.argmax(y,1))

  8. accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

现在我们已经完成定义,可以开始运行了。

  1. #initialize variables

  2. init=tf.global_variables_initializer()

  3. withtf.Session()assess:

  4. sess.run(init)

  5. iter=1

  6. whileiter<800:

  7. batch_x,batch_y=mnist.train.next_batch(batch_size=batch_size)

  8.  

  9. batch_x=batch_x.reshape((batch_size,time_steps,n_input))

  10.  

  11. sess.run(opt,feed_dict={x:batch_x,y:batch_y})

  12.  

  13. ifiter %10==0:

  14. acc=sess.run(accuracy,feed_dict={x:batch_x,y:batch_y})

  15. los=sess.run(loss,feed_dict={x:batch_x,y:batch_y})

  16. print("For iter ",iter)

  17. print("Accuracy ",acc)

  18. print("Loss ",los)

  19. print("__________________")

  20.  

  21. iter=iter+1

需要注意的是我们的每一张图像在开始时被平坦化为 784 维的单一向量,函数 next_batch(batch_size) 必须返回这些 784 维向量的 batch_size 批次数。因此它们的形态要被改造成 [batch_size,time_steps,n_input],从而可以被我们的占位符接受。

我们还可以计算模型的准确率:

  1. #calculating test accuracy

  2. test_data =mnist.test.images[:128].reshape((-1,time_steps,n_input))

  3. test_label =mnist.test.labels[:128]

  4. print("Testing Accuracy:",sess.run(accuracy,feed_dict={x:test_data,y:test_label}))

在运行的时候,模型的测试准确率为 99.21%。

这篇博客旨在让读者熟悉 TensorFlow 中 RNN 的实现细节。我们将会在 TensorFlow 中建立更加复杂的模型以更有效的利用 RNN。敬请期待!

收藏 推荐 打印 | 录入:Cstor | 阅读:
相关新闻      
本文评论   查看全部评论 (0)
表情: 表情 姓名: 字数
点评:
       
评论声明
  • 尊重网上道德,遵守中华人民共和国的各项有关法律法规
  • 承担一切因您的行为而直接或间接导致的民事或刑事法律责任
  • 本站管理人员有权保留或删除其管辖留言中的任意内容
  • 本站有权在网站内转载或引用您的评论
  • 参与本评论即表明您已经阅读并接受上述条款