热线电话:13121318867

登录
首页精彩阅读十分钟告诉你——何为Keras中的序列到序列学习
十分钟告诉你——何为Keras中的序列到序列学习
2019-12-20
收藏
十分钟告诉你——何为<a href='/map/Keras/' style='color:#000;font-size:inherit;'>Keras</a>中的序列到序列学习

作者 | Francois Chollet

编译 | CDA数据分析师

A ten-minute introduction to sequence-to-sequence learning in Keras

什么是顺序学习?

序列到序列学习(Seq2Seq)是关于将模型从一个域(例如英语中的句子)转换为另一域(例如将相同句子翻译为法语的序列)的训练模型。

“猫坐在垫子上” -> [ Seq2Seq 模型] -> “在小吃中聊天

这可用于机器翻译或免费问答(在给定自然语言问题的情况下生成自然语言答案)-通常,它可在需要生成文本的任何时间使用。

有多种处理此任务的方法,可以使用RNN或使用一维卷积网络。在这里,我们将重点介绍RNN

普通情况:输入和输出序列的长度相同

当输入序列和输出序列的长度相同时,您可以简单地使用Keras LSTM或GRU层(或其堆栈)来实现此类模型。在此示例脚本 中就是这种情况, 该脚本显示了如何教RNN学习加编码为字符串的数字:

十分钟告诉你——何为<a href='/map/Keras/' style='color:#000;font-size:inherit;'>Keras</a>中的序列到序列学习

该方法的一个警告是,它假定可以生成target[...t]给定input[...t]。在某些情况下(例如添加数字字符串),该方法有效,但在大多数用例中,则无效。在一般情况下,有关整个输入序列的信息是必需的,以便开始生成目标序列。

一般情况:规范序列间

在一般情况下,输入序列和输出序列具有不同的长度(例如,机器翻译),并且需要整个输入序列才能开始预测目标。这需要更高级的设置,这是人们在没有其他上下文的情况下提到“序列模型的序列”时通常所指的东西。运作方式如下:

RNN层(或其堆栈)充当“编码器”:它处理输入序列并返回其自己的内部状态。请注意,我们放弃了编码器RNN的输出,仅恢复 了状态。在下一步中,此状态将用作解码器的“上下文”或“条件”。

另一个RNN层(或其堆栈)充当“解码器”:在给定目标序列的先前字符的情况下,对其进行训练以预测目标序列的下一个字符。具体而言,它经过训练以将目标序列变成相同序列,但在将来会偏移一个时间步,在这种情况下,该训练过程称为“教师强迫”。重要的是,编码器使用来自编码器的状态向量作为初始状态,这就是解码器如何获取有关应该生成的信息的方式。有效地,解码器学会产生targets[t+1...] 给定的targets[...t],调节所述输入序列。

十分钟告诉你——何为<a href='/map/Keras/' style='color:#000;font-size:inherit;'>Keras</a>中的序列到序列学习

在推断模式下,即当我们想解码未知的输入序列时,我们会经历一个略有不同的过程:

  • 将输入序列编码为状态向量。
  • 从大小为1的目标序列开始(仅是序列开始字符)。
  • 将状态向量和1个字符的目标序列馈送到解码器,以生成下一个字符的预测。
  • 使用这些预测来采样下一个字符(我们仅使用argmax)。
  • 将采样的字符追加到目标序列
  • 重复直到生成序列结束字符或达到字符数限制。
十分钟告诉你——何为<a href='/map/Keras/' style='color:#000;font-size:inherit;'>Keras</a>中的序列到序列学习

同样的过程也可以用于训练Seq2Seq网络,而无需 “教师强制”,即通过将解码器的预测重新注入到解码器中。

一个Keras例子

  • 将句子翻译成3个numpy的阵列,encoderinputdata,decoderinputdata,decodertargetdata:
  • encoderinputdata是一个3D形状的数组,(numpairs, maxenglishsentencelength, numenglishcharacters) 其中包含英语句子的一键向量化。
  • decoderinputdata是一个3D形状的数组,(numpairs, maxfrenchsentencelength, numfrenchcharacters) 其中包含法语句子的一键矢量化。
  • decodertargetdata与相同,decoderinputdata但相差一个时间步长。 decodertargetdata[:, t, :]将与相同decoderinputdata[:, t + 1, :]。
  • 训练基于LSTM的基本Seq2Seq模型,以预测decodertargetdata 给定encoderinputdata和decoderinputdata。我们的模型使用教师强迫。
  • 解码一些句子以检查模型是否正常运行(即,将的样本从encoderinputdata 转换为的对应样本decodertargetdata)。

