MegEngine Python 層模塊串講(下)

前面的文章中,我們簡(jiǎn)單介紹了在 MegEngine imperative 中的各模塊以及它們的作用。對(duì)于新用戶而言可能不太了解各個(gè)模塊的使用方法绵脯,對(duì)于模塊的結(jié)構(gòu)和原理也是一頭霧水。Python 作為現(xiàn)在深度學(xué)習(xí)領(lǐng)域的主流編程語(yǔ)言慰照,其相關(guān)的模塊自然也是深度學(xué)習(xí)框架的重中之重。

模塊串講將對(duì) MegEngine 的 Python 層相關(guān)模塊分別進(jìn)行更加深入的介紹琉朽,會(huì)涉及到一些原理的解釋和代碼解讀毒租。Python 層模塊串講共分為上、中箱叁、下三個(gè)部分墅垮,本文將介紹 Python 層的 quantization 模塊惕医。量化是為了減少模型的存儲(chǔ)空間和計(jì)算量,從而加速模型的推理過(guò)程噩斟。在量化中曹锨,我們將權(quán)重和激活值從浮點(diǎn)數(shù)轉(zhuǎn)換為整數(shù)孤个,從而減少模型的大小和運(yùn)算的復(fù)雜性剃允。通過(guò)本文讀者將會(huì)對(duì)量化的基本原理和使用 MegEngine 得到量化模型有所了解。

降低模型內(nèi)存占用利器 —— quantization 模塊

量化是一種對(duì)深度學(xué)習(xí)模型參數(shù)進(jìn)行壓縮以降低計(jì)算量的技術(shù)齐鲤。它基于這樣一種思想:神經(jīng)網(wǎng)絡(luò)是一個(gè)近似計(jì)算過(guò)程斥废,不需要其中每個(gè)計(jì)算過(guò)程的絕對(duì)的精確。因此在某些情況下可以把需要較多比特存儲(chǔ)的模型參數(shù)轉(zhuǎn)為使用較少比特存儲(chǔ)给郊,而不影響模型的精度牡肉。

量化通過(guò)舍棄數(shù)值表示上的精度來(lái)追求極致的推理速度。直覺(jué)上用低精度/比特類型的模型參數(shù)會(huì)帶來(lái)較大的模型精度下降(稱之為掉點(diǎn))淆九,但在經(jīng)過(guò)一系列精妙的量化處理之后统锤,掉點(diǎn)可以變得微乎其微。

如下圖所示炭庙,量化通常是將浮點(diǎn)模型(常見(jiàn)神經(jīng)網(wǎng)絡(luò)的 Tensor 數(shù)據(jù)類型一般是 float32)處理為一個(gè)量化模型(Tensor 數(shù)據(jù)類型為 int8 等)饲窿。

1.png

量化基本流程

MegEngine 中支持工業(yè)界的兩類主流量化技術(shù),分別是訓(xùn)練后量化(PTQ)和量化感知訓(xùn)練(QAT)焕蹄。

  1. 訓(xùn)練后量化(Post-Training Quantization, PTQ

    訓(xùn)練后量化逾雄,顧名思義就是將訓(xùn)練后的 Float 模型轉(zhuǎn)換成低精度/比特模型。

    比較常見(jiàn)的做法是對(duì)模型的權(quán)重(weight)和激活值(activation)進(jìn)行處理腻脏,把它們轉(zhuǎn)換成精度更低的類型鸦泳。雖然是在訓(xùn)練后再進(jìn)行精度轉(zhuǎn)換,但為了獲取到模型轉(zhuǎn)換需要的一些統(tǒng)計(jì)信息(比如縮放因子 scale)永品,仍然需要在模型進(jìn)行前向計(jì)算時(shí)插入觀察者(Observer)做鹰。

    使用訓(xùn)練后量化技術(shù)通常會(huì)導(dǎo)致模型掉點(diǎn),某些情況下甚至?xí)?dǎo)致模型不可用鼎姐〖佤铮可以使用小批量數(shù)據(jù)在量化之前對(duì) Observer 進(jìn)行校準(zhǔn)(Calibration),這種方案叫做 Calibration 后量化症见。也可以使用 QAT 方案喂走。

  2. 量化感知訓(xùn)練(Quantization-Aware Training, QAT

    QAT 會(huì)向 Float 模型中插入一些偽量化(FakeQuantize)算子,在前向計(jì)算過(guò)程中偽量化算子根據(jù) Observer 觀察到的信息進(jìn)行量化模擬谋作,模擬數(shù)值截?cái)嗟那闆r下的數(shù)值轉(zhuǎn)換芋肠,再將轉(zhuǎn)換后的值還原為原類型。讓被量化對(duì)象在訓(xùn)練時(shí)“提前適應(yīng)”量化操作遵蚜,減少訓(xùn)練后量化的掉點(diǎn)影響帖池。

    而增加這些偽量化算子模擬量化過(guò)程又會(huì)增加訓(xùn)練開(kāi)銷奈惑,因此模型量化通常的思路是:

    • 按照平時(shí)訓(xùn)練模型的流程,設(shè)計(jì)好 Float 模型并進(jìn)行訓(xùn)練睡汹,得到一個(gè)預(yù)訓(xùn)練模型肴甸;
    • 插入 ObserverFakeQuantize 算子,得到 Quantized-Float 模型(QFloat 模型)進(jìn)行量化感知訓(xùn)練囚巴;
    • 訓(xùn)練后量化原在,得到真正的 Quantized 模型(Q 模型),也就是最終用來(lái)進(jìn)行推理的低比特模型彤叉。

    過(guò)程如下圖所示(實(shí)際使用時(shí)庶柿,量化流程也可能會(huì)有變化):

