背景
最近在AXERA M55H工具鏈做一個(gè)語義分割模型量化慌闭。
AXERA官方文檔表示argmax只能接在conv算子后初坠,而我的deeplabv3+模型最后兩個(gè)節(jié)點(diǎn)是resize上采樣接argmax。
試了一下模型轉(zhuǎn)換(編譯铜幽、量化成M55H支持的.joint模型)滞谢,果然在執(zhí)行argmax相關(guān)操作時(shí)報(bào)錯(cuò)。
于是只能手動(dòng)將onnx文件的argmax節(jié)點(diǎn)刪除除抛,在后處理來做argmax了狮杨。
ONNX刪除節(jié)點(diǎn)
由于模型只有結(jié)尾處有一個(gè)argmax節(jié)點(diǎn),所以直接找到op_type == "ArgMax"的節(jié)點(diǎn)將其刪除即可到忽。
node_to_rm = next(node for node in model.graph.node if node.op_type == "ArgMax")
model.graph.node.remove(node_to_rm)
onnx.save(model, dst_model)
此時(shí)用生成的新model推理會(huì)報(bào)錯(cuò)橄教,大概錯(cuò)誤信息是output節(jié)點(diǎn)不在graph中。
查了一些資料發(fā)現(xiàn)model.graph.output和model.graph.node是平行的存在喘漏,也就是說輸出節(jié)點(diǎn)是區(qū)別于中間節(jié)點(diǎn)獨(dú)立存儲(chǔ)在model.graph.output中的护蝶。(輸入節(jié)點(diǎn)也類似)
上述的操作刪除了最后一個(gè)argmax節(jié)點(diǎn),但是沒有刪除輸出節(jié)點(diǎn)翩迈。并且滓走,一個(gè)graph必須包含1個(gè)以上的輸入和輸出節(jié)點(diǎn)。所以我們需要?jiǎng)h除原有的輸出節(jié)點(diǎn)并創(chuàng)建新的帽馋,即更換輸出節(jié)點(diǎn)搅方。
ONNX更換輸出節(jié)點(diǎn)
model.graph.output是一個(gè)list,包含所有輸出節(jié)點(diǎn)绽族。
目前包含一個(gè)輸出節(jié)點(diǎn)姨涡,就是之前的經(jīng)過argmax的分割特征圖。
輸出節(jié)點(diǎn)跟普通節(jié)點(diǎn)的數(shù)據(jù)結(jié)構(gòu)不同吧慢,它包含了節(jié)點(diǎn)名涛漂、輸出的數(shù)據(jù)結(jié)構(gòu)等信息,
因此只需要在現(xiàn)有節(jié)點(diǎn)基礎(chǔ)上進(jìn)行如下修改即可: (也可以通過onnx.helper創(chuàng)建新的節(jié)點(diǎn))
node_to_out = next(node for node in model.graph.node if node.output == node_to_rm.input) # 找到刪除節(jié)點(diǎn)的上游節(jié)點(diǎn)检诗,作為輸出節(jié)點(diǎn)的前置
out = model.graph.output[0] # 在原來的輸出節(jié)點(diǎn)基礎(chǔ)上改即可
out.name = node_to_out.output[0] # 修改為新的輸出節(jié)點(diǎn)名字
out.type.tensor_type.shape.dim[1].dim_value = 4 # 該維度指channel數(shù)匈仗,argmax以后為1,改為4逢慌,因?yàn)樵撃P陀?類悠轩,onehot表示
out.type.tensor_type.elem_type = 1 # 1表示float32, 經(jīng)過argmax后是7攻泼,表示int64
附錄: onnx的elem_type
elem_type: 1 --> float32
elem_type: 2 --> uint8
elem_type: 3 --> int8
elem_type: 4 --> uint16
elem_type: 5 --> int16
elem_type: 6 --> int32
elem_type: 7 --> int64
elem_type: 8 --> string
elem_type: 9 --> boolean
elem_type: 10 --> float16
elem_type: 11 --> float64
elem_type: 12 --> uint32
elem_type: 14 --> uint64
elem_type: 15 --> complex128
elem_type: 16 --> bfloat16
from: https://blog.csdn.net/weixin_43945848/article/details/122474749