当前位置: 首页 > 科技观察

TensorFlow中RNN实现的正确打开方法

时间:2023-03-16 13:13:27 科技观察

本文的主要内容是如何在TensorFlow中实现RNN的几种结构:TensorFlow中RNN实现学习的完整循序渐进的方法。这条学习路径的曲线比较平缓,应该会减少很多学习上的努力,帮助大家少走弯路。一些可能的陷阱TensorFlow源码分析一个CharRNN实现实例,可以用来写诗,生成歌词,甚至可以写网络小说!1.学习单步RNN:RNNCell如果你想学习TensorFlow中的RNN,***站应该是了解“RNNCell”,它是TensorFlow中RNN的基本单元。每个RNNCell都有一个调用方法,用作:(output,next_state)=call(input,state)。用图片可能更容易理解。假设我们有一个初始状态h0,输入x1,调用call(x1,h0)后可以得到(output1,h1):再调用call(x2,h1)可以得到(output2,h2):即也就是说,每调用一次RNNCell的call方法,就相当于在时间上“前进了一步”,这是RNNCell的基本功能。在代码实现上,RNNCell只是一个抽象类,我们在使用的时候会用到它的两个子类BasicRNNCell和BasicLSTMCell。顾名思义,前者是RNN的基础类,后者是LSTM的基础类。这里推荐大家阅读其源码实现。您不需要一开始就阅读所有这些内容。只需要看一下RNNCell、BasicRNNCell、BasicLSTMCell这三个类的注释,就可以了解它们的功能。除了call方法,RNNCell还有两个比较重要的类属性:state_sizeoutput_size前者是隐藏层的大小,后者是输出的大小。例如,我们通常会发送一个批次给模型计算。如果输入数据的shape为(batch_size,input_size),那么计算时得到的隐藏层状态为(batch_size,state_size),输出为(batch_size,output_size)。可以用下面的代码验证一下(注意下面的代码是基于1.2版本的TensorFlow***):128inputs=tf.placeholder(np.float32,shape=(32,100))#32为batch_sizeh0=cell.zero_state(32,np.float32)#通过zero_state得到一个全0的初始状态,shape为(batch_size,state_size)output,h1=cell.call(inputs,h0)#调用call函数print(h1.shape)#(32,128)对于BasicLSTMCell,情况略有不同,因为LSTM可以看成有两个隐藏状态h和c,对应的隐藏层是一个元组,每个元组的形状为(batch_size,state_size):importtensorflowastfimportnumpyasnplstm_cell=tf.nn.rnn_cell.BasicLSTMCell(num_units=128)inputs=tf.placeholder(np.float32,shape=(32,100))#32是batch_sizeh0=lstm_cell.zero_state(32,np.float32)#通过zero_state得到一个全0的初始状态输出,h1=lstm_cell.call(inputs,h0)print(h1.h)#shape=(32,128)print(h1.c)#shape=(32,128)2.学习如何一次执行多个步骤:基于tf.nn.dynamic_rnn的RNNCell有一个明显的问题:对于单个RNNCell,当我们使用它的调用函数来执行操作时,我们只是在序列时间上提前了一步。比如用x1和h0得到h1,通过x2和h1得到h2等等。对于这样的h,如果我们的序列长度是10,就需要调用10次call函数,比较麻烦。对此,TensorFlow提供了一个tf.nn.dynamic_rnn函数,相当于调用了n次call函数。即直接通过{h0,x1,x2,….,xn}得到{h1,h2...,hn}。具体来说,让我们输入数据的格式为(batch_size,time_steps,input_size),其中time_steps代表序列本身的长度。例如,在CharRNN中,长度为10的句子对应的time_steps等于10。***的input_size表示单个时间维度上单个输入数据序列的固有长度。另外,我们定义了一个RNNCell,并调用了RNNCell的调用函数time_steps。对应的代码是:#inputs:shape=(batch_size,time_steps,input_size)#cell:RNNCell#initial_state:shape=(batch_size,cell.state_size)。初始状态。一般可以取零矩阵输出,state=tf.nn.dynamic_rnn(cell,inputs,initial_state=initial_state)此时得到的输出都是time_steps步的输出。它具有形状(batch_size、time_steps、cell.output_size)。state是最后一步的隐藏状态,它的shape是(batch_size,cell.state_size)。建议阅读tf.nn.dynamic_rnn的文档,进一步了解。3.学习如何堆叠RNNCell:MultiRNNCell很多时候,单层RNN的能力是有限的,我们需要多层RNN。将x输入第一层RNN后,得到隐层状态h。这个隐层状态相当于第二层RNN的输入,第二层RNN的隐层状态相当于第三层RNN的输入。比喻。在TensorFlow中,可以使用tf.nn.rnn_cell.MultiRNNCell函数来堆叠RNNCell。对应的示例程序如下:importtensorflowastfimportnumpyasnp#每次调用该函数,返回一个BasicRNNCelldefget_a_cell():returntf.nn.rnn_cell.BasicRNNCell(num_units=128)#使用tf.nn.rnn_cellMultiRNNCell创建3层RNNcell=tf.nn.rnn_cell.MultiRNNCell([get_a_cell()for_inrange(3)])#3层RNN#得到的cell其实是RNNCell的子类#它的state_size是(128,128,128)#(128,128,128)不是这个意思的128x128x128#表示一共有3个隐层状态,每个隐层状态的大小为128print(cell.state_size)#(128,128,128)#使用对应的调用函数inputs=tf.placeholder(np.float32,shape=(32,100))#32是batch_sizeh0=cell.zero_state(32,np.float32)#通过zero_state得到一个全0的初始状态输出,h1=cell.call(inputs,h0)print(h1)#元组包含三个32x128向量。MultiRNNCell得到的cell不是新的。它实际上是RNNCell的子类,所以它也有callmethod、state_size和output_size属性。也可以通过tf.nn.dynamic_rnn一次运行多个步骤。推荐阅读MutiRNNCell源码中的注释,进一步了解其功能。4、可能的陷阱1:输出描述在经典的RNN结构中有这样一张图:在上面的代码中,我们似乎有意忽略了调用call或dynamic_rnn函数后得到的输出的介绍。将上图与TensorFlow的BasicRNNCell进行对比。h对应BasicRNNCell的state_size。那么,y对应的是BasicRNNCell的output_size吗?答案是不。在源码中找到BasicRNNCell的调用函数的实现:defcall(self,inputs,state):"""MostbasicRNN:output=new_state=act(W*input+U*state+B)."""output=self._activation(_linear([inputs,state],self._num_units,True))returnoutput,output这句“returnoutput,output”表明在BasicRNNCell中,output其实和hiddenstate的值是一样的。因此,我们还需要为输出定义一个新的变换,才能得到图中真正的输出y。由于输出和隐藏状态是一回事,在BasicRNNCell中,state_size总是等于output_size。TensorFlow定义BasicRNNCell是为了尽可能简单,所以省略了输出参数。我们必须搞清楚它和图中原来的RNN定义的联系和区别。我们看一下BasicLSTMCell的调用函数定义(函数前几行):new_c=(c*sigmoid(f+self._forget_bias)+sigmoid(i)*self._activation(j))new_h=self._activation(new_c)*sigmoid(o)ifself._state_is_tuple:new_state=LSTMStateTuple(new_c,new_h)else:new_state=array_ops.concat([new_c,new_h],1)returnnew_h,new_state我们只需要关注self._state_is_tuple==True的情况,因为self._state_is_tuple==False的情况将来会被弃用。返回的隐藏状态是new_c和new_h的组合,输出是单个new_h。如果我们处理的是分类问题,那么我们还需要在new_h中添加一个单独的Softmax层,以获得最好的分类概率输出。还是建议大家看一下源码实现,了解其中的细节。5、可能的陷阱2:版本原因导致的错误。我们讲到堆叠RNN时,使用的代码是:#每次调用这个函数,返回一个BasicRNNCelldefget_a_cell():returntf.nn.rnn_cell.BasicRNNCell(num_units=128)#用tf.nn创建3层RNNcell.rnn_cellMultiRNNCell=tf.nn.rnn_cell.MultiRNNCell([get_a_cell()for_inrange(3)])#3-layerRNN这段代码可以在TensorFlow1.2中正确使用。但是在之前的版本中(以及网上很多相关的教程),实现是这样的:one_cell=tf.nn.rnn_cell.BasicRNNCell(num_units=128)cell=tf.nn.rnn_cell.MultiRNNCell([one_cell]*3)#3LayerRNN如果在TensorFlow1.2还是按照原来的方式定义,会报错!6.一个实践项目:以上CharRNN的内容其实就是在TensorFlow中实现RNN的基础知识。这个时候我建议你用一个项目来练习巩固。这里特别推荐CharRNN项目。本项目对应经典的RNN结构。用来实现它的TensorFlow函数就是上面提到的那些。该项目本身很有趣,可以用于文本生成。大家平时看到的,基本都是用深度学习来写诗词的。CharRNN的实现已经有很多了,大家可以自己去Github上找找,我这里也做了一个实现供大家参考。项目地址为:hzy46/Char-RNN-TensorFlow。我主要是在代码中加入了embedding层来支持中文,并且重新整理了代码结构,将API改成了最新的TensorFlow1.2版本。可以用这个项目写诗(下面的诗都是自动生成的):无人见人,此地如是。一夜上山,一夜回山。山风春草色,秋水夜声深。相见,当知老子。何不相见,何处见江边。一叶生云中,春风出竹堂。有来访,不在王心。您还可以生成代码:staticintpage_cpus(structflags*str){intrc;structrq*do_init;};/**Core_trace_periodsthetimeinisisthatsupsed,*/#endif/**Intendifinttostateanded.*/intprint_init(structpriority*rt){/*Commentsighindifseetasksoandthesections*/console(string,&can);}此外,生成英语不是问题(使用莎士比亚的文本训练):LAUNCE:Theformitysomistaliedonhis,thouhastshewastoherhears,whatweshouldbethatsayasounmanWouldthelord和所有的犯规,也说,我们destent和我的和平在这里。PALINA:为什么,你的呼吸是必须的还是你的呼吸,我已经满足了他,我的营地也有我。***,如果你的脑洞足够大,我可以做一些更有趣的事情.比如我用著名的网络小说《斗破苍穹》训练了一个RNN模型,可以生成如下文字:闻言,萧炎吓了一跳,随即把目光转向了旁边的灰袍青年,然后他的目光扫过老人。在那里,一个巨大的石台上,有着一个巨大的坑洞,一些黑色的光柱从里面散发出来。一条巨大的黑色蟒蛇,一股极为恐怖的感觉从天而降,而后一些人的眼中,那些身影之上出现了闪电。在那灵魂之中,有一种强者的感觉,在他们面前,然而,那些身影,犹如黑影一般,在那双眼之中,在这个世界,在那巨大的空间之中,散开……尊级,但不管你,是不可能出手的。那些家伙可以为了这个而动手,而且这里能出现一些异常,他也不能将你的灵魂交给其他人。所以,我不可能给这个人强大的吞噬天芒,这一次,我们的实力,就是能够将它击杀……”“这里的人,也可以和魂殿内的强者一较高下了。”萧炎眼中也是闪过一抹惊骇,旋即笑了起来,顿时一声冷喝,身后的魂殿高手迎向萧炎,一身寒饮冲天而起,一股恐怖的能量从天而降。》还是很好玩的,我也尝试过生成日文等等。7.学习完整版LSTMCell上面只提到了基础版BasicRNNCell和BasicLSTMCell,TensorFlow中还有一个“完整”的LSTM:LSTMCell。这个完整版的LSTM可以定义peephole,增加输出投影层,可以给LSTM的遗忘单元设置bias,rc版已经出,看来正式版也快出炉了,更新的是真快)更新了Seq2SeqAPI,使用这个API,我们可以不用手动定义Seq2Seq模型中的Encoder和Decoder,另外还兼容1.2版本新的数据读入方式兼容Datasets.可以在这里阅读文档学习如何使用九.总结***简单总结一下,本文提供了学习TensorFlowRNN实现的详细路径,包括学习顺序,可能踩到的坑,源码分析是的,还有一个示例工程hzy46/Char-RNN-TensorFlow,希望能对大家有所帮助。