在前面的文章中,我們簡(jiǎn)單介紹了在 MegEngine imperative
中的各模塊以及它們的作用。對(duì)于新用戶而言可能不太了解各個(gè)模塊的使用方法绵脯,對(duì)于模塊的結(jié)構(gòu)和原理也是一頭霧水。Python
作為現(xiàn)在深度學(xué)習(xí)領(lǐng)域的主流編程語(yǔ)言慰照,其相關(guān)的模塊自然也是深度學(xué)習(xí)框架的重中之重。
模塊串講將對(duì) MegEngine
的 Python
層相關(guān)模塊分別進(jìn)行更加深入的介紹琉朽,會(huì)涉及到一些原理的解釋和代碼解讀毒租。Python
層模塊串講共分為上、中箱叁、下三個(gè)部分墅垮,本文將介紹 Python
層的 quantization
模塊惕医。量化是為了減少模型的存儲(chǔ)空間和計(jì)算量,從而加速模型的推理過(guò)程噩斟。在量化中曹锨,我們將權(quán)重和激活值從浮點(diǎn)數(shù)轉(zhuǎn)換為整數(shù)孤个,從而減少模型的大小和運(yùn)算的復(fù)雜性剃允。通過(guò)本文讀者將會(huì)對(duì)量化的基本原理和使用 MegEngine
得到量化模型有所了解。
降低模型內(nèi)存占用利器 —— quantization 模塊
量化是一種對(duì)深度學(xué)習(xí)模型參數(shù)進(jìn)行壓縮以降低計(jì)算量的技術(shù)齐鲤。它基于這樣一種思想:神經(jīng)網(wǎng)絡(luò)是一個(gè)近似計(jì)算過(guò)程斥废,不需要其中每個(gè)計(jì)算過(guò)程的絕對(duì)的精確。因此在某些情況下可以把需要較多比特存儲(chǔ)的模型參數(shù)轉(zhuǎn)為使用較少比特存儲(chǔ)给郊,而不影響模型的精度牡肉。
量化通過(guò)舍棄數(shù)值表示上的精度來(lái)追求極致的推理速度。直覺(jué)上用低精度/比特類型的模型參數(shù)會(huì)帶來(lái)較大的模型精度下降(稱之為掉點(diǎn))淆九,但在經(jīng)過(guò)一系列精妙的量化處理之后统锤,掉點(diǎn)可以變得微乎其微。
如下圖所示炭庙,量化通常是將浮點(diǎn)模型(常見(jiàn)神經(jīng)網(wǎng)絡(luò)的 Tensor
數(shù)據(jù)類型一般是 float32
)處理為一個(gè)量化模型(Tensor
數(shù)據(jù)類型為 int8
等)饲窿。
量化基本流程
MegEngine
中支持工業(yè)界的兩類主流量化技術(shù),分別是訓(xùn)練后量化(PTQ
)和量化感知訓(xùn)練(QAT
)焕蹄。
-
訓(xùn)練后量化(
Post-Training Quantization
,PTQ
)訓(xùn)練后量化逾雄,顧名思義就是將訓(xùn)練后的
Float
模型轉(zhuǎn)換成低精度/比特模型。比較常見(jiàn)的做法是對(duì)模型的權(quán)重(
weight
)和激活值(activation
)進(jìn)行處理腻脏,把它們轉(zhuǎn)換成精度更低的類型鸦泳。雖然是在訓(xùn)練后再進(jìn)行精度轉(zhuǎn)換,但為了獲取到模型轉(zhuǎn)換需要的一些統(tǒng)計(jì)信息(比如縮放因子scale
)永品,仍然需要在模型進(jìn)行前向計(jì)算時(shí)插入觀察者(Observer
)做鹰。使用訓(xùn)練后量化技術(shù)通常會(huì)導(dǎo)致模型掉點(diǎn),某些情況下甚至?xí)?dǎo)致模型不可用鼎姐〖佤铮可以使用小批量數(shù)據(jù)在量化之前對(duì)
Observer
進(jìn)行校準(zhǔn)(Calibration
),這種方案叫做Calibration
后量化症见。也可以使用QAT
方案喂走。 -
量化感知訓(xùn)練(
Quantization-Aware Training
,QAT
)QAT
會(huì)向Float
模型中插入一些偽量化(FakeQuantize
)算子,在前向計(jì)算過(guò)程中偽量化算子根據(jù)Observer
觀察到的信息進(jìn)行量化模擬谋作,模擬數(shù)值截?cái)嗟那闆r下的數(shù)值轉(zhuǎn)換芋肠,再將轉(zhuǎn)換后的值還原為原類型。讓被量化對(duì)象在訓(xùn)練時(shí)“提前適應(yīng)”量化操作遵蚜,減少訓(xùn)練后量化的掉點(diǎn)影響帖池。而增加這些偽量化算子模擬量化過(guò)程又會(huì)增加訓(xùn)練開(kāi)銷奈惑,因此模型量化通常的思路是:
- 按照平時(shí)訓(xùn)練模型的流程,設(shè)計(jì)好
Float
模型并進(jìn)行訓(xùn)練睡汹,得到一個(gè)預(yù)訓(xùn)練模型肴甸; - 插入
Observer
和FakeQuantize
算子,得到Quantized-Float
模型(QFloat
模型)進(jìn)行量化感知訓(xùn)練囚巴; - 訓(xùn)練后量化原在,得到真正的
Quantized
模型(Q
模型),也就是最終用來(lái)進(jìn)行推理的低比特模型彤叉。
過(guò)程如下圖所示(實(shí)際使用時(shí)庶柿,量化流程也可能會(huì)有變化):
- 按照平時(shí)訓(xùn)練模型的流程,設(shè)計(jì)好
- 注意這里的量化感知訓(xùn)練
QAT
是在預(yù)訓(xùn)練好的QFloat
模型上微調(diào)(Fine-tune
)的(而不是在原來(lái)的Float
模型上),這樣減小了訓(xùn)練的開(kāi)銷秽浇,得到的微調(diào)后的模型再做訓(xùn)練后量化PTQ
(“真量化”)浮庐,QModel
就是最終部署的模型。
模型(Model
)與模塊(Module
)
量化是一個(gè)對(duì)模型(Model
)的轉(zhuǎn)換操作柬焕,但其本質(zhì)其實(shí)是對(duì)模型中的模塊( Module
) 進(jìn)行替換审残。
在 MegEngine
中,對(duì)應(yīng)與 Float Model
斑举、QFloat Model
和 Q Model
的 Module
分別為:
- 進(jìn)行正常
float
運(yùn)算的默認(rèn)Module
- 帶有
Observer
和FakeQuantize
算子的qat.QATModule
- 無(wú)法訓(xùn)練搅轿、專門用于部署的
quantized.QuantizedModule
以 Conv
算子為例,這些 Module
對(duì)應(yīng)的實(shí)現(xiàn)分別在:
-
Float Module
:imperative/python/megengine/module/conv.py -
qat.QATModule
:imperative/python/megengine/module/qat/conv.py -
quantized.QuantizedModule
:imperative/python/megengine/module/quantized/conv.py
量化配置 QConfig
量化配置包括 Observer
和 FakeQuantize
兩部分懂昂,要設(shè)置它們介时,用戶可以使用 MegEngine
預(yù)設(shè)配置也可以自定義配置。
-
使用
MegEngine
預(yù)設(shè)配置MegEngine
提供了多種量化預(yù)設(shè)配置凌彬。以
ema_fakequant_qconfig
為例沸柔,用戶可以通過(guò)如下代碼使用該預(yù)設(shè)配置:
import megengine.quantization as Q
Q.quantize_qat(model, qconfig=Q.ema_fakequant_qconfig)
-
用戶自定義量化配置
用戶還可以自己選擇
Observer
和FakeQuantize
,靈活配置 QConfig 靈活選擇weight_observer
铲敛、act_observer
褐澎、weight_fake_quant
和act_fake_quant
)。可選的
Observer
和FakeQuantize
可參考量化 API 參考頁(yè)面伐蒋。
QConfig
提供了一系列用于對(duì)模型做量化的接口工三,要使用這些接口,需要網(wǎng)絡(luò)的 Module
能夠在 forward
時(shí)給權(quán)重先鱼、激活值加上 Observer
以及進(jìn)行 FakeQuantize
俭正。
模型轉(zhuǎn)換的作用是:將普通的 Float Module
替換為支持這些操作的 QATModule
(可以訓(xùn)練),再替換為 QuantizeModule
(無(wú)法訓(xùn)練焙畔、專用于部署)掸读。
以 Conv2d
為例,模型轉(zhuǎn)換的過(guò)程如圖:
在量化時(shí)常常會(huì)用到算子融合(Fusion
)。比如一個(gè) Conv2d
算子加上一個(gè) BatchNorm2d
算子儿惫,可以用一個(gè) ConvBn2d
算子來(lái)等價(jià)替代澡罚,這里 ConvBn2d
算子就是 Conv2d
和 BatchNorm2d
的融合算子。
MegEngine
中提供了一些預(yù)先融合好的 Module
肾请,比如 ConvRelu2d
留搔、ConvBn2d
和 ConvBnRelu2d
等。使用融合算子會(huì)使用底層實(shí)現(xiàn)好的融合算子(kernel
)铛铁,而不會(huì)分別調(diào)用子模塊在底層的 kernel
隔显,因此能夠加快模型的速度,而且框架還無(wú)需根據(jù)網(wǎng)絡(luò)結(jié)構(gòu)進(jìn)行自動(dòng)匹配和融合優(yōu)化避归,同時(shí)存在融合和不需融合的算子也可以讓用戶能更好的控制網(wǎng)絡(luò)轉(zhuǎn)換的過(guò)程荣月。
實(shí)現(xiàn)預(yù)先融合的 Module
也有缺點(diǎn),那就是用戶需要在代碼中修改原先的網(wǎng)絡(luò)結(jié)構(gòu)(把可以融合的多個(gè) Module
改為融合后的 Module
)梳毙。
模型轉(zhuǎn)換的原理是,將父 Module
中的 Quantable
(可被量化的)子 Module
替換為新 Module
捐下。而這些 Quantable submodule
中可能又包含 Quantable submodule
账锹,這些 submodule
不會(huì)再進(jìn)一步轉(zhuǎn)換,因?yàn)槠涓?Module
被替換后的 forward
計(jì)算過(guò)程已經(jīng)改變了坷襟,不再依賴于這些子 Module
奸柬。
有時(shí)候用戶不希望對(duì)模型的部分 Module
進(jìn)行轉(zhuǎn)換,而是保留其 Float
狀態(tài)(比如轉(zhuǎn)換會(huì)導(dǎo)致模型掉點(diǎn))婴程,則可以使用 disable_quantize
方法關(guān)閉量化廓奕。
比如下面這行代碼關(guān)閉了 fc
層的量化處理:
model.fc.disable_quantize()
由于模型轉(zhuǎn)換過(guò)程修改了原網(wǎng)絡(luò)結(jié)構(gòu),因此模型保存與加載無(wú)法直接適用于轉(zhuǎn)換后的網(wǎng)絡(luò)档叔,讀取新網(wǎng)絡(luò)保存的參數(shù)時(shí)桌粉,需要先調(diào)用轉(zhuǎn)換接口得到轉(zhuǎn)換后的網(wǎng)絡(luò),才能用 load_state_dict
將參數(shù)進(jìn)行加載衙四。
量化代碼
要從一個(gè) Float
模型得到一個(gè)可用于部署的量化模型铃肯,大致需要經(jīng)歷三個(gè)步驟:
-
修改網(wǎng)絡(luò)結(jié)構(gòu)。將
Float
模型中的普通Module
替換為已經(jīng)融合好的Module
传蹈,比如ConvBn2d
押逼、ConvBnRelu2d
等(可以參考 imperative/python/megengine/module/quantized 目錄下提供的已融合模塊)。然后在正常模式下預(yù)訓(xùn)練模型惦界,并且在每輪迭代保存網(wǎng)絡(luò)檢查點(diǎn)挑格。以
ResNet18
的BasicBlock
為例,模塊修改前的代碼為:
class BasicBlock(M.Module):
def __init__(self, in_channels, channels):
super().__init__()
self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=dilation, bias=False)
self.bn1 = M.BatchNorm2d
self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False)
self.bn2 = M.BatchNorm2d
self.downsample = (
M.Identity()
if in_channels == channels and stride == 1
else M.Sequential(
M.Conv2d(in_channels, channels, 1, stride, bias=False)
M.BatchNorm2d
)
def forward(self, x):
identity = x
x = F.relu(self.bn1(self.conv1(x)))
x = self.bn2(self.conv2(x))
identity = self.downsample(identity)
x = F.relu(x + identity)
return x
注意到現(xiàn)在的前向中使用的都是普通 Module
拼接在一起沾歪,而實(shí)際上許多模塊是可以融合的漂彤。
用可以融合的模塊替換掉原先的 Module
:
class BasicBlock(M.Module):
def __init__(self, in_channels, channels):
super().__init__()
self.conv_bn_relu1 = M.ConvBnRelu2d(in_channels, channels, 3, 1, padding=dilation, bias=False)
self.conv_bn2 = M.ConvBn2d(channels, channels, 3, 1, padding=1, bias=False)
self.downsample = (
M.Identity()
if in_channels == channels and stride == 1
else M.ConvBn2d(in_channels, channels, 1, 1, bias=False)
)
self.add_relu = M.Elemwise("FUSE_ADD_RELU")
def forward(self, x):
identity = x
x = self.conv_bn_relu1(x)
x = self.conv_bn2(x)
identity = self.downsample(identity)
x = self.add_relu(x, identity)
return x
注意到此時(shí)前向中已經(jīng)有許多模塊使用的是融合后的 Module
。
再對(duì)該模型進(jìn)行若干論迭代訓(xùn)練,并保存檢查點(diǎn):
for step in range(0, total_steps):
# Linear learning rate decay
epoch = step // steps_per_epoch
learning_rate = adjust_learning_rate(step, epoch)
image, label = next(train_queue)
image = tensor(image.astype("float32"))
label = tensor(label.astype("int32"))
n = image.shape[0]
loss, acc1, acc5 = train_func(image, label, net, gm) # traced
optimizer.step().clear_grad()
# Save checkpoints
完整代碼見(jiàn):
- [修改前的模型結(jié)構(gòu)](https://github.com/MegEngine/Models/blob/master/official/vision/classification/resnet/model.py)
- [修改后的模型結(jié)構(gòu)](https://github.com/MegEngine/Models/blob/master/official/quantization/models/resnet.py)
-
調(diào)用 quantize_qat 方法 將
Float
模型轉(zhuǎn)換為QFloat
模型显歧,并進(jìn)行微調(diào)(量化感知訓(xùn)練或校準(zhǔn)仪或,取決于QConfig
)。使用
quantize_qat
方法將Float
模型轉(zhuǎn)換為QFloat
模型的代碼大致為:
from megengine.quantization import ema_fakequant_qconfig, quantize_qat
model = ResNet18()
# QAT
quantize_qat(model, ema_fakequant_qconfig)
# Or Calibration:
# quantize_qat(model, calibration_qconfig)
將 Float
模型轉(zhuǎn)換為 QFloat
模型后士骤,加載預(yù)訓(xùn)練 Float
模型保存的檢查點(diǎn)進(jìn)行微調(diào) / 校準(zhǔn):
if args.checkpoint:
logger.info("Load pretrained weights from %s", args.checkpoint)
ckpt = mge.load(args.checkpoint)
ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
model.load_state_dict(ckpt, strict=False)
# Fine-tune / Calibrate with new traced train_func
# Save checkpoints
完整代碼見(jiàn):
- [Finetune](https://github.com/MegEngine/Models/blob/master/official/quantization/finetune.py)
- [Calibration](https://github.com/MegEngine/Models/blob/master/official/quantization/calibration.py)
- 調(diào)用 quantize 方法將
QFloat
模型轉(zhuǎn)換為Q
模型范删,也就是可用于模型部署的量化模型。
需要在推理的方法中設(shè)置 trace
的 capture_as_const=True
拷肌,以進(jìn)行模型導(dǎo)出:
from megengine.quantization import quantize
@jit.trace(capture_as_const=True)
def infer_func(processed_img):
model.eval()
logits = model(processed_img)
probs = F.softmax(logits)
return probs
quantize(model)
processed_img = transform.apply(image)[np.newaxis, :]
processed_img = processed_img.astype("int8")
probs = infer_func(processed_img)
infer_func.dump(output_file, arg_names=["data"])
調(diào)用了 quantize
后到旦,model
就從 QFloat
模型轉(zhuǎn)換為了 Q
模型,之后便使用這個(gè) Quantized
模型進(jìn)行推理巨缘。
調(diào)用 dump
方法將模型導(dǎo)出添忘,便得到了一個(gè)可用于部署的量化模型。
完整代碼見(jiàn):
小結(jié)
MegEngine Python
層模塊串講系列到這里就結(jié)束了若锁,我們介紹了用戶在使用 MegEngine
時(shí)主要會(huì)接觸到的 python
層的各個(gè)模塊的主要功能搁骑、結(jié)構(gòu)以及使用方法,此外還有一些原理性的介紹又固。對(duì)于各模塊具體實(shí)現(xiàn)感興趣的讀者可以參考 MegEngine 官方文檔 和 github仲器。之后的文章我們會(huì)對(duì) MegEngine
開(kāi)發(fā)相關(guān)工具以及 MegEngine
底層的實(shí)現(xiàn)做更深入的介紹。