- 目標檢測中的資料增強需要處理兩個方面:一是影像本身的修改,二是標註檔案中標註框的修改。所以在進行資料增強時,自然而然地就分為兩種:
- 一種是只修改影像而不需要修改其對應的標註資訊,例如修改色調、加椒鹽噪聲、隨機擦除等等;
- 另一種是既需要修改影像又需要修改標註資訊,甚至生成新的標註資訊,例如 mosaic、旋轉、鏡像等等。
- 下面介紹幾種應用較為廣泛的資料增強方式及其 Python 實現。為了程式碼的完整性,在此之前先給出載入影像資料、讀取標註檔案並獲取標註框資訊的程式碼。
- 在實際場景中拿到的標註檔案格式各種各樣,但必定包含兩方面的資訊:一是圖片中目標的類別,二是目標的左上角座標和右下角座標。我這次拿到的資料使用的是一種與標準 VOC 不同的 XML 標註格式(由下方的 load_annotations 解析);如果是標準 VOC 格式,則改用 voc2data 解析。
class Data:
    """
    Bundle an image with its annotation boxes.

    Attributes:
        name: base file name (without extension) used when the sample is saved.
        boxes: list of Box objects belonging to this image.
        img: the image array, or None when no image was given.
        shape: ``img.shape``, or None when there is no image. ``set_img`` keeps
            it in sync with ``img``.
    """

    def __init__(self, name, boxes=None, img=None):
        # A fresh list per instance; a shared default list would leak boxes between samples.
        if boxes is None:
            boxes = []
        self.name = name
        self.boxes = boxes
        self.img = img
        # img defaults to None, so guard the .shape access instead of raising AttributeError.
        self.shape = img.shape if img is not None else None

    def append_box(self, box):
        """
        Add one annotation box to this sample.

        :param box: Box instance to append.
        :return: None
        """
        self.boxes.append(box)

    def set_name(self, name):
        """Replace the sample's base name."""
        self.name = name

    def set_img(self, img):
        """
        Replace the image and refresh ``shape`` to match it.

        ``shape`` is what gets written into the saved annotation's <size>, so it
        must describe the image actually stored here.
        """
        self.img = img
        self.shape = img.shape if img is not None else None


class Box:
    """
    One annotation box: a class label plus coordinates ``[xmin, ymin, xmax, ymax]``.
    """

    def __init__(self, label, cod):
        self.label = label  # class name of the object
        self.cod = cod  # [xmin, ymin, xmax, ymax]

    def get_label(self):
        """Return the class label."""
        return self.label

    def get_cod(self):
        """Return the ``[xmin, ymin, xmax, ymax]`` coordinate list."""
        return self.cod
def load_data(img_path, xml_path, flog_path, save_path):
    """
    Load every annotation file and its image, then augment and save the data set.

    The annotation directory is authoritative: some images have no annotation and
    are ignored. Annotations whose image cannot be read are skipped.

    :param img_path: directory holding the images (``<name>.tif``).
    :param xml_path: directory holding the annotation XML files.
    :param flog_path: directory holding the cloud/fog images.
    :param save_path: root directory the augmented data set is saved to.
    :return: None
    """
    annotations = os.listdir(xml_path)
    data_list = []
    for annotation in annotations:
        # splitext keeps dots inside the base name ("a.b.xml" -> "a.b").
        name = os.path.splitext(annotation)[0]
        # Close the annotation file as soon as it has been parsed.
        with open(os.path.join(xml_path, annotation), 'br') as xml_file:
            # For standard VOC annotations, call voc2data here instead.
            boxes = load_annotations(xml_file)
        img = cv2.imread(os.path.join(img_path, name + ".tif"))
        # cv2.imread returns None (it does not raise) for missing/unreadable files.
        if img is None:
            print("skipping %s: image not found" % name)
            continue
        data = Data(name=name, boxes=boxes, img=img)
        data_list.append(data)
    # NOTE(review): the listing of this function appears to stop here; the rest of
    # its body (fog-image list, augmentation threads, save) is shown separately.


