最近在作pytorch 模型轉(zhuǎn)tensorflow,通過onnx 中間轉(zhuǎn)換和容易会宪,但是再轉(zhuǎn)換時有一個注意事項,即如何處理batch
pytorch 模型轉(zhuǎn)tensorflow: http://www.reibang.com/p/3e5623696a8e
通過 onnx 手動修改batch 為動態(tài)值: model_onnx.graph.input[0].type.tensor_type.shape.dim[0].dim_param ='?'
這下就可以使用batch 了蚯窥。
最后別忘了用transform_graph 壓縮下模型大械Ф臁: http://www.reibang.com/p/d2637646cda1
self.model.load_state_dict(model_dict)
example = torch.ones(1,1,112,112).cuda() #限定好的tensor 輸入大小
# traced_script_module = torch.jit.trace(self.model, example)
# traced_script_module.save('./lt_model.pt')
torch.onnx.export(self.model, example,'./model_simple.onnx',input_names=['input'],
output_names=['output'])
model_onnx = onnx.load('./model_simple.onnx')
model_onnx.graph.input[0].type.tensor_type.shape.dim[0].dim_param ='?'
tf_rep = prepare(model_onnx)
print(tf_rep.tensor_dict)
tf_rep.export_graph('./lt_model.pb')