test.py add cpu forward

This commit is contained in:
duoduo 2022-07-02 09:47:06 +08:00
parent e64a02cdba
commit df3630dd6e

18
test.py
View File

@ -7,9 +7,6 @@ import torch
from utils.tool import *
from module.detector import Detector
# 指定后端设备CUDA&CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if __name__ == '__main__':
# 指定训练配置文件
parser = argparse.ArgumentParser()
@ -18,12 +15,25 @@ if __name__ == '__main__':
parser.add_argument('--img', type=str, default='', help='The path of test image')
parser.add_argument('--thresh', type=float, default=0.8, help='The path of test image')
parser.add_argument('--onnx', action="store_true", default=False, help='Export onnx file')
parser.add_argument('--cpu', action="store_true", default=False, help='Run on cpu')
opt = parser.parse_args()
assert os.path.exists(opt.yaml), "请指定正确的配置文件路径"
assert os.path.exists(opt.weight), "请指定正确的模型路径"
assert os.path.exists(opt.img), "请指定正确的测试图像路径"
# 选择推理后端
if opt.cpu:
print("run on cpu...")
device = torch.device("cpu")
else:
if torch.cuda.is_available():
print("run on gpu...")
device = torch.device("cuda")
else:
print("run on cpu...")
device = torch.device("cpu")
# 解析yaml配置文件
cfg = LoadYaml(opt.yaml)
print(cfg)
@ -85,4 +95,4 @@ if __name__ == '__main__':
cv2.putText(ori_img, '%.2f' % obj_score, (x1, y1 - 5), 0, 0.7, (0, 255, 0), 2)
cv2.putText(ori_img, category, (x1, y1 - 25), 0, 0.7, (0, 255, 0), 2)
cv2.imwrite("result.png", ori_img)
cv2.imwrite("result.png", ori_img)