Update test.py
This commit is contained in:
parent
b1c5ef77a3
commit
f7f9ab4b96
10
test.py
10
test.py
@ -15,6 +15,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--img', type=str, default='', help='The path of test image')
|
||||
parser.add_argument('--thresh', type=float, default=0.65, help='The path of test image')
|
||||
parser.add_argument('--onnx', action="store_true", default=False, help='Export onnx file')
|
||||
parser.add_argument('--torchscript', action="store_true", default=False, help='Export torchscript file')
|
||||
parser.add_argument('--cpu', action="store_true", default=False, help='Run on cpu')
|
||||
|
||||
opt = parser.parse_args()
|
||||
@ -61,6 +62,15 @@ if __name__ == '__main__':
|
||||
opset_version=11, # the ONNX version to export the model to
|
||||
do_constant_folding=True) # whether to execute constant folding for optimization
|
||||
|
||||
# 导出torchscript模型
|
||||
if opt.torchscript:
|
||||
import copy
|
||||
model_cpu = copy.deepcopy(model).cpu()
|
||||
x = torch.rand(1, 3, cfg.input_height, cfg.input_width)
|
||||
mod = torch.jit.trace(model_cpu, x)
|
||||
mod.save("./FastestDet.pt")
|
||||
print("to convert torchscript to pnnx/ncnn: ./pnnx FastestDet.pt inputshape=[1,3,%d,%d]" % (cfg.input_height, cfg.input_height))
|
||||
|
||||
# 模型推理
|
||||
start = time.perf_counter()
|
||||
preds = model(img)
|
||||
|
Loading…
x
Reference in New Issue
Block a user