模型壓縮和加速——知識(shí)蒸餾(Knowledge Distillation)及pytorch實(shí)現(xiàn)

一、知識(shí)蒸餾的含義

知識(shí)蒸餾(Knowledge Distillation)是一種用于模型壓縮和遷移學(xué)習(xí)的技術(shù),通過將一個(gè)復(fù)雜模型(稱為教師模型)的知識(shí)傳遞給一個(gè)較小的模型(稱為學(xué)生模型),從而使得學(xué)生模型能夠達(dá)到接近教師模型的性能振惰。

  • 具體方法:知識(shí)蒸餾的方式就是將Teacher Network輸出的soft label作為標(biāo)簽來訓(xùn)練Student Network猾编。比如在下圖中我們訓(xùn)練Student Network來使其與Teacher Network有同樣的輸出。這樣的好處是Teacher Network的輸出提供了比獨(dú)熱編碼標(biāo)簽更多的信息朱灿,比如對(duì)于輸入的數(shù)字1,Teacher Network的輸出表明這個(gè)數(shù)字是1钠四,同時(shí)也表明了這個(gè)數(shù)字也有一點(diǎn)像7盗扒,也有一點(diǎn)像9。
  • 損失函數(shù):通常需要2個(gè)損失函數(shù)缀去,一個(gè)是KL 散度損失(soft_loss)侣灶,用于衡量老師和學(xué)生兩個(gè)概率分布之間的差異;另一個(gè)是交叉熵?fù)p失函數(shù)缕碎,用于衡量學(xué)生模型輸出與真實(shí)標(biāo)簽之間的損失褥影。總損失為二者之和咏雌。
知識(shí)蒸餾.png

知識(shí)蒸餾訓(xùn)練出的Student Network有一點(diǎn)神奇的地方就是這個(gè)Network有可能辨識(shí)從來沒有見過的輸入凡怎,這是因?yàn)門eacher Network輸出的soft label提供了額外的信息

知識(shí)蒸餾的另一個(gè)用處是用來擬合集成模型赊抖,有時(shí)候我們會(huì)集成(Ensemble)很多個(gè)模型來獲取其輸出的均值從而提高總體的效果统倒,我們可以使用知識(shí)蒸餾的方式來使得Student Network學(xué)習(xí)集成模型的輸出,從而達(dá)到將集成模型的效果復(fù)制到一個(gè)模型上的目的

知識(shí)蒸餾用于擬合集成模型.png

在進(jìn)行知識(shí)蒸餾時(shí)我們還會(huì)使用到下面的技巧就是調(diào)整最終輸出的sofmax層來避免Teacher Network輸出類似獨(dú)熱編碼的標(biāo)簽:


調(diào)整softmax函數(shù)形式.png

通過下列數(shù)據(jù)的對(duì)比我們可以看出這一操作的作用(在實(shí)際操作時(shí)T是一個(gè)可以調(diào)的參數(shù)):


調(diào)整softmax函數(shù)形式示例.png

二氛雪、代碼實(shí)現(xiàn)

案例1:壓縮 ResNet 模型

假設(shè)我們使用一個(gè)預(yù)訓(xùn)練的 ResNet 作為教師模型房匆,并定義一個(gè)簡(jiǎn)單的兩層全連接網(wǎng)絡(luò)作為學(xué)生模型。

#1. 導(dǎo)入必要的庫
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
#2. 定義教師模型和學(xué)生模型
#假設(shè)我們使用一個(gè)預(yù)訓(xùn)練的 ResNet 作為教師模型报亩,并定義一個(gè)簡(jiǎn)單的兩層全連接網(wǎng)絡(luò)作為學(xué)生模型浴鸿。
# 教師模型:使用預(yù)訓(xùn)練的 ResNet18
teacher_model = models.resnet18(pretrained=True)
teacher_model.fc = nn.Linear(512, 10)  # 修改最后一層以適應(yīng)10類分類任務(wù)

# 學(xué)生模型:簡(jiǎn)單的兩層全連接網(wǎng)絡(luò)
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)
    def forward(self, x):
        x = x.view(-1, 784)  # 將輸入展平
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

