一、知識(shí)蒸餾的含義
知識(shí)蒸餾(Knowledge Distillation)是一種用于模型壓縮和遷移學(xué)習(xí)的技術(shù),通過將一個(gè)復(fù)雜模型(稱為教師模型)的知識(shí)傳遞給一個(gè)較小的模型(稱為學(xué)生模型),從而使得學(xué)生模型能夠達(dá)到接近教師模型的性能振惰。
- 具體方法:知識(shí)蒸餾的方式就是將Teacher Network輸出的soft label作為標(biāo)簽來訓(xùn)練Student Network猾编。比如在下圖中我們訓(xùn)練Student Network來使其與Teacher Network有同樣的輸出。這樣的好處是Teacher Network的輸出提供了比獨(dú)熱編碼標(biāo)簽更多的信息朱灿,比如對(duì)于輸入的數(shù)字1,Teacher Network的輸出表明這個(gè)數(shù)字是1钠四,同時(shí)也表明了這個(gè)數(shù)字也有一點(diǎn)像7盗扒,也有一點(diǎn)像9。
- 損失函數(shù):通常需要2個(gè)損失函數(shù)缀去,一個(gè)是KL 散度損失(soft_loss)侣灶,用于衡量老師和學(xué)生兩個(gè)概率分布之間的差異;另一個(gè)是交叉熵?fù)p失函數(shù)缕碎,用于衡量學(xué)生模型輸出與真實(shí)標(biāo)簽之間的損失褥影。總損失為二者之和咏雌。
知識(shí)蒸餾訓(xùn)練出的Student Network有一點(diǎn)神奇的地方就是這個(gè)Network有可能辨識(shí)從來沒有見過的輸入凡怎,這是因?yàn)門eacher Network輸出的soft label提供了額外的信息。
知識(shí)蒸餾的另一個(gè)用處是用來擬合集成模型赊抖,有時(shí)候我們會(huì)集成(Ensemble)很多個(gè)模型來獲取其輸出的均值從而提高總體的效果统倒,我們可以使用知識(shí)蒸餾的方式來使得Student Network學(xué)習(xí)集成模型的輸出,從而達(dá)到將集成模型的效果復(fù)制到一個(gè)模型上的目的:
在進(jìn)行知識(shí)蒸餾時(shí)我們還會(huì)使用到下面的技巧就是調(diào)整最終輸出的sofmax層來避免Teacher Network輸出類似獨(dú)熱編碼的標(biāo)簽:
通過下列數(shù)據(jù)的對(duì)比我們可以看出這一操作的作用(在實(shí)際操作時(shí)T是一個(gè)可以調(diào)的參數(shù)):
二氛雪、代碼實(shí)現(xiàn)
案例1:壓縮 ResNet 模型
假設(shè)我們使用一個(gè)預(yù)訓(xùn)練的 ResNet 作為教師模型房匆,并定義一個(gè)簡(jiǎn)單的兩層全連接網(wǎng)絡(luò)作為學(xué)生模型。
#1. 導(dǎo)入必要的庫
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
#2. 定義教師模型和學(xué)生模型
#假設(shè)我們使用一個(gè)預(yù)訓(xùn)練的 ResNet 作為教師模型报亩,并定義一個(gè)簡(jiǎn)單的兩層全連接網(wǎng)絡(luò)作為學(xué)生模型浴鸿。
# 教師模型:使用預(yù)訓(xùn)練的 ResNet18
teacher_model = models.resnet18(pretrained=True)
teacher_model.fc = nn.Linear(512, 10) # 修改最后一層以適應(yīng)10類分類任務(wù)
# 學(xué)生模型:簡(jiǎn)單的兩層全連接網(wǎng)絡(luò)
class StudentModel(nn.Module):
def __init__(self):
super(StudentModel, self).__init__()
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 10)
def forward(self, x):
x = x.view(-1, 784) # 將輸入展平
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
student_model = StudentModel()
#3. 定義損失函數(shù)和優(yōu)化器
#知識(shí)蒸餾的損失函數(shù)由兩部分組成:一部分是學(xué)生模型輸出與真實(shí)標(biāo)簽之間的交叉熵?fù)p失(hard_loss),另一部分是學(xué)生模型輸出與教師模型輸出之間的 KL 散度損失(soft_loss)弦追。
# 定義損失函數(shù)
criterion = nn.CrossEntropyLoss()
# 定義優(yōu)化器
optimizer = optim.Adam(student_model.parameters(), lr=0.001)
#4. 加載數(shù)據(jù)集
#我們使用 MNIST 數(shù)據(jù)集進(jìn)行訓(xùn)練和測(cè)試岳链。
# 數(shù)據(jù)預(yù)處理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
# 加載數(shù)據(jù)集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
#5. 訓(xùn)練學(xué)生模型
#在訓(xùn)練過程中,我們使用教師模型的輸出來指導(dǎo)學(xué)生模型的訓(xùn)練劲件。
# 設(shè)置溫度參數(shù)
T = 5.0
# 訓(xùn)練學(xué)生模型
def train(epoch):
student_model.train()
teacher_model.eval() # 教師模型不參與訓(xùn)練宠页,只用來輸出softlable和判卷子
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
# 教師模型的輸出
with torch.no_grad():
teacher_output = teacher_model(data)
# 學(xué)生模型的輸出
student_output = student_model(data)
# 計(jì)算硬標(biāo)簽損失(學(xué)生模型輸出與真實(shí)標(biāo)簽之間的交叉熵?fù)p失)
hard_loss = criterion(student_output, target)
# 計(jì)算軟標(biāo)簽損失(學(xué)生模型輸出與教師模型輸出之間的 KL 散度損失)
soft_loss = F.kl_div(F.log_softmax(student_output / T, dim=1),
F.softmax(teacher_output / T, dim=1),
reduction='batchmean') * T * T
# 總損失
loss = hard_loss + soft_loss
# 反向傳播和優(yōu)化
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print(f'Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
# 測(cè)試學(xué)生模型
def test():
student_model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
output = student_model(data)
test_loss += criterion(output, target).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
# correct計(jì)算模型預(yù)測(cè)準(zhǔn)確率的代碼片段左胞。
# pred.eq(target.view_as(pred))用于比較預(yù)測(cè)結(jié)果 pred 和真實(shí)標(biāo)簽 target。
# target.view_as(pred) 將 target 的形狀調(diào)整為與 pred 相同的形狀
# 最終返回一個(gè)張量举户,相等的位置為1,其他位置為0遍烦。
test_loss /= len(test_loader.dataset) #test_loss 是一個(gè)累加器俭嘁,用于記錄模型在整個(gè)測(cè)試集上的總損失,除以樣本量求得平均損失服猪。
print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
f'({100. * correct / len(test_loader.dataset):.0f}%)\n')
# 訓(xùn)練和測(cè)試
for epoch in range(1, 5):
train(epoch)
test()
案例2:壓縮BERT模型
BERT 模型可以作為教師模型供填,而一個(gè)較小的模型(如 LSTM 或簡(jiǎn)單的 Transformer)可以作為學(xué)生模型。
# 1. 導(dǎo)包
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertModel, BertTokenizer
# 2. 定義教師模型和學(xué)生模型
#我們將使用 Hugging Face 的 transformers 庫來加載 BERT 模型罢猪,
#并定義一個(gè)簡(jiǎn)單的 LSTM 模型作為學(xué)生模型近她。
# 定義教師模型(BERT)
class TeacherModel(nn.Module):
def __init__(self):
super(TeacherModel, self).__init__()
self.bert = BertModel.from_pretrained('bert-base-uncased')
self.fc = nn.Linear(768, 2) # 假設(shè)我們有一個(gè)二分類任務(wù)
def forward(self, input_ids, attention_mask):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
pooled_output = outputs.pooler_output
logits = self.fc(pooled_output)
return logits
# 定義學(xué)生模型(LSTM)
class StudentModel(nn.Module):
def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout):
super(StudentModel, self).__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers, bidirectional=bidirectional, dropout=dropout)
self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, text, text_lengths):
embedded = self.dropout(self.embedding(text))
packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths.to('cpu'))
#nn.utils.rnn.pack_padded_sequence 是 PyTorch 中的一個(gè)函數(shù),用于將填充后的序列(padded sequence)打包成一個(gè)壓縮的序列(packed sequence)膳帕。text_lengths 是一個(gè)包含每個(gè)序列實(shí)際長(zhǎng)度的張量粘捎,打包后的序列只包含實(shí)際的非零元素,從而減少了計(jì)算量和內(nèi)存占用危彩。
packed_output, (hidden, cell) = self.lstm(packed_embedded)
output, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_output)
hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1))
return self.fc(hidden)
#這句代碼的作用是將 LSTM 的隱藏狀態(tài)進(jìn)行拼接攒磨,并通過 dropout 層進(jìn)行正則化處理。
#hidden 是一個(gè)形狀為 (num_layers * num_directions, batch_size, hidden_dim) 的張量汤徽,表示 LSTM 的隱藏狀態(tài)娩缰。
#num_layers 是 LSTM 的層數(shù)。
#num_directions 是 LSTM 的方向數(shù)(1 表示單向谒府,2 表示雙向)拼坎。
#batch_size 是批次大小。
#hidden_dim 是隱藏狀態(tài)的維度完疫。hidden[-2,:,:] 和 hidden[-1,:,:] 分別表示倒數(shù)第二層和最后一層的隱藏狀態(tài)泰鸡。對(duì)于雙向 LSTM,hidden[-2,:,:] 和 hidden[-1,:,:] 分別表示前向和后向的隱藏狀態(tài)趋惨。
# 初始化教師模型和學(xué)生模型
teacher_model = TeacherModel()
student_model = StudentModel(vocab_size=30522, embedding_dim=100, hidden_dim=256, output_dim=2, n_layers=2, bidirectional=True, dropout=0.5)
# 3. 定義損失函數(shù)和優(yōu)化器
criterion = nn.KLDivLoss(reduction='batchmean')
optimizer = optim.Adam(student_model.parameters(), lr=0.001)
# 4. 加載數(shù)據(jù)和預(yù)處理
#假設(shè)我們有一個(gè)簡(jiǎn)單的數(shù)據(jù)集鸟顺,包含文本和標(biāo)簽。我們需要使用 BERT 的 tokenizer 來處理輸入數(shù)據(jù)器虾。
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# 示例數(shù)據(jù)
texts = ["I love programming", "I hate bugs"]
labels = [1, 0] # 1表示正面讯嫂,0表示負(fù)面
# 預(yù)處理數(shù)據(jù)
encoded_texts = [tokenizer.encode(text, add_special_tokens=True) for text in texts]
max_len = max(len(text) for text in encoded_texts)
input_ids = [text + [0] * (max_len - len(text)) for text in encoded_texts]
#對(duì)于每個(gè)文本 text,以0將其填充到 max_len 的長(zhǎng)度兆沙。
attention_mask = [[1] * len(text) + [0] * (max_len - len(text)) for text in encoded_texts]
#1是實(shí)際的欧芽,0是填充的。
input_ids = torch.tensor(input_ids)
attention_mask = torch.tensor(attention_mask)
labels = torch.tensor(labels)
#4. 訓(xùn)練學(xué)生模型
#在訓(xùn)練過程中葛圃,我們使用教師模型的輸出作為軟標(biāo)簽(soft labels)來指導(dǎo)學(xué)生模型的學(xué)習(xí)千扔。
# 訓(xùn)練學(xué)生模型
for epoch in range(5): # 假設(shè)我們訓(xùn)練5個(gè)epoch
student_model.train()
optimizer.zero_grad()
# 獲取教師模型的輸出
with torch.no_grad():
teacher_logits = teacher_model(input_ids, attention_mask)
teacher_probs = torch.softmax(teacher_logits, dim=1)
# 獲取學(xué)生模型的輸出
student_logits = student_model(input_ids, torch.tensor([len(text) for text in encoded_texts]))
student_probs = torch.softmax(student_logits, dim=1)
# 計(jì)算KL散度損失
loss = criterion(torch.log(student_probs), teacher_probs)
loss.backward()
optimizer.step()
print(f'Epoch {epoch+1}, Loss: {loss.item()}')
#5. 測(cè)試學(xué)生模型
#在訓(xùn)練完成后憎妙,可以使用學(xué)生模型進(jìn)行預(yù)測(cè)。
# 測(cè)試學(xué)生模型
student_model.eval()
with torch.no_grad():
test_text = ["I love coding"]
encoded_test_text = [tokenizer.encode(text, add_special_tokens=True) for text in test_text]
max_len = max(len(text) for text in encoded_test_text)
input_ids = [text + [0] * (max_len - len(text)) for text in encoded_test_text]
attention_mask = [[1] * len(text) + [0] * (max_len - len(text)) for text in encoded_test_text]
input_ids = torch.tensor(input_ids)
attention_mask = torch.tensor(attention_mask)
student_logits = student_model(input_ids, torch.tensor([len(text) for text in encoded_test_text]))
student_probs = torch.softmax(student_logits, dim=1)
predicted_label = torch.argmax(student_probs, dim=1).item()
print(f'Predicted Label: {predicted_label}')