因为训练过程和推理过程(解码句子)有很大的不同,所以我们对两者使用不同的模型,尽管它们都利用相同的内部层。

这是我们的训练模型。它利用Keras RNN的三个关键功能:

return_state构造器参数,配置RNN层返回一个列表,其中,第一项是输出与下一个条目是内部RNN状态。这用于恢复编码器的状态。

inital_state呼叫参数,指定一个RNN的初始状态(S)。这用于将编码器状态作为初始状态传递给解码器。

return_sequences构造函数的参数,配置RNN返回其输出全序列(而不只是最后的输出,其默认行为)。在解码器中使用。

十分钟告诉你——何为<a href='/map/Keras/' style='color:#000;font-size:inherit;'>Keras</a>中的序列到序列学习

from keras.models import Model from keras.layers import Input, LSTM, Dense # Define an input sequence and process it. encoder_inputs = Input(shape=(None, num_encoder_tokens)) encoder = LSTM(latent_dim, return_state=True) encoder_outputs, state_h, state_c = encoder(encoder_inputs) # We discard `encoder_outputs` and only keep the states. encoder_states = [state_h, state_c] # Set up the decoder, using `encoder_states` as initial state. decoder_inputs = Input(shape=(None, num_decoder_tokens)) # We set up our decoder to return full output sequences, # and to return internal states as well. We don't use the # return states in the training model, but we will use them in inference. decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True) decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states) decoder_dense = Dense(num_decoder_tokens, activation='softmax') decoder_outputs = decoder_dense(decoder_outputs) # Define the model that will turn # `encoder_input_data` & `decoder_input_data` into `decoder_target_data` model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

我们分两行训练我们的模型,同时监视20%的保留样本中的损失。

# Run training model.compile(optimizer='rmsprop', loss='categorical_crossentropy') model.fit([encoder_input_data, decoder_input_data], decoder_target_data, batch_size=batch_size, epochs=epochs, validation_split=0.2)

在MacBook CPU上运行大约一个小时后,我们就可以进行推断了。为了解码测试语句,我们将反复:

  • 对输入语句进行编码并检索初始解码器状态
  • 以该初始状态和“序列开始”令牌为目标,运行解码器的一步。输出将是下一个目标字符。
  • 添加预测的目标字符并重复。

这是我们的推理设置:

encoder_model = Model(encoder_inputs, encoder_states) decoder_state_input_h = Input(shape=(latent_dim,)) decoder_state_input_c = Input(shape=(latent_dim,)) decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c] decoder_outputs, state_h, state_c = decoder_lstm( decoder_inputs, initial_state=decoder_states_inputs) decoder_states = [state_h, state_c] decoder_outputs = decoder_dense(decoder_outputs) decoder_model = Model( [decoder_inputs] + decoder_states_inputs, [decoder_outputs] + decoder_states)

十分钟告诉你——何为<a href='/map/Keras/' style='color:#000;font-size:inherit;'>Keras</a>中的序列到序列学习

我们使用它来实现上述推理循环:

def decode_sequence(input_seq): # Encode the input as state vectors. states_value = encoder_model.predict(input_seq) # Generate empty target sequence of length 1. target_seq = np.zeros((1, 1, num_decoder_tokens)) # Populate the first character of target sequence with the start character. target_seq[0, 0, target_token_index['\t']] = 1. # Sampling loop for a batch of sequences # (to simplify, here we assume a batch of size 1). stop_condition = False decoded_sentence = '' while not stop_condition: output_tokens, h, c = decoder_model.predict( [target_seq] + states_value) # Sample a token sampled_token_index = np.argmax(output_tokens[0, -1, :]) sampled_char = reverse_target_char_index[sampled_token_index] decoded_sentence += sampled_char # Exit condition: either hit max length # or find stop character. if (sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length): stop_condition = True # Update the target sequence (of length 1). target_seq = np.zeros((1, 1, num_decoder_tokens)) target_seq[0, 0, sampled_token_index] = 1. # Update states states_value = [h, c] return decoded_sentence

我们得到了一些不错的结果-毫不奇怪,因为我们正在解码从训练测试中提取的样本

Input sentence: Be nice. Decoded sentence: Soyez gentil ! - Input sentence: Drop it! Decoded sentence: Laissez tomber ! - Input sentence: Get out! Decoded sentence: Sortez !

到此,我们结束了对Keras中序列到序列模型的十分钟介绍。提醒:此脚本的完整代码可以在GitHub上找到。

数据分析咨询请扫描二维码

最新资讯
更多
客服在线
立即咨询