模型壓縮和加速——參數(shù)共享(Parameter Sharing)及其pytorch實現(xiàn)

一度迂、參數(shù)共享含義

參數(shù)共享(Parameter Sharing)是模型壓縮與加速中的一種重要技術(shù)侨嘀。通過參數(shù)共享妨猩,多個神經(jīng)元或?qū)涌梢怨蚕硐嗤臋?quán)重參數(shù),而不是每個神經(jīng)元或?qū)佣加歇毩⒌膮?shù)遭笋。

二坝冕、一個超級簡單的示例

定義一個簡單的卷積神經(jīng)網(wǎng)絡(luò),一共二層瓦呼,第二層共享第一層的參數(shù)

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        # 第一個卷積層喂窟,使用32個3x3的卷積核
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        # 第二個卷積層,使用32個3x3的卷積核吵血,但我們將共享第一個卷積層的參數(shù)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        # 共享參數(shù):將conv2的參數(shù)設(shè)置為conv1的參數(shù)
        self.conv2.weight = self.conv1.weight
        self.conv2.bias = self.conv1.bias
        
        # 全連接層
        self.fc = nn.Linear(32 * 28 * 28, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # 展平
        x = self.fc(x)
        return x

# 創(chuàng)建模型實例
model = SimpleCNN()

# 打印模型結(jié)構(gòu)
print(model)

# 打印模型的參數(shù)
for name, param in model.named_parameters():
    print(name, param.size())

# 示例輸入
input_data = torch.randn(1, 1, 28, 28)  # 1個樣本,1個通道偷溺,28x28的圖像
output = model(input_data)
print(output)

二蹋辅、指定共享某一模塊

假設(shè)我們有以下兩個模型:

class ANN1(nn.Module):
    def __init__(self,features):
        super(ANN1, self).__init__()
        self.features = features
        self.nn_same = torch.nn.Sequential(
            nn.Linear(features, 128),
            torch.nn.ReLU(),
        )
        self.nn_diff = torch.nn.Sequential(
            nn.Linear(128, 1)
        )

    def forward(self, x):
        # x(batch_size, features)
        x = self.nn_same(x)
        x = self.nn_diff(x)
        return x
class ANN2(nn.Module):
    def __init__(self,features):
        super(ANN2, self).__init__()
        self.features = features
        self.nn_same = torch.nn.Sequential(
            nn.Linear(features, 128),
            torch.nn.ReLU(),
        )
        self.nn_diff = torch.nn.Sequential(
            nn.Linear(128, 1)
        )

    def forward(self, x):
        # x(batch_size, features)
        x = self.nn_same(x)
        x = self.nn_diff(x)
        return x
    
model1 = ANN1(10)
model2 = ANN2(10)
print(model1)
print(model2)


ANN1(
  (nn_same): Sequential(
    (0): Linear(in_features=10, out_features=128, bias=True)
    (1): ReLU()
  )
  (nn_diff): Sequential(
    (0): Linear(in_features=128, out_features=1, bias=True)
  )
)
ANN2(
  (nn_same): Sequential(
    (0): Linear(in_features=10, out_features=128, bias=True)
    (1): ReLU()
  )
  (nn_diff): Sequential(
    (0): Linear(in_features=128, out_features=1, bias=True)
  )
)

其中 nn_same 代表要共享參數(shù)的模塊,模塊名稱可以不相同挫掏,但是模塊結(jié)構(gòu)必須完全相同侦另。因為模型初始化時參數(shù)是隨機初始化的,所以兩個模型的參數(shù)肯定不相同尉共。
下面我們開始進(jìn)行參數(shù)共享:

print("****************遷移前*****************")
for param_tensor in model2.nn_same.state_dict():#輸出遷移前的參數(shù)
    print(param_tensor, "\t", model2.nn_same.state_dict()[param_tensor])
    
model_nn_same = model1.nn_same.state_dict() ##獲取model的nn_same部分的參數(shù)
model2.nn_same.load_state_dict(model_nn_same,strict=True) #更新model2 nn_same部分的參數(shù),#更新model2所有的參數(shù),False表示跳過名稱不同的層褒傅,True表示必須全部匹配(默認(rèn))

print("****************遷移后*****************")
for param_tensor in model2.nn_same.state_dict():#輸出遷移后的參數(shù)
    print(param_tensor, "\t", model2.nn_same.state_dict()[param_tensor])
#此時nn_same參數(shù)更新,nn_diff2參數(shù)不變

三袄友、 共享所有相同名稱的模塊

只需要修改這兩句即可

model_all = model1.state_dict() #獲取model1的所有的參數(shù)
model2.load_state_dict(model_all,strict=False) #更新model2所有的參數(shù),False表示跳過名稱不同的層殿托,True表示必須全部匹配(默認(rèn))

strict=False,表示兩個模型的模塊名不需要完全匹配剧蚣,只會更新名稱相同的模塊支竹。如果兩個模型的模塊名不完全相同但是strict=True那么就會報錯。

本文部分參考了《Pytorch中模型之間的參數(shù)共享》原文鏈接:https://blog.csdn.net/cyj972628089/article/details/127325735
如有侵權(quán)鸠按,請原作者聯(lián)系刪除礼搁。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市目尖,隨后出現(xiàn)的幾起案子馒吴,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 219,110評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件饮戳,死亡現(xiàn)場離奇詭異豪治,居然都是意外死亡,警方通過查閱死者的電腦和手機莹捡,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,443評論 3 395
  • 文/潘曉璐 我一進(jìn)店門鬼吵,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人篮赢,你說我怎么就攤上這事齿椅。” “怎么了启泣?”我有些...
    開封第一講書人閱讀 165,474評論 0 356
  • 文/不壞的土叔 我叫張陵涣脚,是天一觀的道長。 經(jīng)常有香客問我寥茫,道長遣蚀,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,881評論 1 295
  • 正文 為了忘掉前任纱耻,我火速辦了婚禮芭梯,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘弄喘。我一直安慰自己玖喘,他們只是感情好,可當(dāng)我...
    茶點故事閱讀 67,902評論 6 392
  • 文/花漫 我一把揭開白布蘑志。 她就那樣靜靜地躺著累奈,像睡著了一般。 火紅的嫁衣襯著肌膚如雪急但。 梳的紋絲不亂的頭發(fā)上澎媒,一...
    開封第一講書人閱讀 51,698評論 1 305
  • 那天,我揣著相機與錄音波桩,去河邊找鬼戒努。 笑死,一個胖子當(dāng)著我的面吹牛镐躲,可吹牛的內(nèi)容都是我干的柏卤。 我是一名探鬼主播,決...
    沈念sama閱讀 40,418評論 3 419
  • 文/蒼蘭香墨 我猛地睜開眼匀油,長吁一口氣:“原來是場噩夢啊……” “哼缘缚!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起敌蚜,我...
    開封第一講書人閱讀 39,332評論 0 276
  • 序言:老撾萬榮一對情侶失蹤桥滨,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體齐媒,經(jīng)...
    沈念sama閱讀 45,796評論 1 316
  • 正文 獨居荒郊野嶺守林人離奇死亡蒲每,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,968評論 3 337
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了喻括。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片邀杏。...
    茶點故事閱讀 40,110評論 1 351
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖唬血,靈堂內(nèi)的尸體忽然破棺而出望蜡,到底是詐尸還是另有隱情,我是刑警寧澤拷恨,帶...
    沈念sama閱讀 35,792評論 5 346
  • 正文 年R本政府宣布脖律,位于F島的核電站,受9級特大地震影響腕侄,放射性物質(zhì)發(fā)生泄漏小泉。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 41,455評論 3 331
  • 文/蒙蒙 一冕杠、第九天 我趴在偏房一處隱蔽的房頂上張望微姊。 院中可真熱鬧,春花似錦分预、人聲如沸兢交。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,003評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽魁淳。三九已至飘诗,卻和暖如春与倡,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背昆稿。 一陣腳步聲響...
    開封第一講書人閱讀 33,130評論 1 272
  • 我被黑心中介騙來泰國打工纺座, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人溉潭。 一個月前我還...
    沈念sama閱讀 48,348評論 3 373
  • 正文 我出身青樓净响,卻偏偏與公主長得像,于是被迫代替她去往敵國和親喳瓣。 傳聞我的和親對象是個殘疾皇子馋贤,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 45,047評論 2 355

推薦閱讀更多精彩內(nèi)容