1. 程式人生 > >[各種面試題] 兩個陣列和的第K大

[各種面試題] 兩個陣列和的第K大

這是谷歌的一道面試題,有兩個陣列A和B,假設有一個數組C,C[i] = A[j] + B[ k ] , 即C中的元素是A和B中兩個元素的和。

讓你求C中第K大的數字。

之前有一篇轉載的用堆來求的方法,因為每出堆一次最多新增兩個元素進來,所以堆的最大容量是 2* k, 所以入堆出堆複雜度是 logK, 最後的複雜度是KlogK。

注意新增的時候要判斷是否已經進過堆。

#include<iostream>
#include<cstdio>
#include<vector>
#include<queue>
#include<string>
#include<cstring>
#include<climits>
#include<set>
#include<algorithm>
using namespace std;

struct node
{
	int a,b,val;
	node(int aa,int bb,int v):a(aa),b(bb),val(v){}
	bool operator>(const node& oth)const
	{
		return val>oth.val;
	}
};

int findKthSum(int A[],int m,int B[],int n,int k)
{
	priority_queue<node,vector<node>,greater<node> > Q;
	Q.push(node(0,0,A[0]+B[0]));
	set<pair<int,int> > visited;
	visited.insert(pair<int,int>(0,0));
	while(!Q.empty())
	{
		node t=Q.top();
		Q.pop();
		k--;
		if(k==0)
			return t.val;
		set<pair<int,int> >::iterator it;
		if(t.a+1<m&&(it=visited.find(pair<int,int>(t.a+1,t.b)))==visited.end())
		{
			visited.insert(it,pair<int,int>(t.a+1,t.b));
			Q.push(node(t.a+1,t.b,A[t.a+1]+B[t.b]));
		}
		if(t.b+1<n&&(it=visited.find(pair<int,int>(t.a,t.b+1)))==visited.end())
		{
			visited.insert(it,pair<int,int>(t.a,t.b+1));
			Q.push(node(t.a,t.b+1,A[t.a]+B[t.b+1]));
		}
	}
	return -1;
}
int main()
{
	int m,n,k;
	while(scanf("%d%d%d",&m,&n,&k)!=EOF)
	{
		int* A=new int[m];
		int* B=new int[n];
		for(int i=0;i<m;i++)
			scanf("%d",&A[i]);
		for(int i=0;i<n;i++)
			scanf("%d",&B[i]);
		sort(A,A+m);
		sort(B,B+n);

		int ans=findKthSum(A,m,B,n,k);
		printf("%d\n",ans);
	}
}


然後還有一種比較巧妙的方法,用二分答案然後判定來做。排序後,最小和是A[0]+B[0],最大是A[m-1]+B[n-1] ,所以在這個範圍內二分答案,然後判定。

判定的準則是對於二分到的目標值 Piv, 在C裡需要有至少K個和要小於它,如果滿足,那麼它就是可能的答案,但並不一定是,因為piv並不一定出現在C中,所以要記錄它,然後繼續二分。

計算小於的個數的時候,充分利用A和B有序,可以做到O(m+n),所以加上二分的複雜度,最後複雜度是  O ( logMAXSUM * (m+n ) )

#include<iostream>
#include<cstdio>
#include<vector>
#include<string>
#include<cstring>
#include<climits>
#include<algorithm>
using namespace std;

long long  countSmaller(long long A[],long long m,long long B[],long long n,long long piv)
{
	long long pa=0,pb=n-1;
	long long cnt=0;
	for(;pa<m;pa++)
	{
		if(A[pa]>piv)
			break;
		while(pb>=0&&A[pa]+B[pb]>piv)
			pb--;
		cnt+=pb+1;
	}
	return cnt;
}
long long findKthSum(long long A[],long long m,long long B[],long long n,long long k)
{
	long long  l=A[0]+B[0];
	long long r=A[m-1]+B[n-1];
	long long ans=-1;
	while(l<=r)
	{
		long long mid=l+((r-l)>>1);
		if( countSmaller(A,m,B,n,mid)>=k )
		{
			ans=mid;
			r=mid-1;
		}
		else
			l=mid+1;
	}
	return ans;
}
int main()
{
	long long m,n,k;
	while(scanf("%lld%lld%lld",&m,&n,&k)!=EOF)
	{
		long long* A=new long long[m];
		long long* B=new long long[n];
		for(long long i=0;i<m;i++)
			scanf("%lld",&A[i]);
		for(long long i=0;i<n;i++)
			scanf("%lld",&B[i]);
		sort(A,A+m);
		sort(B,B+n);
		long long ans=findKthSum(A,m,B,n,k);
		printf("%lld\n",ans);
	}
}