Introduction
A recent project required me to study and learn PyTorch PTQ and QAT quantization. Newer PyTorch releases recommend FX Graph Mode Quantization, which is what we use here.
FX Graph Mode Quantization Demo
The main workflow of Post-Training Quantization (PTQ) static quantization:
PyTorch FX Graph mode quantization proceeds in steps 1 through 5:
- step1: configuration; choose the quantization scheme: e.g. per-channel vs. per-tensor/layer QScheme, and the quantized value range (Qmin, Qmax)
- step2: prepare_fx:
  * a) convert the input model (nn.Module) into a GraphModule (IR conversion)
  * b) rewrite subgraphs and fuse ops (e.g. conv + relu --> ConvReLU)
  * c) insert Observers before and after Conv, Linear, and similar ops to collect the statistics (value ranges) of the activation feature maps
- step3: feed calibration data to calibrate the activations
- step4: compute the weight and activation quantization parameters (e.g. scale, zero_point) and convert the model FP32 --> INT8
- step5: evaluate the accuracy of the quantized INT8 model
from torchvision.models import resnet18, resnet50
import torch
from torch.ao.quantization import quantize_fx, get_default_qconfig
import copy

import utils


def calibrate(model, data_loader, num_batch, device):
    utils.evaluate(model=model, data_loader=data_loader, neval_batches=num_batch, n_print=1, device=device)


if __name__ == '__main__':
    device = torch.device('cuda', 0)
    eval_batch_size = 32
    imagenet_data = '/media/wei/Document/ImageNet/ILSVRC2012'

    model_fp = resnet50(pretrained=True, progress=True).to(device)
    model_fp.eval()

    _, test_dataloader = utils.prepare_dataloader(data_path=imagenet_data, eval_batch_size=eval_batch_size, num_workers=8)
    utils.evaluate(model=model_fp, criterion=None, data_loader=test_dataloader, device=device)
    # ResNet-18: tested on imagenet-val: batch:3125 Acc@1 56.25 (69.76), Acc@5 75.00 (89.08)
    # ResNet-50: batch:1560 Acc@1 59.38 (76.18), Acc@5 90.62 (92.87)

    # torch quantization
    model_prepare = copy.deepcopy(model_fp)
    model_prepare.eval()
    # step1: choose the quantization configuration
    # (a dict {"": qconfig} applies the qconfig to the whole model)
    qconfig_dict = {"": get_default_qconfig(backend='fbgemm')}
    # step2: convert to GraphModule, fuse ops, insert observers
    model_prepare = quantize_fx.prepare_fx(model=model_prepare, qconfig_dict=qconfig_dict)
    model_prepare.eval()
    # step3: calibrate -- feed data so the observers can record activation ranges
    calibrate(model_prepare, test_dataloader, 10, device)
    # step4: convert FP32 --> INT8 using the configured scheme and calibrated stats;
    # fbgemm INT8 kernels run on x86 CPU, so move the model off the GPU first
    model_prepare.to(torch.device('cpu'))
    quantized_model = quantize_fx.convert_fx(graph_module=model_prepare)
    quantized_model.eval()
    # step5: evaluate the accuracy of the quantized INT8 model (on CPU)
    utils.evaluate(model=quantized_model, criterion=None, data_loader=test_dataloader, device=torch.device('cpu'))
Thanks to the lean design of the PyTorch FX Graph quantization API, we can quantize a model with very little code and very few changes. Next, let's dig into how FX Graph quantization is actually implemented.
Below we analyze the FX Graph quantization process step by step.
PyTorch FX Graph Quantization - Step 1: Choosing the Quantization Configuration
This is PyTorch's default PTQ configuration. 'fbgemm' is a matrix-computation library that provides INT8 Conv, Linear, and other ops on server-side x86 CPUs.
qconfig_dict = {"": get_default_qconfig(backend='fbgemm')}
def get_default_qconfig(backend='fbgemm'):
    """
    Returns the default PTQ qconfig for the specified backend.

    Args:
      * `backend`: a string representing the target backend. Currently supports `fbgemm`
        and `qnnpack`.

    Return:
        qconfig
    """
    if backend == 'fbgemm':
        qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True),
                          weight=default_per_channel_weight_observer)
    elif backend == 'qnnpack':
        qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False),
                          weight=default_weight_observer)
    else:
        qconfig = default_qconfig
    return qconfig

default_per_channel_weight_observer = PerChannelMinMaxObserver.with_args(
    dtype=torch.qint8, qscheme=torch.per_channel_symmetric
)
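As a quick sanity check (assuming a recent PyTorch where torch.ao.quantization is available), we can inspect the default 'fbgemm' qconfig directly and see the observer factories it wires up:

# Inspect the two halves of the default 'fbgemm' qconfig.
from torch.ao.quantization import get_default_qconfig

qconfig = get_default_qconfig('fbgemm')
print(qconfig.activation)  # should show a partial of HistogramObserver (reduce_range=True)
print(qconfig.weight)      # should show a partial of PerChannelMinMaxObserver (per_channel_symmetric)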
Notice that a qconfig has two parts: the quantization configuration for the activations and for the weights. Activations use HistogramObserver, a histogram-based per-tensor/layer asymmetric (affine) quantization scheme, while weights use PerChannelMinMaxObserver, a per-channel symmetric scheme.
Why? Why are activations and weights quantized differently?
- Weight quantization:
  - the distribution of weight values differs from activations: weights are typically zero-mean and roughly symmetric (Gaussian-like), so symmetric quantization fits them well
  - symmetric quantization has zero_point=0, which removes the zero-point correction term and reduces the arithmetic inside quantized ops, as the sketch below illustrates
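To see why zero_point = 0 saves work, expand an integer dot product between affine-quantized operands; the cross term that depends on the weight zero point vanishes when the weights are quantized symmetrically. A minimal sketch with made-up scales and shapes:

# y = sum((w_q - z_w) * (x_q - z_x)) * s_w * s_x
# Expanding produces a term sum(z_w * x_q) that would have to be recomputed
# for every input; with symmetric weights z_w == 0 and that term disappears.
import torch

s_w, z_w = 0.02, 0           # symmetric weight quantization: zero_point is 0
s_x, z_x = 0.1, 128          # affine activation quantization
w_q = torch.randint(-128, 128, (16,), dtype=torch.int32)
x_q = torch.randint(0, 256, (16,), dtype=torch.int32)
acc = torch.sum(w_q * (x_q - z_x))   # no z_w correction needed in the inner loop
y = float(acc) * s_w * s_x           # dequantized result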
For more background, see the introduction in Qualcomm AI Research's quantization white paper.
The role of the Observer
In short, an Observer watches the data distribution and computes the quantization parameters scale and zero_point. Let's work through the code.
Analyzing the PerChannelMinMaxObserver class
class PerChannelMinMaxObserver(_ObserverBase):
    r"""Observer module for computing the quantization parameters based on the
    running per channel min and max values.

    This observer uses the tensor min/max statistics to compute the per channel
    quantization parameters. The module records the running minimum and maximum
    of incoming tensors, and uses this statistic to compute the quantization
    parameters.

    Args:
        ch_axis: Channel axis
        dtype: Quantized data type
        qscheme: Quantization scheme to be used
        reduce_range: Reduces the range of the quantized data type by 1 bit
        quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
        quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
        memoryless: Boolean that controls whether observer removes old data when a new input is seen.
                    This is most useful for simulating dynamic quantization, especially during QAT.

    The quantization parameters are computed the same way as in
    :class:`~torch.ao.quantization.observer.MinMaxObserver`, with the difference
    that the running min/max values are stored per channel.
    Scales and zero points are thus computed per channel as well.

    .. note:: If the running minimum equals to the running maximum, the scales
              and zero_points are set to 1.0 and 0.
    """
    min_val: torch.Tensor
    max_val: torch.Tensor

    def __init__(
        self,
        ch_axis=0,
        dtype=torch.quint8,
        qscheme=torch.per_channel_affine,
        reduce_range=False,
        quant_min=None,
        quant_max=None,
        factory_kwargs=None,
        memoryless=False,
    ) -> None:
        super(PerChannelMinMaxObserver, self).__init__(
            dtype=dtype,
            qscheme=qscheme,
            reduce_range=reduce_range,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs,
        )
        self.memoryless = memoryless
        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
        self.ch_axis = ch_axis
        self.register_buffer("min_val", torch.tensor([], **factory_kwargs))
        self.register_buffer("max_val", torch.tensor([], **factory_kwargs))
        if (
            self.qscheme == torch.per_channel_symmetric
            and self.reduce_range
            and self.dtype == torch.quint8
        ):
            raise NotImplementedError(
                "Cannot reduce range for symmetric quantization for quint8"
            )

    def forward(self, x_orig):
        return self._forward(x_orig)

    def _forward(self, x_orig):
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.detach()  # avoid keeping autograd tape
        min_val = self.min_val
        max_val = self.max_val
        x_dim = x.size()
        new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
        new_axis_list[self.ch_axis] = 0
        new_axis_list[0] = self.ch_axis
        y = x.permute(new_axis_list)
        # Need to match dtype of min/max because the updates to buffers
        # are done in place and types need to match for comparisons
        y = y.to(self.min_val.dtype)
        y = torch.flatten(y, start_dim=1)
        if min_val.numel() == 0 or max_val.numel() == 0 or self.memoryless:
            min_val, max_val = torch.aminmax(y, dim=1)
        else:
            min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
            min_val = torch.min(min_val_cur, min_val)
            max_val = torch.max(max_val_cur, max_val)
        self.min_val.resize_(min_val.shape)
        self.max_val.resize_(max_val.shape)
        self.min_val.copy_(min_val)
        self.max_val.copy_(max_val)
        return x_orig

    @torch.jit.export
    def calculate_qparams(self):
        return self._calculate_qparams(self.min_val, self.max_val)
To compute the parameters needed for quantization, PyTorch defines a family of observers, e.g. MinMaxObserver, MovingAverageMinMaxObserver, and so on. All of these XXXObserver classes inherit from a common base class, and each implements two key methods:
- forward(self, x_orig): observes the min and max of the incoming elements (e.g. of a weight tensor)
- calculate_qparams(self): computes scale and zero_point
forward(self, x_orig)
What this function does:
- input: x_orig, the weight tensor; a CNN weight is typically a 4-D tensor of shape Oc * Ic * Kh * Kw
- output/result: the observed min and max values

When the observer is instantiated, the __init__() parameter ch_axis=0 selects the channel dimension: ch_axis=0 means the observer tracks min/max along the Oc (output_channels) dimension of the weight. The core line that observes the min and max values:
min_val, max_val = torch.aminmax(y, dim=1)
Because Oc sits on axis 0, after the permute and flatten, aminmax(dim=1) reduces over axis 1 and yields Oc pairs of (min_val, max_val); in other words, every output channel of the weight gets its own scale and zero_point.
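To make the permute/flatten mechanics concrete, here is a tiny standalone sketch (the 4-D weight shape is made up):

# Move ch_axis to the front, flatten the remaining dims, then reduce over
# dim=1 to obtain one (min, max) pair per output channel.
import torch

w = torch.randn(8, 3, 5, 5)                 # Oc=8, Ic=3, Kh=Kw=5
y = torch.flatten(w, start_dim=1)           # ch_axis=0 is already in front -> shape (8, 75)
min_val, max_val = torch.aminmax(y, dim=1)  # 8 mins and 8 maxes, one per output channel

The full _forward implementation for reference: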
def _forward(self, x_orig):
    if x_orig.numel() == 0:
        return x_orig
    x = x_orig.detach()  # avoid keeping autograd tape
    min_val = self.min_val
    max_val = self.max_val
    x_dim = x.size()
    new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
    new_axis_list[self.ch_axis] = 0
    new_axis_list[0] = self.ch_axis
    y = x.permute(new_axis_list)
    # Need to match dtype of min/max because the updates to buffers
    # are done in place and types need to match for comparisons
    y = y.to(self.min_val.dtype)
    y = torch.flatten(y, start_dim=1)
    if min_val.numel() == 0 or max_val.numel() == 0 or self.memoryless:
        min_val, max_val = torch.aminmax(y, dim=1)
    else:
        min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
        min_val = torch.min(min_val_cur, min_val)
        max_val = torch.max(max_val_cur, max_val)
    self.min_val.resize_(min_val.shape)
    self.max_val.resize_(max_val.shape)
    self.min_val.copy_(min_val)
    self.max_val.copy_(max_val)
    return x_orig
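Putting the two methods together, a short usage sketch (random weight fed through the real observer class; the shapes are assumptions for illustration):

import torch
from torch.ao.quantization.observer import PerChannelMinMaxObserver

obs = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.qint8,
                               qscheme=torch.per_channel_symmetric)
obs(torch.randn(8, 3, 3, 3))            # forward() records per-channel min/max
scale, zero_point = obs.calculate_qparams()
print(scale.shape, zero_point.shape)    # torch.Size([8]) torch.Size([8])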
calculate_qparams
What this function does: it computes the quantization parameters scale & zero_point (for linear quantization). Walking through the implementation:
- input: the observed max_val and min_val, plus the predefined qmax and qmin
- output: the computed scale and zero_point
def _calculate_qparams(
    self, min_val: torch.Tensor, max_val: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Calculates the quantization parameters, given min and max
    value tensors. Works for both per tensor and per channel cases

    Args:
        min_val: Minimum values per channel
        max_val: Maximum values per channel

    Returns:
        scales: Scales tensor of shape (#channels,)
        zero_points: Zero points tensor of shape (#channels,)
    """
    if not check_min_max_valid(min_val, max_val):
        return torch.tensor([1.0], device=min_val.device.type), torch.tensor([0], device=min_val.device.type)

    quant_min, quant_max = self.quant_min, self.quant_max
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

    device = min_val_neg.device
    scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device)
    zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    if (
        self.qscheme == torch.per_tensor_symmetric
        or self.qscheme == torch.per_channel_symmetric
    ):
        max_val_pos = torch.max(-min_val_neg, max_val_pos)
        scale = max_val_pos / (float(quant_max - quant_min) / 2)
        scale = torch.max(scale, self.eps)
        if self.dtype == torch.quint8:
            if self.has_customized_qrange:
                # When customized quantization range is used, down-rounded midpoint of the range is chosen.
                zero_point = zero_point.new_full(
                    zero_point.size(), (quant_min + quant_max) // 2
                )
            else:
                zero_point = zero_point.new_full(zero_point.size(), 128)
    elif self.qscheme == torch.per_channel_affine_float_qparams:
        scale = (max_val - min_val) / float(quant_max - quant_min)
        scale = torch.where(scale > self.eps, scale, torch.ones_like(scale))
        # We use the quantize function
        # xq = Round(Xf * inv_scale + zero_point),
        # setting zero_point to (-1 * min *inv_scale) we get
        # Xq = Round((Xf - min) * inv_scale)
        zero_point = -1 * min_val / scale
    else:
        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
        scale = torch.max(scale, self.eps)
        zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
        zero_point = torch.clamp(zero_point, quant_min, quant_max)

    # For scalar values, cast them to Tensors of size 1 to keep the shape
    # consistent with default values in FakeQuantize.
    if len(scale.shape) == 0:
        # TODO: switch to scale.item() after adding JIT support
        scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
    if len(zero_point.shape) == 0:
        # TODO: switch to zero_point.item() after adding JIT support
        zero_point = torch.tensor(
            [int(zero_point)], dtype=zero_point.dtype, device=device
        )
        if self.qscheme == torch.per_channel_affine_float_qparams:
            zero_point = torch.tensor(
                [float(zero_point)], dtype=zero_point.dtype, device=device
            )

    return scale, zero_point
The core code for computing the quantization parameters scale and zero_point:
- symmetric quantization:

max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(quant_max - quant_min) / 2)
zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

- asymmetric (affine) quantization:

scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
scale = torch.max(scale, self.eps)
zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
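A worked example of the affine branch with assumed values (min_val = -0.5, max_val = 1.0, quint8 with the full 0..255 range):

# Plain-Python arithmetic mirroring the affine formulas above.
quant_min, quant_max = 0, 255
min_val_neg, max_val_pos = -0.5, 1.0
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)  # 1.5 / 255 ~= 0.00588
zero_point = quant_min - round(min_val_neg / scale)                 # 0 - (-85) = 85
zero_point = min(max(zero_point, quant_min), quant_max)             # clamp to [0, 255] -> 85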
That covers how the quantization parameters of an nn.Conv2d layer's weight are computed and how PerChannelMinMaxObserver is implemented. Next we analyze how the activation quantization parameters are computed.
Computing activation quantization parameters: HistogramObserver
In the default backend='fbgemm' configuration, activations use HistogramObserver, which derives the quantization parameters from a histogram of the observed values.
if backend == 'fbgemm':
    qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True),
                      weight=default_per_channel_weight_observer)
How HistogramObserver works
- initialization: __init__()
  - bins=2048 by default: building a histogram requires a bin count, i.e. the [min_val, max_val] range is divided into 2048 intervals
  - qscheme=per_tensor_affine: per-tensor/layer affine quantization, i.e. a single scale + zero_point pair for the whole tensor rather than one per channel
class HistogramObserver(_ObserverBase):
    r"""
    The module records the running histogram of tensor values along with
    min/max values. ``calculate_qparams`` will calculate scale and zero_point.

    Args:
        bins: Number of bins to use for the histogram
        upsample_rate: Factor by which the histograms are upsampled, this is
                       used to interpolate histograms with varying ranges across observations
        dtype: Quantized data type
        qscheme: Quantization scheme to be used
        reduce_range: Reduces the range of the quantized data type by 1 bit

    The scale and zero point are computed as follows:

    1. Create the histogram of the incoming inputs.
       The histogram is computed continuously, and the ranges per bin change
       with every new tensor observed.
    2. Search the distribution in the histogram for optimal min/max values.
       The search for the min/max values ensures the minimization of the
       quantization error with respect to the floating point model.
    3. Compute the scale and zero point the same way as in the
       :class:`~torch.ao.quantization.MinMaxObserver`
    """
    histogram: torch.Tensor
    min_val: torch.Tensor
    max_val: torch.Tensor

    def __init__(
        self,
        bins: int = 2048,
        upsample_rate: int = 128,
        dtype: torch.dtype = torch.quint8,
        qscheme=torch.per_tensor_affine,
        reduce_range=False,
        quant_min=None,
        quant_max=None,
        factory_kwargs=None,
    ) -> None:
        # bins: The number of bins used for histogram calculation.
        super(HistogramObserver, self).__init__(
            dtype=dtype,
            qscheme=qscheme,
            reduce_range=reduce_range,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs,
        )
        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
        self.bins = bins
        self.register_buffer("histogram", torch.zeros(self.bins, **factory_kwargs))
        self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
        self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
        self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits
        self.upsample_rate = upsample_rate
- observing the activation tensor and accumulating its statistics:
def forward(self, x_orig: torch.Tensor) -> torch.Tensor:
    if x_orig.numel() == 0:
        return x_orig
    x = x_orig.detach()
    min_val = self.min_val
    max_val = self.max_val
    same_values = min_val.item() == max_val.item()
    is_uninitialized = min_val == float("inf") and max_val == float("-inf")
    if is_uninitialized or same_values:
        min_val, max_val = torch.aminmax(x)
        self.min_val.resize_(min_val.shape)
        self.min_val.copy_(min_val)
        self.max_val.resize_(max_val.shape)
        self.max_val.copy_(max_val)
        assert (
            min_val.numel() == 1 and max_val.numel() == 1
        ), "histogram min/max values must be scalar."
        torch.histc(
            x, self.bins, min=int(min_val), max=int(max_val), out=self.histogram
        )
    else:
        new_min, new_max = torch.aminmax(x)
        combined_min = torch.min(new_min, min_val)
        combined_max = torch.max(new_max, max_val)
        # combine the existing histogram and new histogram into 1 histogram
        # We do this by first upsampling the histogram to a dense grid
        # and then downsampling the histogram efficiently
        (
            combined_min,
            combined_max,
            downsample_rate,
            start_idx,
        ) = self._adjust_min_max(combined_min, combined_max, self.upsample_rate)
        assert (
            combined_min.numel() == 1 and combined_max.numel() == 1
        ), "histogram min/max values must be scalar."
        combined_histogram = torch.histc(
            x, self.bins, min=int(combined_min), max=int(combined_max)
        )
        if combined_min == min_val and combined_max == max_val:
            combined_histogram += self.histogram
        else:
            combined_histogram = self._combine_histograms(
                combined_histogram,
                self.histogram,
                self.upsample_rate,
                downsample_rate,
                start_idx,
                self.bins,
            )
        self.histogram.detach_().resize_(combined_histogram.shape)
        self.histogram.copy_(combined_histogram)
        self.min_val.detach_().resize_(combined_min.shape)
        self.min_val.copy_(combined_min)
        self.max_val.detach_().resize_(combined_max.shape)
        self.max_val.copy_(combined_max)
    return x_orig
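To tie this together, a short standalone sketch that feeds a few stand-in calibration batches through a real HistogramObserver and reads back the resulting qparams (the tensor shapes and batch count are assumptions):

import torch
from torch.ao.quantization.observer import HistogramObserver

obs = HistogramObserver(bins=2048, dtype=torch.quint8,
                        qscheme=torch.per_tensor_affine, reduce_range=True)
for _ in range(4):                    # stand-in for calibration batches
    obs(torch.randn(32, 64) * 2.0)    # each call updates the running histogram
scale, zero_point = obs.calculate_qparams()
print(float(scale), int(zero_point)) # a single per-tensor (scale, zero_point) pair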