student_model = StudentModel()
#3. 定義損失函數(shù)和優(yōu)化器
#知識(shí)蒸餾的損失函數(shù)由兩部分組成:一部分是學(xué)生模型輸出與真實(shí)標(biāo)簽之間的交叉熵?fù)p失(hard_loss),另一部分是學(xué)生模型輸出與教師模型輸出之間的 KL 散度損失(soft_loss)弦追。
# 定義損失函數(shù)
criterion = nn.CrossEntropyLoss()
# 定義優(yōu)化器
optimizer = optim.Adam(student_model.parameters(), lr=0.001)

#4. 加載數(shù)據(jù)集
#我們使用 MNIST 數(shù)據(jù)集進(jìn)行訓(xùn)練和測(cè)試岳链。
# 數(shù)據(jù)預(yù)處理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
# 加載數(shù)據(jù)集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

#5. 訓(xùn)練學(xué)生模型
#在訓(xùn)練過程中,我們使用教師模型的輸出來指導(dǎo)學(xué)生模型的訓(xùn)練劲件。
# 設(shè)置溫度參數(shù)
T = 5.0
# 訓(xùn)練學(xué)生模型
def train(epoch):
    student_model.train()
    teacher_model.eval()  # 教師模型不參與訓(xùn)練宠页,只用來輸出softlable和判卷子

    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()

        # 教師模型的輸出
        with torch.no_grad():
            teacher_output = teacher_model(data)
        # 學(xué)生模型的輸出
        student_output = student_model(data)

        # 計(jì)算硬標(biāo)簽損失(學(xué)生模型輸出與真實(shí)標(biāo)簽之間的交叉熵?fù)p失)
        hard_loss = criterion(student_output, target)

        # 計(jì)算軟標(biāo)簽損失(學(xué)生模型輸出與教師模型輸出之間的 KL 散度損失)
        soft_loss = F.kl_div(F.log_softmax(student_output / T, dim=1),
                             F.softmax(teacher_output / T, dim=1),
                             reduction='batchmean') * T * T

        # 總損失
        loss = hard_loss + soft_loss

        # 反向傳播和優(yōu)化
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

# 測(cè)試學(xué)生模型
def test():
    student_model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            output = student_model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item() 
# correct計(jì)算模型預(yù)測(cè)準(zhǔn)確率的代碼片段左胞。
# pred.eq(target.view_as(pred))用于比較預(yù)測(cè)結(jié)果 pred 和真實(shí)標(biāo)簽 target。
# target.view_as(pred) 將 target 的形狀調(diào)整為與 pred 相同的形狀
# 最終返回一個(gè)張量举户,相等的位置為1,其他位置為0遍烦。

    test_loss /= len(test_loader.dataset) #test_loss 是一個(gè)累加器俭嘁,用于記錄模型在整個(gè)測(cè)試集上的總損失,除以樣本量求得平均損失服猪。
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({100. * correct / len(test_loader.dataset):.0f}%)\n')

# 訓(xùn)練和測(cè)試
for epoch in range(1, 5):
    train(epoch)
    test()

案例2:壓縮BERT模型

BERT 模型可以作為教師模型供填,而一個(gè)較小的模型(如 LSTM 或簡(jiǎn)單的 Transformer)可以作為學(xué)生模型。

# 1. 導(dǎo)包
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertModel, BertTokenizer
# 2. 定義教師模型和學(xué)生模型
#我們將使用 Hugging Face 的 transformers 庫來加載 BERT 模型罢猪,
#并定義一個(gè)簡(jiǎn)單的 LSTM 模型作為學(xué)生模型近她。

# 定義教師模型(BERT)
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fc = nn.Linear(768, 2)  # 假設(shè)我們有一個(gè)二分類任務(wù)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.fc(pooled_output)
        return logits

