使用C++標準庫sort自定義比較函式導致死迴圈問題
永遠讓比較函式對相等的值返回false(來自Effective
C++)
---------------------------------------------------------------------------------------------------------
轉自http://www.cnblogs.com/yuanzz/p/3735213.html
最近寫程式碼,無意中發現了一個坑,關於自定義比較函式的stl sort函式的坑,於是記錄下來。
先貼程式碼:
1 #include <iostream> 2 #include <vector> 3#include <algorithm> 4 5 struct finder 6 { 7 bool operator()(int first, int second){return first <= second;} 8 } my_finder; 9 10 int main(int argc, char** argv) 11 { 12 int value = atoi(argv[1]); 13 int num = atoi(argv[2]); 14 std::vector<int> vecTest;15 for(int i=0; i!=num; ++i) 16 vecTest.push_back(value); 17 18 std::sort(vecTest.begin(), vecTest.end(), my_finder); 19 for(int i=0; i!=vecTest.size(); ++i) 20 std::cout<<vecTest[i]<<'\t'; 21 std::cout<<std::endl; 22 23return 0; 24 }
這段程式碼看上去好好的,實際上卻有core的可能。
且看圖:
敏思苦想很久,也想不出為啥會core,後來查了資料,才發現了問題所在,現在通過原始碼分析一下原因。
於是定位到sort函式:
1 template<typename _RandomAccessIterator, typename _Compare> 2 inline void 3 sort(_RandomAccessIterator __first, _RandomAccessIterator __last, 4 _Compare __comp) 5 { 6 typedef typename iterator_traits<_RandomAccessIterator>::value_type 7 _ValueType; 8 9 // concept requirements 10 __glibcxx_function_requires(_Mutable_RandomAccessIteratorConcept< 11 _RandomAccessIterator>) 12 __glibcxx_function_requires(_BinaryPredicateConcept<_Compare, _ValueType, 13 _ValueType>) 14 __glibcxx_requires_valid_range(__first, __last); 15 16 if (__first != __last) 17 { 18 std::__introsort_loop(__first, __last, __lg(__last - __first) * 2, 19 __comp); 20 std::__final_insertion_sort(__first, __last, __comp); 21 } 22 }
這是stl_algo.h中的sort函式,且忽略10-14行的引數檢查,實際上sort函式先是用了introsort(內省排序,http://en.wikipedia.org/wiki/Introsort),然後採用了insertsort(插入排序)。
1、我們先來分析內省排序吧。
先來看看__introsort_loop的函式原型
1 template<typename _RandomAccessIterator, typename _Size> 2 void 3 __introsort_loop(_RandomAccessIterator __first, 4 _RandomAccessIterator __last, 5 _Size __depth_limit) 6 { 7 typedef typename iterator_traits<_RandomAccessIterator>::value_type 8 _ValueType; 9 10 while (__last - __first > int(_S_threshold)) 11 { 12 if (__depth_limit == 0) 13 { 14 std::partial_sort(__first, __last, __last); 15 return; 16 } 17 --__depth_limit; 18 _RandomAccessIterator __cut = 19 std::__unguarded_partition(__first, __last, 20 _ValueType(std::__median(*__first, 21 *(__first 22 + (__last 23 - __first) 24 / 2), 25 *(__last 26 - 1)))); 27 std::__introsort_loop(__cut, __last, __depth_limit); 28 __last = __cut; 29 } 30 }
如果__last - __first > int(_S_threshold)的時候,就開始迴圈了。
關於_S_threshold的定義:
1 enum { _S_threshold = 16 };
好吧,是寫死的,為【16】
注意,我為什麼把16標紅,這就是坑開始的地方了。如果元素小於16(第10行),就直接略過,開始了:
std::__final_insertion_sort(__first, __last, __comp); // 本次不對其進行分析
我們繼續往下走,__depth_limit哪來的呢。看程式碼:
1 template<typename _Size> 2 inline _Size 3 __lg(_Size __n) 4 { 5 _Size __k; 6 for (__k = 0; __n != 1; __n >>= 1) 7 ++__k; 8 return __k; 9 }
還記得sort呼叫introsort嗎?
std::__introsort_loop(__first, __last, __lg(__last - __first) * 2, __comp);
當__depth_limit != 0時,則開始了introsort遞迴,而真正影響它的是__unguarded_partition函式。
__unguarded_partition函式原型:
1 template<typename _RandomAccessIterator, typename _Tp, typename _Compare> 2 _RandomAccessIterator 3 __unguarded_partition(_RandomAccessIterator __first, 4 _RandomAccessIterator __last, 5 _Tp __pivot, _Compare __comp) 6 { 7 while (true) 8 { 9 while (__comp(*__first, __pivot)) 10 ++__first; 11 --__last; 12 while (__comp(__pivot, *__last)) 13 --__last; 14 if (!(__first < __last)) 15 return __first; 16 std::iter_swap(__first, __last); 17 ++__first; 18 } 19 }
好吧,終於要找到原因了,就是這個__pivot了。
還記得我們自定義的__comp函式嗎?
struct finder
{
bool operator()(int first, int second){return first <= second;}
} my_finder;
當*__first == __pivot的時候,返回了true,然後,執行了16行的元素交換。
那如果像我測試的程式碼那樣,所有元素都相等呢?豈不就是走進了死迴圈,等著的就可能是越界了。
為什麼__unguarded_partition不檢查邊界呢?有人分析稱是為了效率,再大資料排序的時候,每次都需要校驗邊界,確實也是個很大開銷。
後來看人總結:永遠讓比較函式對相等的值返回false(來自Effective C++)
附上正確程式碼:
1 #include <iostream> 2 #include <vector> 3 #include <algorithm> 4 5 struct finder 6 { 7 //永遠讓比較函式中對相等值返回false 8 bool operator()(int first, int second){return first < second;} 9 } my_finder; 10 11 int main(int argc, char** argv) 12 { 13 int value = atoi(argv[1]); 14 int num = atoi(argv[2]); 15 std::vector<int> vecTest; 16 for(int i=0; i!=num; ++i) 17 vecTest.push_back(value); 18 19 std::sort(vecTest.begin(), vecTest.end(), my_finder); 20 for(int i=0; i!=vecTest.size(); ++i) 21 std::cout<<vecTest[i]<<'\t'; 22 std::cout<<std::endl; 23 24 return 0; 25 }
結果: