實踐torch.fx第一篇——基于Pytorch的模型優(yōu)化量化神器

第一篇——什么是torch.fx

今天聊一下比較重要torch.fx淫僻,也趁著這次機會把之前的torch.fx筆記整理下砸烦,筆記大概拆成三份甚纲,分別對應三篇:

  • 什么是torch.fx
  • 基于torch.fx做量化
  • 基于torch.fx量化部署到TensorRT

本文對應第一篇,主要介紹torch.fx和基本使用方法曲楚。廢話不多說尚氛,直接開始吧!

什么是Torch.FX

torch.fxPytorch 1.8出來的一套工具或者說一個庫洞渤,是做python-to-python code transformation阅嘶,大意就是可以把pytorch中的python前向代碼轉換為你想要的樣子,官方介紹如下:

We apply this principle in torch.fx, a program capture and
transformation library for PyTorch written entirely in Python and optimized for high developer productivity by ML practitioners
上述來源于FX的論文载迄,感興趣的可以看TORCH.FX: PRACTICAL PROGRAM CAPTURE AND TRANSFORMATION FOR DEEP LEARNING IN PYTHON這篇讯柔,知乎上也有一篇不錯的解讀,這里就不復述了护昧。不過本文也會介紹論文中的內容魂迄,更多的是以實踐的角度。

核心的關鍵詞是program capturetransformation library惋耙,這兩個概念很重要捣炬。

那么FX怎么用呢?直觀了解一下绽榛,我們定義了一個pytorch.nn.module

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

很簡單地繼承于torch.nn.Module的Module(熟悉pytorch的應該都懂)湿酸。其中前向forward函數也記錄了這個module的具體操作邏輯。

如果我們想把這個Module中forward中的一部分操作邏輯self.linear(x + self.param).clamp(min=0.0, max=1.0)clamp部分替換為sigmoid灭美,應該怎么搞呢推溃?

當然可以直接改代碼么,但是如果這些操作很多届腐,或者說你寫了很多模塊铁坎,或者說你想要做很多實驗(某些模塊中改某些模塊中不改),再這樣就比較煩瑣了犁苏。

這時候就需要FX硬萍,不需要我們手動修改代碼(就是自己改這個forward實現),只需要設定好規(guī)則围详,使用torch.fx朴乖,帶入這個模型實例進去,跑一下代碼短曾。然后你的這個MyModule中forward部分就會變?yōu)?code>self.linear(x + self.param).sigmoid():

module = MyModule()

from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation
# 打印查看FX的IR
print(symbolic_traced.graph)
"""
graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""

# Code generation - valid Python code
# 通過FX生成的代碼寒砖,可以視為module中的forward代碼
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
"""

這樣,FX會幫助你修改這個Module嫉拐,并且修改好的這個model就和平常一樣使用就可以哩都,注意這里,FX capture 了你寫的forward代碼婉徘,然后進行了transform漠嵌,修改了其中的操作咐汞。

當然這只是很簡單很簡單的fx的一個功能,我們還可以通過fx:

  • 融合兩個op儒鹿,比如conv和bn
  • 去掉某些op
  • 替換某些op
  • 在某些op后插入一些op或者其他操作

等等等等化撕。

可能大家會疑惑,這些操作是不是很像AI編譯器中的PASS约炎,而操作對象也是神經網絡這種DAG(有向無環(huán)圖)植阴。其實吧,FX你也可以理解為是一種編譯器圾浅,不過這個編譯器最終產生的可執(zhí)行文件掠手,而是python->python,最終的產物還是基于Pytorch規(guī)則的python代碼狸捕,也就是為什么FX一直說自己是Python-to-Python (or Module-to-Module) transformation toolkit而不是compiler了喷鸽。

FX目前大部分API已經穩(wěn)定(在torch-1.10中正式發(fā)布),使用起來歷史包袱不大灸拍。

fx的官方介紹:

torch.fx與量化的關系

FX的出現第一利好是基于Pytorch的量化工具做祝,這也是我介紹FX的一個原因。借助FX可以很方便地對pytorch模型做量化操作鸡岗,之前商湯就出了一個基于fx的量化工具MQBench混槐。

對于量化來說,不論是PTQ(需要插入觀察op來收集每一層的激活分布以及權重分布)還是QTA(需要插入fake量化節(jié)點來模擬量化)纤房,都會涉及到fx的功能纵隔。所以如果想基于Pytorch框架來做量化翻诉,建議直接上手torch.fx炮姨。

fx在pytorch-1.10中已經處于stable狀態(tài),大部分API已經穩(wěn)定了碰煌,我也拿torch.fx量化了幾個模型舒岸,最終搞到TensorRT上,涉及到卷積芦圾、BN蛾派、反卷積、add逆趋、concat等基本操作共屈,使用的版本是Pytorch-1.10TensorRT-8.2狂塘。

其中fx部分自己修改了下源碼,補充了一些op壳澳。這里我是直接把最新release的pytorch中的fx部分摘出來,然后pip安裝torch-1.10.0+cu113-cp38-cp38-linux_x86_64.whl茫经,兩者搭配食用巷波。

與TorchScript的區(qū)別

