1. 程式人生 > >分治法最近點對問題

分治法最近點對問題

在二維平面上的n個點中,如何快速的找出最近的一對點,就是最近點對問題。

    一種簡單的想法是暴力列舉每兩個點,記錄最小距離,顯然,時間複雜度為O(n^2)。

    在這裡介紹一種時間複雜度為O(nlognlogn)的演算法。其實,這裡用到了分治的思想。將所給平面上n個點的集合S分成兩個子集S1和S2,每個子集中約有n/2個點。然後在每個子集中遞迴地求最接近的點對。在這裡,一個關鍵的問題是如何實現分治法中的合併步驟,即由S1和S2的最接近點對,如何求得原集合S中的最接近點對。如果這兩個點分別在S1和S2中,問題就變得複雜了。

    為了使問題變得簡單,首先考慮一維的情形。此時,S中的n個點退化為x軸上的n個實數x1,x2,...,xn。最接近點對即為這n個實數中相差最小的兩個實數。顯然可以先將點排好序,然後線性掃描就可以了。但我們為了便於推廣到二維的情形,嘗試用分治法解決這個問題。

    假設我們用m點將S分為S1和S2兩個集合,這樣一來,對於所有的p(S1中的點)和q(S2中的點),有p<q。

    遞迴地在S1和S2上找出其最接近點對{p1,p2}和{q1,q2},並設

d = min{ |p1-p2| , |q1-q2| }

    由此易知,S中最接近點對或者是{p1,p2},或者是{q1,q2},或者是某個{q3,p3},如下圖所示。



 

    如果最接近點對是{q3,p3},即|p3-q3|<d,則p3和q3兩者與m的距離都不超過d,且在區間(m-d,d]和(d,m+d]各有且僅有一個點。這樣,就可以線上性時間內實現合併。

    此時,一維情形下的最近點對時間複雜度為O(nlogn)。

    在二維情形下,類似的,利用分治法,但是難點在於如何實現線性的合併?



 

    由上圖可見,形成的寬為2d的帶狀區間,最多可能有n個點,合併時間最壞情況下為n^2,。但是,P1和P2中的點具有以下稀疏的性質,對於P1中的任意一點,P2中的點必定落在一個d X 2d的矩形中,且最多隻需檢查六個點(鴿巢原理)。

    這樣,先將帶狀區間的點按y座標排序,然後線性掃描,這樣合併的時間複雜度為O(nlogn),幾乎為線性了。

    光說不練也不行,經過自己的思考和參考網上的程式,完成了最近點對的程式,並在各OJ上成功AC了。

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
using namespace std ;
const int  maxn = 1000001 ;
const int  INF = 1000000001 ;
struct Point
{
	double x , y ;
}point[ maxn ] ;
int n ;
int temp[ maxn ];

bool cmp(const  Point& a , const Point& b )
{
	if( a.x == b.x )
		return  a.y < b.y ;
	else
		return a.x < b.x ;
}

bool cmpy( const int& a , const int& b )
{
	return point[ a ].y < point[ b ].y ;
}

double min( double a , double b )
{
	return a < b ? a : b ;
}

double dist( int i , int j )
{
	return sqrt( (point[ i ].x - point[ j ].x) * ( point[ i ].x - point[ j ].x ) + ( point[ i ].y - point [ j ].y ) * ( point[ i ].y - point[ j ].y ) ) ;
}

double merge( int left , int right )
{
	double d = INF ;
	if( left == right )
		return d ;
	if( left + 1 == right )
		return dist( left , right ) ;

	int mid = ( left + right ) >> 1 ;
	double d1 = merge( left , mid ) ;
	double d2 = merge( mid + 1 , right ) ;
	d = min( d1 , d2 ) ;
	int i , j , k = 0 ;
	for( i = left ; i <= right ; ++i )
	{
		if( fabs( point[ mid ].x - point[ i ].x ) <= d )
			temp[ k++ ] = i ;
	}
	sort( temp , temp + k , cmpy ) ;
	for( i = 0 ; i < k ; ++i )
		for( j = i + 1 ; j < k &&  point[ temp[ j ] ].y - point[ temp[ i ] ].y  < d ; ++j )
		{
			double d3 = dist( temp[ i ] , temp[ j ] ) ; 
			if( d > d3 )
				d = d3 ;
		}
	return d ;
}

int main()
{
	while( scanf( "%d" , &n ) && n )
	{
		for(int  i = 0 ; i < n ; ++i )
		{
			scanf( "%lf%lf" , &point[ i ].x , &point[ i ].y ) ;
		}
		sort( point , point + n , cmp ) ;
		printf( "%.2lf\n" , merge( 0 , n - 1 ) / 2 ) ;
	}
	return 0 ;
}