1. 程式人生 > 其它 >關於SIFT,GIFT在旋轉不變性上的對比實驗

關於SIFT,GIFT在旋轉不變性上的對比實驗

目錄

關於SIFT,GIFT在旋轉不變性上的對比實驗

這篇文章不討論SIFTGIFT的實現原理,只從最終匹配結果的準確度上來進行對比。

回顧

先簡單回顧一下,兩種方法略有不同。SIFT是檢測出各個特徵點,並得到特徵點描述子。

GIFT是先使用其它演算法(SIFT,SuperPoint,Harris)等方法得到特徵點,然後將原圖加檢測得到的特徵點輸入網路得到特徵點描述子。

用特徵點描述子進行匹配,我們可以得到如下的試驗結果:

原圖1 原圖2 GIFT SIFT

準確率測試

但是仔細觀察GIFT的匹配結果,肉眼就能看出很多錯誤的匹配。(SIFT,GIFT都使用的是FLANN)

對此,我設計了一個方法來估算一下二者匹配的準確率,也能側面反映出描述子的健壯性。

思路:一張圖片\(img\),將其旋轉90°,得到\(img_{}'\),然後分別使用SIFT,GIFT進行特徵點匹配,由於\(ima\)\(ima_{}'\)的畫素點存在一個旋轉矩陣的對應關係,我們可以據此來大體估算兩種方法的準確率。

實驗環境:pycharm, cpu, opencv

  • Test1

我們看一下在這兩張圖片上的匹配結果:

SIFT SIFT points+GIFT SuperPoint points+GIFT
特徵點匹配數目 5184 1506 630
正確匹配數目 5038 876 358
準確率 0.9718 0.5817 0.5682

特徵點匹配對應的就是附錄程式碼裡的good_mathes,正確匹配數目是使用兩張圖片的旋轉矩陣計算而來,也就是,\(img\)裡的一個特徵點座標\(p\),對應\(img_{}'\)裡的\(p_{}'\), 有$p_{}' = Mp $。詳情請檢視附錄程式碼。

  • Test2

SIFT SIFT points+GIFT SuperPoint points+GIFT
特徵點匹配數目 594 256 216
正確匹配數目 587 155 127
準確率 0.9882 0.6055 0.5880

總結

雖然GIFT具有一定的旋轉不變性,但是效果不是很好,使用特徵描述子匹配出的錯點比較多,特徵描述子的健壯性也不如SIFT通用。

最後要說的是,受限於筆者目前的知識水平和技術水平,不排除在復現GIFT原始碼時出現概念性錯誤,或者是由於粗心導致的細節疏忽。所以此篇文章僅供參考,如果您有新的想法或者建議,歡迎在評論區指出或者傳送郵件([email protected])討論。實驗程式碼請看附錄。

核心程式碼

GIFT_Test.py 程式碼修改自GIFT論文作者在GitHub釋出的demo.ipynb

備註:

  • 復現此程式碼時需要修改test_acc下的M矩陣,因為實驗所用的旋轉變更了座標系,因此每張圖片的旋轉矩陣都是不同的

GIFT_Test.py

import numpy as np
import torch
from skimage.io import imread
from network.wrapper import GIFTDescriptor
from train.evaluation import EvaluationWrapper, Matcher
from utils.superpoint_utils import SuperPointWrapper, SuperPointDescriptor
import matplotlib.pyplot as plt
import cv2

MIN_MATCH_COUNT = 10


def test_acc(good, kps0, kps1):

    # 給特徵點末尾新增一列變為其次座標
    points1 = np.insert(kps0, 2, values=np.ones(kps0.shape[0]), axis=1)
    points2 = np.insert(kps1, 2, values=np.ones(kps1.shape[0]), axis=1)

    # 旋轉矩陣
    M = np.array([[0, -1, 528],
                  [1, 0, 0],
                  [0, 0, 1]], dtype=np.float32)
    # 圖一是圖二旋轉 90°得到,因此畫素座標乘以一個旋轉矩陣即可

    count = 0
    for i in good:
        pts_after_rotate = M.dot(points1[i.queryIdx].T)
        if (pts_after_rotate - points2[i.trainIdx]).sum() < 3:
            count += 1

    print(len(good))
    print(count)


if __name__ == '__main__':

    detector = SuperPointWrapper(EvaluationWrapper.load_cfg('configs/eval/superpoint_det.yaml'))
    gift_desc = GIFTDescriptor(EvaluationWrapper.load_cfg('configs/eval/gift_pretrain_desc.yaml'))
    superpoint_desc = SuperPointDescriptor(EvaluationWrapper.load_cfg('configs/eval/superpoint_desc.yaml'))
    # matcher = Matcher(EvaluationWrapper.load_cfg('configs/eval/match_v0.yaml'))


    img0 = imread("demo/woman.jpg")
    img1 = imread("demo/woman_ro.jpg")


    # 此處可以通過修改註釋來切換檢測器
    
    # to use superpoint detector
    # kps0, _ = detector(img0)
    # kps1, _ = detector(img1)


    # to use SIFT detector
    sift = cv2.SIFT_create()
    kps0, _ = sift.detectAndCompute(img0, None)
    kps1, _ = sift.detectAndCompute(img1, None)

    kps0 = np.array([[i.pt[0], i.pt[1]] for i in kps0])
    kps1 = np.array([[i.pt[0], i.pt[1]] for i in kps1])
    # -----------------------

	
    # 得到GIFT特徵描述子
    des1 = gift_desc(img0, kps0)
    des2 = gift_desc(img1, kps1)

	
    #描述子匹配
    FLANN_INDEX_KDTREE = 0
    index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
    search_params = dict(checks=50)
    flann = cv2.FlannBasedMatcher(index_params, search_params)
    matches1 = flann.knnMatch(des1, des2, k=2)
    # matches2 = flann.knnMatch(des2, des1, k=1)

    ratio_thresh = 0.98 # 此處設定為0.7,0.8時就一個都匹配不上了。
    good_matches = []
    for m, n in matches1:
        if m.distance < ratio_thresh * n.distance:
            good_matches.append(m)

    test_acc(good_matches, kps0, kps1)

	
    
    kps0 = [cv2.KeyPoint(kps0[i][0], kps0[i][1], 1)
                for i in range(kps0.shape[0])]
    kps1 = [cv2.KeyPoint(kps1[i][0], kps1[i][1], 1)
                     for i in range(kps1.shape[0])]

    img_matches = np.empty(
        (max(img0.shape[0], img1.shape[0]), img0.shape[1] + img1.shape[1], 3),
        dtype=np.uint8)

    cv2.drawMatches(img0, kps0, img1, kps1, good_matches, img_matches,
                    flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS)

    cv2.namedWindow("Good Matches of GIFT", 0)
    cv2.resizeWindow("Good Matches of GIFT", 1024, 1024)
    cv2.imshow('Good Matches of GIFT', img_matches)
    cv2.waitKey()

SIFT_Test.py

from __future__ import print_function
import cv2 as cv
import numpy as np


pic1 = "data/woman.jpg"
pic2 = "data/woman_ro.jpg"


def test_acc(good, kps0, kps1):

    count = 0
    # 給特徵點末尾新增一列變為其次座標
    points1 = np.insert(kps0, 2, values=np.ones(kps0.shape[0]), axis=1)
    points2 = np.insert(kps1, 2, values=np.ones(kps1.shape[0]), axis=1)

    # 旋轉矩陣
    M = np.array([[0, -1, 528],
                  [1, 0, 0],
                  [0, 0, 1]], dtype=np.float32)
    # 圖一是圖二旋轉 90°得到,因此畫素座標乘以一個旋轉矩陣即可

    count = 0
    for i in good:
        pts_after_rotate = M.dot(points1[i.queryIdx].T)
        if (pts_after_rotate-points2[i.trainIdx]).sum() < 3:
            count += 1
    print(len(good))
    print(count)



img_object = cv.imread(pic1)
img_scene = cv.imread(pic2)
if img_object is None or img_scene is None:
    print('Could not open or find the images!')
    exit(0)

#-- Step 1: Detect the keypoints using SURF Detector, compute the descriptors
sift = cv.SIFT_create()

keypoints_obj, descriptors_obj = sift.detectAndCompute(img_object,None)
keypoints_scene, descriptors_scene = sift.detectAndCompute(img_scene,None)

#-- Step 2: Matching descriptor vectors with a FLANN based matcher
# Since SURF is a floating-point descriptor NORM_L2 is used
matcher = cv.DescriptorMatcher_create(cv.DescriptorMatcher_FLANNBASED)
knn_matches = matcher.knnMatch(descriptors_obj, descriptors_scene, 2)

#-- Filter matches using the Lowe's ratio test
ratio_thresh = 0.75
good_matches = []
for m,n in knn_matches:
    if m.distance < ratio_thresh * n.distance:
        good_matches.append(m)

print("The number of keypoints in image1 is", len(keypoints_obj))
print("The number of keypoints in image2 is", len(keypoints_scene))

#
kp1 = np.array([[i.pt[0], i.pt[1]] for i in keypoints_obj], dtype=np.int32)
kp2 = np.array([[i.pt[0], i.pt[1]] for i in keypoints_scene], dtype=np.int32)

test_acc(good_matches, kp1, kp2)

#-- Draw matches
img_matches = np.empty((max(img_object.shape[0], img_scene.shape[0]), img_object.shape[1]+img_scene.shape[1], 3), dtype=np.uint8)
cv.drawMatches(img_object, keypoints_obj, img_scene, keypoints_scene, good_matches, img_matches, flags=cv.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS)

cv.namedWindow("Good Matches of SIFT", 0)
cv.resizeWindow("Good Matches of SIFT", 1024, 1024)
cv.imshow('Good Matches of SIFT', img_matches)
cv.waitKey()