SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient代碼分析

github鏈接:https://github.com/LantaoYu/SeqGAN
論文及appendix里有很好的代碼說明。

sequence_gan.py是主文件

main函數(shù)里,首先定義了generator和discriminator結(jié)構(gòu)體摊灭。

generator里:

是RNN生成長度為20的句子的過程(即:一行有20個數(shù)字)。LSTM結(jié)構(gòu)是自己實現(xiàn)的煤杀,即g_recurrent_unit函數(shù)沈自。由于輸出的是hidden_state,再用g_output_unit函數(shù)轉(zhuǎn)換為output_token_prob异逐。
placeholder里腥例,self.x是real sentence,self.rewards是RL里的Rewards燎竖。
self.h0里包括了LSTM的hidden_state和Ct璃弄。
gen_xgen_o是generator生成的sentence(word id)和每個word的prob。
接下來是一個循環(huán)构回,循環(huán)生成20個word夏块。start_token和start_hidden_state是初始化好的。具體生成方法見函數(shù)_g_recurrence纤掸。依次把結(jié)果寫入gen_xgen_o脐供。
到以上部分截止,都和待輸入的placeholder沒什么關(guān)系借跪,就是簡單的RNN網(wǎng)絡(luò)政己。


接下來是有監(jiān)督學(xué)習(xí)部分,即用MLE的思想來訓(xùn)練網(wǎng)絡(luò)掏愁。和GAN沒關(guān)系歇由。用和上文循環(huán)相似的結(jié)構(gòu)來生成prediction,即每個word的prob果港,而不是具體的word id沦泌。通過和轉(zhuǎn)化成one-hot形式的placeholder里的x作對比來計算self.pretrain_loss
再接下來是無監(jiān)督學(xué)習(xí)部分京腥,即GAN的generator部分赦肃。根據(jù)論文里的公式

self.g_loss = -tf.reduce_sum(
    tf.reduce_sum(
        tf.one_hot(tf.to_int32(tf.reshape(self.x, [-1])), self.num_emb, 1.0, 0.0) * tf.log(
            tf.clip_by_value(tf.reshape(self.g_predictions, [-1, self.num_emb]), 1e-20, 1.0)
        ), 1) * tf.reshape(self.rewards, [-1])  # rewards是RL才有的
)

把輸入的64*20個單詞全變成一維,乘以對應(yīng)的reward公浪,再加和得到expected rewards他宛。取負(fù)即為GAN的generator部分的loss。

target_lstm里:

結(jié)構(gòu)大致和generator相同欠气,但是它的作用是生成真實的句子厅各,所以不存在訓(xùn)練過程。
設(shè)置tf.set_random_seed(66)预柒,這就是real sentence的特征队塘。

discriminator里:

待輸入的placeholder是x(sentence), y([0,1] or [1,0]), keep_prob.
其內(nèi)容是標(biāo)準(zhǔn)的CNN用于text classification的代碼。

# Convolution Layer
filter_shape = [filter_size, embedding_size, 1, num_filter]
W = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name="W")
b = tf.Variable(tf.constant(0.1, shape=[num_filter]), name="b")
conv = tf.nn.conv2d(
    self.embedded_chars_expanded, #(batch_size, sequence_length , embedding_size, 1)
    W,
    strides=[1, 1, 1, 1],
    padding="VALID",
    name="conv")
# Apply nonlinearity
h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu")#(batch_size, sequence_length,filter_size, num_filter )
# Maxpooling over the outputs
pooled = tf.nn.max_pool(
    h,
    ksize=[1, sequence_length - filter_size + 1, 1, 1],
    strides=[1, 1, 1, 1],
    padding='VALID',
    name="pool")#(batch_size, filter_size, 1, num_filter)

最后把輸出轉(zhuǎn)化成[-1, num_filters_total]宜鸯,再用linear函數(shù)輸出最后的prob憔古。


回到主函數(shù)里。
使用target_lstm生成real sentence淋袖,寫入real_data.txt文件鸿市。

pre-train generator:
outputs = sess.run([self.pretrain_updates, self.pretrain_loss], 
            feed_dict={self.x: x}) # 不需要generator生成sentence
pre-training discriminator:
feed = {
    discriminator.input_x: x_batch,
    discriminator.input_y: y_batch,
    discriminator.dropout_keep_prob: dis_dropout_keep_prob
}
_ = sess.run(discriminator.train_op, feed)
到了最關(guān)鍵的部分,roll-out policy,即計算reward的部分焰情。
先定義reward結(jié)構(gòu)體

參數(shù)設(shè)置都和generator一樣陌凳,placeholder里的x也是real sentence,但是新增了placeholder given_num内舟,也就是上文公式中的t合敦,意指generator句子的長度為given_num
當(dāng)i<given_num的時候验游,不用生成sentence充岛,直接讀取:

h_t = self.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
x_tp1 = ta_emb_x.read(i)
gen_x = gen_x.write(i, ta_x.read(i))
return i + 1, x_tp1, h_t, given_num, gen_x

反之批狱,還是要按基本法裸准,依次選出單詞:

# 這里的input_x是generator生成的fake sentence
h_t = self.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
o_t = self.g_output_unit(h_t)  # batch x vocab , logits not prob
log_prob = tf.log(tf.nn.softmax(o_t))
next_token = tf.cast(tf.reshape(tf.multinomial(log_prob, 1), [self.batch_size]), tf.int32)
x_tp1 = tf.nn.embedding_lookup(self.g_embeddings, next_token)  # batch x emb_dim
gen_x = gen_x.write(i, next_token)  # indices, batch_size
return i + 1, x_tp1, h_t, given_num, gen_x

get_reward函數(shù)中,從1到20嘗試given_num赔硫,計算expected reward。