其實一開始torch.fx出現的時候也想過這兩個有啥區(qū)別萎津,都是先解析模型、然后生成IR抹镊、然后基于IR做一些優(yōu)化锉屈,最后生成一個最終版的優(yōu)化后的模型,難道一個是python版本的一個是C++版垮耳?肯定沒有這么簡單颈渊。當你FX用多了,會發(fā)現FX和torchscript的定位是不一樣的终佛,FX更側重于對模型進行一些功能性的改變(比如批量增加儡炼、修改某個操作,比如增加統(tǒng)計操作查蓉,比如量化)乌询;而torchscript更側重于優(yōu)化當前模型的性能,并且可以脫離python豌研,僅在C++環(huán)境運行妹田。

借一句官方大佬的回答:

torch.fx is different from TorchScript in that it is a platform for Python-to-Python transformations of PyTorch code. TorchScript, on the other hand, is more targeted at moving PyTorch programs outside of Python for deployment purposes. In this sense, FX and TorchScript are orthogonal to each other, and can even be composed with each other (e.g. transform PyTorch programs with FX, then subsequently export to TorchScript for deployment).

大意就是,FX僅僅是做Python2Python的轉換鹃共,不像Torchscript一樣是為了做部署(脫離Python這個環(huán)境鬼佣,在C++中運行)而做轉換。兩者沒什么關系霜浴,不沖突晶衷,用FX轉換后的模型也可以用torchscript繼續(xù)轉換,兩者是正交的阴孟。

Python to Python?

不過需要注意的是晌纫,FX的代碼生成式由Python到Python。也就是說永丝,FX生成的代碼锹漱,和我們平常使用nn.Module搭建的網絡沒區(qū)別,可以直接使用Pytorch的eager mode跑慕嚷,不像torchscript一樣哥牍,是另一套runtime(我們跑torchscript的時候其實調用的是一個VM,也就是虛擬機喝检,通過VM在C++中跑通過torchscript導出的模型)嗅辣。

因此fx轉換后的模型類型和nn.Module一毛一樣,所以對nn.Module能做的挠说,對轉換后的模型也能做澡谭,咱們可以連續(xù)套娃:

  • 自己寫的Module -> fx后還是Module -> 連續(xù)fx變化 -> 得到最終的fx模型

FX的IR和Jit的IR

這倆IR不一樣,FX的IR相較Jit的來說纺涤,有兩個優(yōu)點:

  • FX緊密地整合到Python的runtime中译暂,因為FX可以更加精準地捕獲prograim representations抠忘,不像jit.trace有時候會出錯。
  • FX的Graph和torch.nn.module沒啥區(qū)別外永,其IR沒有那么底層崎脉,所以說用起來更簡單,效率也會提升伯顶。

這里簡單列一下FX的IR囚灼,很簡單,只有六種祭衩,大概功能就是調函數灶体、提取attr、獲取輸入輸出等:

  • placeholder represents a function input. The name attribute specifies the name this value will take on. target is similarly the name of the argument. args holds either: 1) nothing, or 2) a single argument denoting the default parameter of the function input. kwargs is don't-care. Placeholders correspond to the function parameters (e.g. x) in the graph printout.
  • get_attr retrieves a parameter from the module hierarchy. name is similarly the name the result of the fetch is assigned to. target is the fully-qualified name of the parameter's position in the module hierarchy. args and kwargs are don't-care
  • call_function applies a free function to some values. name is similarly the name of the value to assign to. target is the function to be applied. args and kwargs represent the arguments to the function, following the Python calling convention
  • call_module applies a module in the module hierarchy's forward() method to given arguments. name is as previous. target is the fully-qualified name of the module in the module hierarchy to call. args and kwargs represent the arguments to invoke the module on, including the self argument.
  • call_method calls a method on a value. name is as similar. target is the string name of the method to apply to the self argument. args and kwargs represent the arguments to invoke the module on, including the self argument
  • output contains the output of the traced function in its args[0] attribute. This corresponds to the "return" statement in the Graph printout.

相比torchscript的IR掐暮,FX的可就簡單多了蝎抽,我們理解使用起來也很簡單。

symbolic tracer

回到一開頭示例的那段代碼路克,其中有一行是symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)樟结,這里核心就是symbolic_trace函數,也就是FX解析精算、轉換模型的起點瓢宦。這個函數其實內部是這樣的:

@compatibility(is_backward_compatible=True)
def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None,
                   enable_cpatching: bool = False) -> GraphModule:
    """
    Symbolic tracing API

    Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule``
    constructed by recording operations seen while tracing through ``root``.

    ...
    """
    tracer = Tracer(enable_cpatching=enable_cpatching)
    graph = tracer.trace(root, concrete_args)
    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    return GraphModule(tracer.root, graph, name)

首先會創(chuàng)建一個Tracer類然后使用成員函數trace我們的torch.nn.Module。我們在trace這個模型之后灰羽,就可以對這個模型進行修改了:

