Tensorflow中float32模型強(qiáng)制轉(zhuǎn)為float16半浮點(diǎn)模型

最近看到一個(gè)巨牛的人工智能教程,分享一下給大家型雳。教程不僅是零基礎(chǔ)当凡,通俗易懂,而且非常風(fēng)趣幽默纠俭,像看小說一樣沿量!覺得太牛了,所以分享給大家冤荆。平時(shí)碎片時(shí)間可以當(dāng)小說看朴则,【點(diǎn)這里可以去膜拜一下大神的“小說”】

在Tensorflow框架訓(xùn)練完成后匙赞,部署模型時(shí)希望對(duì)模型進(jìn)行壓縮佛掖。一種方案是前面文字介紹的方法《【Ubuntu】Tensorflow對(duì)訓(xùn)練后的模型做8位(uint8)量化轉(zhuǎn)換》妖碉。另一種方法是半浮點(diǎn)量化,今天我們主要介紹如何通過修改Tensorflow的pb文件中的計(jì)算節(jié)點(diǎn)和常量(const)芥被,將float32數(shù)據(jù)類型的模型大小壓縮減半為float16數(shù)據(jù)類型的模型欧宜。

1 加載pb模型

封裝函數(shù),加載pb模型:

def load_graph(model_path):
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        if model_path.endswith("pb"):
            with open(model_path, "rb") as f:
                graph_def.ParseFromString(f.read())
        else:
            with open(model_path, "r") as pf:
                text_format.Parse(pf.read(), graph_def)
        tf.import_graph_def(graph_def, name="")
        sess = tf.Session(graph=graph)
        ops=graph.get_operations()
        for op in ops:
            print(op.name)
        return sess

2 重寫B(tài)atchNorm

由于BatchNorm對(duì)精度比較敏感拴魄,需要保持float32類型冗茸,因此BatchNorm需要特殊處理。

#用FusedBatchNormV2替換FusedBatchNorm匹中,以保證反向梯度下降計(jì)算時(shí)使用的是float
def rewrite_batch_norm_node_v2(node, graph_def, target_type='fp16'): 
    if target_type == 'fp16':
        dtype = types_pb2.DT_HALF
    elif target_type == 'fp64':
        dtype = types_pb2.DT_DOUBLE
    else:
        dtype = types_pb2.DT_FLOAT
    new_node = graph_def.node.add()
    new_node.op = "FusedBatchNormV2"
    new_node.name = node.name
    new_node.input.extend(node.input)
    new_node.attr["U"].CopyFrom(attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT))
    for attr in list(node.attr.keys()):
        if attr == "T":
            node.attr[attr].type = dtype
        new_node.attr[attr].CopyFrom(node.attr[attr])
    print("rewrite fused_batch_norm done!")

3 Graph轉(zhuǎn)換

重新構(gòu)造graph夏漱,參數(shù)從原始pb的graph中拷貝,并轉(zhuǎn)為float16


def convert_graph_to_fp16(model_path, save_path, name, as_text=False, target_type='fp16', input_name=None, output_names=None):
    #生成新的圖數(shù)據(jù)類型
    if target_type == 'fp16':
        dtype = types_pb2.DT_HALF
    elif target_type == 'fp64':
        dtype = types_pb2.DT_DOUBLE
    else:
        dtype = types_pb2.DT_FLOAT

    #加載需要轉(zhuǎn)換的模型
    source_sess = load_graph(model_path)
    source_graph_def = source_sess.graph.as_graph_def()
    #創(chuàng)建新的模圖對(duì)象
    target_graph_def = graph_pb2.GraphDef()
    target_graph_def.versions.CopyFrom(source_graph_def.versions)
    #對(duì)加載的模型遍歷計(jì)算節(jié)點(diǎn)
    for node in source_graph_def.node:
        # 對(duì)FusedBatchNorm計(jì)算節(jié)點(diǎn)替換為FusedBatchNormV2
        if node.op == "FusedBatchNorm":
            rewrite_batch_norm_node_v2(node, target_graph_def, target_type=target_type)
            continue
        # 復(fù)制計(jì)算節(jié)點(diǎn)
        new_node = target_graph_def.node.add()
        new_node.op = node.op
        new_node.name = node.name
        new_node.input.extend(node.input)

        #對(duì)attrs屬性進(jìn)行復(fù)制顶捷,attrs屬性主要關(guān)注
        attrs = list(node.attr.keys())
        # BatchNorm屬性保持不變
        if ("BatchNorm" in node.name) or ('batch_normalization' in node.name):
            for attr in attrs:
                new_node.attr[attr].CopyFrom(node.attr[attr])
            continue
        # 除了BatchNorm以外其他計(jì)算節(jié)點(diǎn)的屬性單獨(dú)
        for attr in attrs:
            # 對(duì)指定的計(jì)算節(jié)點(diǎn)保持不變
            if node.name in keep_fp32_node_name:
                new_node.attr[attr].CopyFrom(node.attr[attr])
                continue
            #將Float類型修改為設(shè)置的目標(biāo)類型
            if node.attr[attr].type == types_pb2.DT_FLOAT:
                # modify node dtype
                node.attr[attr].type = dtype
                
            #重點(diǎn)關(guān)注value挂绰,weights都是保存在value屬性中
            if attr == "value":
                tensor = node.attr[attr].tensor
                if tensor.dtype == types_pb2.DT_FLOAT:
                    # if float_val exists
                    if tensor.float_val:
                        float_val = tf.make_ndarray(node.attr[attr].tensor)
                        new_node.attr[attr].tensor.CopyFrom(tf.make_tensor_proto(float_val, dtype=dtype))
                        continue
                    # if tensor content exists
                    if tensor.tensor_content:
                        tensor_shape = [x.size for x in tensor.tensor_shape.dim]
                        tensor_weights = tf.make_ndarray(tensor)
                        # reshape tensor
                        tensor_weights = np.reshape(tensor_weights, tensor_shape)
                        tensor_proto = tf.make_tensor_proto(tensor_weights, dtype=dtype)
                        new_node.attr[attr].tensor.CopyFrom(tensor_proto)
                        continue
            new_node.attr[attr].CopyFrom(node.attr[attr])
    # transform graph
    if output_names:
        if not input_name:
            input_name = []
        transforms = ["strip_unused_nodes"]
        target_graph_def = TransformGraph(target_graph_def, input_name, output_names, transforms)
    # write graph_def to model
    tf.io.write_graph(target_graph_def, logdir=save_path, name=name, as_text=as_text)
    print("Converting done ...")