def load_annotations(xml_file):
    """
    Parse one annotation in this data set's custom (non-VOC) XML format.

    Each <objects>/<object> carries its label in <possibleresult>/<name> and its
    outline as <points>/<point> entries of the form "x,y". The axis-aligned
    bounding box of those points becomes the Box coordinates.

    :param xml_file: path or binary file object of the annotation.
    :return: list of Box objects (empty if the file has no <objects> node).
    """
    boxes = []
    root = ET.parse(xml_file).getroot()
    objects = root.find("objects")
    # A file without <objects> simply has no annotated targets.
    if objects is None:
        return boxes
    for obj in objects.iter("object"):
        label = obj.find("possibleresult").find("name").text
        xs = []
        ys = []
        for point in obj.find("points").iter("point"):
            xy = point.text.split(",")
            xs.append(int(float(xy[0])))
            ys.append(int(float(xy[1])))
        cod = [min(xs), min(ys), max(xs), max(ys)]
        boxes.append(Box(label=label, cod=cod))
    return boxes


def voc2data(xml_file):
    """
    Parse a standard Pascal-VOC annotation file into Box objects.

    :param xml_file: path or binary file object of the annotation.
    :return: list of Box objects, one per <object>.
    """
    boxes = []
    root = ET.parse(xml_file).getroot()
    for obj in root.iter("object"):
        label = obj.find("name").text
        bndbox = obj.find("bndbox")
        # Some VOC exporters write float coordinates ("12.0"); int("12.0") would fail.
        xmin = int(float(bndbox.find("xmin").text))
        ymin = int(float(bndbox.find("ymin").text))
        xmax = int(float(bndbox.find("xmax").text))
        ymax = int(float(bndbox.find("ymax").text))
        boxes.append(Box(label=label, cod=[xmin, ymin, xmax, ymax]))
    return boxes
def rot90(data):
    """
    Rotate the sample by 90 degrees with ``np.rot90`` and remap its boxes.

    ``np.rot90`` turns the image counter-clockwise (as displayed), so a source
    point (x, y) ends up at (y, W - x), where W = ``data.shape[1]`` is the
    source width. Hence the new xmin/xmax are the old ymin/ymax, and the new
    ymin/ymax are W minus the old xmax/xmin.

    :param data: Data instance to rotate.
    :return: a new Data named ``<name>_rot90``.
    """
    src_width = data.shape[1]
    turned = []
    for box in data.boxes:
        old = box.cod
        turned.append(Box(box.label, [old[1], src_width - old[2], old[3], src_width - old[0]]))
    return Data(data.name + "_rot90", turned, np.rot90(data.img))
def flip_vertical(data):
    """
    Flip the sample upside down and remap its boxes.

    ``cv2.flip`` with flipCode 0 mirrors the image around the x-axis, so only
    y changes: y -> H - y, with H = ``data.shape[0]`` the image height. The old
    ymax therefore becomes the new ymin and vice versa.

    :param data: Data instance to flip.
    :return: a new Data named ``<name>_flip``.
    """
    img_height = data.shape[0]
    mirrored = cv2.flip(data.img, 0)
    remapped = [
        Box(box.label, [box.cod[0], img_height - box.cod[3], box.cod[2], img_height - box.cod[1]])
        for box in data.boxes
    ]
    return Data(data.name + "_flip", remapped, mirrored)
複製小目標這種資料增強方式主要針對小目標檢測:將資料中面積較小的目標複製並貼上到影像的其他位置,以增加小目標的數量。理論依據出自論文《Augmentation for small object detection》。
複製小目標首先要確定哪些是小目標,即為「什麼樣的目標算小目標」定一個標準。下面的實現以標註框面積小於 20000 畫素作為判定標準。
def paste(data, area_threshold=20000, copies=5):
    """
    Copy-paste small objects to increase their number.

    This follows the idea of "Augmentation for small object detection". A box is
    small when its area (width * height, in pixels) is below ``area_threshold``.
    Each small object is cropped and pasted at up to ``copies`` random positions
    that overlap neither the existing boxes nor the copies pasted so far.

    :param data: source Data.
    :param area_threshold: area in pixels below which an object counts as small
        (default 20000).
    :param copies: number of pasted copies requested per small object (default 5).
    :return: a new Data named ``<name>_small``, or None when nothing was pasted.
    """
    img_new = data.img.copy()
    # Shallow copy: the original sample's box list must not grow.
    boxes = data.boxes.copy()
    changed = False
    for box in data.boxes:
        src = box.cod
        width = src[2] - src[0]
        height = src[3] - src[1]
        # Empty or inverted boxes have nothing to copy; large ones are not targeted.
        if width <= 0 or height <= 0 or width * height >= area_threshold:
            continue
        cropped_img = data.img[src[1]:src[3], src[0]:src[2]]
        # `boxes` grows as copies are pasted, so later copies avoid earlier ones too.
        for target in getRandomCod(data.img.shape, boxes, copies, width, height):
            img_new[target[1]:target[3], target[0]:target[2]] = cropped_img
            boxes.append(Box(box.label, target))
            changed = True
    if not changed:
        return None
    return Data(name=data.name + "_small", img=img_new, boxes=boxes)