def transform(m: nn.Module,
              tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    # Step 1: Acquire a Graph representing the code in `m`
    # 使用 Tracer 類對象去trace模型 m
    # 這邊是拆開了驮履,這個transform函數就是實現torch.fx.symbolic_trace的功能
    graph : torch.fx.Graph = tracer_class().trace(m)

    # Step 2: 這里就可以任意修改模型了,這也是重點
    graph = ...

    # Step 3: Construct a Module to return
    return torch.fx.GraphModule(m, graph)

修改之后的模型可以直接拿來用廉嚼,也可以通過graph_module.to_folder玫镐,把這個模型摘出來當做單獨的模塊去使用(這個之后說)。整體的流程大概就是這樣:

symbolic tracing -> intermediate representation -> transforms -> Python code generation前鹅。

各自的功能為:

  • symbolic

The symbolic tracer performs “symbolic execution” of the Python code. It feeds fake values, called Proxies, through the code. Operations on theses Proxies are recorded. More information about symbolic tracing can be found in the symbolic_trace() and Tracer documentation.

  • intermediate representation

The intermediate representation is the container for the operations that were recorded during symbolic tracing. It consists of a list of Nodes that represent function inputs, callsites (to functions, methods, or torch.nn.Module instances), and return values. More information about the IR can be found in the documentation for Graph. The IR is the format on which transformations are applied.

  • Python code generation

Python code generation is what makes FX a Python-to-Python (or Module-to-Module) transformation toolkit. For each Graph IR, we can create valid Python code matching the Graph’s semantics. This functionality is wrapped up in GraphModule, which is a torch.nn.Module instance that holds a Graph as well as a forward method generated from the Graph.

上述就是FX的三個核心功能摘悴。

Proxy/Retracingsymbolic trace的核心。因為我對Proxy/Retracing的理解還不是很深舰绘,這里就不擅自描述了,摘一下官方的介紹:

Proxy objects are Node wrappers that flow through the program during symbolic tracing and record all the operations (torch function calls, method calls, operators) that they touch into the growing FX Graph.

If you’re doing graph transforms, you can wrap your own Proxy method around a raw Node so that you can use the overloaded operators to add additional things to a Graph.

相關結構

FX主要的結構就是GraphGraphModule了葱椭,其中A Graph is a data structure that represents a method on a GraphModule捂寿。可以理解為Graph中存放著網絡中最關鍵的Node孵运,這些node就是網絡中的一個個節(jié)點(比如卷積秦陋、relu、add治笨、concat等等)驳概,這些node記錄了對應的method和輸入輸出信息赤嚼,從而可以串起來組成網絡的邏輯。

通過print_tabular()可以將graph中的node信息打印出來:

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(
            self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)
# 這里打印module中的node
gm.graph.print_tabular()

打印信息如下:

graph中的node

可以看到顺又,對于輸入x更卒,對應的IR類型是placeholder;對于權重信息稚照,對應的IR類型是get_attr蹂空;對于具體的實際操作(add、linear果录、sum上枕、relu、topk等)弱恒,對應著call_function辨萍、call_module這倆IR,最后的輸出對應著output這個IR返弹。

同時還打印了每個node的輸入信息和額外的參數信息分瘦,通過這些信息就可以把node連起來。

不過光有graph是不夠的琉苇,還需要GraphModule嘲玫。GraphModule繼承于torch.nn.Module,包含了前向forward函數和網絡中模塊需要的參數并扇,這些參數會被graph中的node調用去团。

總結一下,那就是graph中的node包含了網絡的邏輯信息穷蛹,然后這些node前后調用關系會被FX重新組合為GraphModule中FX生成的前向forward代碼(可以通過traced.code打印出來)土陪,而這些生成的代碼會需要GraphModule中的參數信息來保證順利執(zhí)行。

修改Graph

既然知道graph中包含了網絡的順序執(zhí)行信息肴熏,那么想要修改網絡鬼雀,直接修改node就可以:

import torch
import torch.fx

# Sample module
class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

def transform(m: torch.nn.Module,
              tracer_class : type = fx.Tracer) -> torch.nn.Module:
    graph : fx.Graph = tracer_class().trace(m)
    # 對于graph中的node,FX會以順序的形式來表示這個網絡
    # 所以我們可以直接for循環(huán)來遍歷:
    for node in graph.nodes:
        # 檢測該node的IR類型是否是call_function
        if node.op == 'call_function':
            # 修改node.target為torch.mul蛙吏,網絡也因此變了
            if node.target == torch.add:
                node.target = torch.mul

    graph.lint() # Does some checks to make sure the
                 # Graph is well-formed.

    return fx.GraphModule(m, graph)

簡單提一句源哩,node.target代表call_function中call的是哪個target,而torch.add也就是pytorch自帶的操作op鸦做,調用這個node的時候會實際調用到torch.add励烦。

優(yōu)雅地修改graph網絡

上述直接修改簡單粗暴,FX也貼心地為我們提供了Graph rewrites工具泼诱,我們可以借助這些工具方便地增加或者刪除某一個node:

# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with traced.graph.inserting_after(node):
    # Insert a new `call_function` node calling `torch.relu`
    new_node = traced.graph.call_function(
        torch.relu, args=(node,))
    # We want all places that used the value of `node` to
    # now use that value after the `relu` call we've added.
    # We use the `replace_all_uses_with` API to do this.
    node.replace_all_uses_with(new_node)

借助replace_pattern來修改網絡

Graph rewrites工具都有了(相關概念是來源于編譯器)坛掠,那么match pattern肯定也有了,我們可以通過replace_pattern()來對整個graph進行修改。pattern的話可以用fx自帶的也可以自己添加自己的規(guī)則:

# Sample module
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, w1, w2):
        val1 = torch.neg(w1)
        m1 = torch.cat([val1, w2]).sum()
        val2 = torch.neg(w1)
        m2 = torch.cat([val2, w2]).sum()
        return x + torch.max(m1) + torch.max(m2)