# 定義學(xué)生模型(LSTM)
class StudentModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout):
        super(StudentModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers, bidirectional=bidirectional, dropout=dropout)
        self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text, text_lengths):
        embedded = self.dropout(self.embedding(text))
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths.to('cpu')) 
#nn.utils.rnn.pack_padded_sequence 是 PyTorch 中的一個(gè)函數(shù),用于將填充后的序列(padded sequence)打包成一個(gè)壓縮的序列(packed sequence)膳帕。text_lengths 是一個(gè)包含每個(gè)序列實(shí)際長(zhǎng)度的張量粘捎,打包后的序列只包含實(shí)際的非零元素,從而減少了計(jì)算量和內(nèi)存占用危彩。
        packed_output, (hidden, cell) = self.lstm(packed_embedded)
        output, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_output)
        hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1))
        return self.fc(hidden)
#這句代碼的作用是將 LSTM 的隱藏狀態(tài)進(jìn)行拼接攒磨,并通過 dropout 層進(jìn)行正則化處理。
#hidden 是一個(gè)形狀為 (num_layers * num_directions, batch_size, hidden_dim) 的張量汤徽,表示 LSTM 的隱藏狀態(tài)娩缰。
#num_layers 是 LSTM 的層數(shù)。
#num_directions 是 LSTM 的方向數(shù)(1 表示單向谒府,2 表示雙向)拼坎。
#batch_size 是批次大小。
#hidden_dim 是隱藏狀態(tài)的維度完疫。hidden[-2,:,:] 和 hidden[-1,:,:] 分別表示倒數(shù)第二層和最后一層的隱藏狀態(tài)泰鸡。對(duì)于雙向 LSTM,hidden[-2,:,:] 和 hidden[-1,:,:] 分別表示前向和后向的隱藏狀態(tài)趋惨。

# 初始化教師模型和學(xué)生模型
teacher_model = TeacherModel()
student_model = StudentModel(vocab_size=30522, embedding_dim=100, hidden_dim=256, output_dim=2, n_layers=2, bidirectional=True, dropout=0.5)

# 3. 定義損失函數(shù)和優(yōu)化器
criterion = nn.KLDivLoss(reduction='batchmean')
optimizer = optim.Adam(student_model.parameters(), lr=0.001)
# 4. 加載數(shù)據(jù)和預(yù)處理
#假設(shè)我們有一個(gè)簡(jiǎn)單的數(shù)據(jù)集鸟顺,包含文本和標(biāo)簽。我們需要使用 BERT 的 tokenizer 來處理輸入數(shù)據(jù)器虾。
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# 示例數(shù)據(jù)
texts = ["I love programming", "I hate bugs"]
labels = [1, 0]  # 1表示正面讯嫂,0表示負(fù)面

# 預(yù)處理數(shù)據(jù)
encoded_texts = [tokenizer.encode(text, add_special_tokens=True) for text in texts]
max_len = max(len(text) for text in encoded_texts)
input_ids = [text + [0] * (max_len - len(text)) for text in encoded_texts]
#對(duì)于每個(gè)文本 text,以0將其填充到 max_len 的長(zhǎng)度兆沙。
attention_mask = [[1] * len(text) + [0] * (max_len - len(text)) for text in encoded_texts]
#1是實(shí)際的欧芽,0是填充的。

input_ids = torch.tensor(input_ids)
attention_mask = torch.tensor(attention_mask)
labels = torch.tensor(labels)

#4. 訓(xùn)練學(xué)生模型
#在訓(xùn)練過程中葛圃,我們使用教師模型的輸出作為軟標(biāo)簽(soft labels)來指導(dǎo)學(xué)生模型的學(xué)習(xí)千扔。
# 訓(xùn)練學(xué)生模型
for epoch in range(5):  # 假設(shè)我們訓(xùn)練5個(gè)epoch
    student_model.train()
    optimizer.zero_grad()

    # 獲取教師模型的輸出
    with torch.no_grad():
        teacher_logits = teacher_model(input_ids, attention_mask)
        teacher_probs = torch.softmax(teacher_logits, dim=1)

    # 獲取學(xué)生模型的輸出
    student_logits = student_model(input_ids, torch.tensor([len(text) for text in encoded_texts]))
    student_probs = torch.softmax(student_logits, dim=1)

    # 計(jì)算KL散度損失
    loss = criterion(torch.log(student_probs), teacher_probs)
    loss.backward()
    optimizer.step()

    print(f'Epoch {epoch+1}, Loss: {loss.item()}')