4 完整的代碼

import tensorflow as tf
from tensorflow.core.framework import types_pb2, graph_pb2, attr_value_pb2
from tensorflow.tools.graph_transforms import TransformGraph
from google.protobuf import text_format
import numpy as np

# object detection api input and output nodes
input_name = "input_tf"
output_names = ["output:0"]
keep_fp32_node_name = []

def load_graph(model_path):
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        if model_path.endswith("pb"):
            with open(model_path, "rb") as f:
                graph_def.ParseFromString(f.read())
        else:
            with open(model_path, "r") as pf:
                text_format.Parse(pf.read(), graph_def)
        tf.import_graph_def(graph_def, name="")
        sess = tf.Session(graph=graph)
        ops=graph.get_operations()
        for op in ops:
            print(op.name)
        return sess

#用FusedBatchNormV2替換FusedBatchNorm,以保證反向梯度下降計(jì)算時(shí)使用的是float
def rewrite_batch_norm_node_v2(node, graph_def, target_type='fp16'): 
    if target_type == 'fp16':
        dtype = types_pb2.DT_HALF
    elif target_type == 'fp64':
        dtype = types_pb2.DT_DOUBLE
    else:
        dtype = types_pb2.DT_FLOAT
    new_node = graph_def.node.add()
    new_node.op = "FusedBatchNormV2"
    new_node.name = node.name
    new_node.input.extend(node.input)
    new_node.attr["U"].CopyFrom(attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT))
    for attr in list(node.attr.keys()):
        if attr == "T":
            node.attr[attr].type = dtype
        new_node.attr[attr].CopyFrom(node.attr[attr])
    print("rewrite fused_batch_norm done!")

def convert_graph_to_fp16(model_path, save_path, name, as_text=False, target_type='fp16', input_name=None, output_names=None):
    #生成新的圖數(shù)據(jù)類型
    if target_type == 'fp16':
        dtype = types_pb2.DT_HALF
    elif target_type == 'fp64':
        dtype = types_pb2.DT_DOUBLE
    else:
        dtype = types_pb2.DT_FLOAT

    #加載需要轉(zhuǎn)換的模型
    source_sess = load_graph(model_path)
    source_graph_def = source_sess.graph.as_graph_def()
    #創(chuàng)建新的模圖對(duì)象
    target_graph_def = graph_pb2.GraphDef()
    target_graph_def.versions.CopyFrom(source_graph_def.versions)
    #對(duì)加載的模型遍歷計(jì)算節(jié)點(diǎn)
    for node in source_graph_def.node:
        # 對(duì)FusedBatchNorm計(jì)算節(jié)點(diǎn)替換為FusedBatchNormV2
        if node.op == "FusedBatchNorm":
            rewrite_batch_norm_node_v2(node, target_graph_def, target_type=target_type)
            continue
        # 復(fù)制計(jì)算節(jié)點(diǎn)
        new_node = target_graph_def.node.add()
        new_node.op = node.op
        new_node.name = node.name
        new_node.input.extend(node.input)

        #對(duì)attrs屬性進(jìn)行復(fù)制服赎,attrs屬性主要關(guān)注
        attrs = list(node.attr.keys())
        # BatchNorm屬性保持不變
        if ("BatchNorm" in node.name) or ('batch_normalization' in node.name):
            for attr in attrs:
                new_node.attr[attr].CopyFrom(node.attr[attr])
            continue
        # 除了BatchNorm以外其他計(jì)算節(jié)點(diǎn)的屬性單獨(dú)
        for attr in attrs:
            # 對(duì)指定的計(jì)算節(jié)點(diǎn)保持不變
            if node.name in keep_fp32_node_name:
                new_node.attr[attr].CopyFrom(node.attr[attr])
                continue
            #將Float類型修改為設(shè)置的目標(biāo)類型
            if node.attr[attr].type == types_pb2.DT_FLOAT:
                # modify node dtype
                node.attr[attr].type = dtype
                
            #重點(diǎn)關(guān)注value葵蒂,weights都是保存在value屬性中
            if attr == "value":
                tensor = node.attr[attr].tensor
                if tensor.dtype == types_pb2.DT_FLOAT:
                    # if float_val exists
                    if tensor.float_val:
                        float_val = tf.make_ndarray(node.attr[attr].tensor)
                        new_node.attr[attr].tensor.CopyFrom(tf.make_tensor_proto(float_val, dtype=dtype))
                        continue
                    # if tensor content exists
                    if tensor.tensor_content:
                        tensor_shape = [x.size for x in tensor.tensor_shape.dim]
                        tensor_weights = tf.make_ndarray(tensor)
                        # reshape tensor
                        tensor_weights = np.reshape(tensor_weights, tensor_shape)
                        tensor_proto = tf.make_tensor_proto(tensor_weights, dtype=dtype)
                        new_node.attr[attr].tensor.CopyFrom(tensor_proto)
                        continue
            new_node.attr[attr].CopyFrom(node.attr[attr])
    # transform graph
    if output_names:
        if not input_name:
            input_name = []
        transforms = ["strip_unused_nodes"]
        target_graph_def = TransformGraph(target_graph_def, input_name, output_names, transforms)
    # write graph_def to model
    tf.io.write_graph(target_graph_def, logdir=save_path, name=name, as_text=as_text)
    print("Converting done ...")

