文獻編號:5
文獻著作信息:
Vision Transformer for Fast and Efficient Scene Text Recognition
論文地址
代碼地址
18 May 2021
研究主題:
Scene text recognition
Transformer
Data augmentation
研究問題:
低參數量摇天,低計算量的STR模型
主要是精度不變情況下的提速
研究原因:
因為大多數只考慮了識別的精度宠哄,并沒有考慮到移動設備的需求
我的收獲和疑問
為了平衡準確性乓旗、速度和效率的重要性违孝,作者建議利用視覺轉換器(Vit)的簡單和高效的優(yōu)勢刹前。例如數據高效圖像轉換器(Deit)
Deit.pdf (arxiv.org)
ViT證明,僅使用transformer的encoder(好幾個連起來)就可以實現ImageNet識別中得到SOTA結果雌桑。
ViT繼承了transformer的所有特性喇喉,包括速度和計算效率
作者的框架也是這樣做的,因為圖片校坑,也需要位置編碼
用了【我的參考文獻3】的框架拣技,相同框架下,才能更好的比較不同模型的性能
MJ和ST 各用50%撒踪,如果用100%性能會下降
再自己寫論文時,可以把自己的參數設置寫成表格呈現給審稿人
研究設計:
作者試圖平衡準確性大渤、速度和效率制妄。準確性是指識別文本的正確性。速度是通過單位時間內處理多少文本圖像來衡量的泵三。效率可以用處理一張圖像所需的參數和計算(如FLOPS)的數量近似表示耕捞。參數的數量反映內存需求,而FLOPS估計完成任務所需的指令數量烫幕。理想的STR是精確和快速的俺抽,而只需要很少的計算資源。
研究發(fā)現:
使用Deit的模型權重,Deit簡單地是通過知識蒸餾訓練的VIT
對于機器來說枪孩,在人類環(huán)境中閱讀文本是一項具有挑戰(zhàn)性的任務憔晒,因為符號的可能外觀不同藻肄。圖2顯示了受曲率、字體樣式拒担、模糊嘹屯、旋轉、噪聲澎蛛、幾何圖形抚垄、照明、遮擋和分辨率影響的文本的示例谋逻。還有許多其他因素可能會影響文本圖像呆馁,例如天氣條件、相機傳感器缺陷毁兆、運動浙滤、照明等。
研究結論:
通過針對STR的數據增強气堕,ViTSTR可以顯著提高準確性纺腊,特別是對于不規(guī)則數據集。
當規(guī)模擴大時茎芭,ViTSTR保持在前沿揖膜,以平衡精度、速度和計算要求梅桩。
帶問題看論文:
相關工作
字符串以正確的順序標識圖像中文本的每個字符壹粟。與通常只有一類對象的對象識別不同,對于給定的文本圖像宿百,可以有零個或多個字符趁仙。因此,STR模型更加復雜垦页。與許多視覺問題類似雀费,早期的方法[24,38]使用手工制作的特征痊焊,導致性能較差盏袄。深度學習極大地推動了STR領域的發(fā)展。
2019年薄啥,Baek等人提出【我的參考文獻編號3】貌矿。[1]提出了一個對現代STR設計模式進行建模的框架。圖3顯示了STR的四個階段或模塊罪佳。從廣義上講逛漫,即使是最近提出的基于變壓器的模型、無遞歸序列對序列文本識別器(NRTR)[29]和自注意文本識別網絡(SATRN)[18]等方法也可以適用于校正-特征提取(Backbone)-序列建模-預測框架
校正階段去除文字圖像的失真赘艳,使文本水平或規(guī)范化酌毡。這使得特征提取(Backbone)模塊更容易確定不變特征克握。薄板樣條(TPS)[3]通過尋找和校正基準點來模擬畸變。RARE(帶有自動校正的健壯文本識別器)[31]枷踏、STAR-Net(空間注意殘留網絡)[21]和TRBA (TPS- resnet - bilstm -Attention)[1]使用TPS菩暗。ESIR(端到端可訓練場景文本識別)[41]采用迭代校正網絡,顯著提高了文本識別模型的性能旭蠕。在某些情況下停团,沒有采用整改,如CRNN卷積循環(huán)神經網絡[30]掏熬,R2AM(帶有注意力建模的遞歸循環(huán)神經網絡)[17]佑稠,GCRNN(門控循環(huán)卷積神經網絡)[36]和Rosetta[4]
特征提取(Backbone)階段的作用是自動確定每個字符符號的不變特征。STR在對象識別任務中使用相同的特征提取器旗芬,如VGG[32]舌胶、ResNet[11]和CNN的一個變體RCNN[17]。Rosetta, STAR-Net和TRBA使用ResNet疮丛。利用VGG提取RARE和CRNN特征幔嫂。R2AM和GCRNN建立在RCNN的基礎上√鼙。基于變壓器的模型NRTR和SATRN使用定制的CNN塊來提取變壓器編碼器-解碼器文本識別的特征
預測階段檢查由主干或序列建模產生的特征履恩,以達到字符預測序列。CTC(連接主義時間分類)[8]通過有效地對所有可能的輸入-輸出序列對齊進行求和呢蔫,最大限度地提高了輸出序列的可能性切心。CTC的替代方案是注意力機制[2],它學習圖像特征和符號之間的對齊咐刨。CRNN, GRCNN, Rosetta和STAR-Net使用CTC昙衅。R2AM, RARE和TRBA是基于注意力的
與自然語言處理(NLP)一樣余蟹,變壓器通過并行的自我注意和預測克服了序列建模和預測卷胯。這就產生了一個快速有效的模型。如圖3所示威酒,基于電流互感器的STR模型仍然需要一個骨干和一個變壓器編碼器-解碼器窑睁。最近挺峡,ViT[7]證明了它可以在ImageNet1k[28]分類上僅使用變壓器編碼器,但在非常大的數據集(如ImageNet21k和JFT-300M)上預先訓練它担钮,從而擊敗諸如ResNet[11]和efficiency entnet[33]等深度網絡的性能橱赠。DeiT[34]證明了ViT不需要大數據集,甚至可以獲得更好的結果箫津,但必須使用知識蒸餾[13]進行訓練狭姨。ViT是使用預先訓練的DeiT權重的基礎,是我們提出的快速有效的STR稱為ViTSTR的基礎苏遥。如圖3所示饼拍,ViTSTR是一個非常簡單的模型,只有一級暖眼,可以輕松地將基于變壓器的STR的參數數量和FLOPS減少一半惕耕。
ViT和ViTSTR之間的唯一區(qū)別是預測頭。ViTSTR必須識別具有正確序列順序和長度的多個字符诫肠,而不是單一對象類識別司澎。預測是并行進行的
在原始的ViT中,使用與可學習類嵌入相對應的輸出向量進行對象類別預測栋豫。在ViTSTR中挤安,這對應于[GO]令牌。此外丧鸯,我們不再只提取一個輸出向量蛤铜,而是從編碼器中提取多個特征向量。這個數字等于數據集中文本的最大長度加上兩個[GO]和[s]令牌丛肢。我們使用[GO]標記標記文本預測的開始围肥,并使用[s]標記注明結尾或空格。[s]在每個文本預測的末尾重復蜂怎,直到最大序列長度穆刻,以標記文本字符之后沒有任何內容。
圖5顯示了一個編碼器塊內的層杠步。每個輸入都經過層歸一化(LN)氢伟。多頭自注意層(Multi-head Self-Attention layer, MSA)確定特征向量之間的關系。Vaswani等人[35]發(fā)現幽歼,使用多個頭部而不是一個頭部可以讓模型共同關注來自不同位置的不同表示子空間的信息朵锣。頭部數為h。多層感知器(Multilayer Perceptron, MLP)進行特征提取甸私。它的輸入也是層規(guī)范化的诚些。MLP由2層組成,GELU激活[12]皇型。LN的輸出與MSA/MLP之間存在殘差連接诬烹。
or l = 1…L為編碼器塊的深度或數量
for i = 1…S是[GO]和[S]令牌的最大文本長度加2。表1總結了ViTSTR配置椅您。
作者用了【我的參考文獻3】的框架
為了對不同的模型做出公平的評價外冀,一個統(tǒng)一的框架是很重要的。統(tǒng)一的框架確保在評估中使用一致的訓練和測試條件掀泳。下面的討論描述了在性能比較中一直存在爭議的訓練數據集和測試數據集雪隧。使用不同的訓練和測試數據集可能會嚴重傾向于支持或反對某種性能報告。
數據集
由于缺乏大數據集的真實數據员舵,STR模型訓練的實踐是使用合成數據脑沿。使用兩個流行的數據集:1)MJSynth (MJ)[14]或也稱為Synth90k和2)SynthText (ST)[9]。
MJ
MJSynth (MJ)是一個合成生成的數據集马僻,由890萬逼真的文字圖像組成庄拇。MJSynth被設計成有3層:1)背景,2)前景和3)可選的陰影/邊框韭邓。它使用了1400種不同的字體措近。字體的字距、粗細女淑、下劃線和其他屬性是不同的瞭郑。MJSynth還利用了不同的背景效果,邊界/陰影渲染鸭你,基礎著色屈张,投影失真,自然圖像混合和噪聲袱巨。
ST
SynthText (ST)是另一個由550萬單詞圖像合成生成的數據集阁谆。SynthText是通過在自然圖像上混合合成文本生成的。它使用場景幾何愉老、紋理和表面法線來自然地混合和扭曲圖像中物體表面上的文本渲染场绿。與MJSynth類似,SynthText的文本使用隨機字體俺夕。文字圖像是從嵌入合成文本的自然圖像中裁剪出來的
在STR框架中裳凸,每個數據集占整個列車數據集的50%贱鄙。將兩個數據集100%地結合在一起會導致性能下降[我的參考文獻3]劝贸。圖6顯示了來自MJ和ST的示例圖像
測試數據集是由幾個小的公開的自然圖像文本STR數據集組成的。這些數據集通常分為兩組:1)常規(guī)和2)不規(guī)則
常規(guī)數據集的文本圖像是正面的逗宁,水平的映九,并且有最小的失真。IIIT5K-Words[23]瞎颗,街景文本(SVT) [37]件甥, ICDAR2003 (IC03)[22]和ICDAR2013 (IC13)[16]被認為是常規(guī)數據集捌议。同時,不規(guī)則數據集包含具有挑戰(zhàn)性外觀的文本引有,如彎曲瓣颅、垂直、透視譬正、低分辨率或扭曲宫补。ICDAR2015 (IC15)[15]、SVT Perspective (SVTP)[25]和CUTE80 (CT)[27]屬于不規(guī)則數據集曾我。圖7顯示了來自規(guī)則和不規(guī)則數據集的樣本粉怕。對于兩個數據集,只有測試分割用于評估
規(guī)則數據集
IIT5K包含3000張用于測試的圖像抒巢。圖像大多來自街景贫贝,如招牌、品牌標志蛉谜、門牌號或路牌稚晚。
SVT有647張圖片用于測試。文本圖像是從谷歌街景圖片裁剪型诚。
IC03包含來自ICDAR2003健壯閱讀比賽的1,110張測試圖像蜈彼。圖像是從自然場景中捕捉的。在刪除長度小于3個字符的單詞后俺驶,結果是860張圖像幸逆。然而,另外7張圖片被發(fā)現丟失了暮现。因此还绘,該框架還包含867個測試圖像版本∑艽—IC13是
IC03的擴展拍顷,共享類似的鏡像。IC13是為ICDAR2013健壯閱讀比賽而創(chuàng)建的塘幅。在文獻和框架中昔案,使用了兩個版本的測試數據集:1)857和2)1015。
不規(guī)則的數據集
IC15有ICDAR2015健壯閱讀比賽的文本圖片电媳。許多圖像模糊踏揣、嘈雜、旋轉匾乓,有時分辨率很低捞稿,因為這些圖像是使用谷歌眼鏡拍攝的,佩戴者處于無約束運動狀態(tài)。文獻和框架中使用了兩個版本:1)1811張和2)2077張圖像娱局。2077個版本包含旋轉彰亥、垂直、透視和彎曲的圖像衰齐。
SVTP有645張來自谷歌街景的測試圖像任斋。大多數是商業(yè)標牌的圖片。-
CT專注于從襯衫和產品標志中捕獲的彎曲文本圖像耻涛。該數據集有288張圖像仁卷。
表2列出了框架中推薦的培訓配置。我們復制了幾個強基線模型的結果:CRNN, R2AM, GCRNN, Rosetta, RARE, STAR-Net和TRBA犬第,以與ViTSTR進行公平的比較锦积。我們使用不同的隨機種子對所有模型進行至少5次訓練。保存測試數據集上表現最好的權重以獲得平均評估分數歉嗓。
對于ViTSTR丰介,我們使用相同的列車配置,除了輸入被調整為224 × 224鉴分,以匹配預訓練的DeiT[34]的尺寸哮幢。在訓練ViTSTR之前,會自動下載DeiT預訓練的權重文件志珍。ViTSTR可以端到端訓練橙垢,沒有凍結參數
表3和表4顯示了不同模型的性能得分。我們報告了準確性伦糯、速度柜某、參數數量和FLOPS,以得到折衷的總體情況敛纲,如圖1所示喂击。為了準確性,我們在大多數STR模型的大小寫敏感訓練和大小寫不敏感評估中遵循框架評估協(xié)議淤翔。對于速度翰绊,報告的數字是基于2080Ti GPU上的模型運行時間。與其他模型基準(如[19,20])不同旁壮,在評估之前监嗜,我們不旋轉垂直文本圖像(例如,表5 IC15)抡谐。
數據增強
使用專門針對STR的數據增強配方可以顯著提高ViTSTR的準確性厦坛,在圖8中五垮,我們可以看到不同之處
數據擴充會改變圖像,但不會改變其中文本的含義杜秸。表3顯示放仗,對不同的圖像變換(如反轉、彎曲撬碟、模糊诞挨、噪聲、扭曲呢蛤、旋轉惶傻、拉伸/壓縮、透視和收縮)應用RandAugment[6]后其障,ViTSTR-TINY的通用性提高了+1.8%银室,ViTSTR-Small的通用性提高了+1.6%,ViTSTR-Base的通用性提高了1.5%励翼。準確率提高最大的是不規(guī)則數據集蜈敢,例如CT(+9.2%極小,+6.6%小和基本)汽抚、SVTP(+3.8%極小抓狭,+3.3%小,+1.8%基本)造烁、IC15 1,811(+2.7%極小否过,+2.6%小,+1.7%基本)和IC15 2,077(+2.5%極小惭蟋,+2.2%小叠纹,+1.5%基本)。
注意力
圖9顯示了ViTSTR讀出文本圖像時的注意圖敞葛。當注意力適當地集中在每個字符上時誉察,ViTSTR也會關注相鄰的字符。也許惹谐,上下文是在單個符號預測期間放置的持偏。
STR模型的 性能懲罰
在STR模型中每增加一個階段,就會獲得一個精度氨肌,但代價是速度變慢和計算量增加鸿秆。例如,RARE?→TRBA提高了2.2%的準確率怎囚,但需要388m的參數卿叽,并將任務完成速度降低了4 msec/image桥胞。像STAR-Net?→TRBA那樣將CTC階段替換為Attention,將計算速度從8.8 msec/張圖像顯著降低到22.8 msec/張圖像考婴,從而獲得額外的2.5%的精度贩虾。事實上,從CTC到Attention的變化所帶來的放緩沥阱,與在管道中添加BiLSTM或TPS相比缎罢,是> 10倍。在ViTSTR中,從小版本到小版本的過渡需要增加嵌入尺寸和頭部數量。不需要額外的階段疏日。為了獲得2.3%的精度芜飘,性能損失是參數數量增加16.1M。從微小到基本,獲得3.4%的精度的性能懲罰是額外的80.4M參數。在這兩種情況下,速度幾乎沒有變化酬蹋,因為我們在MLP和MSA中使用了相同的并行張量點積、softmax和加法運算層的變壓器編碼器抽莱。只有張量維度增加范抓,導致任務完成速度降低0.2到0.3 msec/圖像。與多級STR不同食铐,額外的模塊需要額外的連續(xù)的前向傳播層匕垫,這不能并行化,從而導致顯著的性能損失
失敗案例
表5顯示了ViTSTR-Small在每個測試數據集中失敗的預測樣本虐呻。導致預測錯誤的主要原因是相似符號混淆(如8和B, J和I)象泵,腳本字體(如Inc中的I),字符眩光斟叼,垂直文本偶惠,嚴重彎曲的文本圖像和部分遮擋的符號。請注意朗涩,在某些情況下忽孽,即使是人類讀者也很容易犯錯誤。然而谢床,人類使用語義來解決歧義兄一。語義已經在最近的STR方法中使用了[26,39]
代碼閱讀
def get_args(is_train=True):
parser = argparse.ArgumentParser(description='STR')
# for test
parser.add_argument('--eval_data', required=not is_train, help='path to evaluation dataset')
parser.add_argument('--benchmark_all_eval', action='store_true', help='evaluate 10 benchmark evaluation datasets')
parser.add_argument('--calculate_infer_time', action='store_true', help='calculate inference timing')
parser.add_argument('--flops', action='store_true', help='calculates approx flops (may not work)')
# for train
parser.add_argument('--exp_name', help='Where to store logs and models')
parser.add_argument('--train_data', required=is_train, help='path to training dataset')
parser.add_argument('--valid_data', required=is_train, help='path to validation dataset')
parser.add_argument('--manualSeed', type=int, default=1111, help='for random seed setting')
parser.add_argument('--workers', type=int, help='number of data loading workers. Use -1 to use all cores.', default=4)
parser.add_argument('--batch_size', type=int, default=192, help='input batch size')
parser.add_argument('--num_iter', type=int, default=300000, help='number of iterations to train for')
parser.add_argument('--valInterval', type=int, default=2000, help='Interval between each validation')
parser.add_argument('--saved_model', default='', help="path to model to continue training")
parser.add_argument('--FT', action='store_true', help='whether to do fine-tuning')
parser.add_argument('--sgd', action='store_true', help='Whether to use SGD (default is Adadelta)')
parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is Adadelta)')
parser.add_argument('--lr', type=float, default=1, help='learning rate, default=1.0 for Adadelta')
parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam. default=0.9')
parser.add_argument('--rho', type=float, default=0.95, help='decay rate rho for Adadelta. default=0.95')
parser.add_argument('--eps', type=float, default=1e-8, help='eps for Adadelta. default=1e-8')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping value. default=5')
parser.add_argument('--baiduCTC', action='store_true', help='for data_filtering_off mode')
""" Data processing """
parser.add_argument('--select_data', type=str, default='MJ-ST',
help='select training data (default is MJ-ST, which means MJ and ST used as training data)')
parser.add_argument('--batch_ratio', type=str, default='0.5-0.5',
help='assign ratio for each selected data in the batch')
parser.add_argument('--total_data_usage_ratio', type=str, default='1.0',
help='total data usage ratio, this ratio is multiplied to total number of data.')
parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length')
parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
parser.add_argument('--rgb', action='store_true', help='use rgb input')
parser.add_argument('--character', type=str,
default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode')
parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')
parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode')
""" Model Architecture """
parser.add_argument('--Transformer', action='store_true', help='Use end-to-end transformer')
choices = ["vitstr_tiny_patch16_224", "vitstr_small_patch16_224", "vitstr_base_patch16_224", "vitstr_tiny_distilled_patch16_224", "vitstr_small_distilled_patch16_224"]
parser.add_argument('--TransformerModel', default=choices[0], help='Which vit/deit transformer model', choices=choices)
parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
parser.add_argument('--FeatureExtraction', type=str, required=True,
help='FeatureExtraction stage. VGG|RCNN|ResNet')
parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM')
parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. None|CTC|Attn')
parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
parser.add_argument('--input_channel', type=int, default=1,
help='the number of input channel of Feature extractor')
parser.add_argument('--output_channel', type=int, default=512,
help='the number of output channel of Feature extractor')
parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')
# selective augmentation
# can choose specific data augmentation
parser.add_argument('--issel_aug', action='store_true', help='Select augs')
parser.add_argument('--sel_prob', type=float, default=1., help='Probability of applying augmentation')
parser.add_argument('--pattern', action='store_true', help='Pattern group')
parser.add_argument('--warp', action='store_true', help='Warp group')
parser.add_argument('--geometry', action='store_true', help='Geometry group')
parser.add_argument('--weather', action='store_true', help='Weather group')
parser.add_argument('--noise', action='store_true', help='Noise group')
parser.add_argument('--blur', action='store_true', help='Blur group')
parser.add_argument('--camera', action='store_true', help='Camera group')
parser.add_argument('--process', action='store_true', help='Image processing routines')
# use cosine learning rate decay
parser.add_argument('--scheduler', action='store_true', help='Use lr scheduler')
parser.add_argument('--intact_prob', type=float, default=0.5, help='Probability of not applying augmentation')
parser.add_argument('--isrand_aug', action='store_true', help='Use RandAug')
parser.add_argument('--augs_num', type=int, default=3, help='Number of data augment groups to apply. 1 to 8.')
parser.add_argument('--augs_mag', type=int, default=None, help='Magnitude of data augment groups to apply. None if random.')
# for comparison to other augmentations
parser.add_argument('--issemantic_aug', action='store_true', help='Use Semantic')
parser.add_argument('--isrotation_aug', action='store_true', help='Use ')
parser.add_argument('--isscatter_aug', action='store_true', help='Use ')
parser.add_argument('--islearning_aug', action='store_true', help='Use ')
# orig paper uses this for fast benchmarking
parser.add_argument('--fast_acc', action='store_true', help='Fast average accuracy computation')
parser.add_argument('--infer_model', type=str,
default=None, help='generate inference jit model')
parser.add_argument('--quantized', action='store_true', help='Model quantization')
parser.add_argument('--static', action='store_true', help='Static model quantization')
args = parser.parse_args()
return
傳參
opt = get_args()
模型
請忽略縮進骂束,需要源代碼可去github上下載
class Model(nn.Module):
def __init__(self, opt):
super(Model, self).__init__()
self.opt = opt
self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
'Seq': opt.SequenceModeling, 'Pred': opt.Prediction,
'ViTSTR': opt.Transformer}
""" Transformation """
if opt.Transformation == 'TPS':
self.Transformation = TPS_SpatialTransformerNetwork(
F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
else:
print('No Transformation module specified')
if opt.Transformer:
self.vitstr= create_vitstr(num_tokens=opt.num_class, model=opt.TransformerModel)
return
""" FeatureExtraction """
if opt.FeatureExtraction == 'VGG':
self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
elif opt.FeatureExtraction == 'RCNN':
self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
elif opt.FeatureExtraction == 'ResNet':
self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
else:
raise Exception('No FeatureExtraction module specified')
self.FeatureExtraction_output = opt.output_channel # int(imgH/16-1) * 512
self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1
""" Sequence modeling"""
if opt.SequenceModeling == 'BiLSTM':
self.SequenceModeling = nn.Sequential(
BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
self.SequenceModeling_output = opt.hidden_size
else:
print('No SequenceModeling module specified')
self.SequenceModeling_output = self.FeatureExtraction_output
""" Prediction """
if opt.Prediction == 'CTC':
self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
elif opt.Prediction == 'Attn':
self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class)
else:
raise Exception('Prediction is neither CTC or Attn')
def forward(self, input, text, is_train=True, seqlen=25):
""" Transformation stage """
if not self.stages['Trans'] == "None":
input = self.Transformation(input)
if self.stages['ViTSTR']:
prediction = self.vitstr(input, seqlen=seqlen)
return prediction
""" Feature extraction stage """
visual_feature = self.FeatureExtraction(input)
visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h]
visual_feature = visual_feature.squeeze(3)
""" Sequence modeling stage """
if self.stages['Seq'] == 'BiLSTM':
contextual_feature = self.SequenceModeling(visual_feature)
else:
contextual_feature = visual_feature # for convenience. this is NOT contextually modeled by BiLSTM
""" Prediction stage """
if self.stages['Pred'] == 'CTC':
prediction = self.Prediction(contextual_feature.contiguous())
else:
prediction = self.Prediction(contextual_feature.contiguous(), text, is_train, batch_max_length=self.opt.batch_max_length)
return prediction
沒有數據增強的訓練
RANDOM=$$
CUDA_VISIBLE_DEVICES=0 python3 train.py --train_data data_lmdb_release/training
--valid_data data_lmdb_release/evaluation --select_data MJ-ST
--batch_ratio 0.5-0.5 --Transformation None --FeatureExtraction None \
--SequenceModeling None --Prediction None --Transformer
--TransformerModel=vitstr_tiny_patch16_224 --imgH 224 --imgW 224
--manualSeed=$RANDOM --sensitive
無特征提取耳璧,序列模型,只有transformer