一度迂、參數(shù)共享含義
參數(shù)共享(Parameter Sharing)是模型壓縮與加速中的一種重要技術(shù)侨嘀。通過參數(shù)共享妨猩,多個神經(jīng)元或?qū)涌梢怨蚕硐嗤臋?quán)重參數(shù),而不是每個神經(jīng)元或?qū)佣加歇毩⒌膮?shù)遭笋。
二坝冕、一個超級簡單的示例
定義一個簡單的卷積神經(jīng)網(wǎng)絡(luò),一共二層瓦呼,第二層共享第一層的參數(shù)
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
# 第一個卷積層喂窟,使用32個3x3的卷積核
self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
# 第二個卷積層,使用32個3x3的卷積核吵血,但我們將共享第一個卷積層的參數(shù)
self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
# 共享參數(shù):將conv2的參數(shù)設(shè)置為conv1的參數(shù)
self.conv2.weight = self.conv1.weight
self.conv2.bias = self.conv1.bias
# 全連接層
self.fc = nn.Linear(32 * 28 * 28, 10)
def forward(self, x):
x = torch.relu(self.conv1(x))
x = torch.relu(self.conv2(x))
x = x.view(x.size(0), -1) # 展平
x = self.fc(x)
return x
# 創(chuàng)建模型實例
model = SimpleCNN()
# 打印模型結(jié)構(gòu)
print(model)
# 打印模型的參數(shù)
for name, param in model.named_parameters():
print(name, param.size())
# 示例輸入
input_data = torch.randn(1, 1, 28, 28) # 1個樣本,1個通道偷溺,28x28的圖像
output = model(input_data)
print(output)
二蹋辅、指定共享某一模塊
假設(shè)我們有以下兩個模型:
class ANN1(nn.Module):
def __init__(self,features):
super(ANN1, self).__init__()
self.features = features
self.nn_same = torch.nn.Sequential(
nn.Linear(features, 128),
torch.nn.ReLU(),
)
self.nn_diff = torch.nn.Sequential(
nn.Linear(128, 1)
)
def forward(self, x):
# x(batch_size, features)
x = self.nn_same(x)
x = self.nn_diff(x)
return x
class ANN2(nn.Module):
def __init__(self,features):
super(ANN2, self).__init__()
self.features = features
self.nn_same = torch.nn.Sequential(
nn.Linear(features, 128),
torch.nn.ReLU(),
)
self.nn_diff = torch.nn.Sequential(
nn.Linear(128, 1)
)
def forward(self, x):
# x(batch_size, features)
x = self.nn_same(x)
x = self.nn_diff(x)
return x
model1 = ANN1(10)
model2 = ANN2(10)
print(model1)
print(model2)
ANN1(
(nn_same): Sequential(
(0): Linear(in_features=10, out_features=128, bias=True)
(1): ReLU()
)
(nn_diff): Sequential(
(0): Linear(in_features=128, out_features=1, bias=True)
)
)
ANN2(
(nn_same): Sequential(
(0): Linear(in_features=10, out_features=128, bias=True)
(1): ReLU()
)
(nn_diff): Sequential(
(0): Linear(in_features=128, out_features=1, bias=True)
)
)
其中 nn_same 代表要共享參數(shù)的模塊,模塊名稱可以不相同挫掏,但是模塊結(jié)構(gòu)必須完全相同侦另。因為模型初始化時參數(shù)是隨機初始化的,所以兩個模型的參數(shù)肯定不相同尉共。
下面我們開始進(jìn)行參數(shù)共享:
print("****************遷移前*****************")
for param_tensor in model2.nn_same.state_dict():#輸出遷移前的參數(shù)
print(param_tensor, "\t", model2.nn_same.state_dict()[param_tensor])
model_nn_same = model1.nn_same.state_dict() ##獲取model的nn_same部分的參數(shù)
model2.nn_same.load_state_dict(model_nn_same,strict=True) #更新model2 nn_same部分的參數(shù),#更新model2所有的參數(shù),False表示跳過名稱不同的層褒傅,True表示必須全部匹配(默認(rèn))
print("****************遷移后*****************")
for param_tensor in model2.nn_same.state_dict():#輸出遷移后的參數(shù)
print(param_tensor, "\t", model2.nn_same.state_dict()[param_tensor])
#此時nn_same參數(shù)更新,nn_diff2參數(shù)不變
三袄友、 共享所有相同名稱的模塊
只需要修改這兩句即可
model_all = model1.state_dict() #獲取model1的所有的參數(shù)
model2.load_state_dict(model_all,strict=False) #更新model2所有的參數(shù),False表示跳過名稱不同的層殿托,True表示必須全部匹配(默認(rèn))
strict=False,表示兩個模型的模塊名不需要完全匹配剧蚣,只會更新名稱相同的模塊支竹。如果兩個模型的模塊名不完全相同但是strict=True那么就會報錯。
本文部分參考了《Pytorch中模型之間的參數(shù)共享》原文鏈接:https://blog.csdn.net/cyj972628089/article/details/127325735
如有侵權(quán)鸠按,請原作者聯(lián)系刪除礼搁。