# Symbolically trace an instance of `M`
traced = symbolic_trace(M())

# Define the pattern. 
def pattern(a1, a2):
    val1 = torch.neg(a1)
    return torch.cat([val1, a2]).sum()

# Define the replacement (same rules as the pattern)
def replacement(w1, w2):
    return torch.stack([w1, w2])

# Replace `pattern` with `replacement` in `traced`
replace_pattern(traced, pattern, replacement)

# After calling `replace_pattern`, the generated code is:
'''
def forward(self, x, w1, w2):
    stack = torch.stack([w1, w2])
    max_1 = torch.max(stack);  stack = None
    add = x + max_1;  x = max_1 = None
    stack_1 = torch.stack([w1, w2]);  w1 = w2 = None
    max_2 = torch.max(stack_1);  stack_1 = None
    add_1 = add + max_2;  add = max_2 = None
    return add_1
'''

Interpreter

Interpreter屉栓,即解釋器舷蒲,這個名字用的好。其實就是以一個比較優(yōu)雅的方式循環(huán)一個Graph的node并且執(zhí)行它們友多,并同時順帶完成一些任務牲平。比如我們想看下模型在運行期間每一層的shape變化:

import torch
import torch.fx
from torch.fx.node import Node

from typing import Dict

class ShapeProp:
    """
    Shape propagation. This class takes a `GraphModule`.
    Then, its `propagate` method executes the `GraphModule`
    node-by-node with the given arguments. As each operation
    executes, the ShapeProp class stores away the shape and
    element type for the output values of each operation on
    the `shape` and `dtype` attributes of the operation's
    `Node`.
    """
    def __init__(self, mod):
        self.mod = mod
        self.graph = mod.graph
        self.modules = dict(self.mod.named_modules())

    def propagate(self, *args):
        args_iter = iter(args)
        env : Dict[str, Node] = {}

        def load_arg(a):
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target : str):
            target_atoms = target.split('.')
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}")
                attr_itr = getattr(attr_itr, atom)
            return attr_itr

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = fetch_attr(node.target)
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))

            # This is the only code specific to shape propagation.
            # you can delete this `if` branch and this becomes
            # a generic GraphModule interpreter.
            if isinstance(result, torch.Tensor):
                node.shape = result.shape
                node.dtype = result.dtype

            env[node.name] = result

        return load_arg(self.graph.result)

上述的propagate函數很簡單,遍歷一遍node然后記錄信息到node.shapenode.dtype中夷陋。FX也提供了interpreter類欠拾,存放了一些util的function,我們直接繼承就可以使用(類似于上面這個ShapeProp)骗绕。

Transformer

Transformer就是對torch.nn.Module做一些變換藐窄,這些變換我們可以封裝成一個函數或者寫到類里頭,其實Transformer也可以叫做PASS酬土,總之就是對網絡進行一些修改荆忍。比如這樣:

import torch
import torch.fx

def transform(m: nn.Module,
              tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    # Step 1: Acquire a Graph representing the code in `m`

    # trace nn.Module
    graph : torch.fx.Graph = tracer_class().trace(m)

    # Step 2: 這里對Graph進行修改
    graph = ...

    # Step 3: Construct a Module to return
    return torch.fx.GraphModule(m, graph)

Your transform will take in an torch.nn.Module, acquire a Graph from it, do some modifications, and return a new torch.nn.Module. You should think of the torch.nn.Module that your FX transform returns as identical to a regular torch.nn.Module – you can pass it to another FX transform, you can pass it to TorchScript, or you can run it. Ensuring that the inputs and outputs of your FX transform are a torch.nn.Module will allow for composability.

當然也可以直接修改GraphModule,沒必要非要返回一個新的:

import torch
import torch.fx

def transform(m : nn.Module) -> nn.Module:
    gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)

    # 這里修改 gm.graph
    # <...>

    # Recompile the forward() method of `gm` from its Graph
    gm.recompile()

    return gm

需要注意撤缴,gm.recompile()這句必須是要加上的刹枉,我們修改graph后,需要recompile來重新生成forward代碼屈呕。

舉個FX的栗子

鋪墊了那么多微宝,簡單舉一個FX的實際例子吧,這里我們用FX去量化一個基于CenterNet框架的目標檢測模型虎眨,backbone使用的是Resnet50蟋软,限于篇幅,本篇只介紹trace完模型和fuse的部分嗽桩,量化和導出trt之后的文章再說岳守。

首先搭建CenterNet模型,然后進行trace:

model = FXCenterNet()
tracer = Tracer()
graph_module = GraphModule(model, tracer.trace(model))

其中trace的函數如下碌冶,大概就是遍歷model中的操作湿痢,按照規(guī)則轉換為node存放到graph中,包含attr和op扑庞、輸入輸出等信息譬重,最終返回graph這個IR結構:

