FastestDet/train.py
2022-07-01 21:22:39 +08:00

134 lines
5.5 KiB
Python

import os
import math
import torch
import argparse
from tqdm import tqdm
from torch import optim
from torchsummary import summary
from utils.tool import *
from utils.datasets import *
from utils.evaluation import CocoDetectionEvaluator
from module.loss import DetectorLoss
from module.detector import Detector
# Select the compute backend: CUDA when it is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class FastestDet:
    """Training driver for the FastestDet detector.

    Constructing an instance parses the command line (``--yaml`` and
    ``--weight``), loads the YAML configuration and builds the model, the SGD
    optimizer, the MultiStepLR scheduler, the loss, the evaluator and the
    train/validation data loaders. ``train()`` then runs the epoch loop.
    """

    # Number of epochs over which the learning rate is ramped up.
    WARMUP_EPOCHS = 5
    # Evaluate the model and save a checkpoint every this many epochs.
    EVAL_INTERVAL = 10

    def __init__(self):
        """Parse CLI arguments and build every object needed for training.

        Raises:
            AssertionError: if the path given with ``--yaml`` does not exist.
        """
        # Command-line options: path of the training config and optional weights.
        parser = argparse.ArgumentParser()
        parser.add_argument('--yaml', type=str, default="", help='.yaml config')
        parser.add_argument('--weight', type=str, default=None, help='.weight config')
        opt = parser.parse_args()
        assert os.path.exists(opt.yaml), "请指定正确的配置文件路径"

        # Parse the YAML configuration file.
        self.cfg = LoadYaml(opt.yaml)
        print(self.cfg)

        # Build the network; optionally initialise it from a checkpoint.
        if opt.weight is not None:
            print("load weight from:%s"%opt.weight)
            # NOTE(review): the second Detector argument is True only when a
            # checkpoint is supplied -- its exact meaning is defined in
            # module.detector; confirm there.
            self.model = Detector(self.cfg.category_num, True).to(device)
            # map_location lets a checkpoint saved on a CUDA device be loaded
            # on a CPU-only host (and vice versa).
            # NOTE: torch.load unpickles the file -- only load trusted weights.
            self.model.load_state_dict(torch.load(opt.weight, map_location=device))
        else:
            self.model = Detector(self.cfg.category_num, False).to(device)

        # Print the per-layer tensor dimensions of the network.
        summary(self.model, input_size=(3, self.cfg.input_height, self.cfg.input_width))

        # Build the optimizer.
        print("use SGD optimizer")
        self.optimizer = optim.SGD(
            params=self.model.parameters(),
            lr=self.cfg.learn_rate,
            momentum=0.949,
            weight_decay=0.0005,
        )
        # Learning-rate decay: multiply by 0.1 at each configured milestone.
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=self.cfg.milestones,
            gamma=0.1,
        )

        # Loss function.
        self.loss_function = DetectorLoss(device)

        # Validation metric.
        self.evaluation = CocoDetectionEvaluator(self.cfg.names, device)

        # Datasets (the last argument is False for validation, True for training).
        val_dataset = TensorDataset(self.cfg.val_txt, self.cfg.input_width, self.cfg.input_height, False)
        train_dataset = TensorDataset(self.cfg.train_txt, self.cfg.input_width, self.cfg.input_height, True)

        # Validation loader keeps order and every sample; the training loader
        # shuffles and drops the incomplete last batch.
        self.val_dataloader = self._build_dataloader(val_dataset, False)
        self.train_dataloader = self._build_dataloader(train_dataset, True)

    def _build_dataloader(self, dataset, training):
        """Create a DataLoader for *dataset*; shuffle/drop_last follow *training*."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.cfg.batch_size,
            shuffle=training,
            collate_fn=collate_fn,
            num_workers=4,
            drop_last=training,
            persistent_workers=True,
        )

    @staticmethod
    def _warmup_lr(base_lr, batch_num, warmup_num):
        """Return the warm-up rate: ``base_lr * (batch_num / warmup_num) ** 4``."""
        return base_lr * math.pow(batch_num / warmup_num, 4)

    def train(self):
        """Run the training loop for ``cfg.end_epoch + 1`` epochs.

        Every ``EVAL_INTERVAL`` epochs (except epoch 0) the model is evaluated
        on the validation loader and its state dict is saved under
        ``checkpoint/``. The MultiStepLR scheduler is stepped once per epoch.
        """
        batch_num = 0
        # Total number of warm-up batches; loop-invariant, so computed once.
        warmup_num = self.WARMUP_EPOCHS * len(self.train_dataloader)
        print('Starting training for %g epochs...' % self.cfg.end_epoch)
        for epoch in range(self.cfg.end_epoch + 1):
            self.model.train()
            pbar = tqdm(self.train_dataloader)
            for imgs, targets in pbar:
                # Pre-processing: move to the device and scale pixels by 1/255.
                imgs = imgs.to(device).float() / 255.0
                targets = targets.to(device)

                # Forward pass.
                preds = self.model(imgs)
                # Loss computation.
                iou, obj, cls, total = self.loss_function(preds, targets)
                # Back-propagate to obtain the gradients.
                total.backward()

                # Learning-rate warm-up. This must happen BEFORE the optimizer
                # step, otherwise the very first batch is applied with the full
                # base learning rate.
                if batch_num <= warmup_num:
                    warm_lr = self._warmup_lr(self.cfg.learn_rate, batch_num, warmup_num)
                    for group in self.optimizer.param_groups:
                        group['lr'] = warm_lr
                lr = self.optimizer.param_groups[-1]["lr"]

                # Update the model parameters.
                self.optimizer.step()
                self.optimizer.zero_grad()

                # Show the training statistics on the progress bar.
                info = "Epoch:%d LR:%f IOU:%f Obj:%f Cls:%f Total:%f" % (
                    epoch, lr, iou, obj, cls, total)
                pbar.set_description(info)
                batch_num += 1

            # Periodic validation and checkpointing.
            if epoch % self.EVAL_INTERVAL == 0 and epoch > 0:
                self.model.eval()
                print("computer mAP...")
                mAP05 = self.evaluation.compute_map(self.val_dataloader, self.model)
                # Make sure the output directory exists so torch.save cannot
                # fail after hours of training.
                os.makedirs("checkpoint", exist_ok=True)
                torch.save(self.model.state_dict(), "checkpoint/weight_AP05:%f_%d-epoch.pth"%(mAP05, epoch))

            # Per-epoch learning-rate schedule.
            self.scheduler.step()
# Script entry point: build the trainer (parses CLI arguments) and start training.
if __name__ == "__main__":
    trainer = FastestDet()
    trainer.train()