137 lines
4.6 KiB
Python
137 lines
4.6 KiB
Python
import os
|
|
import cv2
|
|
import numpy as np
|
|
|
|
import torch
|
|
import random
|
|
|
|
def random_crop(image, boxes):
    """Randomly crop ``image`` and re-express ``boxes`` relative to the crop.

    The crop keeps between 75% and 100% of each side (width and height are
    sampled independently) and is positioned uniformly at random inside the
    source image.

    Args:
        image: Array indexed as ``(height, width, channels)``.
        boxes: Rows of ``[index, category, x, y, w, h]``. The last four
            values are multiplied by the image width/height here, i.e. they
            are treated as fractions of the source image. May be empty.

    Returns:
        Tuple ``(roi, output)``: the cropped slice of ``image`` and a float64
        array of shape ``(N, 6)`` holding the boxes rescaled to the crop.
        Boxes are neither clipped nor dropped, so coordinates can fall
        outside ``[0, 1]`` when a box lies partly outside the crop.
    """
    height, width, _ = image.shape

    # The lower bound is clamped to 1 so that a 1-pixel side can never produce
    # an empty crop (which would make the normalisation below divide by zero).
    # The draw order cw, ch, cx, cy is kept so seeded runs stay reproducible.
    cw = random.randint(max(1, int(width * 0.75)), width)
    ch = random.randint(max(1, int(height * 0.75)), height)
    cx = random.randint(0, width - cw)
    cy = random.randint(0, height - ch)

    roi = image[cy:cy + ch, cx:cx + cw]
    roi_h, roi_w, _ = roi.shape

    boxes = np.asarray(boxes, dtype=float)
    if boxes.size == 0:
        # Keep the column count for images without objects so that callers
        # always receive a 2-D (N, 6) array.
        return roi, np.zeros((0, 6), dtype=float)

    # Columns 0 and 1 (index, category) are passed through unchanged.
    output = boxes[:, :6].copy()
    # Normalised -> source pixels -> shifted into the crop -> normalised again.
    output[:, 2] = (boxes[:, 2] * width - cx) / roi_w
    output[:, 3] = (boxes[:, 3] * height - cy) / roi_h
    output[:, 4] = boxes[:, 4] * width / roi_w
    output[:, 5] = boxes[:, 5] * height / roi_h

    return roi, output
|
|
|
|
def random_narrow(image, boxes):
    """Paste ``image`` onto a larger grey canvas and remap ``boxes`` to it.

    The canvas is between 100% and 125% of the image size on each side
    (sampled independently), filled with the value 128, and the image is
    placed at a uniformly random offset inside it. Relative to the canvas the
    image content therefore appears smaller ("narrowed").

    Args:
        image: Array indexed as ``(height, width, channels)``; it is written
            into a 3-channel ``uint8`` canvas.
        boxes: Rows of ``[index, category, x, y, w, h]``. The last four
            values are multiplied by the image width/height here, i.e. they
            are treated as fractions of the source image. May be empty.

    Returns:
        Tuple ``(background, output)``: the ``uint8`` canvas of shape
        ``(ch, cw, 3)`` and a float64 array of shape ``(N, 6)`` holding the
        boxes expressed as fractions of the canvas.
    """
    height, width, _ = image.shape

    # The draw order cw, ch, cx, cy is kept so seeded runs stay reproducible.
    cw = random.randint(width, int(width * 1.25))
    ch = random.randint(height, int(height * 1.25))
    cx = random.randint(0, cw - width)
    cy = random.randint(0, ch - height)

    background = np.full((ch, cw, 3), 128, dtype=np.uint8)
    background[cy:cy + height, cx:cx + width] = image

    boxes = np.asarray(boxes, dtype=float)
    if boxes.size == 0:
        # Keep the column count for images without objects so that callers
        # always receive a 2-D (N, 6) array.
        return background, np.zeros((0, 6), dtype=float)

    # Columns 0 and 1 (index, category) are passed through unchanged.
    output = boxes[:, :6].copy()
    # Normalised -> source pixels -> shifted by the paste offset -> normalised
    # by the canvas size.
    output[:, 2] = (boxes[:, 2] * width + cx) / cw
    output[:, 3] = (boxes[:, 3] * height + cy) / ch
    output[:, 4] = boxes[:, 4] * width / cw
    output[:, 5] = boxes[:, 5] * height / ch

    return background, output
|
|
|
|
def collate_fn(batch):
    """Merge ``(image, label)`` samples into one batch.

    Column 0 of every non-empty label tensor is overwritten in place with the
    sample's position in the batch, so rows can still be traced back to their
    image after all labels are concatenated.

    Args:
        batch: Sequence of ``(image_tensor, label_tensor)`` pairs.

    Returns:
        Tuple ``(images, labels)``: the images stacked along a new leading
        dimension and all label tensors concatenated along dimension 0.
    """
    images, targets = zip(*batch)

    for sample_idx, target in enumerate(targets):
        if target.shape[0] == 0:
            # Nothing to tag for a sample without any label rows.
            continue
        target[:, 0] = sample_idx

    stacked_images = torch.stack(images)
    merged_targets = torch.cat(targets, 0)
    return stacked_images, merged_targets
