test.py add cpu forward
This commit is contained in:
parent
e64a02cdba
commit
df3630dd6e
16
test.py
16
test.py
@ -7,9 +7,6 @@ import torch
|
|||||||
from utils.tool import *
|
from utils.tool import *
|
||||||
from module.detector import Detector
|
from module.detector import Detector
|
||||||
|
|
||||||
# 指定后端设备CUDA&CPU
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# 指定训练配置文件
|
# 指定训练配置文件
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -18,12 +15,25 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--img', type=str, default='', help='The path of test image')
|
parser.add_argument('--img', type=str, default='', help='The path of test image')
|
||||||
parser.add_argument('--thresh', type=float, default=0.8, help='The path of test image')
|
parser.add_argument('--thresh', type=float, default=0.8, help='The path of test image')
|
||||||
parser.add_argument('--onnx', action="store_true", default=False, help='Export onnx file')
|
parser.add_argument('--onnx', action="store_true", default=False, help='Export onnx file')
|
||||||
|
parser.add_argument('--cpu', action="store_true", default=False, help='Run on cpu')
|
||||||
|
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
assert os.path.exists(opt.yaml), "请指定正确的配置文件路径"
|
assert os.path.exists(opt.yaml), "请指定正确的配置文件路径"
|
||||||
assert os.path.exists(opt.weight), "请指定正确的模型路径"
|
assert os.path.exists(opt.weight), "请指定正确的模型路径"
|
||||||
assert os.path.exists(opt.img), "请指定正确的测试图像路径"
|
assert os.path.exists(opt.img), "请指定正确的测试图像路径"
|
||||||
|
|
||||||
|
# 选择推理后端
|
||||||
|
if opt.cpu:
|
||||||
|
print("run on cpu...")
|
||||||
|
device = torch.device("cpu")
|
||||||
|
else:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
print("run on gpu...")
|
||||||
|
device = torch.device("cuda")
|
||||||
|
else:
|
||||||
|
print("run on cpu...")
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
# 解析yaml配置文件
|
# 解析yaml配置文件
|
||||||
cfg = LoadYaml(opt.yaml)
|
cfg = LoadYaml(opt.yaml)
|
||||||
print(cfg)
|
print(cfg)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user