for given_num in range(1, 20):  # 最后只到19
    feed = {self.x: input_x, self.given_num: given_num}# 這里需要輸入input_x才能生成gen_x盐肃,因為gen_x的前一部分是固定好的
    samples = sess.run(self.gen_x, feed)
    feed = {discriminator.input_x: samples, discriminator.dropout_keep_prob: 1.0}
    ypred_for_auc = sess.run(discriminator.ypred_for_auc, feed)
    ypred = np.array([item[1] for item in ypred_for_auc])# 得到每個句子的分值
    if i == 0:
        rewards.append(ypred)
    else:
        rewards[given_num - 1] += ypred  # 把所有得分加在一起

# the last token reward
feed = {discriminator.input_x: input_x, discriminator.dropout_keep_prob: 1.0}#最后一個word就不用自己run生成了爪膊,直接讀取input_x
ypred_for_auc = sess.run(discriminator.ypred_for_auc, feed)
ypred = np.array([item[1] for item in ypred_for_auc])
if i == 0:
    rewards.append(ypred)
else:
    rewards[19] += ypred

這樣,就得到了對于任意長度的句子的rewards砸王。

就是對于已經(jīng)生成的不同長度的sentence推盛,擴展到完整的句子,再用discriminator來打分谦铃。

再根據(jù)形如

 self.Wi = self.update_rate * self.Wi + (1 - self.update_rate) * tf.identity(self.lstm.Wi)

的公式更新roll-out里的參數(shù)耘成。
discriminator還是根據(jù)相同的公式來訓(xùn)練。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末驹闰,一起剝皮案震驚了整個濱河市瘪菌,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌嘹朗,老刑警劉巖师妙,帶你破解...
    沈念sama閱讀 206,839評論 6 482
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異屹培,居然都是意外死亡默穴,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 88,543評論 2 382
  • 文/潘曉璐 我一進店門褪秀,熙熙樓的掌柜王于貴愁眉苦臉地迎上來蓄诽,“玉大人,你說我怎么就攤上這事媒吗÷胤眨” “怎么了?”我有些...
    開封第一講書人閱讀 153,116評論 0 344
  • 文/不壞的土叔 我叫張陵蝴猪,是天一觀的道長调衰。 經(jīng)常有香客問我膊爪,道長,這世上最難降的妖魔是什么嚎莉? 我笑而不...
    開封第一講書人閱讀 55,371評論 1 279
  • 正文 為了忘掉前任米酬,我火速辦了婚禮,結(jié)果婚禮上趋箩,老公的妹妹穿的比我還像新娘赃额。我一直安慰自己,他們只是感情好叫确,可當(dāng)我...
    茶點故事閱讀 64,384評論 5 374
  • 文/花漫 我一把揭開白布跳芳。 她就那樣靜靜地躺著,像睡著了一般竹勉。 火紅的嫁衣襯著肌膚如雪飞盆。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 49,111評論 1 285
  • 那天次乓,我揣著相機與錄音吓歇,去河邊找鬼。 笑死票腰,一個胖子當(dāng)著我的面吹牛城看,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播杏慰,決...
    沈念sama閱讀 38,416評論 3 400
  • 文/蒼蘭香墨 我猛地睜開眼测柠,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了缘滥?” 一聲冷哼從身側(cè)響起轰胁,我...
    開封第一講書人閱讀 37,053評論 0 259
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎完域,沒想到半個月后软吐,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 43,558評論 1 300
  • 正文 獨居荒郊野嶺守林人離奇死亡吟税,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 36,007評論 2 325
  • 正文 我和宋清朗相戀三年凹耙,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片肠仪。...
    茶點故事閱讀 38,117評論 1 334
  • 序言:一個原本活蹦亂跳的男人離奇死亡肖抱,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出异旧,到底是詐尸還是另有隱情意述,我是刑警寧澤,帶...
    沈念sama閱讀 33,756評論 4 324
  • 正文 年R本政府宣布,位于F島的核電站荤崇,受9級特大地震影響拌屏,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜术荤,卻給世界環(huán)境...
    茶點故事閱讀 39,324評論 3 307
  • 文/蒙蒙 一倚喂、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧瓣戚,春花似錦端圈、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,315評論 0 19
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至仑嗅,卻和暖如春宴倍,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背仓技。 一陣腳步聲響...
    開封第一講書人閱讀 31,539評論 1 262
  • 我被黑心中介騙來泰國打工啊楚, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人浑彰。 一個月前我還...
    沈念sama閱讀 45,578評論 2 355
  • 正文 我出身青樓,卻偏偏與公主長得像拯辙,于是被迫代替她去往敵國和親郭变。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 42,877評論 2 345

推薦閱讀更多精彩內(nèi)容

  • “深度解讀:GAN模型及其在2016年度的進展”[1]一文對過去一年GAN的進展做了詳細(xì)介紹涯保,十分推薦學(xué)習(xí)GAN的...
    MiracleJQ閱讀 2,008評論 0 7
  • 作者:貝克·哈吉斯 美國著名企業(yè)家诉濒、國際營銷大師,擁有多家快速成長的公司夕春。 這是一本薄薄的書未荒,122頁,2小時候左...
    李樺成長記閱讀 1,373評論 0 0
  • 一.CSS的全稱是什么及志? 層疊樣式表(英文全稱:Cascading Style Sheets)是一種用來表現(xiàn)HTM...
    Sunset125閱讀 406評論 0 1
  • 一月結(jié)束了片排。二月,沓芷蝶來速侈。 一切都在有條不紊的進行著率寡,恰似春光旭日,慢慢悠悠倚搬,舒適自然地啟幕冶共,打開、進行、乍泄捅僵。...
    Miss瓦爾登湖閱讀 331評論 0 0