隨機生成的貼上位置需要與已有標註框做重疊檢測。
def check_overlap(cod_1, cod_2):
    """
    Tell whether two ``[xmin, ymin, xmax, ymax]`` boxes share a region of positive area.

    Boxes that only touch along an edge or at a corner are not overlapping.

    :param cod_1: first box coordinates.
    :param cod_2: second box coordinates.
    :return: True if the intersection area is positive, otherwise False.
    """
    inter_w = max(0, min(cod_1[2], cod_2[2]) - max(cod_1[0], cod_2[0]))
    inter_h = max(0, min(cod_1[3], cod_2[3]) - max(cod_1[1], cod_2[1]))
    return bool(inter_w * inter_h > 0)
def getRandomCod(shape, boxes, num, width, height, max_tries=1000):
    """
    Pick up to ``num`` random positions for a ``width`` x ``height`` box.

    Each returned box lies inside the image, does not overlap any box in
    ``boxes``, and does not overlap any other returned box.

    :param shape: image shape; ``shape[0]`` is the height and ``shape[1]`` the width.
    :param boxes: existing Box objects to stay clear of.
    :param num: number of positions wanted.
    :param width: width of the box to place.
    :param height: height of the box to place.
    :param max_tries: random attempts allowed per position before giving up
        (default 1000). This replaces an unbounded retry loop that could hang
        when the image has no free spot left.
    :return: list of ``[xmin, ymin, xmax, ymax]``; it may hold fewer than ``num``
        entries when no further free spot was found.
    """
    img_h, img_w = shape[0], shape[1]
    ret = []
    # Degenerate or oversized boxes can never be placed; return instead of spinning.
    if num <= 0 or width <= 0 or height <= 0 or width > img_w or height > img_h:
        return ret
    for _ in range(num):
        for _attempt in range(max_tries):
            # Sample only top-left corners that keep the whole box inside the image.
            x = random.randint(0, img_w - width)
            y = random.randint(0, img_h - height)
            cod_tmp = [x, y, x + width, y + height]
            if check_overSize(cod_tmp, shape):
                continue
            if any(check_overlap(box.cod, cod_tmp) for box in boxes):
                continue
            if any(check_overlap(cod, cod_tmp) for cod in ret):
                continue
            ret.append(cod_tmp)
            break
        else:
            # No free spot within max_tries: more attempts for further copies would also fail.
            break
    return ret
高斯模糊:直接使用 OpenCV 中的高斯濾波函式即可實現影像模糊的效果,對應的標註框位置不變。
def gaussian_blur(data):
    """
    Blur the sample with an 11x11 Gaussian kernel.

    sigmaX is passed as 0, so OpenCV derives sigma from the kernel size.
    Blurring moves no pixels, so the boxes are copied unchanged.

    :param data: source Data.
    :return: a new Data named ``<name>_gaussian``.
    """
    blurred = cv2.GaussianBlur(data.img, (11, 11), 0)
    return Data(name=data.name + "_gaussian", boxes=data.boxes.copy(), img=blurred)
def add_flog(data, flog_list):
    """
    Blend a randomly chosen cloud/fog image over the sample.

    The fog image is rotated by 90 degrees with probability 5/11
    (``random.randint(0, 10) > 5``), resized to the sample's size and mixed in
    as ``0.6 * image + 0.4 * fog``. The boxes are unchanged.

    :param data: source Data.
    :param flog_list: absolute paths of the fog images to choose from.
    :return: a new Data named ``<name>_flog``.
    :raises FileNotFoundError: if the chosen fog image cannot be read.
    """
    img_src = data.img.copy()
    flog_file = random.choice(flog_list)
    img_flog = cv2.imread(flog_file)
    # cv2.imread returns None instead of raising for missing or non-image files;
    # report the offending path here rather than failing obscurely in cv2.resize.
    if img_flog is None:
        raise FileNotFoundError("cannot read fog image: %s" % flog_file)
    if random.randint(0, 10) > 5:
        img_flog = np.rot90(img_flog)
    # cv2.resize expects (width, height): the reversed (rows, cols) of the sample.
    img_flog = cv2.resize(img_flog, img_src.shape[:2][::-1])
    img_new = cv2.addWeighted(img_src, 0.6, img_flog, 0.4, 0)
    return Data(name=data.name + "_flog", boxes=data.boxes.copy(), img=img_new)