@compatibility(is_backward_compatible=True)
def trace(self, root: Union[torch.nn.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
    # root FXCenterNet
    if isinstance(root, torch.nn.Module):
        self.root = root
        fn = type(root).forward
        self.submodule_paths = {mod: name for name, mod in root.named_modules()}
    else:
        self.root = torch.nn.Module()
        fn = root

    tracer_cls: Optional[Type['Tracer']] = getattr(self, '__class__', None)
    self.graph = Graph(tracer_cls=tracer_cls)
    # 這里大概就是遍歷root中的操作,按照規(guī)則轉換為node存放到graph中嫩挤,
    # 包含attr和op害幅、輸入輸出等信息,最終返回graph這個IR結構
    ... 
    return self.graph

生成的self.graph類型是torch.fx.graph.Graph岂昭。

self.graph
<torch.fx.graph.Graph object at 0x7f57f59efdf0>

調用self.graph.print_tabular()打印graph的node信息,可以看到熟悉的resnet-50-backbone的結構,以IR的形式組織起來:

生成centernet-graph中的node信息

生成graph后约啊,開始組裝GraphModule邑遏,GraphModule是由graph生成的,GraphModule會把graph的node中的參數和模塊信息復制一份到自己:

@compatibility(is_backward_compatible=True)
class GraphModule(torch.nn.Module):
    def __new__(cls: 'Type[GraphModule]', *args, **kwargs):
        for t in cls.__mro__:
            c = t.__qualname__.split('.')[-1]
            if c != 'GraphModuleImpl':
                cls = t
                break

        class GraphModuleImpl(cls):  # type: ignore[misc, valid-type]
            pass
        return super().__new__(GraphModuleImpl)

    @compatibility(is_backward_compatible=True)
    def __init__(self,
                 root: Union[torch.nn.Module, Dict[str, Any]],
                 graph: Graph,
                 class_name: str = 'GraphModule'):
        super().__init__()
        self.__class__.__name__ = class_name
        if isinstance(root, torch.nn.Module):
            if hasattr(root, 'training'):
                self.training = root.training
            # 這里拷貝graph中的參數信息和模塊信息到self也就是GraphModule中
            for node in graph.nodes:
                if node.op in ['get_attr', 'call_module']:
                    assert isinstance(node.target, str)
                    _copy_attr(root, self, node.target)
        elif isinstance(root, dict):
            targets_to_copy = []
            for node in graph.nodes:
                if node.op in ['get_attr', 'call_module']:
                    assert isinstance(node.target, str)
                    if node.target not in root:
                        raise RuntimeError('Node ' + str(node) + ' referenced target ' + node.target +
                                           ' but that target was not provided in ``root``!')
                    targets_to_copy.append(node.target)
            targets_to_copy.sort(key=lambda t: t.count('.'))
            for target_to_copy in targets_to_copy:
                _assign_attr(root[target_to_copy], self, target_to_copy)
        else:
            raise RuntimeError('Unsupported type ' + str(root) + ' passed for root!')

        self.graph = graph
        self._tracer_cls = None
        if self.graph._tracer_cls and '<locals>' not in self.graph._tracer_cls.__qualname__:
            self._tracer_cls = self.graph._tracer_cls
    __jit_unused_properties__ = ['graph']

最終graph_module中包含了生成的代碼恰矩,通過print(graph_module.code)打印出來:

def forward(self, input):
    input_1 = input
    upsampler_deconv_layers_0_bias = getattr(self.upsampler.deconv_layers, "0").bias
    ...
    head_angle_0 = getattr(self.head.angle, "0")(upsampler_deconv_layers_11);  upsampler_deconv_layers_11 = None
    head_angle_1 = getattr(self.head.angle, "1")(head_angle_0);  head_angle_0 = None
    head_angle_2 = getattr(self.head.angle, "2")(head_angle_1);  head_angle_1 = None
    return {'hm': head_hm_2, 'wh': head_wh_2, 'reg': head_reg_2, 'angle': head_angle_2}

這個時候我們就有了trace后的Module记盒,這個Module和原始模型并沒有區(qū)別,forward函數也是按照原始模型的forward生成的外傅。因為我們只是簡單地trace了一遍纪吮,所以相同輸入結果也是一樣的:graph_module(input) == original_model(input),畢竟沒干啥特殊的萎胰。

OP融合

接下來就是fuse碾盟,這里直接調用FX提供的fuse函數,其實里頭也就是調用了Fuser


def _fuse_fx(
    graph_module: GraphModule,
    is_qat: bool,
    fuse_custom_config_dict: Optional[Dict[str, Any]] = None,
    backend_config_dict: Optional[Dict[str, Any]] = None,
) -> GraphModule:
    r""" Internal helper function to fuse modules in preparation for quantization

    Args:
        graph_module: GraphModule object from symbolic tracing (torch.fx.symbolic_trace)
    """
    _check_is_graph_module(graph_module)
    fuser = Fuser()
    return fuser.fuse(
        graph_module, is_qat, fuse_custom_config_dict, backend_config_dict)

來看看Fuser都干了啥技竟,其實很簡單冰肴,就是遍歷一遍input_graph = model.graph中的node,然后根據指定好的fuse規(guī)則進行融合榔组,融合會涉及到修改graph結構:

class Fuser:
    def fuse(
        self,
        model: GraphModule,
        is_qat: bool,
        fuse_custom_config_dict: Optional[Dict[str, Any]] = None,
        backend_config_dict: Optional[Dict[str, Any]] = None,
    ) -> GraphModule:
        if fuse_custom_config_dict is None:
            fuse_custom_config_dict = {}

        input_root = model
        input_graph = model.graph
        # 這里首先copy 原始模型中的named_modules中熙尉,之后會根據fuse情況進行修改
        self.modules = dict(input_root.named_modules())  
        ... 
        # 這里查找匹配的fuse pattern
        fusion_pairs = self._find_matches(
            input_root, input_graph, fusion_pattern_to_fuse_handler_cls)
        self.fused_graph = Graph()
        env: Dict[Any, Any] = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        def get_root_node(node_pattern):
            while not isinstance(node_pattern[-1], Node):
                node_pattern = node_pattern[-1]
            return node_pattern[-1]

        for node in input_graph.nodes:
            maybe_last_node, pattern, matched_node_pattern, obj = \
                fusion_pairs.get(node.name, (None, None, None, None))
            if maybe_last_node is node:
                assert obj is not None
                # TODO: currently we hard code the root node, which only works for
                # a sequence of ops and assume the root node is the last node,
                # we want to make this more general to support more complex patterns
                root_node = get_root_node(matched_node_pattern)  # 尋找fuse的根node
                env[node.name] = obj.fuse( # 這里將self傳入,對self進行修改
                    self, load_arg, root_node, matched_node_pattern,  # type: ignore[arg-type]
                    fuse_custom_config_dict, fuser_method_mapping, is_qat)
            elif maybe_last_node is None:
                env[node.name] = self.fused_graph.node_copy(node, load_arg)
            # node matched in patterns and is not root is removed here

        preserved_attributes = set(fuse_custom_config_dict.get("preserved_attributes", []))
        model = FusedGraphModule(input_root, self.fused_graph, preserved_attributes)
        return model

    def _find_matches(
            self, root: GraphModule, graph: Graph,
            patterns: Dict[Pattern, Callable]
    ) -> Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler]]:
        modules = dict(root.named_modules())
        match_map : Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler]] = {}  # node name -> (root_node, match_value)

        def apply_match(pattern, node, match, matched_node_pattern):
            if isinstance(pattern, tuple):
                s, *args = pattern
                current_node_pattern: List[Node] = []
                apply_match(s, node, match, current_node_pattern)
                for subpattern, arg in zip(args, node.args):
                    apply_match(subpattern, arg, match, current_node_pattern)
                matched_node_pattern.append(tuple(current_node_pattern))
            else:
                # the first pattern matches will take precedence
                if node.name not in match_map:
                    matched_node_pattern.append(node)
                    root_node, pattern, handler = match
                    match_map[node.name] = (root_node, pattern, matched_node_pattern, handler)
        # 這里就是match過程
        for node in reversed(graph.nodes):
            if node.name not in match_map:
                for pattern, value in patterns.items():
                    matched_node_pattern: List[Node] = []
                    if is_match(modules, node, pattern):
                        apply_match(pattern, node, (node, pattern, value(self, node)), matched_node_pattern)

        return match_map

