ir_version: 3
producer_name: "onnx.utils.extract_model"
graph {
  node {
    input: "Input3"
    input: "Constant339"
    output: "Minus340_Output_0"
    name: "Minus340"
    op_type: "Sub"
    doc_string: ""
    domain: ""
  }
  node {
    input: "Minus340_Output_0"
    input: "Constant343"
    output: "Block352_Output_0"
    name: "Block352"
    op_type: "Div"
    doc_string: ""
    domain: ""
  }
  name: "Extracted from {CNTKGraph}"
  initializer {
    data_type: 1
    float_data: 127.5
    name: "Constant339"
  }
  initializer {
    data_type: 1
    float_data: 255.0
    name: "Constant343"
  }
  input {
    name: "Input3"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 64
          }
          dim {
            dim_value: 64
          }
        }
      }
    }
  }
  output {
    name: "Block352_Output_0"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 64
          }
          dim {
            dim_value: 64
          }
        }
      }
    }
  }
  value_info {
    name: "Minus340_Output_0"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 64
          }
          dim {
            dim_value: 64
          }
        }
      }
    }
  }
  value_info {
    name: "Block352_Output_0"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 64
          }
          dim {
            dim_value: 64
          }
        }
      }
    }
  }
}
opset_import {
  domain: ""
  version: 7
}
# Once we understand the structure of an ONNX model, we can split it into several single-node ONNX models, which makes it easy to test and analyze individual nodes of the whole model.
import onnx
from onnx import helper, numpy_helper

def show_weight(weight):
    print("="*10, "details of weight: ", weight.name, "="*10)
    print("data type: ", weight.data_type)
    print("shape: ", weight.dims)
    data_numpy = numpy_helper.to_array(weight)
    # data_numpy = np.frombuffer(weight.raw_data, dtype=xxx)
    # print("detail data:", data_numpy)
    print("="*40)
# onnx.utils.extract_model("emotion-ferplus-7.onnx","mini_model.onnx",["Input3"],["Block352_Output_0"])
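# A minimal sketch of the per-node splitting described above. This is an assumption-laden
# example: the file names are made up, shape inference is run first so that extract_model
# can find the intermediate tensors, and node inputs that are initializers are skipped
# because extract_model expects real tensor names as sub-model inputs.
import onnx

full_model = onnx.shape_inference.infer_shapes(onnx.load("emotion-ferplus-7.onnx"))
onnx.save(full_model, "emotion-ferplus-7-shaped.onnx")
init_names = {init.name for init in full_model.graph.initializer}
for idx, nd in enumerate(full_model.graph.node):
    in_names = [name for name in nd.input if name and name not in init_names]
    out_names = [name for name in nd.output if name]
    if not in_names or not out_names:
        continue  # node fed only by initializers; skip it in this sketch
    onnx.utils.extract_model("emotion-ferplus-7-shaped.onnx",
                             f"node_{idx}_{nd.op_type}.onnx",
                             in_names, out_names)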
model = onnx.load("emotion-ferplus-7.onnx")
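# Quick use of the show_weight helper defined above: dump the metadata of the first
# weight tensor (printing every initializer of the full model would be noisy).
if len(model.graph.initializer) > 0:
    show_weight(model.graph.initializer[0])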
# print(model.ir_version)     # IR version
# print(model.producer_name)  # producer name
# print(model.opset_import)   # opset version info
# graph
# A graph contains node (NodeProto), input (ValueInfoProto), output (ValueInfoProto) and initializer (TensorProto) fields.
# node stores all of the model's compute nodes, input stores the model's inputs, output stores the model's outputs, and initializer stores all of the model's constant weight tensors;
# value_info stores the metadata of the intermediate (dynamic) tensors.
# In an ONNX model, initializer and value_info both describe tensors, but they serve different purposes.
# initializer is a list of all pre-initialized tensors in the model. These are usually the model's weights and biases, learned during training and used at inference time. Each element is a TensorProto carrying the tensor's data type, shape and values.
# value_info describes the model's inputs, outputs and intermediate results. It holds a tensor's name, data type and shape, but not its values. value_info is mainly used in the graph definition to describe tensors that are neither model inputs nor model outputs but are consumed or produced during computation.
# In short, both describe tensors: initializer focuses on a tensor's values, while value_info focuses on its metadata.
# node describes the model's topology through the producer/consumer relationships between its input and output tensors.
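# A small sketch mapping the description above onto the protobuf fields. Note that
# value_info may be empty unless shape inference has been run on this model.
from onnx import numpy_helper

print("nodes:", len(model.graph.node))
print("graph inputs:", [i.name for i in model.graph.input])
print("graph outputs:", [o.name for o in model.graph.output])
print("initializers:", len(model.graph.initializer), "value_info entries:", len(model.graph.value_info))
for vi in model.graph.value_info[:3]:
    dims = [d.dim_value for d in vi.type.tensor_type.shape.dim]
    print("value_info:", vi.name, "elem_type:", vi.type.tensor_type.elem_type, "shape:", dims)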
# for node in model.graph.node:
#     print(node)
# print(model.graph.input)
# print(model.graph.output)
# # get the number of nodes
# print(len(model.graph.node))
# # with open("model.txt", "w") as f:
# #     f.write(str(model))
# # print(model.graph)
# print(getNodeNameList(model))
# print("-----------------------------------------------")
# # How do we modify the value of an initializer?
# # Option 1: remove the old initializer and append a new one
# # Option 2: modify the current initializer in place
# for initializer in model.graph.initializer:
#     print(initializer)
# # model.graph.initializer.remove(next(init for init in model.graph.initializer if init.name == 'Constant343'))
# model.graph.initializer.remove(model.graph.initializer[0])
# # Create a new TensorProto object to serve as the new initializer
# new_initializer = helper.make_tensor(name='Constant339', data_type=onnx.TensorProto.FLOAT,
#                                      dims=[1], vals=[255.0], raw=False)
# model.graph.initializer.append(new_initializer)
# print("-----------------------------------------------")
# for initializer in model.graph.initializer:
#     print(initializer)
# print("-----------------------------------------------")
# init = model.graph.initializer[0]
# print(init)
# init.name = "Constant343_new"
# # float_data is a list-like container; manipulate it with the usual list operations
# # init.float_data.pop()
# # init.float_data[:] = []
# # init.float_data.extend([123.0])
# # print(dir(init.float_data))
# init.float_data[0] = 123.50
# print("-----------------------------------------------")
# print(model.graph.initializer)
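# A compact, runnable version of the two approaches sketched above. The initializer
# names come from the extracted dump at the top of this file, and the new values are
# arbitrary, demonstration-only. If a tensor stores its data in raw_data instead of
# float_data, go through numpy_helper rather than editing the fields directly.
import numpy as np
from onnx import numpy_helper

init_map = {init.name: init for init in model.graph.initializer}

# Approach 1: remove the old initializer and append a freshly built one
model.graph.initializer.remove(init_map["Constant339"])
model.graph.initializer.append(
    numpy_helper.from_array(np.array(255.0, dtype=np.float32), name="Constant339"))

# Approach 2: edit the existing TensorProto in place
init_map = {init.name: init for init in model.graph.initializer}
init_map["Constant343"].float_data[0] = 123.5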
#####################################################################################################################
from onnx import TensorShapeProto

for info in model.graph.value_info:
    # info.type.tensor_type.shape.dim[:] = [1, 2, 3]   # won't work: dim holds Dimension messages, not ints
    # Add a dimension
    dim = TensorShapeProto.Dimension()
    dim.dim_value = 666  # size of the new dimension
    info.type.tensor_type.shape.dim.append(dim)
    info.type.tensor_type.shape.dim[0].dim_value = 333
    # Delete a dimension (guard against shapes with fewer than 3 dims)
    if len(info.type.tensor_type.shape.dim) > 2:
        del info.type.tensor_type.shape.dim[2]
    print(info.type.tensor_type.shape.dim)
    # print(dir(info.type.tensor_type.shape.dim))
# tensorname_to_info = {info.name:info for info in model.graph.value_info}
# print(tensorname_to_info)
init_mapper = {init.name: init for init in model.graph.initializer}
valueinfo_mapper = {value_info.name: value_info for value_info in model.graph.value_info}
node_mapper = {node.name: node for node in model.graph.node}
def get_parent_node(graph, node):
    # Parents are the nodes that produce this node's inputs
    parents = []
    for input_name in node.input:
        for n in graph.node:
            if input_name in n.output:
                parents.append(n)
    return parents

def get_children_node(graph, node):
    # Children are the nodes that consume this node's outputs
    children = []
    for output_name in node.output:
        for n in graph.node:
            if output_name in n.input:
                children.append(n)
    return children
node = node_mapper["Minus340"]   # e.g. the Sub node seen in the extracted dump above
parents = get_parent_node(model.graph, node)
children = get_children_node(model.graph, node)
def remove_node_by_name(graph, node_name):
    # Find the node to remove
    node_to_remove = None
    for node in graph.node:
        if node.name == node_name:
            node_to_remove = node
            break
    if node_to_remove is None:
        return False
    # Re-route the edges around the node being removed:
    # every consumer of its outputs is redirected to its first input
    # (assumes the first input carries the data tensor)
    for output in node_to_remove.output:
        for node in graph.node:
            for index, input_name in enumerate(node.input):
                if input_name == output:
                    node.input[index] = node_to_remove.input[0]
    # Remove the node from the graph
    graph.node.remove(node_to_remove)
    return True
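# Hypothetical usage sketch: drop the Div node "Block352" (node/tensor names taken
# from the extracted dump above). Every consumer of "Block352_Output_0" is rewired
# to read "Minus340_Output_0", the removed node's first input.
print("removed Block352:", remove_node_by_name(model.graph, "Block352"))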
## Insert a new node after an existing one
def insert_node_after(graph, existing_node_name, new_node):
    # Add the new node to the graph and rewire the surrounding connections.
    # new_node is assumed to have been built with the existing node's output as its
    # input and a fresh tensor name as its output.
    nodes = graph.node
    for i, node in enumerate(nodes):
        if node.name == existing_node_name:
            # Assume the existing node has a single output
            existing_output_name = node.output[0]
            new_output_name = new_node.output[0]
            # Every later node that consumed the existing output now reads the new node's output
            for later_node in nodes[i + 1:]:
                for j, input_name in enumerate(later_node.input):
                    if input_name == existing_output_name:
                        later_node.input[j] = new_output_name
            # Insert the new node right after the existing node
            nodes.insert(i + 1, new_node)
            break
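# Hypothetical usage sketch: append a Neg node right after the Sub node "Minus340".
# The new node consumes the existing output and exposes a fresh tensor name, matching
# the assumption documented in insert_node_after above.
from onnx import helper

neg_node = helper.make_node(
    "Neg",
    inputs=["Minus340_Output_0"],        # the existing node's output
    outputs=["Minus340_Neg_Output_0"],   # fresh, previously unused tensor name
    name="Minus340_Neg",
)
insert_node_after(model.graph, "Minus340", neg_node)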
def insert_node_before(graph, existing_node_name, new_node):
    # Insert new_node right before existing_node_name. Assumes new_node has a single
    # output; if new_node.input is empty, it takes over the existing node's first input.
    # Find the existing node
    existing_node_index = None
    for index, node in enumerate(graph.node):
        if node.name == existing_node_name:
            existing_node_index = index
            break
    if existing_node_index is None:
        raise ValueError(f"No node with name {existing_node_name} found in the graph.")
    existing_node = graph.node[existing_node_index]
    new_node_output = new_node.output[0]
    # The new node takes over the existing node's first (data) input ...
    if not new_node.input:
        new_node.input.append(existing_node.input[0])
    # ... and the existing node now reads the new node's output instead
    existing_node.input[0] = new_node_output
    # Insert the new node into the graph right before the existing node
    graph.node.insert(existing_node_index, new_node)
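# Hypothetical usage sketch: scale the raw input with a Mul node inserted before
# "Minus340". The scale constant and all new names are made up for illustration.
import onnx
import numpy as np
from onnx import helper, numpy_helper

model.graph.initializer.append(
    numpy_helper.from_array(np.array(2.0, dtype=np.float32), name="PreScale_Const"))
mul_node = helper.make_node(
    "Mul",
    inputs=["Input3", "PreScale_Const"],
    outputs=["Input3_scaled"],
    name="PreScale",
)
insert_node_before(model.graph, "Minus340", mul_node)
# Structural edits are easy to get wrong, so re-check the model and persist it.
onnx.checker.check_model(model)
onnx.save(model, "emotion-ferplus-7-modified.onnx")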
# NodeNameList = []
# for i in range(len(model.graph.node)):
#     node = model.graph.node[i]
#     print(node.input)
#     print(node.output)
#     print(node.attribute)
#     NodeNameList.append(model.graph.node[i].name)
# out_tvi = [inner_output for inner_output in model.graph.value_info if inner_output.name == name]
# References:
# https://github.com/ZhangGe6/onnx-modifier/tree/master
# https://github.com/bindog/onnx-surgery/blob/master/surgery.py
# https://bindog.github.io/blog/2020/03/13/deep-learning-model-convert-and-depoly/
# https://www.zhihu.com/question/386526462
# https://blog.csdn.net/ChuiGeDaQiQiu/article/details/123794387