Tensorflow2單機多GPU數(shù)據(jù)準備與訓(xùn)練說明

前言

能看到這篇文章的驳规,都是富貴讓我們相遇。
現(xiàn)在這光景艾君,單GPU都困難采够,何況多GPU訓(xùn)練。腻贰。吁恍。

幾個需要注意的點

  1. 模型生成部分需要使用tf.distribute.MirroredStrategy
  2. 為了將batch size的數(shù)據(jù)均等分配給各個GPU的顯存,需要通過tf.data.Dataset.from_generator托管數(shù)據(jù)播演,從迭代器加載冀瓦,同時顯式關(guān)閉AutoShardPolicy。如果不做這一步写烤,顯存分配可能會出問題翼闽,不僅顯存會爆,還可能過程中的validation loss計算會出問題洲炊。
  3. 為了避免觸發(fā)tensorflow2在完成以上步驟感局,訓(xùn)練過程中metrics的計算bug,需要做到如下幾點暂衡!這個地方是痛點询微,如果不仔細跟蹤,是很難發(fā)現(xiàn)的狂巢!
    metrics一定設(shè)置為binary_accuracy撑毛,或者sparse_categorical_accuracy
    不能簡單設(shè)置為acc
    否則之后會報:as_list() is not defined on an unknown TensorShape的錯誤
  4. 之所以使用生成器動態(tài)產(chǎn)生訓(xùn)練數(shù)據(jù),不僅僅是為了避免一次性加載訓(xùn)練數(shù)據(jù)唧领,直接吃爆顯存藻雌,還因為需要實時對訓(xùn)練數(shù)據(jù)做數(shù)據(jù)增強與變換,增加模型的魯棒性斩个。

代碼部分

模型生成與編譯部分

直接看tf.distribute.MirroredStrategy的用法胯杭,損失函數(shù),優(yōu)化函數(shù)的根據(jù)自己習(xí)慣來受啥。但是metrics一定不能選擇acc做个!

gpus = tf.config.list_physical_devices('GPU')
batchsize = 8
print('apply: Adam + weighted_bce_dice_loss_v1_7_3')
if len(gpus) > 1:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(device=gpu, enable=True)
    batchsize *= len(gpus)
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = table_line.get_model(input_shape=(512, 512, 3),
                                     is_resnest_unet=is_resnest_unet,
                                     is_swin_unet=is_swin_unet,
                                     resnest_pretrain_model=resnest_pretrain_model)
        # apply custom loss
        model.compile(
            optimizer=Adam(
                lr=0.0001),
            loss=weighted_bce_dice_loss_v1_7_3,
            metrics=['binary_accuracy'])
else:
    model = table_line.get_model(input_shape=(512, 512, 3),
                                 is_resnest_unet=is_resnest_unet,
                                 is_swin_unet=is_swin_unet,
                                 resnest_pretrain_model=resnest_pretrain_model)
    model.compile(
        optimizer=Adam(
            lr=0.0001),
        loss=weighted_bce_dice_loss_v1_7_3,
        metrics=['binary_accuracy'])
print('batch size: {0}, GPUs: {1}'.format(batchsize, gpus))

數(shù)據(jù)迭代器生成部分

def makeDataset(generator_func,
                data_list,
                line_path,
                batchsize,
                draw_line,
                is_raw,
                need_rotate,
                only_flip,
                is_wide_line,
                strategy=None):
    # Get amount of files
    ds = tf.data.Dataset.from_generator(generator_func,
                                        args=[data_list, line_path, batchsize,
                                              draw_line, is_raw, need_rotate,
                                              only_flip, is_wide_line],
                                        output_types=(tf.float64, tf.float64))
    # Make a dataset from the generator. MAKE SURE TO SPECIFY THE DATA TYPE!!!
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
    ds = ds.with_options(options)

    # Optional: Make it a distributed dataset if you're using a strategy
    if strategy is not None:
        ds = strategy.experimental_distribute_dataset(ds)

    return ds

獲取training與validation數(shù)據(jù)獲取的迭代器
其中g(shù)en是生成數(shù)據(jù)的方程,其余參數(shù), 除了最后一個strategy參數(shù)滚局,都是生成數(shù)據(jù)方程所需的參數(shù)

training_ds = makeDataset(gen,
                          data_list=trainP,
                          line_path=line_path,
                          batchsize=batchsize,
                          draw_line=False,
                          is_raw=is_raw,
                          need_rotate=need_rotate,
                          only_flip=only_flip,
                          is_wide_line=is_wide_line,
                          strategy=None)
validation_ds = makeDataset(gen,
                            data_list=testP,
                            line_path=line_path,
                            batchsize=batchsize,
                            draw_line=False,
                            is_raw=is_raw,
                            need_rotate=need_rotate,
                            only_flip=only_flip,
                            is_wide_line=is_wide_line,
                            strategy=None)

生成數(shù)據(jù)方程的示例居暖,學(xué)過iterate的都明白在說啥