#5. 測(cè)試學(xué)生模型
#在訓(xùn)練完成后憎妙,可以使用學(xué)生模型進(jìn)行預(yù)測(cè)。
# 測(cè)試學(xué)生模型
student_model.eval()
with torch.no_grad():
    test_text = ["I love coding"]
    encoded_test_text = [tokenizer.encode(text, add_special_tokens=True) for text in test_text]
    max_len = max(len(text) for text in encoded_test_text)
    input_ids = [text + [0] * (max_len - len(text)) for text in encoded_test_text]
    attention_mask = [[1] * len(text) + [0] * (max_len - len(text)) for text in encoded_test_text]

    input_ids = torch.tensor(input_ids)
    attention_mask = torch.tensor(attention_mask)

    student_logits = student_model(input_ids, torch.tensor([len(text) for text in encoded_test_text]))
    student_probs = torch.softmax(student_logits, dim=1)
    predicted_label = torch.argmax(student_probs, dim=1).item()

    print(f'Predicted Label: {predicted_label}')
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末曲楚,一起剝皮案震驚了整個(gè)濱河市厘唾,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌龙誊,老刑警劉巖抚垃,帶你破解...
    沈念sama閱讀 219,110評(píng)論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異趟大,居然都是意外死亡鹤树,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,443評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門逊朽,熙熙樓的掌柜王于貴愁眉苦臉地迎上來罕伯,“玉大人,你說我怎么就攤上這事叽讳∽匪” “怎么了?”我有些...
    開封第一講書人閱讀 165,474評(píng)論 0 356
  • 文/不壞的土叔 我叫張陵绽榛,是天一觀的道長(zhǎng)湿酸。 經(jīng)常有香客問我,道長(zhǎng)灭美,這世上最難降的妖魔是什么推溃? 我笑而不...
    開封第一講書人閱讀 58,881評(píng)論 1 295
  • 正文 為了忘掉前任,我火速辦了婚禮届腐,結(jié)果婚禮上铁坎,老公的妹妹穿的比我還像新娘。我一直安慰自己犁苏,他們只是感情好硬萍,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,902評(píng)論 6 392
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著围详,像睡著了一般朴乖。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上助赞,一...
    開封第一講書人閱讀 51,698評(píng)論 1 305
  • 那天买羞,我揣著相機(jī)與錄音,去河邊找鬼雹食。 笑死畜普,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的群叶。 我是一名探鬼主播吃挑,決...
    沈念sama閱讀 40,418評(píng)論 3 419
  • 文/蒼蘭香墨 我猛地睜開眼钝荡,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來了舶衬?” 一聲冷哼從身側(cè)響起埠通,我...
    開封第一講書人閱讀 39,332評(píng)論 0 276
  • 序言:老撾萬榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎逛犹,沒想到半個(gè)月后植阴,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,796評(píng)論 1 316
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡圾浅,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,968評(píng)論 3 337
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了憾朴。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片狸捕。...
    茶點(diǎn)故事閱讀 40,110評(píng)論 1 351
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖众雷,靈堂內(nèi)的尸體忽然破棺而出灸拍,到底是詐尸還是另有隱情,我是刑警寧澤砾省,帶...
    沈念sama閱讀 35,792評(píng)論 5 346
  • 正文 年R本政府宣布鸡岗,位于F島的核電站,受9級(jí)特大地震影響编兄,放射性物質(zhì)發(fā)生泄漏轩性。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,455評(píng)論 3 331
  • 文/蒙蒙 一狠鸳、第九天 我趴在偏房一處隱蔽的房頂上張望揣苏。 院中可真熱鬧,春花似錦件舵、人聲如沸卸察。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,003評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽坑质。三九已至,卻和暖如春临梗,著一層夾襖步出監(jiān)牢的瞬間涡扼,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,130評(píng)論 1 272
  • 我被黑心中介騙來泰國打工夜焦, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留壳澳,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 48,348評(píng)論 3 373
  • 正文 我出身青樓茫经,卻偏偏與公主長(zhǎng)得像巷波,于是被迫代替她去往敵國和親萎津。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,047評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容