diff --git a/test.py b/test.py index 4f5671f..28b4170 100644 --- a/test.py +++ b/test.py @@ -15,6 +15,7 @@ if __name__ == '__main__': parser.add_argument('--img', type=str, default='', help='The path of test image') parser.add_argument('--thresh', type=float, default=0.65, help='The path of test image') parser.add_argument('--onnx', action="store_true", default=False, help='Export onnx file') + parser.add_argument('--torchscript', action="store_true", default=False, help='Export torchscript file') parser.add_argument('--cpu', action="store_true", default=False, help='Run on cpu') opt = parser.parse_args() @@ -61,6 +62,15 @@ if __name__ == '__main__': opset_version=11, # the ONNX version to export the model to do_constant_folding=True) # whether to execute constant folding for optimization + # 导出torchscript模型 + if opt.torchscript: + import copy + model_cpu = copy.deepcopy(model).cpu() + x = torch.rand(1, 3, cfg.input_height, cfg.input_width) + mod = torch.jit.trace(model_cpu, x) + mod.save("./FastestDet.pt") + print("to convert torchscript to pnnx/ncnn: ./pnnx FastestDet.pt inputshape=[1,3,%d,%d]" % (cfg.input_height, cfg.input_height)) + # 模型推理 start = time.perf_counter() preds = model(img)