2.png
  1. 注意這里的量化感知訓(xùn)練 QAT 是在預(yù)訓(xùn)練好的 QFloat 模型上微調(diào)(Fine-tune)的(而不是在原來(lái)的 Float 模型上),這樣減小了訓(xùn)練的開(kāi)銷秽浇,得到的微調(diào)后的模型再做訓(xùn)練后量化 PTQ(“真量化”)浮庐,QModel 就是最終部署的模型。

模型(Model)與模塊(Module

量化是一個(gè)對(duì)模型(Model)的轉(zhuǎn)換操作柬焕,但其本質(zhì)其實(shí)是對(duì)模型中的模塊( Module) 進(jìn)行替換审残。

MegEngine 中,對(duì)應(yīng)與 Float Model 斑举、QFloat ModelQ ModelModule 分別為:

  1. 進(jìn)行正常 float 運(yùn)算的默認(rèn) Module
  2. 帶有 ObserverFakeQuantize 算子的 qat.QATModule
  3. 無(wú)法訓(xùn)練搅轿、專門用于部署的 quantized.QuantizedModule

Conv 算子為例,這些 Module 對(duì)應(yīng)的實(shí)現(xiàn)分別在:

量化配置 QConfig

量化配置包括 ObserverFakeQuantize 兩部分懂昂,要設(shè)置它們介时,用戶可以使用 MegEngine 預(yù)設(shè)配置也可以自定義配置。

  1. 使用 MegEngine 預(yù)設(shè)配置

    MegEngine 提供了多種量化預(yù)設(shè)配置凌彬。

    ema_fakequant_qconfig 為例沸柔,用戶可以通過(guò)如下代碼使用該預(yù)設(shè)配置:

import megengine.quantization as Q
Q.quantize_qat(model, qconfig=Q.ema_fakequant_qconfig)
  1. 用戶自定義量化配置

    用戶還可以自己選擇 ObserverFakeQuantize,靈活配置 QConfig 靈活選擇 weight_observer铲敛、act_observer褐澎、weight_fake_quantact_fake_quant)。

    可選的 ObserverFakeQuantize 可參考量化 API 參考頁(yè)面伐蒋。

QConfig 提供了一系列用于對(duì)模型做量化的接口工三,要使用這些接口,需要網(wǎng)絡(luò)的 Module 能夠在 forward 時(shí)給權(quán)重先鱼、激活值加上 Observer 以及進(jìn)行 FakeQuantize俭正。

模型轉(zhuǎn)換的作用是:將普通的 Float Module 替換為支持這些操作的 QATModule(可以訓(xùn)練),再替換為 QuantizeModule(無(wú)法訓(xùn)練焙畔、專用于部署)掸读。

Conv2d 為例,模型轉(zhuǎn)換的過(guò)程如圖:

3.png

