Keras同時用多張顯卡訓練網(wǎng)絡

Author: Zongwei Zhou | 周縱葦
Weibo: @MrGiovanni
Email: zongweiz@asu.edu


References.

官方文檔:multi_gpu_model
以及Google

0. 誤區(qū)

目前Keras是支持了多個GPU同時訓練網(wǎng)絡权均,非常容易,但是靠以下這個代碼是不行的锅锨。

os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"

當你監(jiān)視GPU的使用情況(nvidia-smi -l 1)的時候會發(fā)現(xiàn)叽赊,盡管GPU不空閑,實質(zhì)上只有一個GPU在跑必搞,其他的就是閑置的占用狀態(tài)必指,也就是說,如果你的電腦里面有多張顯卡恕洲,無論有沒有上面的代碼塔橡,Keras都會默認的去占用所有能檢測到的GPU梅割。這行代碼在你只需要一個GPU的時候時候用的,也就是可以讓Keras檢測不到電腦里其他的GPU葛家。假設你一共有三張顯卡户辞,每個顯卡都是有自己的標號的(0, 1, 2),為了不影響別人的使用癞谒,你只用其中一個底燎,比如用gpu=1的這張,那么

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

然后再監(jiān)視GPU的使用情況(nvidia-smi -l 1)扯俱,確實只有一個被占用书蚪,其他都是空閑狀態(tài)。所以這是一個Keras使用多顯卡的誤區(qū)迅栅,它并不能同時利用多個GPU殊校。

1. 目的

為什么要同時用多個GPU來訓練?

單個顯卡內(nèi)存太小 -> batch size無法設的比較大读存,有時甚至batch_size=1都內(nèi)存溢出(OUT OF MEMORY)

從我跑深度網(wǎng)絡的經(jīng)驗來看为流,batch_size設的大一點會比較好,相當于每次反向傳播更新權重让簿,網(wǎng)絡都可以看到更多的樣本敬察,從而不會每次iteration都過擬合到不同的地方去Don't Decay the Learning Rate, Increase the Batch Size。當然尔当,我也看過有論文說也不能設的過大莲祸,原因不明... 反正我也沒有機會試過。我建議的batch_size大概就是64~256的范圍內(nèi)椭迎,都沒什么大問題锐帜。

但是隨著現(xiàn)在網(wǎng)絡的深度越來越深,對于GPU的內(nèi)存要求也越來越大畜号,很多入門的新人最大的問題往往不是代碼缴阎,而是從Github里面抄下來的代碼自己的GPU太渣,實現(xiàn)不了简软,只能降低batch_size蛮拔,最后訓練不出那種效果。

解決方案兩個:一是買一個超級牛逼的GPU痹升,內(nèi)存巨大無比建炫;二是買多個一般般的GPU,一起用疼蛾。

第一個方案不行踱卵,因為目前即便最好的NVIDIA顯卡,內(nèi)存也不過十幾個G了不起了,網(wǎng)絡一深也掛惋砂,并且買一個牛逼顯卡的性價比不高。所以绳锅、學會在Keras下用多個GPU是比較靠譜的選擇西饵。

2. 實現(xiàn)

2.1 設計一個類

cite: parallel_model.py

import tensorflow as tf
import keras.backend as K
import keras.layers as KL
import keras.models as KM


class ParallelModel(KM.Model):
    """Subclasses the standard Keras Model and adds multi-GPU support.
    It works by creating a copy of the model on each GPU. Then it slices
    the inputs and sends a slice to each copy of the model, and then
    merges the outputs together and applies the loss on the combined
    outputs.
    """

    def __init__(self, keras_model, gpu_count):
        """Class constructor.
        keras_model: The Keras model to parallelize
        gpu_count: Number of GPUs. Must be > 1
        """
        self.inner_model = keras_model
        self.gpu_count = gpu_count
        merged_outputs = self.make_parallel()
        super(ParallelModel, self).__init__(inputs=self.inner_model.inputs,
                                            outputs=merged_outputs)

    def __getattribute__(self, attrname):
        """Redirect loading and saving methods to the inner model. That's where
        the weights are stored."""
        if 'load' in attrname or 'save' in attrname:
            return getattr(self.inner_model, attrname)
        return super(ParallelModel, self).__getattribute__(attrname)

    def summary(self, *args, **kwargs):
        """Override summary() to display summaries of both, the wrapper
        and inner models."""
        super(ParallelModel, self).summary(*args, **kwargs)
        self.inner_model.summary(*args, **kwargs)

    def make_parallel(self):
        """Creates a new wrapper model that consists of multiple replicas of
        the original model placed on different GPUs.
        """
        # Slice inputs. Slice inputs on the CPU to avoid sending a copy
        # of the full inputs to all GPUs. Saves on bandwidth and memory.
        input_slices = {name: tf.split(x, self.gpu_count)
                        for name, x in zip(self.inner_model.input_names,
                                           self.inner_model.inputs)}

        output_names = self.inner_model.output_names
        outputs_all = []
        for i in range(len(self.inner_model.outputs)):
            outputs_all.append([])

        # Run the model call() on each GPU to place the ops there
        for i in range(self.gpu_count):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('tower_%d' % i):
                    # Run a slice of inputs through this replica
                    zipped_inputs = zip(self.inner_model.input_names,
                                        self.inner_model.inputs)
                    inputs = [
                        KL.Lambda(lambda s: input_slices[name][i],
                                  output_shape=lambda s: (None,) + s[1:])(tensor)
                        for name, tensor in zipped_inputs]
                    # Create the model replica and get the outputs
                    outputs = self.inner_model(inputs)
                    if not isinstance(outputs, list):
                        outputs = [outputs]
                    # Save the outputs for merging back together later
                    for l, o in enumerate(outputs):
                        outputs_all[l].append(o)

        # Merge outputs on CPU
        with tf.device('/cpu:0'):
            merged = []
            for outputs, name in zip(outputs_all, output_names):
                # If outputs are numbers without dimensions, add a batch dim.
                def add_dim(tensor):
                    """Add a dimension to tensors that don't have any."""
                    if K.int_shape(tensor) == ():
                        return KL.Lambda(lambda t: K.reshape(t, [1, 1]))(tensor)
                    return tensor
                outputs = list(map(add_dim, outputs))

                # Concatenate
                merged.append(KL.Concatenate(axis=0, name=name)(outputs))
        return merged

