1. 程式人生 > >[BZOJ4310]跳蚤-字尾陣列-二分答案

[BZOJ4310]跳蚤-字尾陣列-二分答案

跳蚤

Description

很久很久以前,森林裡住著一群跳蚤。一天,跳蚤國王得到了一個神祕的字串,它想進行研究。首先,他會把串分成不超過 k 個子串,然後對於每個子串 S,他會從S的所有子串中選擇字典序最大的那一個,並在選出來的 k 個子串中選擇字典序最大的那一個。他稱其為“魔力串”。現在他想找一個最優的分法讓“魔力串”字典序最小。

Input

第一行一個整數 k,K<=15
接下來一個長度不超過 10^5 的字串 S。

Output

輸出一行,表示字典序最小的“魔力串”。

Sample Input

2
ababa

Sample Output

ba

//解釋:
分成aba和ba兩個串,其中字典序最大的子串為ba

頹廢十多天後的第一篇部落格-_-
另外這題題面寫錯了,是最小化“魔力串”的字典序~

思路:

使最大值最小,可以考慮二分答案,即“魔力串”的字典序編號。

根據字尾陣列的特性,字串本質不同的子串數量為sum=nsa[i]height[i]+1
(式子中的”+1”與個人的字尾陣列實現有關)
這個式子的含義是,排名相鄰的兩個字尾中,不同的字首的方案數。

於是,二分的區間即為1sum之間。
同時,根據這個式子,可以快速找出某個確定字典序排名對應的子串。

考慮二分一個答案後如何檢驗。
顯然,方法為貪心地分段,若分段段數不超過k,則當前答案合法。
由於字尾陣列考慮的是字尾的排名,那麼從後往前考慮,每次在最前端新增一個字元時,判斷當前段與二分的答案之間的排名先後順序,若大於當前二分排名則需要在上一個位置處分一段,否則不進行分段。
查詢串的字典序大小可以使用h

eight陣列做成rmq來快速實現。

於是這就完成了~

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;

typedef pair<int,int> pr;
typedef long long ll;
const int N=1e6+9;
const int K=23;

int n,k;
int wa[N],wb[N],wc[N];
int sa[N],rk[N],hi[N];
int st[N][K],logs[N];
char s[N];

inline
int minn(int a,int b){return a>b?b:a;} inline void calc(char *r,int n,int m) { int *x=wa,*y=wb; for(int i=1;i<=m;i++) wc[i]=0; for(int i=1;i<=n;i++) wc[x[i]=r[i]-'a'+1]++; for(int i=1;i<=m;i++) wc[i]+=wc[i-1]; for(int i=n;i>=1;i--) sa[wc[x[i]]--]=i; for(int j=1,p=0;j<=n && p<n;j<<=1,m=p,p=0) { for(int i=n-j+1;i<=n;i++) y[++p]=i; for(int i=1;i<=n;i++) if(sa[i]>j) y[++p]=sa[i]-j; for(int i=1;i<=m;i++) wc[i]=0; for(int i=1;i<=n;i++) wc[x[i]]++; for(int i=1;i<=m;i++) wc[i]+=wc[i-1]; for(int i=n;i>=1;i--) sa[wc[x[y[i]]]--]=y[i]; swap(x,y); x[sa[1]]=p=1; for(int i=2;i<=n;i++) if(y[sa[i]]==y[sa[i-1]] && y[sa[i]+j]==y[sa[i-1]+j]) x[sa[i]]=p; else x[sa[i]]=++p; } for(int i=1;i<=n;i++) rk[sa[i]]=i; for(int i=1,p=1;i<=n;i++) { if(p)p--; if(rk[i]==n)continue; int j=sa[rk[i]+1]; while(r[i+p]==r[j+p])p++; hi[rk[i]]=p; } } inline void build() { logs[0]=-1; for(int i=2;i<=n;i++) logs[i]=logs[i>>1]+1; for(int i=1;i<=n;i++) st[i][0]=hi[i]; for(int i=1;i<=logs[n];i++) for(int j=1;j+(1<<i)-1<=n;j++) st[j][i]=minn(st[j][i-1],st[j+(1<<i-1)][i-1]); } inline int lcp(int l,int r) { if(l==r)return n-l+1; l=rk[l];r=rk[r]; if(l>r)swap(l,r);r--; int dt=logs[r-l+1]; return minn(st[l][dt],st[r-(1<<dt)+1][dt]); } inline pr find(ll kth) { static int p; for(p=1;p<=n && kth>n-sa[p]-hi[p-1]+1;p++) kth-=n-sa[p]-hi[p-1]+1; return pr(sa[p],sa[p]+hi[p-1]+kth-1); } inline int cmp(pr a,pr b) { int la=a.second-a.first+1; int lb=b.second-b.first+1; int dt=lcp(a.first,b.first); if(la<=dt || lb<=dt)return la<=lb; return s[a.first+dt]<s[b.first+dt]; } inline bool check(ll mid) { pr kth=find(mid); for(int i=n,lst=n,cnt=1;i>=1;i--) { if(s[i]>s[kth.first])return 0; if(!cmp(pr(i,lst),kth))cnt++,lst=i; if(cnt>k)return 0; } return 1; } int main() { scanf("%d%s",&k,s+1); n=strlen(s+1); calc(s,n,26); build(); ll sum=0; for(int i=1;i<=n;i++) sum+=n-sa[i]-hi[i-1]+1; ll l=1,r=sum,ans=sum,mid; while(l<=r) { mid=l+r>>1; if(check(mid)) r=mid-1,ans=mid; else l=mid+1; } pr ret=find(ans); for(int i=ret.first;i<=ret.second;i++) putchar(s[i]); return 0; }