def gen(paths,
        line_path,
        batchsize=2,
        draw_line=True,
        is_raw=False,
        need_rotate=False,
        only_flip: bool = True,
        is_wide_line=False):
    num = len(paths)
    i = 0
    while True:
        # sizes = [512,512,512,512,640,1024] ##多尺度訓(xùn)練
        # size = np.random.choice(sizes,1)[0]
        size = 512
        X = np.zeros((batchsize, size, size, 3))
        Y = np.zeros((batchsize, size, size, 2))
        print(i)
        for j in range(batchsize):
            while True:
                if i >= num:
                    i = 0
                    np.random.shuffle(paths)
                p = paths[i]
                i += 1
                try:
                    if is_raw:
                        img, lines, labelImg = get_img_label_raw(p,
                                                                 line_path,
                                                                 size=(size, size),
                                                                 draw_line=draw_line,
                                                                 is_wide_line=is_wide_line)
                    else:
                        img, lines, labelImg = get_img_label_transform(p,
                                                                       line_path,
                                                                       size=(size, size),
                                                                       draw_line=draw_line,
                                                                       need_rotate=need_rotate,
                                                                       only_flip=only_flip,
                                                                       is_wide_line=is_wide_line)
                    break
                except Exception as e:
                    print(e)
            X[j] = img
            Y[j] = labelImg
        yield X, Y

模型訓(xùn)練部分的代碼

訓(xùn)練方法:fit

之前調(diào)用數(shù)據(jù)生成器的訓(xùn)練方法是fit_generator,TF2之后統(tǒng)一用fit方程了

steps參數(shù)的寫法核畴,重點膝但!

注意steps_per_epoch與validation_steps的寫法,batchsize必須與調(diào)用makeDataset時谤草,傳入的batchsize的值相同跟束,否則無法計算出正確的steps

model.fit(training_ds,
          callbacks=[checkpointer, earlyStopping],
          steps_per_epoch=max(1, len(trainP) // batchsize),
          validation_data=validation_ds,
          validation_steps=max(1, len(testP) // batchsize),
          epochs=300)
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末莺奸,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子冀宴,更是在濱河造成了極大的恐慌灭贷,老刑警劉巖,帶你破解...
    沈念sama閱讀 216,372評論 6 498
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件略贮,死亡現(xiàn)場離奇詭異甚疟,居然都是意外死亡,警方通過查閱死者的電腦和手機逃延,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,368評論 3 392
  • 文/潘曉璐 我一進店門览妖,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人揽祥,你說我怎么就攤上這事讽膏。” “怎么了拄丰?”我有些...
    開封第一講書人閱讀 162,415評論 0 353
  • 文/不壞的土叔 我叫張陵府树,是天一觀的道長。 經(jīng)常有香客問我料按,道長奄侠,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,157評論 1 292
  • 正文 為了忘掉前任载矿,我火速辦了婚禮垄潮,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘恢准。我一直安慰自己魂挂,他們只是感情好甫题,可當我...
    茶點故事閱讀 67,171評論 6 388
  • 文/花漫 我一把揭開白布馁筐。 她就那樣靜靜地躺著,像睡著了一般坠非。 火紅的嫁衣襯著肌膚如雪敏沉。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,125評論 1 297
  • 那天炎码,我揣著相機與錄音盟迟,去河邊找鬼。 笑死潦闲,一個胖子當著我的面吹牛攒菠,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播歉闰,決...
    沈念sama閱讀 40,028評論 3 417
  • 文/蒼蘭香墨 我猛地睜開眼辖众,長吁一口氣:“原來是場噩夢啊……” “哼卓起!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起凹炸,我...
    開封第一講書人閱讀 38,887評論 0 274
  • 序言:老撾萬榮一對情侶失蹤戏阅,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后啤它,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體奕筐,經(jīng)...
    沈念sama閱讀 45,310評論 1 310
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,533評論 2 332
  • 正文 我和宋清朗相戀三年变骡,在試婚紗的時候發(fā)現(xiàn)自己被綠了离赫。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 39,690評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡塌碌,死狀恐怖笆怠,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情誊爹,我是刑警寧澤蹬刷,帶...
    沈念sama閱讀 35,411評論 5 343
  • 正文 年R本政府宣布,位于F島的核電站频丘,受9級特大地震影響办成,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜搂漠,卻給世界環(huán)境...
    茶點故事閱讀 41,004評論 3 325
  • 文/蒙蒙 一迂卢、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧桐汤,春花似錦而克、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,659評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至拣度,卻和暖如春碎绎,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背抗果。 一陣腳步聲響...
    開封第一講書人閱讀 32,812評論 1 268
  • 我被黑心中介騙來泰國打工筋帖, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人冤馏。 一個月前我還...
    沈念sama閱讀 47,693評論 2 368
  • 正文 我出身青樓日麸,卻偏偏與公主長得像,于是被迫代替她去往敵國和親逮光。 傳聞我的和親對象是個殘疾皇子代箭,可洞房花燭夜當晚...
    茶點故事閱讀 44,577評論 2 353

推薦閱讀更多精彩內(nèi)容