def shrink(data, area_threshold=20000):
    """
    Shrink a sample that contains only large objects and centre it on a black canvas.

    The image is scaled to half its width and height and pasted in the middle of
    a zero-filled canvas of the original size. This yields smaller versions of
    large objects. Boxes are scaled by the same factors and shifted by the paste
    offset.

    :param data: source Data.
    :param area_threshold: a sample is left alone if any box has an area (in
        pixels) below this value (default 20000).
    :return: a new Data named ``<name>_shrink`` with the original image size, or
        None when the sample has a small object or is too small to halve.
    """
    if any((b.cod[3] - b.cod[1]) * (b.cod[2] - b.cod[0]) < area_threshold for b in data.boxes):
        return None
    src_h, src_w = data.img.shape[:2]
    small_w, small_h = src_w // 2, src_h // 2
    # cv2.resize rejects a zero-sized target.
    if small_w == 0 or small_h == 0:
        return None
    small_img = cv2.resize(data.img, (small_w, small_h))
    # Offset that centres the half-size image on the full-size canvas.
    top, left = small_h // 2, small_w // 2
    plate = np.zeros_like(data.img)
    plate[top:top + small_h, left:left + small_w] = small_img
    # x scales with width, y with height; then move into the pasted region.
    scale_x = small_w / src_w
    scale_y = small_h / src_h
    boxes = []
    for box in data.boxes:
        cod = box.cod
        boxes.append(Box(label=box.label, cod=[
            int(scale_x * cod[0]) + left,
            int(scale_y * cod[1]) + top,
            int(scale_x * cod[2]) + left,
            int(scale_y * cod[3]) + top,
        ]))
    # Build a fresh Data so that `shape` describes the full-size canvas actually stored.
    return Data(name=data.name + "_shrink", boxes=boxes, img=plate)
Mosaic 技術是 YOLOv4 的技巧之一,也是一種用於增強網路對小目標泛化能力的手段,見論文 YOLOv4: Optimal Speed and Accuracy of Object Detection。其思想是隨機讀取四張影像,對這些影像分別隨機進行資料增強(如翻轉、旋轉等),然後拼接組合成一張影像。原文中四張影像的尺寸是不一致的;這裡為了方便拼接,將四張影像的尺寸統一處理為長寬皆為 600,然後再進行組合。
def _shift_boxes(boxes, dx, dy):
    """Return new Box objects translated by (dx, dy) pixels."""
    return [
        Box(label=b.label, cod=[b.cod[0] + dx, b.cod[1] + dy, b.cod[2] + dx, b.cod[3] + dy])
        for b in boxes
    ]


def mosaic(data_list):
    """
    Stitch the first four samples of ``data_list`` into one 2x2 mosaic.

    Layout: data_list[0] top-left, data_list[1] bottom-left, data_list[2]
    top-right, data_list[3] bottom-right. ``np.vstack``/``np.hstack`` require
    the two images in each column to share a width, and both columns to share a
    total height. Box offsets are taken from the actual tile sizes rather than a
    fixed 600 px.

    :param data_list: at least four Data instances.
    :return: a new Data named ``<first name>_mosaic``.
    """
    top_left, bottom_left, top_right, bottom_right = (
        data_list[0], data_list[1], data_list[2], data_list[3])
    left_col = np.vstack((top_left.img, bottom_left.img))
    right_col = np.vstack((top_right.img, bottom_right.img))
    img_new = np.hstack((left_col, right_col))
    # The right column starts where the left column ends.
    col_w = top_left.img.shape[1]
    boxes_new = list(top_left.boxes)
    boxes_new += _shift_boxes(top_right.boxes, col_w, 0)
    boxes_new += _shift_boxes(bottom_left.boxes, 0, top_left.img.shape[0])
    boxes_new += _shift_boxes(bottom_right.boxes, col_w, top_right.img.shape[0])
    return Data(name=top_left.name + "_mosaic", img=img_new, boxes=boxes_new)