至于定義了哪些fuse的規(guī)則搓扯,可以在pytorch/torch/ao/quantization/fx/fusion_patterns.py這里頭找到:

# /ao/quantization/fx/fusion_patterns.py
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv1d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv3d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv1d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv3d))
@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Conv1d))
@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.BatchNorm3d, torch.nn.Conv3d))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Linear))
class DefaultFuseHandler(FuseHandler):
    def __init__(
            self,
            quantizer: QuantizerCls,
            node: Node):
        super().__init__(quantizer, node)

    def fuse(...):
        # 這里執(zhí)行實際的融合操作

具體的融合操作在DefaultFuseHandler類中的fuse方法內執(zhí)行检痰,找到對應的fuser_method,然后調用锨推,返回融合后的fused_module使用setattr來修改網絡的modules铅歼,同樣也會通過node_copy修改graph中的node:

matched_module_types = get_matched_types(matched_modules)
module_parent_name, module_name = _parent_name(root_node.target)
fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping)
# TODO: change the signature for fuser_method to take matched module patterns
# as input
fused_module = fuser_method(is_qat, *matched_modules)
# TODO: maybe add a pass to cleanup bn modules?
setattr(quantizer.modules[module_parent_name], module_name, fused_module) # 往fuse控制的新模型中加入 新的modules
return quantizer.fused_graph.node_copy(root_node, load_arg)                # 往fuse控制的新graph中加入forward參數

其中,Conv+bn+relu的融合細節(jié)會調用pytorch/torch/ao/quantization/fuser_method_mappings.py中的fuse_conv_bn_relu函數:

def fuse_conv_bn_relu(is_qat, conv, bn, relu):
    assert(conv.training == bn.training == relu.training),\
        "Conv and BN both must be in the same mode (train or eval)."
    fused_module : Optional[Type[nn.Sequential]] = None
    map_to_fused_module_eval = {
        nn.Conv1d: nni.ConvReLU1d,
        nn.Conv2d: nni.ConvReLU2d,
        nn.Conv3d: nni.ConvReLU3d,
    }
    fused_module = map_to_fused_module_eval.get(type(conv), None)
    if fused_module is not None:
        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
        return fused_module(fused_conv, relu)
    else:
        raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))