在量化時(shí)常常會(huì)用到算子融合(Fusion)。比如一個(gè) Conv2d 算子加上一個(gè) BatchNorm2d 算子儿惫,可以用一個(gè) ConvBn2d 算子來(lái)等價(jià)替代澡罚,這里 ConvBn2d 算子就是 Conv2dBatchNorm2d 的融合算子。

MegEngine 中提供了一些預(yù)先融合好的 Module肾请,比如 ConvRelu2d留搔、ConvBn2dConvBnRelu2d 等。使用融合算子會(huì)使用底層實(shí)現(xiàn)好的融合算子(kernel)铛铁,而不會(huì)分別調(diào)用子模塊在底層的 kernel隔显,因此能夠加快模型的速度,而且框架還無(wú)需根據(jù)網(wǎng)絡(luò)結(jié)構(gòu)進(jìn)行自動(dòng)匹配和融合優(yōu)化避归,同時(shí)存在融合和不需融合的算子也可以讓用戶能更好的控制網(wǎng)絡(luò)轉(zhuǎn)換的過(guò)程荣月。

實(shí)現(xiàn)預(yù)先融合的 Module 也有缺點(diǎn),那就是用戶需要在代碼中修改原先的網(wǎng)絡(luò)結(jié)構(gòu)(把可以融合的多個(gè) Module 改為融合后的 Module)梳毙。

模型轉(zhuǎn)換的原理是,將父 Module 中的 Quantable (可被量化的)子 Module 替換為新 Module捐下。而這些 Quantable submodule 中可能又包含 Quantable submodule账锹,這些 submodule 不會(huì)再進(jìn)一步轉(zhuǎn)換,因?yàn)槠涓?Module 被替換后的 forward 計(jì)算過(guò)程已經(jīng)改變了坷襟,不再依賴于這些子 Module奸柬。

有時(shí)候用戶不希望對(duì)模型的部分 Module 進(jìn)行轉(zhuǎn)換,而是保留其 Float 狀態(tài)(比如轉(zhuǎn)換會(huì)導(dǎo)致模型掉點(diǎn))婴程,則可以使用 disable_quantize 方法關(guān)閉量化廓奕。

比如下面這行代碼關(guān)閉了 fc 層的量化處理:

model.fc.disable_quantize()

由于模型轉(zhuǎn)換過(guò)程修改了原網(wǎng)絡(luò)結(jié)構(gòu),因此模型保存與加載無(wú)法直接適用于轉(zhuǎn)換后的網(wǎng)絡(luò)档叔,讀取新網(wǎng)絡(luò)保存的參數(shù)時(shí)桌粉,需要先調(diào)用轉(zhuǎn)換接口得到轉(zhuǎn)換后的網(wǎng)絡(luò),才能用 load_state_dict 將參數(shù)進(jìn)行加載衙四。

量化代碼

要從一個(gè) Float 模型得到一個(gè)可用于部署的量化模型铃肯,大致需要經(jīng)歷三個(gè)步驟:

  1. 修改網(wǎng)絡(luò)結(jié)構(gòu)。將 Float 模型中的普通 Module 替換為已經(jīng)融合好的 Module传蹈,比如 ConvBn2d押逼、ConvBnRelu2d 等(可以參考 imperative/python/megengine/module/quantized 目錄下提供的已融合模塊)。然后在正常模式下預(yù)訓(xùn)練模型惦界,并且在每輪迭代保存網(wǎng)絡(luò)檢查點(diǎn)挑格。

    ResNet18BasicBlock 為例,模塊修改前的代碼為:

