Update test.py

This commit is contained in:
nihui 2022-07-07 19:21:58 +08:00 committed by GitHub
parent b1c5ef77a3
commit f7f9ab4b96
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

10
test.py
View File

@ -15,6 +15,7 @@ if __name__ == '__main__':
parser.add_argument('--img', type=str, default='', help='The path of test image')
parser.add_argument('--thresh', type=float, default=0.65, help='The path of test image')
parser.add_argument('--onnx', action="store_true", default=False, help='Export onnx file')
parser.add_argument('--torchscript', action="store_true", default=False, help='Export torchscript file')
parser.add_argument('--cpu', action="store_true", default=False, help='Run on cpu')
opt = parser.parse_args()
@ -61,6 +62,15 @@ if __name__ == '__main__':
opset_version=11, # the ONNX version to export the model to
do_constant_folding=True) # whether to execute constant folding for optimization
# 导出torchscript模型
if opt.torchscript:
import copy
model_cpu = copy.deepcopy(model).cpu()
x = torch.rand(1, 3, cfg.input_height, cfg.input_width)
mod = torch.jit.trace(model_cpu, x)
mod.save("./FastestDet.pt")
print("to convert torchscript to pnnx/ncnn: ./pnnx FastestDet.pt inputshape=[1,3,%d,%d]" % (cfg.input_height, cfg.input_height))
# 模型推理
start = time.perf_counter()
preds = model(img)