上述fused_moduletorch.nn.intrinsic.modules.fused.ConvReLU2d類爱态,會調用fuse_conv_bn_eval來實際吸bn到conv:

def fuse_conv_bn_eval(conv, bn, transpose=False):
    assert(not (conv.training or bn.training)), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    fused_conv.weight, fused_conv.bias = \
        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
                             bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias, transpose)

    return fused_conv

def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=False):
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)

    if transpose:
        shape = [1, -1] + [1] * (len(conv_w.shape) - 2)
    else:
        shape = [-1, 1] + [1] * (len(conv_w.shape) - 2)

    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape(shape)
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b

    return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)

吸完后谭贪,得到新的conv,然后帶入ConvReLU2d類锦担。

class ConvReLU2d(_FusedModule):
    r"""This is a sequential container which calls the Conv2d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, relu):
        assert type(conv) == Conv2d and type(relu) == ReLU, \
            'Incorrect types for input modules{}{}'.format(
                type(conv), type(relu))
        super().__init__(conv, relu)

整體流程就是conv + bn->conv然后conv + relu -> ConvReLU2d俭识。

fuse后的code就清爽很多了,bn和relu都被融合進去了(當然還有其他融合):

def forward(self, input):
    input_1 = input
    backbone_conv1 = self.backbone.conv1(input_1)
    backbone_maxpool = self.backbone.maxpool(backbone_relu)
    backbone_layer1_0_conv1 = getattr(self.backbone.layer1, "0").conv1(backbone_maxpool)
    backbone_layer1_0_conv2 = getattr(self.backbone.layer1, "0").conv2(backbone_layer1_0_relu)
    backbone_layer1_0_conv3 = getattr(self.backbone.layer1, "0").conv3(backbone_layer1_0_relu_1)
    ...
    head_reg_0 = getattr(self.head.reg, "0")(upsampler_deconv_layers_11)
    head_reg_2 = getattr(self.head.reg, "2")(head_reg_1)
    head_angle_0 = getattr(self.head.angle, "0")(upsampler_deconv_layers_11)
    head_angle_2 = getattr(self.head.angle, "2")(head_angle_1)
    return {'hm': head_hm_2, 'wh': head_wh_2, 'reg': head_reg_2, 'angle': head_angle_2}

至此洞渔,就得到了trace后和fuse后的模型套媚,可以看到融合后的ConvReLU2d模塊。

trace后以及fuse后的module

這個GraphModuletorch.nn.module的使用方式一模一樣磁椒,可以簡單輸入一個image驗證一下堤瘤。

下一篇中我們會對這個GraphModule進行量化操作。

如何debug

那么我們得到了最終的GraphModule浆熔,該如何debug呢本辐,也就是一步一步單獨調試。這也是有辦法的,調試fx生成model的方式有三種:

直接通過pdb進行debug

我們是可以進入FX的Generated Code中的慎皱,也可以設置斷點:

FX生成的代碼是可以debug進去的

打印生成的代碼老虫,并且和Module組合

因為graph中的node包含了指定邏輯,GraphModule中包含了模型權重等信息茫多,而這些權重信息是通過原始的Module獲取的祈匙,那么我們可以直接將生成的code放到原始Module子類的forward中,組成一個新的Module來調用天揖。

# Assume that `traced` is a GraphModule that has undergone some
# number of transforms

# Copy this code for later
print(traced)
# Print the code generated from symbolic tracing. This outputs:
"""
def forward(self, y):
    x = self.x
    add_1 = x + y;  x = y = None
    return add_1