class BasicBlock(M.Module):
      def __init__(self, in_channels, channels):
         super().__init__()
         self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=dilation, bias=False)
         self.bn1 = M.BatchNorm2d
         self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False)
         self.bn2 = M.BatchNorm2d
         self.downsample = (
            M.Identity()
            if in_channels == channels and stride == 1
            else M.Sequential(
            M.Conv2d(in_channels, channels, 1, stride, bias=False)
            M.BatchNorm2d
         )

      def forward(self, x):
         identity = x
         x = F.relu(self.bn1(self.conv1(x)))
         x = self.bn2(self.conv2(x))
         identity = self.downsample(identity)
         x = F.relu(x + identity)
         return x

注意到現(xiàn)在的前向中使用的都是普通 Module 拼接在一起沾歪,而實(shí)際上許多模塊是可以融合的漂彤。

用可以融合的模塊替換掉原先的 Module

class BasicBlock(M.Module):
      def __init__(self, in_channels, channels):
         super().__init__()
         self.conv_bn_relu1 = M.ConvBnRelu2d(in_channels, channels, 3, 1, padding=dilation, bias=False)
         self.conv_bn2 = M.ConvBn2d(channels, channels, 3, 1, padding=1, bias=False)
         self.downsample = (
            M.Identity()
            if in_channels == channels and stride == 1
            else M.ConvBn2d(in_channels, channels, 1, 1, bias=False)
         )
         self.add_relu = M.Elemwise("FUSE_ADD_RELU")

      def forward(self, x):
         identity = x
         x = self.conv_bn_relu1(x)
         x = self.conv_bn2(x)
         identity = self.downsample(identity)
         x = self.add_relu(x, identity)
         return x

注意到此時(shí)前向中已經(jīng)有許多模塊使用的是融合后的 Module

再對(duì)該模型進(jìn)行若干論迭代訓(xùn)練,并保存檢查點(diǎn):

for step in range(0, total_steps):
    # Linear learning rate decay
    epoch = step // steps_per_epoch
    learning_rate = adjust_learning_rate(step, epoch)

    image, label = next(train_queue)
    image = tensor(image.astype("float32"))
    label = tensor(label.astype("int32"))

    n = image.shape[0]

    loss, acc1, acc5 = train_func(image, label, net, gm)  # traced
    optimizer.step().clear_grad()

    # Save checkpoints

完整代碼見(jiàn):

-   [修改前的模型結(jié)構(gòu)](https://github.com/MegEngine/Models/blob/master/official/vision/classification/resnet/model.py)
-   [修改后的模型結(jié)構(gòu)](https://github.com/MegEngine/Models/blob/master/official/quantization/models/resnet.py)
  1. 調(diào)用 quantize_qat 方法 將 Float 模型轉(zhuǎn)換為 QFloat 模型显歧,并進(jìn)行微調(diào)(量化感知訓(xùn)練或校準(zhǔn)仪或,取決于 QConfig)。

    使用 quantize_qat 方法將 Float 模型轉(zhuǎn)換為 QFloat 模型的代碼大致為:

from megengine.quantization import ema_fakequant_qconfig, quantize_qat

model = ResNet18()

# QAT
quantize_qat(model, ema_fakequant_qconfig)

# Or Calibration:
# quantize_qat(model, calibration_qconfig)

Float 模型轉(zhuǎn)換為 QFloat 模型后士骤,加載預(yù)訓(xùn)練 Float 模型保存的檢查點(diǎn)進(jìn)行微調(diào) / 校準(zhǔn):

if args.checkpoint:
    logger.info("Load pretrained weights from %s", args.checkpoint)
    ckpt = mge.load(args.checkpoint)
    ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    model.load_state_dict(ckpt, strict=False)

# Fine-tune / Calibrate with new traced train_func
# Save checkpoints

完整代碼見(jiàn):

-   [Finetune](https://github.com/MegEngine/Models/blob/master/official/quantization/finetune.py)
-   [Calibration](https://github.com/MegEngine/Models/blob/master/official/quantization/calibration.py)
  1. 調(diào)用 quantize 方法將 QFloat 模型轉(zhuǎn)換為 Q 模型范删,也就是可用于模型部署的量化模型。

需要在推理的方法中設(shè)置 tracecapture_as_const=True拷肌,以進(jìn)行模型導(dǎo)出:

from megengine.quantization import quantize

@jit.trace(capture_as_const=True)
def infer_func(processed_img):
    model.eval()
    logits = model(processed_img)
    probs = F.softmax(logits)
    return probs

quantize(model)

processed_img = transform.apply(image)[np.newaxis, :]
processed_img = processed_img.astype("int8")
probs = infer_func(processed_img)

infer_func.dump(output_file, arg_names=["data"])

調(diào)用了 quantize 后到旦,model 就從 QFloat 模型轉(zhuǎn)換為了 Q 模型,之后便使用這個(gè) Quantized 模型進(jìn)行推理巨缘。

調(diào)用 dump 方法將模型導(dǎo)出添忘,便得到了一個(gè)可用于部署的量化模型。

完整代碼見(jiàn):

小結(jié)

MegEngine Python 層模塊串講系列到這里就結(jié)束了若锁,我們介紹了用戶在使用 MegEngine 時(shí)主要會(huì)接觸到的 python 層的各個(gè)模塊的主要功能搁骑、結(jié)構(gòu)以及使用方法,此外還有一些原理性的介紹又固。對(duì)于各模塊具體實(shí)現(xiàn)感興趣的讀者可以參考 MegEngine 官方文檔github仲器。之后的文章我們會(huì)對(duì) MegEngine 開(kāi)發(fā)相關(guān)工具以及 MegEngine 底層的實(shí)現(xiàn)做更深入的介紹。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末仰冠,一起剝皮案震驚了整個(gè)濱河市乏冀,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌洋只,老刑警劉巖辆沦,帶你破解...
    沈念sama閱讀 217,542評(píng)論 6 504
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異识虚,居然都是意外死亡肢扯,警方通過(guò)查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,822評(píng)論 3 394
  • 文/潘曉璐 我一進(jìn)店門舷礼,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)鹃彻,“玉大人,你說(shuō)我怎么就攤上這事妻献≈胫辏” “怎么了?”我有些...
    開(kāi)封第一講書人閱讀 163,912評(píng)論 0 354
  • 文/不壞的土叔 我叫張陵育拨,是天一觀的道長(zhǎng)谨履。 經(jīng)常有香客問(wèn)我,道長(zhǎng)熬丧,這世上最難降的妖魔是什么笋粟? 我笑而不...
    開(kāi)封第一講書人閱讀 58,449評(píng)論 1 293
  • 正文 為了忘掉前任,我火速辦了婚禮,結(jié)果婚禮上害捕,老公的妹妹穿的比我還像新娘绿淋。我一直安慰自己,他們只是感情好尝盼,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,500評(píng)論 6 392
  • 文/花漫 我一把揭開(kāi)白布吞滞。 她就那樣靜靜地躺著,像睡著了一般盾沫。 火紅的嫁衣襯著肌膚如雪裁赠。 梳的紋絲不亂的頭發(fā)上,一...
    開(kāi)封第一講書人閱讀 51,370評(píng)論 1 302
  • 那天赴精,我揣著相機(jī)與錄音佩捞,去河邊找鬼。 笑死蕾哟,一個(gè)胖子當(dāng)著我的面吹牛一忱,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播渐苏,決...
    沈念sama閱讀 40,193評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼掀潮,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來(lái)了琼富?” 一聲冷哼從身側(cè)響起,我...
    開(kāi)封第一講書人閱讀 39,074評(píng)論 0 276
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤庄新,失蹤者是張志新(化名)和其女友劉穎鞠眉,沒(méi)想到半個(gè)月后,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體择诈,經(jīng)...
    沈念sama閱讀 45,505評(píng)論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡械蹋,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,722評(píng)論 3 335
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了羞芍。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片哗戈。...
    茶點(diǎn)故事閱讀 39,841評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖荷科,靈堂內(nèi)的尸體忽然破棺而出唯咬,到底是詐尸還是另有隱情,我是刑警寧澤畏浆,帶...
    沈念sama閱讀 35,569評(píng)論 5 345
  • 正文 年R本政府宣布胆胰,位于F島的核電站,受9級(jí)特大地震影響刻获,放射性物質(zhì)發(fā)生泄漏蜀涨。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,168評(píng)論 3 328
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望厚柳。 院中可真熱鬧氧枣,春花似錦、人聲如沸别垮。這莊子的主人今日做“春日...
    開(kāi)封第一講書人閱讀 31,783評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)宰闰。三九已至茬贵,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間移袍,已是汗流浹背解藻。 一陣腳步聲響...
    開(kāi)封第一講書人閱讀 32,918評(píng)論 1 269
  • 我被黑心中介騙來(lái)泰國(guó)打工, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留葡盗,地道東北人螟左。 一個(gè)月前我還...
    沈念sama閱讀 47,962評(píng)論 2 370
  • 正文 我出身青樓,卻偏偏與公主長(zhǎng)得像觅够,于是被迫代替她去往敵國(guó)和親胶背。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,781評(píng)論 2 354

推薦閱讀更多精彩內(nèi)容