|
|
|
|
class TensorDataset():
    """Detection dataset driven by a text file listing one image path per line.

    Every image must have a label file next to it with the same base name and
    a ``.txt`` extension. Each non-blank label line holds five
    whitespace-separated numbers; they are loaded as columns 1-5 of a label
    row, while column 0 is a placeholder (0) that ``collate_fn`` later
    overwrites with the sample's position in the batch.
    """

    def __init__(self, path, img_width, img_height, aug=False):
        """Read and validate the image list.

        Args:
            path: Text file containing one image path per line. Blank lines
                are ignored.
            img_width: Width every image is resized to.
            img_height: Height every image is resized to.
            aug: When true, each sample is randomly passed through either
                ``random_narrow`` or ``random_crop`` before resizing.

        Raises:
            AssertionError: If ``path`` does not exist.
            Exception: If a listed image does not exist or its extension is
                not one of ``img_formats`` (compared case-insensitively).
        """
        assert os.path.exists(path), "%s文件路径错误或不存在" % path

        self.aug = aug
        self.path = path
        self.data_list = []
        self.img_width = img_width
        self.img_height = img_height
        self.img_formats = ['bmp', 'jpg', 'jpeg', 'png']

        # Validate the whole list up front so a bad entry fails at
        # construction time rather than in the middle of an epoch.
        with open(self.path, 'r') as f:
            for line in f:
                data_path = line.strip()
                if not data_path:
                    # Tolerate blank lines, e.g. a trailing newline.
                    continue
                if not os.path.exists(data_path):
                    raise Exception("%s is not exist" % data_path)
                # splitext looks only at the final extension, so dots in
                # directory names cannot be mistaken for the image type.
                img_type = os.path.splitext(data_path)[1][1:]
                if img_type.lower() not in self.img_formats:
                    raise Exception("img type error:%s" % img_type)
                self.data_list.append(data_path)

    @staticmethod
    def _load_label(label_path):
        """Parse a label file into an ``(N, 6)`` float32 array.

        Blank lines are skipped and any run of whitespace separates fields.
        Only the first five fields of a line are used. An empty file yields
        an array of shape ``(0, 6)``.
        """
        rows = []
        with open(label_path, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                # Column 0 is a placeholder for the batch index.
                rows.append([0.0] + [float(fields[k]) for k in range(5)])
        # reshape keeps the column count even when there are no rows.
        # NOTE: values are not range-checked here (negative or > 1 passes).
        return np.array(rows, dtype=np.float32).reshape(-1, 6)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            Tuple ``(img, label)``. ``img`` is a channel-first tensor built
            from the image after resizing to ``(img_height, img_width)``;
            pixel values are not normalised here. ``label`` is the
            ``(N, 6)`` float32 label tensor, or, when augmentation is
            enabled, whatever array the augmentation function returned.

        Raises:
            Exception: If the image cannot be decoded or the label file is
                missing.
        """
        img_path = self.data_list[index]
        # Replace only the final extension; splitting on the first "." would
        # break for paths such as "./data/a.jpg" or "v1.2/a.jpg".
        label_path = os.path.splitext(img_path)[0] + ".txt"

        # Load the image. cv2.imread signals failure by returning None
        # instead of raising, so check explicitly.
        img = cv2.imread(img_path)
        if img is None:
            raise Exception("%s could not be read as an image" % img_path)

        # Load the label file.
        if not os.path.exists(label_path):
            raise Exception("%s is not exist" % label_path)
        label = self._load_label(label_path)

        # Optional data augmentation: pick one of the two transforms.
        if self.aug:
            if random.randint(1, 10) % 2 == 0:
                img, label = random_narrow(img, label)
            else:
                img, label = random_crop(img, label)

        img = cv2.resize(img, (self.img_width, self.img_height), interpolation=cv2.INTER_LINEAR)

        # HWC -> CHW
        img = img.transpose(2, 0, 1)

        return torch.from_numpy(img), torch.from_numpy(label)

    def __len__(self):
        """Return the number of validated image paths."""
        return len(self.data_list)
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: load the first sample and print the tensor shapes.
    # TensorDataset requires the target size explicitly; calling it with only
    # the list path raises TypeError. 352x352 is just an example size here.
    data = TensorDataset("/home/xuehao/Desktop/TMP/pytorch-yolo/widerface/train.txt", 352, 352)
    img, label = data[0]
    print(img.shape)
    print(label.shape)