HDU-3553 Just a String (二分 + 字尾陣列)
題意:找出文字串中字典序第 k 大的字串
思路:
首先我們不能僅僅按字尾陣列排完序後每個字串的大小來找,因為重複字元也參與排名,比如 AAB 2, 結果是 A 而不是 AA。
注:以下第 i 個字尾均指排完序後第 i 小的字尾。
所以我們二分找第 k 大的字串位於哪個區間,假定我們現在確定目標位於字尾區間 \([le, ri]\) (排完序的),我們求出 \(LCP(le, ri) = x\),
並找出最小的LCP對應的字尾 \(mid\) 如果 \(x \times(ri - le + 1) \geq k\) ,那麼我們就可以確定該字串的長度
\(len = k / (ri - le + 1) + k \% (ri - le + 1)\)
\(x \times(ri - le + 1) < k\) 的情況,我們首先先將 \(k = k - x \times (ri - mid)\),然後看區間 \([le, mid]\) 的字串總數量 \(sum\) 是否大於等於\(k\),如果大於等
於\(k\),那麼我們可以確定目標一定在區間 \([le, mid]\) 中, 因為區間 \([le, mid]\) 的任意字串字典序一定小於區間 \([mid + 1, ri]\) 中的所有長度大於 \(x\) 的字串。
如果區間 \([le, mid]\) 的字串總數量小於 \(k\), 那麼目標就一定在區間 \([mid + 1, ri]\)
\(k = k + x \times (ri - mid) - sum\) ,因為可以保證區間 \([le, mid]\) 中所有字串的都小於 \(k\),因為我們最好控制每次開始查詢的 \([le, ri]\) 區間裡的字串還未被減去, 這樣方便我們程式設計。關於具體細節看下面程式碼。
#include <bits/stdc++.h> using namespace std; typedef long long LL; const int maxn = 1e5 + 50; int Sa[maxn], Height[maxn], Tax[maxn], Rank[maxn], tp[maxn], a[maxn], n, m; LL sum[maxn]; int ca = 0; char str[maxn]; void Rsort(){ for(int i = 0; i <= m; i++) Tax[i] = 0; for(int i = 1; i <= n; i++) Tax[Rank[tp[i]]]++; for(int i = 1; i <= m; i++) Tax[i] += Tax[i - 1]; for(int i = n; i >= 1; i--) Sa[Tax[Rank[tp[i]]]--] = tp[i]; } int cmp(int *f, int x, int y, int w){ if(x + w > n || y + w > n) return 0; // 注意防止越界,多組輸入的時候這條必須有 return f[x] == f[y] && f[x + w] == f[y + w]; } void Suffix(){ for(int i = 1; i <= n; i++) Rank[i] = a[i], tp[i] = i; m = 500, Rsort(); int p = 0; for(int w = 1, i; p < n; w <<= 1, m = p){ for(p = 0, i = n - w + 1; i <= n; i++) tp[++p] = i; for(i = 1; i <= n; i++) if(Sa[i] > w) tp[++p] = Sa[i] - w; Rsort(); for(int i = 1;i <= n;i++) tp[i] = Rank[i]; Rank[Sa[1]] = p = 1; for(int i = 2; i <= n; i++) Rank[Sa[i]] = cmp(tp, Sa[i], Sa[i - 1], w) ? p : ++p; } int j, k = 0; for(int i = 1; i <= n; Height[Rank[i++]] = k){ for(k = k ? k - 1 : k, j = Sa[Rank[i] - 1]; i + k <= n && j + k <= n && a[i + k] == a[j + k]; ++k); } } int min_st(int p1, int p2){ if(Height[p1] <= Height[p2]) return p1; else return p2; } int dpmi[maxn][30]; void RMQ(){ for(int i = 1; i <= n; i++) dpmi[i][0] = i; for(int j = 1; (1 << j) <= n; j++){ for(int i = 1; i + (1 << j) - 1 <= n; i++){ dpmi[i][j] = min_st(dpmi[i][j - 1], dpmi[i + (1 << (j - 1))][j - 1]); } } } int QueryMin(int le, int ri){ int k = log2(ri - le + 1); return min_st(dpmi[le][k], dpmi[ri - (1 << k) + 1][k]); } int QueryLcp(int le, int ri){ if(le > ri) swap(le, ri); le++; return QueryMin(le, ri); } void Solve(LL k){ int le = 1, ri = n; while(le <= ri){ if(le == ri){ for(int i = 0; i < k; i++){ printf("%c", str[Sa[le] + i]); } printf("\n"); break; } int mid = QueryLcp(le, ri) - 1; if(k <= 1LL * Height[mid + 1] * (ri - le + 1)){ int len = k / (ri - le + 1); if(k % (ri - le + 1)) len++; for(int i = 0; i < len; i++){ printf("%c", str[Sa[le] + i]); } printf("\n"); break; } else { k -= 1LL * (ri - mid) * Height[mid + 1]; if(sum[mid] - sum[le - 1] >= k){ ri = mid; } else { k += 1LL * (ri - mid) * Height[mid + 1]; k -= (sum[mid] - sum[le - 1]); le = mid + 1; } } } } int main(int argc, char const *argv[]) { int tt; scanf("%d", &tt); while(tt--){ scanf("%s", str + 1); n = strlen(str + 1); LL k; scanf("%lld", &k); for(int i = 1; i <= n; i++) { a[i] = str[i]; } Suffix(); for(int i = 1; i <= n; i++){ sum[i] = sum[i - 1] + n - Sa[i] + 1; } RMQ(); printf("Case %d: ", ++ca); Solve(k); } return 0; }