save_path = "test"
name = "output_fp16.pb"
model_path="test.pb"
as_text = False
target_type = 'fp16'
convert_graph_to_fp16(model_path, save_path, name, as_text=as_text, target_type=target_type, input_name=input_name, output_names=output_names)
# 測(cè)試一下轉(zhuǎn)換后的模型是否能夠加載
sess = load_graph(save_path+"/"+name)
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市重虑,隨后出現(xiàn)的幾起案子践付,更是在濱河造成了極大的恐慌,老刑警劉巖缺厉,帶你破解...
    沈念sama閱讀 219,490評(píng)論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件永高,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡提针,警方通過查閱死者的電腦和手機(jī)命爬,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,581評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來关贵,“玉大人遇骑,你說我怎么就攤上這事卖毁∫驹” “怎么了?”我有些...
    開封第一講書人閱讀 165,830評(píng)論 0 356
  • 文/不壞的土叔 我叫張陵亥啦,是天一觀的道長(zhǎng)炭剪。 經(jīng)常有香客問我,道長(zhǎng)翔脱,這世上最難降的妖魔是什么奴拦? 我笑而不...
    開封第一講書人閱讀 58,957評(píng)論 1 295
  • 正文 為了忘掉前任,我火速辦了婚禮届吁,結(jié)果婚禮上错妖,老公的妹妹穿的比我還像新娘绿鸣。我一直安慰自己,他們只是感情好暂氯,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,974評(píng)論 6 393
  • 文/花漫 我一把揭開白布潮模。 她就那樣靜靜地躺著,像睡著了一般痴施。 火紅的嫁衣襯著肌膚如雪擎厢。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,754評(píng)論 1 307
  • 那天辣吃,我揣著相機(jī)與錄音动遭,去河邊找鬼。 笑死神得,一個(gè)胖子當(dāng)著我的面吹牛厘惦,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播哩簿,決...
    沈念sama閱讀 40,464評(píng)論 3 420
  • 文/蒼蘭香墨 我猛地睜開眼绵估,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來了卡骂?” 一聲冷哼從身側(cè)響起国裳,我...
    開封第一講書人閱讀 39,357評(píng)論 0 276
  • 序言:老撾萬榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎全跨,沒想到半個(gè)月后缝左,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,847評(píng)論 1 317
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡浓若,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,995評(píng)論 3 338
  • 正文 我和宋清朗相戀三年渺杉,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片挪钓。...
    茶點(diǎn)故事閱讀 40,137評(píng)論 1 351
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡是越,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出碌上,到底是詐尸還是另有隱情倚评,我是刑警寧澤,帶...
    沈念sama閱讀 35,819評(píng)論 5 346
  • 正文 年R本政府宣布馏予,位于F島的核電站天梧,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏霞丧。R本人自食惡果不足惜呢岗,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,482評(píng)論 3 331
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧后豫,春花似錦悉尾、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,023評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至饭豹,卻和暖如春鸵赖,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背拄衰。 一陣腳步聲響...
    開封第一講書人閱讀 33,149評(píng)論 1 272
  • 我被黑心中介騙來泰國(guó)打工它褪, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人翘悉。 一個(gè)月前我還...
    沈念sama閱讀 48,409評(píng)論 3 373
  • 正文 我出身青樓茫打,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親妖混。 傳聞我的和親對(duì)象是個(gè)殘疾皇子老赤,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,086評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容