Chen Z, Song Y, Chang T H, et al. Generating Radiology Reports via Memory-driven Transformer[C]//Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP). 2020: 1439-1449.
代碼倉:R2Gen code
任務目標
輸入一張醫(yī)學影像狮崩,生成相應的報告飞蹂。
難點
- 醫(yī)學報告句子很多集乔,使用常用的只生成一句話的Image Captioning models可能不足以生成醫(yī)學報告。
- 要求的精度也比較高辉词。
醫(yī)學報告也有特有的特征禽绪,圖片和報告的格式都高度模式化。目前的解決方案:
- retrieval-based. 大數據集的準備
- retrieval-based + generation-based + manually extracted templates. 模板的準備
- 本文使用的是 generation-based model
模型簡介
本文使用memory-dirven Transformer生成醫(yī)學報告侣肄。主要工作:
- 提出了relational memory (RM) 模塊記錄之前生成過程的信息旧困;
- 提出了memory-driven conditional layer normalization (MCLN) 把RM和Transformer結合起來。
模型結構:Visual Extractor + Encoder + Decoder + Relational Memory
1. Visual Extractor
這一部分的主要任務就是把圖像轉化為序列數據稼锅,從而可以輸入到Encoder中吼具。使用常用的卷積神經網絡就可以,把最后的Linear去掉矩距,留有最后的patch feature以及fc_feature就可以拗盒。
例如本文使用ResNet101預訓練模型,每一組數據輸入的圖像為兩張彩色圖像锥债。
- 輸入shape為(b, 2, c, h, w)
- 視覺提取器分別對兩張圖進行特征提取陡蝇。
- ResNet101去除掉最后一層Linear與Pooling層,輸入(b, c, h, w)哮肚,輸出(b, 2048, 7, 7)
- 最后經過resize登夫,permutation,shape=(b, 49, 2048)
- 兩張圖像的特征在axis=1上拼接允趟,得到patch_feature(b, 98, 2048)恼策。這個維度可以看作是batch * seq_len * embedding。
- 第二組特征fc_feature是在patch_feature的基礎上再次Pooling生成的(b, 2048)拼窥,在axis=1上拼接后戏蔑,得到(b, 4096)。
2. Encoder
編碼器把視覺特征處理鲁纠,使用attention機制总棵,得到最終的特征,作為K改含,V輸入到decoder中情龄。
- 首先有一個src_embedding,把視覺特征維度轉換為d_model=512,方便輸入Transformer骤视。(按理說這部分是在transformer里實現的鞍爱,但作者的代碼在CaptionModel里實現)
- 數據x.shape=(b, seq_len, d_model)輸入后,作為query专酗,key睹逃,value輸入到attention中(這里d_model=head * d_k),最后又得到相同shape的輸出祷肯。
3. Relational Memory
這塊的設計是為了使模型可以學到更好的report patterns沉填,和retrieval-based 里面模板的準備差不多。RM使用矩陣存儲pattern information with each row佑笋,稱作memory slot翼闹。每步生成的過程,矩陣都會更新蒋纬。在第t步猎荠,矩陣用作Q,和前一步輸出的embedding
拼接起來作為K蜀备,V進入到MultiHeadAttention关摇。
Attention
這里的K Q V計算機制與Encoder里的稍有不同
最終attention計算得到的結果記為。因為M是循環(huán)計算的琼掠,可能梯度消失或者爆炸拒垃,因此引入了residual connections 和 gate mechanism。
Residual connection
M的中間值為
Gate Mechanism
gate mechanism的結構如圖所示:
輸入門和遺忘門用來平衡
其中的U和W都是可訓練參數横堡。最終gate的輸出為:
其中
Memory-driven Conditional Layer Normalization
常見的模型memory都在encoder部分,本文單獨設計并與decoder緊密聯系食听。與Attention中 LayerNorm對比胸蛛,提出了MCLN。把用到了Norm里
的計算上樱报。主要思路是葬项,把
拉成一個向量,再用MLP去預測
的變化量迹蛤,最后再更新民珍。
4. Decoder
共有三個結構襟士,self_attention + src_attention + FFN
輸入參數有:
- 輸出序列的embedding:tgt,經過了embedding + positionalEncoding嚷量,輸入到self_attention中陋桂,得到的結果作為src_attention的Q,這里使用了tgt_mask
- encoder的輸出內容:src蝶溶,要用到src_attention中的K和V嗜历,這里使用了src_mask
- src_mask, tgt_mask 在attention 計算過程中用到
- RM的輸出結果memory,每一個t時刻的memory都被拉成了向量身坐,最后拼接在一起秸脱,在每一個MCLN中用到落包。
代碼理解
作者的R2Gen模型里EncoderDecoder模塊是最復雜的部蛇。
- 首先實現了CaptionModel類,可以調用函數咐蝇,分別執(zhí)行_forward()和_sample()涯鲁,實現了beam_search()。
- 然后AttModel繼承于CaptionModel類有序,實現了_sample()函數抹腿,在測試過程中用到。
- 最后EncoderDecoder又繼承于AttModel旭寿,實現了_forward()警绩,在訓練過程中用到。
- 最后的搜索過程盅称,也就是_sample() 函數根據不同的策略會有不同的實現肩祥。
本文的創(chuàng)新之處在于設計了Relational Memory 模塊,并使用到MCLN中缩膝。