1 Star 0 Fork 0

dhc3 / handpose_x

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
onnx_inference.py 3.45 KB
一键复制 编辑 原始数据 按行查看 历史
lixiang 提交于 2021-10-13 13:29 . add onnx tool
#-*-coding:utf-8-*-
# date:2021-10-5
# Author: Eric.Lee
# function: onnx Inference
import os, sys
sys.path.append(os.getcwd())
import onnxruntime
import onnx
import cv2
import torch
import numpy as np
from hand_data_iter.datasets import draw_bd_handpose
class ONNXModel():
    """Thin wrapper around an ``onnxruntime.InferenceSession``.

    The model's input and output names are collected once at construction;
    :meth:`forward` feeds one array to every model input and returns every
    model output.
    """

    def __init__(self, onnx_path, gpu_cfg=False, device_id=0):
        """
        :param onnx_path: path to the ``.onnx`` model file.
        :param gpu_cfg: if True, prefer the CUDA execution provider (with the
            CPU provider kept as a fallback); otherwise run on CPU only.
        :param device_id: CUDA device index used when ``gpu_cfg`` is True.
        """
        # Providers are passed to the constructor rather than set afterwards:
        # since onnxruntime 1.9, GPU-enabled builds raise ValueError when an
        # InferenceSession is created without an explicit ``providers`` list.
        self.onnx_session = onnxruntime.InferenceSession(
            onnx_path, providers=self._providers(gpu_cfg, device_id))
        self.input_name = self.get_input_name(self.onnx_session)
        self.output_name = self.get_output_name(self.onnx_session)
        print("input_name:{}".format(self.input_name))
        print("output_name:{}".format(self.output_name))

    @staticmethod
    def _providers(gpu_cfg, device_id=0):
        """Build the execution-provider list handed to ``InferenceSession``.

        :param gpu_cfg: whether the CUDA provider should be requested.
        :param device_id: CUDA device index for the CUDA provider.
        :return: list of provider names / ``(name, options)`` tuples.
        """
        if gpu_cfg:
            # CPU stays in the list so the session still works when CUDA
            # is unavailable.
            return [("CUDAExecutionProvider", {"device_id": device_id}),
                    "CPUExecutionProvider"]
        return ["CPUExecutionProvider"]

    def get_output_name(self, onnx_session):
        """Return the names of all model outputs, in session order.

        :param onnx_session: object exposing ``get_outputs()`` whose items
            have a ``name`` attribute (an ``InferenceSession``).
        :return: list of output names.
        """
        return [node.name for node in onnx_session.get_outputs()]

    def get_input_name(self, onnx_session):
        """Return the names of all model inputs, in session order.

        :param onnx_session: object exposing ``get_inputs()`` whose items
            have a ``name`` attribute (an ``InferenceSession``).
        :return: list of input names.
        """
        return [node.name for node in onnx_session.get_inputs()]

    def get_input_feed(self, input_name, image_numpy):
        """Build the ``{input_name: array}`` feed dict for ``session.run``.

        Every input name is mapped to the *same* array, so this is only
        meaningful for single-input models (or models whose inputs all take
        the same tensor).

        :param input_name: iterable of model input names.
        :param image_numpy: the array to feed.
        :return: dict mapping each input name to ``image_numpy``.
        """
        return {name: image_numpy for name in input_name}

    def forward(self, image_numpy):
        """Run inference on ``image_numpy`` and return all model outputs.

        :param image_numpy: input array, already pre-processed by the caller
            (batch axis included) — presumably float32 NCHW, TODO confirm
            against the exported model.
        :return: list of outputs, one entry per name in ``self.output_name``.
        """
        input_feed = self.get_input_feed(self.input_name, image_numpy)
        return self.onnx_session.run(self.output_name, input_feed=input_feed)
if __name__ == "__main__":
    # Demo: run the hand-keypoint model on every image in ./image/ and show
    # the predicted keypoints and their skeleton connections.
    img_size = 256  # square network input resolution
    model = ONNXModel("resnet_50_size-256.onnx")
    path_ = "./image/"
    for f_ in sorted(os.listdir(path_)):
        img0 = cv2.imread(os.path.join(path_, f_))
        if img0 is None:
            # cv2.imread returns None for entries it cannot decode
            # (non-image files, sub-directories); skip instead of crashing.
            print("skip unreadable file: {}".format(f_))
            continue
        img_height, img_width = img0.shape[:2]

        # Pre-process: resize to the network input size, HWC -> CHW,
        # scale pixel values to [0, 1], add a batch axis, cast to float32.
        img = cv2.resize(img0, (img_size, img_size), interpolation=cv2.INTER_CUBIC)
        img_ndarray = img.transpose((2, 0, 1)) / 255.
        img_ndarray = np.expand_dims(img_ndarray, 0).astype('float32')

        # First output tensor, first batch item: treated as a flat vector of
        # interleaved x, y values.
        output = np.array(model.forward(img_ndarray)[0][0])
        print(output.shape[0])

        # Scale back to the original image size. Assumes the model predicts
        # coordinates normalised to [0, 1] — TODO confirm. zip() drops a
        # trailing unpaired value, matching the original len // 2 pairs.
        xs = output[0::2] * float(img_width)
        ys = output[1::2] * float(img_height)
        pts = list(zip(xs, ys))

        # Structure expected by draw_bd_handpose: {"0": {"x":.., "y":..}, ...}
        pts_hand = {str(i): {"x": x, "y": y} for i, (x, y) in enumerate(pts)}
        draw_bd_handpose(img0, pts_hand, 0, 0)  # draw keypoint connections

        # Draw the keypoints themselves on top of the connections.
        for x, y in pts:
            center = (int(x), int(y))
            cv2.circle(img0, center, 3, (255, 50, 60), -1)
            cv2.circle(img0, center, 1, (255, 150, 180), -1)

        cv2.namedWindow('image', cv2.WINDOW_NORMAL)
        cv2.imshow('image', img0)
        if cv2.waitKey(600) == 27:  # ESC stops the slideshow early
            break
    cv2.waitKey(0)
    cv2.destroyAllWindows()
1
https://gitee.com/dhc3/handpose_x.git
git@gitee.com:dhc3/handpose_x.git
dhc3
handpose_x
handpose_x
master

搜索帮助

53164aa7 5694891 3bd8fe86 5694891