Merge pull request #6 from nihui/patch-1

Export torchscript
This commit is contained in:
xuehao.ma 2022-07-07 22:38:34 +08:00 committed by GitHub
commit 1a28f35322
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 0 deletions

View File

@@ -181,6 +181,11 @@ TRAIN:
```
python3 test.py --yaml configs/coco.yaml --weight weights/coco_ap05_0.250_280epoch.pth --img data/3.jpg --onnx
```
## Export torchscript
* You can export .pt by adding the --torchscript option when executing test.py
```
python3 test.py --yaml configs/coco.yaml --weight weights/coco_ap05_0.250_280epoch.pth --img data/3.jpg --torchscript
```
## onnx-runtime
* You can learn about the pre and post-processing methods of FastestDet in this Sample
```

10
test.py
View File

@@ -15,6 +15,7 @@ if __name__ == '__main__':
parser.add_argument('--img', type=str, default='', help='The path of test image')
parser.add_argument('--thresh', type=float, default=0.65, help='The path of test image')
parser.add_argument('--onnx', action="store_true", default=False, help='Export onnx file')
parser.add_argument('--torchscript', action="store_true", default=False, help='Export torchscript file')
parser.add_argument('--cpu', action="store_true", default=False, help='Run on cpu')
opt = parser.parse_args()
@@ -61,6 +62,15 @@ if __name__ == '__main__':
opset_version=11, # the ONNX version to export the model to
do_constant_folding=True) # whether to execute constant folding for optimization
# Export TorchScript model (traced on CPU so the saved module is device-independent)
if opt.torchscript:
    import copy
    # Deep-copy before .cpu() so the live (possibly GPU) model used below is untouched
    model_cpu = copy.deepcopy(model).cpu()
    # Dummy input matching the configured network resolution: NCHW = (1, 3, H, W)
    x = torch.rand(1, 3, cfg.input_height, cfg.input_width)
    mod = torch.jit.trace(model_cpu, x)
    mod.save("./FastestDet.pt")
    # Fix: second %d is the width — original passed cfg.input_height twice,
    # which printed a wrong pnnx inputshape for non-square resolutions
    print("to convert torchscript to pnnx/ncnn: ./pnnx FastestDet.pt inputshape=[1,3,%d,%d]" % (cfg.input_height, cfg.input_width))
# 模型推理
start = time.perf_counter()
preds = model(img)