def do_mosaic(data_list, result):
    """
    Build mosaics over a shuffled copy of ``data_list`` with a sliding window of four.

    Each sample is resized (default 600x600) and added to the window. Whenever
    the window holds four samples, one mosaic is produced and the oldest sample
    leaves the window. N samples therefore yield N - 3 mosaics (none if N < 4).

    :param data_list: source Data list; it is not modified.
    :param result: list the new mosaics are appended to.
    :return: ``result``.
    """
    shuffled = random.sample(data_list, len(data_list))
    window = []
    for sample in shuffled:
        window.append(resize(sample))
        if len(window) == 4:
            result.append(mosaic(window))
            window = window[1:]
    return result
resize 函式的實現:這是一個工具函式,多個處理都用到了這個函式。
def resize(data, shape=None):
    """
    Resize the sample's image and scale its boxes to match.

    :param data: source Data.
    :param shape: target size as ``(width, height)``, the ``dsize`` order used
        by ``cv2.resize``. Defaults to (600, 600) so that mosaic tiles line up.
    :return: a new Data with the same name, the resized image and scaled boxes.
    """
    if shape is None:
        shape = (600, 600)
    dst_w, dst_h = shape
    # Read the size from the image itself; it is the ground truth for the pixels.
    src_h, src_w = data.img.shape[:2]
    img_new = cv2.resize(data.img, (dst_w, dst_h))
    # x coordinates scale with the width ratio, y coordinates with the height ratio.
    scale_x = dst_w / src_w
    scale_y = dst_h / src_h
    boxes = []
    for box in data.boxes:
        cod = box.cod
        boxes.append(Box(label=box.label, cod=[
            int(scale_x * cod[0]),
            int(scale_y * cod[1]),
            int(scale_x * cod[2]),
            int(scale_y * cod[3]),
        ]))
    return Data(name=data.name, boxes=boxes, img=img_new)
載入所有的資料之後,需要對整個資料列表進行處理。以上的每個函式都只處理單個 data 物件,處理整個列表需要
迴圈,所以用一個函式對以上的每一種處理進行封裝。
def do(func, data_list, ret, ext=None):
    """
    Apply one augmentation to every sample and collect the non-None results.

    :param func: augmentation taking ``(data)``, or ``(data, ext)`` when ``ext``
        is not None. Returning None means "nothing produced for this sample".
    :param data_list: samples to process.
    :param ret: list the produced samples are appended to.
    :param ext: optional extra argument forwarded to ``func``.
    :return: None
    """
    extra_args = () if ext is None else (ext,)
    for data in data_list:
        produced = func(data, *extra_args)
        if produced is None:
            continue
        ret.append(produced)
# Remainder of load_data(): collect the fog images, run the augmentations and save.
flog_list = [os.path.join(flog_path, x) for x in os.listdir(flog_path)]
ret = []

# Stage 1: rotation, flip, gaussian blur and fog. The workers only read data_list
# and append to ret, so they can run side by side. Start every thread before
# joining any; otherwise they run one after another.
stage_1 = [
    threading.Thread(target=do, args=(rot90, data_list, ret)),
    threading.Thread(target=do, args=(flip_vertical, data_list, ret)),
    threading.Thread(target=do, args=(gaussian_blur, data_list, ret)),
    threading.Thread(target=do, args=(add_flog, data_list, ret, flog_list)),
]
for worker in stage_1:
    worker.start()
for worker in stage_1:
    worker.join()

# Stage 2 also sees the stage-1 results: small-object copy-paste, shrink and mosaic.
data_list += ret
ret = []
print("processing 1")
stage_2 = [
    # The small-object copy-paste augmentation is the `paste` function.
    threading.Thread(target=do, args=(paste, data_list, ret)),
    threading.Thread(target=do, args=(shrink, data_list, ret)),
    threading.Thread(target=do_mosaic, args=(data_list, ret)),
]
for worker in stage_2:
    worker.start()
for worker in stage_2:
    worker.join()