2.2 調(diào)用非常簡潔

GPU_COUNT = 3 # 同時使用3個GPU
model = keras.applications.densenet.DenseNet201() # 比如使用DenseNet-201
model = ParallelModel(model, GPU_COUNT)
model.compile(optimizer=Adam(lr=1e-5), loss='binary_crossentropy', metrics = ['accuracy'])
model.fit(X_train, y_train,
              batch_size=batch_size*GPU_COUNT, 
              epochs=nb_epoch, verbose=0, shuffle=True,
              validation_data=(X_valid, y_valid))

model.save_weights('/path/to/save/model.h5')
最后編輯于
?著作權歸作者所有,轉載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市鳞芙,隨后出現(xiàn)的幾起案子眷柔,更是在濱河造成了極大的恐慌,老刑警劉巖原朝,帶你破解...
    沈念sama閱讀 218,941評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件驯嘱,死亡現(xiàn)場離奇詭異,居然都是意外死亡喳坠,警方通過查閱死者的電腦和手機鞠评,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,397評論 3 395
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來壕鹉,“玉大人剃幌,你說我怎么就攤上這事×涝。” “怎么了负乡?”我有些...
    開封第一講書人閱讀 165,345評論 0 356
  • 文/不壞的土叔 我叫張陵,是天一觀的道長脊凰。 經(jīng)常有香客問我抖棘,道長,這世上最難降的妖魔是什么狸涌? 我笑而不...
    開封第一講書人閱讀 58,851評論 1 295
  • 正文 為了忘掉前任切省,我火速辦了婚禮,結果婚禮上杈抢,老公的妹妹穿的比我還像新娘数尿。我一直安慰自己,他們只是感情好惶楼,可當我...
    茶點故事閱讀 67,868評論 6 392
  • 文/花漫 我一把揭開白布右蹦。 她就那樣靜靜地躺著,像睡著了一般歼捐。 火紅的嫁衣襯著肌膚如雪何陆。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,688評論 1 305
  • 那天豹储,我揣著相機與錄音贷盲,去河邊找鬼。 笑死,一個胖子當著我的面吹牛巩剖,可吹牛的內(nèi)容都是我干的铝穷。 我是一名探鬼主播,決...
    沈念sama閱讀 40,414評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼佳魔,長吁一口氣:“原來是場噩夢啊……” “哼曙聂!你這毒婦竟也來了?” 一聲冷哼從身側響起鞠鲜,我...
    開封第一講書人閱讀 39,319評論 0 276
  • 序言:老撾萬榮一對情侶失蹤宁脊,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后贤姆,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體榆苞,經(jīng)...
    沈念sama閱讀 45,775評論 1 315
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,945評論 3 336
  • 正文 我和宋清朗相戀三年霞捡,在試婚紗的時候發(fā)現(xiàn)自己被綠了坐漏。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 40,096評論 1 350
  • 序言:一個原本活蹦亂跳的男人離奇死亡弄砍,死狀恐怖仙畦,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情音婶,我是刑警寧澤慨畸,帶...
    沈念sama閱讀 35,789評論 5 346
  • 正文 年R本政府宣布,位于F島的核電站衣式,受9級特大地震影響寸士,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜碴卧,卻給世界環(huán)境...
    茶點故事閱讀 41,437評論 3 331
  • 文/蒙蒙 一弱卡、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧住册,春花似錦婶博、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,993評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間饶辙,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,107評論 1 271
  • 我被黑心中介騙來泰國打工岸晦, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個月前我還...
    沈念sama閱讀 48,308評論 3 372
  • 正文 我出身青樓启上,卻偏偏與公主長得像邢隧,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子冈在,可洞房花燭夜當晚...
    茶點故事閱讀 45,037評論 2 355

推薦閱讀更多精彩內(nèi)容