"""

# 這里繼承原始的Module
class SubclassM(M):
    def __init__(self):
        super().__init__()

    # 把生成的代碼粘到這里
    def forward(self, y):
        x = self.x
        add_1 = x + y;  x = y = None
        return add_1

# Create an instance of the original, untraced Module. Then, create an
# instance of the Module with the copied `forward` function. We can
# now compare the output of both the original and the traced version.
pre_trace = M()
post_trace = SubclassM()

是不是很符合常識夺欲!

使用to_folder函數

就像之前例子里說到的,GraphModule.to_folder()是一個神奇的函數今膊,可以直接將FX生成的module導出為一個文件夾些阅,文件夾中包含了模型需要的參數(.pt格式)和模型的定義。

FX代碼導出fold

module.py的代碼也幫你生成好了:

# 導出的module.py中
import torch
from torch.nn import *
class FusedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # 這里加載權重參數信息
        self.backbone = torch.load(r'fx_debug/backbone.pt')
        self.load_state_dict(torch.load(r'fx_debug/state_dict.pt'))
        ...

    def forward(self, input):
        # 這里就是生成的code部分万细,也幫你寫到forward中了
        input_1 = input
        backbone_conv1 = self.backbone.conv1(input_1)
        backbone_maxpool = self.backbone.maxpool(backbone_relu)
        backbone_layer1_0_conv1 = getattr(self.backbone.layer1, "0").conv1(backbone_maxpool)
        ...
        head_angle_0 = getattr(self.head.angle, "0")(upsampler_deconv_layers_11)
        head_angle_2 = getattr(self.head.angle, "2")(head_angle_1)
        return {'hm': head_hm_2, 'wh': head_wh_2, 'reg': head_reg_2, 'angle': head_angle_2}

是不是很強大扑眉?!

我們也可以修改這個生成的代碼來做其他方面的實驗(不過這個導出有一些bug赖钞,不知道是不是我使用姿勢不對)腰素。

一些限制

torch.fx也是有一些限制的(畢竟不可能十全十美)。

因為Symbolic execution的限制雪营。
Proxy object cannot be iterated. This can be attempted when the Proxy is used in a loop or as a *args or **kwargs function argument. See the torch.fx docs on pytorch.org for a more detailed explanation of what types of control flow can be traced, and check out the Proxy docstring for help troubleshooting Proxy iteration errors

The main limitation of symbolic tracing is it does not currently support dynamic control flow. That is, loops or if statements where the condition may depend on the input values of the program.

更詳細的限制可以看官方的介紹:

就先寫到這里吧弓千,關于FX的功能使用更多是在量化過程中體現了,下一篇的量化實操中献起,會結合量化過程來理解FX洋访,同時也會總結下PTQ量化的流程和注意點,我是老潘谴餐,我們下一篇再見~

撩我吧

  • 如果你與我志同道合于此姻政,老潘很愿意與你交流!
  • 如果你喜歡老潘的內容岂嗓,歡迎關注和支持~
  • 如果有問題想要聯(lián)系我汁展,可加公眾號直接私信,

    厌殉!

參考鏈接

撩我吧

  • 如果你與我志同道合于此食绿,老潘很愿意與你交流;
  • 如果你喜歡老潘的內容公罕,歡迎關注和支持器紧。
  • 如果你喜歡我的文章,希望點贊?? 收藏 ?? 評論 ?? 三連一下~

想知道老潘是如何學習踩坑的楼眷,想與我交流問題~請關注公眾號「oldpan博客」铲汪。
老潘也會整理一些自己的私藏熊尉,希望能幫助到大家,點擊神秘傳送門獲取桥状。

?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
  • 序言:七十年代末帽揪,一起剝皮案震驚了整個濱河市硝清,隨后出現的幾起案子辅斟,更是在濱河造成了極大的恐慌,老刑警劉巖芦拿,帶你破解...
    沈念sama閱讀 218,122評論 6 505
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件士飒,死亡現場離奇詭異,居然都是意外死亡蔗崎,警方通過查閱死者的電腦和手機酵幕,發(fā)現死者居然都...
    沈念sama閱讀 93,070評論 3 395
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來缓苛,“玉大人芳撒,你說我怎么就攤上這事∥辞牛” “怎么了笔刹?”我有些...
    開封第一講書人閱讀 164,491評論 0 354
  • 文/不壞的土叔 我叫張陵,是天一觀的道長冬耿。 經常有香客問我舌菜,道長,這世上最難降的妖魔是什么亦镶? 我笑而不...
    開封第一講書人閱讀 58,636評論 1 293
  • 正文 為了忘掉前任日月,我火速辦了婚禮,結果婚禮上缤骨,老公的妹妹穿的比我還像新娘爱咬。我一直安慰自己,他們只是感情好绊起,可當我...
    茶點故事閱讀 67,676評論 6 392
  • 文/花漫 我一把揭開白布精拟。 她就那樣靜靜地躺著,像睡著了一般勒庄。 火紅的嫁衣襯著肌膚如雪串前。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,541評論 1 305
  • 那天实蔽,我揣著相機與錄音荡碾,去河邊找鬼。 笑死局装,一個胖子當著我的面吹牛坛吁,可吹牛的內容都是我干的劳殖。 我是一名探鬼主播,決...
    沈念sama閱讀 40,292評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼拨脉,長吁一口氣:“原來是場噩夢啊……” “哼哆姻!你這毒婦竟也來了?” 一聲冷哼從身側響起玫膀,我...
    開封第一講書人閱讀 39,211評論 0 276
  • 序言:老撾萬榮一對情侶失蹤矛缨,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后帖旨,有當地人在樹林里發(fā)現了一具尸體箕昭,經...
    沈念sama閱讀 45,655評論 1 314
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 37,846評論 3 336
  • 正文 我和宋清朗相戀三年解阅,在試婚紗的時候發(fā)現自己被綠了落竹。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 39,965評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡货抄,死狀恐怖述召,靈堂內的尸體忽然破棺而出,到底是詐尸還是另有隱情蟹地,我是刑警寧澤积暖,帶...
    沈念sama閱讀 35,684評論 5 347
  • 正文 年R本政府宣布,位于F島的核電站锈津,受9級特大地震影響呀酸,放射性物質發(fā)生泄漏。R本人自食惡果不足惜琼梆,卻給世界環(huán)境...
    茶點故事閱讀 41,295評論 3 329
  • 文/蒙蒙 一性誉、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧茎杂,春花似錦错览、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,894評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至刽脖,卻和暖如春羞海,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背曲管。 一陣腳步聲響...
    開封第一講書人閱讀 33,012評論 1 269
  • 我被黑心中介騙來泰國打工却邓, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人院水。 一個月前我還...
    沈念sama閱讀 48,126評論 3 370
  • 正文 我出身青樓腊徙,卻偏偏與公主長得像简十,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子撬腾,可洞房花燭夜當晚...
    茶點故事閱讀 44,914評論 2 355

推薦閱讀更多精彩內容