diff --git a/test.py b/test.py index 3343b6b..d623f4b 100644 --- a/test.py +++ b/test.py @@ -7,9 +7,6 @@ import torch from utils.tool import * from module.detector import Detector -# 指定后端设备CUDA&CPU -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if __name__ == '__main__': # 指定训练配置文件 parser = argparse.ArgumentParser() @@ -18,12 +15,25 @@ if __name__ == '__main__': parser.add_argument('--img', type=str, default='', help='The path of test image') parser.add_argument('--thresh', type=float, default=0.8, help='The path of test image') parser.add_argument('--onnx', action="store_true", default=False, help='Export onnx file') + parser.add_argument('--cpu', action="store_true", default=False, help='Run on cpu') opt = parser.parse_args() assert os.path.exists(opt.yaml), "请指定正确的配置文件路径" assert os.path.exists(opt.weight), "请指定正确的模型路径" assert os.path.exists(opt.img), "请指定正确的测试图像路径" + # 选择推理后端 + if opt.cpu: + print("run on cpu...") + device = torch.device("cpu") + else: + if torch.cuda.is_available(): + print("run on gpu...") + device = torch.device("cuda") + else: + print("run on cpu...") + device = torch.device("cpu") + # 解析yaml配置文件 cfg = LoadYaml(opt.yaml) print(cfg) @@ -85,4 +95,4 @@ if __name__ == '__main__': cv2.putText(ori_img, '%.2f' % obj_score, (x1, y1 - 5), 0, 0.7, (0, 255, 0), 2) cv2.putText(ori_img, category, (x1, y1 - 25), 0, 0.7, (0, 255, 0), 2) - cv2.imwrite("result.png", ori_img) \ No newline at end of file + cv2.imwrite("result.png", ori_img)