LSTM

2018-12-06
Let's walk through the LSTM implementation from Udacity's deep learning course.

RNN and LSTM

Suppose you have a sequence of events that unfolds over time, and you want to make a prediction at a given time step while also taking earlier events into account. Passing the state of every previous time step to the current one is impractical, so an RNN instead summarizes everything seen so far at each step and passes that summary forward to the current state; this lets it learn from every position in the sequence.


RNN-rolled

RNN-unrolled

The two figures above are equivalent.
Note that the sequence is fed into the RNN one element at a time, not all at once.
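A minimal sketch of this unrolled computation (illustrative names only, not the course code):

# h_t = f(h_{t-1}, x_t): at each step the RNN combines the new input with a
# running summary (the hidden state) of everything seen so far.
def run_rnn(step_fn, inputs, h0):
    h = h0
    outputs = []
    for x in inputs:        # the sequence is read one element at a time
        h = step_fn(h, x)   # update the summary of the past
        outputs.append(h)
    return outputs, h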

The problem

Backpropagation through an RNN:
Because an RNN shares the same weights across all time steps, the updates are very unstable and can lead to exploding or vanishing gradients.

Solutions
  • gradient clipping (a numpy sketch follows this list)


    Gradient clipping
  • LSTM (long short-term memory)



    Memory cell
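A minimal numpy sketch of gradient clipping (illustrative, not the course code; the training section further below uses tf.clip_by_global_norm for the real thing): rescale all gradients by a common factor whenever their combined norm exceeds a threshold.

import numpy as np

def clip_by_global_norm(grads, clip_norm=1.25):
    # global norm = L2 norm of all gradients stacked into one vector
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    if global_norm > clip_norm:
        grads = [g * clip_norm / global_norm for g in grads]
    return grads, global_norm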


Code

Reading in the data

The data is still text8.zip.

Create a small validation set.

valid_size = 1000
valid_text = text[:valid_size]
train_text = text[valid_size:]
train_size = len(train_text)
print(train_size, train_text[:64])
print(valid_size, valid_text[:64])

99999000 ons anarchists advocate social relations based upon voluntary as
1000  anarchism originated as a term of abuse first used against earl

Build a mapping from characters to ids

vocabulary_size = len(string.ascii_lowercase) + 1 # [a-z] + ' '
first_letter = ord(string.ascii_lowercase[0])

def char2id(char):
  if char in string.ascii_lowercase:
    return ord(char) - first_letter + 1
  elif char == ' ':
    return 0
  else:
    print('Unexpected character: %s' % char)
    return 0
  
def id2char(dictid):
  if dictid > 0:
    return chr(dictid + first_letter - 1)
  else:
    return ' '

print(char2id('a'), char2id('z'), char2id(' '), char2id('?'))
print(id2char(1), id2char(26), id2char(0))

Unexpected character: ?
1 26 0 0
a z  

Building training data for the model

batch_size=64
num_unrollings=10

class BatchGenerator(object):
  def __init__(self, text, batch_size, num_unrollings):
    self._text = text
    self._text_size = len(text)
    self._batch_size = batch_size
    self._num_unrollings = num_unrollings
    segment = self._text_size // batch_size
    self._cursor = [ offset * segment for offset in range(batch_size)]
    self._last_batch = self._next_batch()
  
  def _next_batch(self):
    """Generate a single batch from the current cursor position in the data."""
    batch = np.zeros(shape=(self._batch_size, vocabulary_size), dtype=float)
    for b in range(self._batch_size):
      batch[b, char2id(self._text[self._cursor[b]])] = 1.0
      self._cursor[b] = (self._cursor[b] + 1) % self._text_size
    # the modulo lets the cursor wrap around and cycle through the text
    return batch
  
  def next(self):
    """Generate the next array of batches from the data. The array consists of
    the last batch of the previous array, followed by num_unrollings new ones.
    """
    batches = [self._last_batch]
    # 'batches' is really a sequence of num_unrollings + 1 batches (including the carried-over last batch)
    for step in range(self._num_unrollings):
      batches.append(self._next_batch())
    self._last_batch = batches[-1]
    # keep the last batch so the next call can start from it
    return batches

train_batches = BatchGenerator(train_text, batch_size, num_unrollings)
valid_batches = BatchGenerator(valid_text, 1, 1)

batch_size is the batch size, and num_unrollings is the sequence length (the number of unrolled steps).
The cursor list guarantees that each position in the batch keeps reading from its own fixed stream of characters.

For example, take 'abcdefghij', a string of length 10, with a batch size of 2 and an unrolling length of 2.
In the output below, each array is one batch; the number of arrays is the number of steps in the sequence.

To make this concrete: the batch size determines how many character streams run in parallel. With a batch size of 2 the text is split into two segments, 'abcde' and 'fghij', so each batch holds one character from each segment: 'a,f', then 'b,g', and so on. In other words, there are as many starting positions (and hence parallel segments) as there are batch positions.

Because each call also returns the last batch of the previous call, every call yields three batches.

test = BatchGenerator('abcdefghij', 2, 2)
test.next()

[array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., #a
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],                             
         [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,    #f
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]),
 array([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,   #b
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,      #g
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]),
 array([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,    #c
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,       #h
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])]

Utility functions

  • Display the most likely character
def characters(probabilities):
  """Turn a 1-hot encoding or a probability distribution over the possible
  characters back into its (most likely) character representation."""
  return [id2char(c) for c in np.argmax(probabilities, 1)]
  • Convert a sequence of batches back to strings
def batches2string(batches):
  """Convert a sequence of batches back into their (most likely) string
  representation."""
  s = [''] * batches[0].shape[0]
  for b in batches:
    s = [''.join(x) for x in zip(s, characters(b))]
  return s
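A quick sanity check of these helpers with the toy generator (assuming the definitions above; the expected output is shown in the comments):

toy = BatchGenerator('abcdefghij', 2, 2)
print(batches2string(toy.next()))  # ['abc', 'fgh']
print(batches2string(toy.next()))  # ['cde', 'hij']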

A simple LSTM model

num_nodes = 64

graph = tf.Graph()
with graph.as_default():

num_nodes is the number of hidden units in the LSTM cell.

Defining the variables
  # Parameters:
  # Input gate: input, previous output, and bias.
  ix = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))
  im = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))
  ib = tf.Variable(tf.zeros([1, num_nodes]))
  # Forget gate: input, previous output, and bias.
  fx = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))
  fm = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))
  fb = tf.Variable(tf.zeros([1, num_nodes]))
  # Memory cell: input, state and bias.                             
  cx = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))
  cm = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))
  cb = tf.Variable(tf.zeros([1, num_nodes]))
  # Output gate: input, previous output, and bias.
  ox = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))
  om = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))
  ob = tf.Variable(tf.zeros([1, num_nodes]))
  # Variables saving state across unrollings.
  saved_output = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False)
  saved_state = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False)
  # Classifier weights and biases.
  w = tf.Variable(tf.truncated_normal([num_nodes, vocabulary_size], -0.1, 0.1))
  b = tf.Variable(tf.zeros([vocabulary_size]))

Let's pull up the LSTM diagram again for reference:


lstm cell

The code above defines the following:

  • input gate: ix, im, ib
  • forget gate: fx, fm, fb
  • memory cell: cx, cm, cb
  • output gate: ox, om, ob
  • saved_output, saved_state: the h and c carried across unrollings (initialized to zero)
  • classifier: w, b, the weights and bias used for the final classification
Defining the LSTM cell
  # Definition of the cell computation.
  def lstm_cell(i, o, state):
    """Create a LSTM cell. See e.g.: http://arxiv.org/pdf/1402.1128v1.pdf
    Note that in this formulation, we omit the various connections between the
    previous state and the gates."""
    input_gate = tf.sigmoid(tf.matmul(i, ix) + tf.matmul(o, im) + ib)
    forget_gate = tf.sigmoid(tf.matmul(i, fx) + tf.matmul(o, fm) + fb)
    output_gate = tf.sigmoid(tf.matmul(i, ox) + tf.matmul(o, om) + ob)
    update = tf.matmul(i, cx) + tf.matmul(o, cm) + cb
    state = forget_gate * state + input_gate * tf.tanh(update)
    return output_gate * tf.tanh(state), state
LSTM

Matching the code to the diagram:
input_gate: i
forget_gate: f
output_gate: o
update: g
The three inputs are:
state: c_{t-1}
o: h_{t-1}
i: x_t
The two outputs are h_t and c_t.
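Written out, lstm_cell computes the following (\sigma is the sigmoid and \odot is element-wise multiplication; this is just the code above in formula form):

i_t = \sigma(x_t W_{ix} + h_{t-1} W_{im} + b_i)
f_t = \sigma(x_t W_{fx} + h_{t-1} W_{fm} + b_f)
o_t = \sigma(x_t W_{ox} + h_{t-1} W_{om} + b_o)
g_t = x_t W_{cx} + h_{t-1} W_{cm} + b_c
c_t = f_t \odot c_{t-1} + i_t \odot \tanh(g_t)
h_t = o_t \odot \tanh(c_t)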

Defining the input placeholders
  # Input data.
  train_data = list()
  for _ in range(num_unrollings + 1):
    train_data.append(
      tf.placeholder(tf.float32, shape=[batch_size,vocabulary_size]))
  train_inputs = train_data[:num_unrollings]
  train_labels = train_data[1:]  # labels are inputs shifted by one time step.

The training labels are the inputs shifted by one time step: each character's label is the character that follows it.
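A tiny illustration of the shift (not from the course code, just ordinary string slicing):

stream = 'anarchism'
inputs = stream[:4]    # 'anar'
labels = stream[1:5]   # 'narc': each label is the character that follows its input
print(inputs, labels)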

The unrolled LSTM training loop
  # Unrolled LSTM loop.
  outputs = list()
  output = saved_output
  state = saved_state
  for i in train_inputs:
    output, state = lstm_cell(i, output, state)
    outputs.append(output)
Defining the loss

The following explanation is adapted from a blog post:
TensorFlow graphs are not executed sequentially like ordinary code; operations without a data dependency have no guaranteed execution order. tf.control_dependencies imposes such an order, guaranteeing that the two assignments run before anything inside the block.

Here that means saved_output and saved_state are saved first, and only then are logits and loss computed. Without control_dependencies, the computation below would not be connected to the two assignments, so the training op would never trigger them.

tf.concat(outputs, 0) concatenates the list of per-step outputs along dimension 0. outputs holds one tensor per unrolled step; the classifier then maps each row to a 27-way distribution, one dimension per character.
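For reference, the shapes involved (with the values used in this notebook: batch_size = 64, num_unrollings = 10, num_nodes = 64, vocabulary_size = 27):

outputs                     : list of 10 tensors, each (64, 64)
tf.concat(outputs, 0)       : (640, 64)
logits = xw_plus_b(...)     : (640, 27)
tf.concat(train_labels, 0)  : (640, 27)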

  # State saving across unrollings.
  with tf.control_dependencies([saved_output.assign(output),
                                saved_state.assign(state)]):
    # Classifier.
    logits = tf.nn.xw_plus_b(tf.concat(outputs, 0), w, b)
    loss = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
        labels=tf.concat(train_labels, 0), logits=logits))
Defining the training optimizer

clip_by_global_norm first computes global_norm, the L2 norm of all the gradients taken together. If that norm exceeds the threshold of 1.25 used here, each gradient is scaled to gradients * 1.25 / global_norm; otherwise the gradients are left unchanged.

Finally, apply_gradients. The global_step passed in is incremented by one on every step, so the next learning_rate computation sees the new global_step value.

  # Optimizer.
  global_step = tf.Variable(0)
  learning_rate = tf.train.exponential_decay(
    10.0, global_step, 5000, 0.1, staircase=True)
  optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  gradients, v = zip(*optimizer.compute_gradients(loss))
  gradients, _ = tf.clip_by_global_norm(gradients, 1.25)
  # prevent exploding gradients
  optimizer = optimizer.apply_gradients(
    zip(gradients, v), global_step=global_step)
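In formula form, the clipping applied above is:

global\_norm = \sqrt{\sum_i \lVert g_i \rVert_2^2}
g_i \leftarrow g_i \cdot 1.25 / \max(global\_norm,\ 1.25)

So with a global norm of 5.0 every gradient is scaled by 0.25, while a global norm below 1.25 leaves the gradients untouched.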
Defining the predictions

  # Predictions.
  train_prediction = tf.nn.softmax(logits)
Sampling and validation evaluation
  # Sampling and validation eval: batch 1, no unrolling.
  sample_input = tf.placeholder(tf.float32, shape=[1, vocabulary_size])
  saved_sample_output = tf.Variable(tf.zeros([1, num_nodes]))
  saved_sample_state = tf.Variable(tf.zeros([1, num_nodes]))
  reset_sample_state = tf.group(
    saved_sample_output.assign(tf.zeros([1, num_nodes])),
    saved_sample_state.assign(tf.zeros([1, num_nodes])))
  sample_output, sample_state = lstm_cell(
    sample_input, saved_sample_output, saved_sample_state)
  with tf.control_dependencies([saved_sample_output.assign(sample_output),
                                saved_sample_state.assign(sample_state)]):
    sample_prediction = tf.nn.softmax(tf.nn.xw_plus_b(sample_output, w, b))

The training loop

The metric used to evaluate training here is perplexity, which is derived from the cross entropy.
From information theory, perplexity (see the Wikipedia definition) and cross entropy are related as follows:
perplexity = e^{cross\_entropy}
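As a quick sanity check: guessing uniformly over the 27 characters gives a cross entropy of \ln 27 \approx 3.30, i.e. a perplexity of about 27, so a useful model should push the validation perplexity well below 27.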

num_steps = 7001
summary_frequency = 100

with tf.Session(graph=graph) as session:
  tf.global_variables_initializer().run()
  print('Initialized')
  mean_loss = 0
  for step in range(num_steps):
    batches = train_batches.next()  # fetch the next sequence of training batches
    feed_dict = dict()
    for i in range(num_unrollings + 1):
      feed_dict[train_data[i]] = batches[i]  # feed each step's batch into its placeholder
    _, l, predictions, lr = session.run(
      [optimizer, loss, train_prediction, learning_rate], feed_dict=feed_dict)

    mean_loss += l
    if step % summary_frequency == 0:
      if step > 0:
        mean_loss = mean_loss / summary_frequency
      # The mean loss is an estimate of the loss over the last few batches.
      print(
        'Average loss at step %d: %f learning rate: %f' % (step, mean_loss, lr))
      mean_loss = 0
      # Note the helper functions used below: logprob, sample, random_distribution, characters.
      labels = np.concatenate(list(batches)[1:])
      print('Minibatch perplexity: %.2f' % float(
        np.exp(logprob(predictions, labels))))
      if step % (summary_frequency * 10) == 0:
        # Generate some samples.
        # (this block prints a few generated text samples for inspection)
        print('=' * 80)
        for _ in range(5):
          feed = sample(random_distribution())
          sentence = characters(feed)[0]
          reset_sample_state.run()
          for _ in range(79):
            prediction = sample_prediction.eval({sample_input: feed})
            feed = sample(prediction)
            sentence += characters(feed)[0]
          print(sentence)
        print('=' * 80)
      # Measure validation set perplexity.
      reset_sample_state.run()
      valid_logprob = 0
      for _ in range(valid_size):
        b = valid_batches.next()
        predictions = sample_prediction.eval({sample_input: b[0]})
        valid_logprob = valid_logprob + logprob(predictions, b[1])
      print('Validation set perplexity: %.2f' % float(np.exp(
        valid_logprob / valid_size)))

A few helper functions:
logprob: computes the cross entropy between the labels and the predictions.

Recall the cross entropy:

Cross\ Entropy = -\sum_{i=1}^{N} labels_i \cdot \log(predictions_i)
Then, with N the number of rows in the batch,

logprob = { Cross\ Entropy \over N }

def logprob(predictions, labels):
  """Log-probability of the true labels in a predicted batch."""
  predictions[predictions < 1e-10] = 1e-10
  return np.sum(np.multiply(labels, -np.log(predictions))) / labels.shape[0]

random_distribution(): generates a random probability distribution over the vocabulary by drawing uniform values in [0, 1] and normalizing them so they sum to 1.

def random_distribution():
  """Generate a random column of probabilities."""
  b = np.random.uniform(0.0, 1.0, size=[1, vocabulary_size])
  return b/np.sum(b, 1)[:,None]

sample_distribution(distribution): samples an index in [0, len(distribution)) according to the given probabilities (inverse-CDF sampling with a uniform random number).


def sample_distribution(distribution):
  """Sample one element from a distribution assumed to be an array of normalized
  probabilities.
  """
  r = random.uniform(0, 1)
  s = 0
  for i in range(len(distribution)):
    s += distribution[i]
    if s >= r:
      return i
  return len(distribution) - 1

sample(prediction): turns a prediction into a randomly sampled one-hot vector.

def sample(prediction):
  """Turn a (column) prediction into 1-hot encoded samples."""
  p = np.zeros(shape=[1, vocabulary_size], dtype=float)
  p[0, sample_distribution(prediction[0])] = 1.0
  return p