data_list += ret
save(data_list, save_path)
def save(data_list, save_path):
    """
    Write every sample as a VOC-style pair: image and XML annotation.

    Images go to ``<save_path>/JPEGImages/<name>.tif`` and annotations to
    ``<save_path>/Annotations/<name>.xml``. Missing directories, including
    ``save_path`` itself, are created.

    :param data_list: Data instances to save.
    :param save_path: root directory of the data set.
    :return: None
    :raises OSError: if an image cannot be written.
    """
    annotation_path = os.path.join(save_path, "Annotations")
    img_path = os.path.join(save_path, "JPEGImages")
    # makedirs (unlike mkdir) also creates missing parent directories.
    os.makedirs(annotation_path, exist_ok=True)
    os.makedirs(img_path, exist_ok=True)
    for data in data_list:
        img_file = data.name + ".tif"
        img_target = os.path.join(img_path, img_file)
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(img_target, data.img):
            raise OSError("failed to write image %s" % img_target)
        print("saving image %s to %s" % (img_file, img_path))
        xml_file = data.name + ".xml"
        tree = data2xml(data)
        tree.write(os.path.join(annotation_path, xml_file), "utf-8", True)
        print("saving annotation %s to %s" % (xml_file, annotation_path))
將 Data 物件轉換為 VOC 格式的 XML 物件:
def data2xml(data):
    """
    Build a VOC-style annotation tree for one sample.

    The tree records the folder ("VOC"), the file name ``<name>.tif``, the
    image size, the source database, and one <object> per box with its
    label, bndbox, truncated=0 and difficult=0.

    :param data: Data whose name, image size and boxes are serialised.
    :return: ``xml.etree.ElementTree.ElementTree`` rooted at <annotation>.
    """
    # The stored image is the ground truth for the size written to <size>.
    shape = data.img.shape if data.img is not None else data.shape
    root = ET.Element("annotation")
    ET.SubElement(root, "folder").text = "VOC"
    ET.SubElement(root, "filename").text = data.name + ".tif"
    size = ET.SubElement(root, "size")
    ET.SubElement(size, "width").text = str(shape[1])
    ET.SubElement(size, "height").text = str(shape[0])
    # Single-channel images have a 2-D shape; record them with depth 1.
    ET.SubElement(size, "depth").text = str(shape[2] if len(shape) > 2 else 1)
    source = ET.SubElement(root, "source")
    ET.SubElement(source, "database").text = "高分軟體大賽"
    for box in data.boxes:
        obj = ET.SubElement(root, "object")
        ET.SubElement(obj, "name").text = box.label
        bndbox = ET.SubElement(obj, "bndbox")
        for tag, value in zip(("xmin", "ymin", "xmax", "ymax"), box.cod):
            ET.SubElement(bndbox, tag).text = str(value)
        ET.SubElement(obj, "truncated").text = '0'
        ET.SubElement(obj, "difficult").text = '0'
    pretty_xml(element=root, indent='\t', newline='\n')
    return ET.ElementTree(root)
def random_check(save_path, num):
    """
    Draw the saved boxes onto randomly chosen samples for visual inspection.

    Annotations are read from ``<save_path>/Annotations`` and images from
    ``<save_path>/JPEGImages``. Each picked image is written, with its labels
    and rectangles drawn, to ``<save_path>/tmp``.

    :param save_path: data set root produced by ``save``.
    :param num: number of samples to check, capped at the number of annotations.
    :return: None
    :raises RuntimeError: if an annotation has no <filename> or its image cannot be read.
    """
    temp_path = os.path.join(save_path, "tmp")
    os.makedirs(temp_path, exist_ok=True)
    font = cv2.FONT_HERSHEY_SIMPLEX
    annotation_path = os.path.join(save_path, "Annotations")
    img_path = os.path.join(save_path, "JPEGImages")
    xml_names = os.listdir(annotation_path)
    # random.sample raises ValueError when asked for more items than exist.
    xml_names = random.sample(xml_names, min(num, len(xml_names)))
    for xml_name in xml_names:
        root = ET.parse(os.path.join(annotation_path, xml_name)).getroot()
        filename_node = root.find("filename")
        if filename_node is None or not filename_node.text:
            raise RuntimeError("xml tag error: %s" % xml_name)
        img_name = filename_node.text
        img = cv2.imread(os.path.join(img_path, img_name))
        # cv2.imread signals a missing/unreadable file by returning None, not by raising.
        if img is None:
            raise RuntimeError("no such file: %s" % img_name)
        for obj in root.iter("object"):
            name = obj.find("name").text
            box = obj.find("bndbox")
            xmin, ymin, xmax, ymax = (
                int(float(box.find(tag).text)) for tag in ("xmin", "ymin", "xmax", "ymax")
            )
            img = cv2.putText(img, name, (xmin, ymin), font, 1.2, (255, 255, 255), 2)
            cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (255, 0, 0), 2)
        cv2.imwrite(os.path.join(temp_path, img_name), img)