1. 程式人生 > 實用技巧 >codeforces 1037H - Security (字尾自動機 + 線段樹合併)

codeforces 1037H - Security (字尾自動機 + 線段樹合併)

題解:首先分析,要大於給出的模式串並且儘可能小,那麼一定是優先找和給出的模式串公共字首儘可能長的字串,假設模式串 \(t\) 的長度為 \(tlen\)

\(t[tlen + 1] = a - 1\) ,那麼思路的流程大概如下

1、首先後綴自動機上尋找匹配 \(t\) 的最長字串, 假設長度為 \(x\) 去掉 ,看 \(2\)

2、在模式串 \(x + 1\) 上的字元 \(+1\) ,比如 將 \(a\) 變成 \(b\) ,看 \(3\)

3、判斷新的模式串是否存在文字串的字串匹配長度為 \(x +1\) , 若存在,看 \(4\)

4、判斷是否存在對應 \(endpoint\)

的字串,若存在, 則輸出當前模式串,若不存在,看 \(5\)

5、若新的字元已經是 \(z\) ,看 \(6\),否則將新 \(x + 1\) 的字元 +1, 比如 將 \(a\) 變成 \(b\) 後,看 \(3\)

6、\(x -= 1\) , 若 \(x == -1\) ,輸出 \(-1\), 否則看 \(2\)

某個節點是否存在某個 \(endpoint\) 我們可以用線段樹合併維護,具體看程式碼。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <cmath>
using namespace std;
typedef long long LL;
const int maxn = 2e5 + 50;
const LL mod = 1e9 + 7;
double eps = 1e-6;
 
struct state
{
    int len, link;
    int nex[30];
} st[maxn];
 
int sz, last;;
 
void sam_init(){
    st[0].len = 0;
    st[0].link = -1;
    sz = 1;
    last = 0;
}
 
char s[maxn], t[maxn];
int n;
int ed[maxn];
void sam_extend(int x){
    int cur = sz++;
    st[cur].len = st[last].len + 1; 
    int p = last;
    while(p != -1 && !st[p].nex[x]){
        st[p].nex[x] = cur;
        p = st[p].link;
    }
 
    if(p == -1){
        st[cur].link = 0;
    } else {
        int q = st[p].nex[x];
        if(st[p].len + 1 == st[q].len){
            st[cur].link = q;
        } else {
            int clone = sz++;
            st[clone].len = st[p].len + 1;
            for(int i = 0; i < 26; i++){
                st[clone].nex[i] = st[q].nex[i];
            }
            st[clone].link = st[q].link;
            while(p != -1 && st[p].nex[x] == q){
                st[p].nex[x] = clone;
                p = st[p].link;
            }
            st[q].link = st[cur].link = clone;
        }
    }
    last = cur;
}
 
struct qnode
{
    int ls, rs, val;
} tree[maxn * 30];
int tot, root[maxn];
void insert(int le, int ri, int pos, int &rt){
    if(!rt) rt = ++tot;
    tree[rt].val = 1;
    if(le == ri) return ;
    int mid = (le + ri) >> 1;
    if(pos <= mid) insert(le, mid, pos, tree[rt].ls);
    else insert(mid + 1, ri, pos, tree[rt].rs);
}
 
int merge(int u, int v){ // 線段樹合併
    if(!u || !v) return u | v;
    int p = ++tot;  //記住要開新節點存合併的線段樹,因為後面的查詢可能要用到所有節點的線段樹
    tree[p].val = tree[u].val | tree[v].val; //由於只用判斷該區間存不存在值,所以這樣更新就好了
    tree[p].ls = merge(tree[u].ls, tree[v].ls);
    tree[p].rs = merge(tree[u].rs, tree[v].rs);
    return p;
}
 
int Query(int le, int ri, int L, int R, int rt){
    if(L <= le && ri <= R) return tree[rt].val;
    if(!tree[rt].val) return 0;
    int mid = (le + ri) >> 1;
    if(L <= mid && Query(le, mid, L, R, tree[rt].ls)) return 1;
    if(R > mid && Query(mid + 1, ri, L, R, tree[rt].rs)) return 1;
    return 0;    
}
int tax[maxn], id[maxn];
void pre(){ // 根據每個節點的最大長度進行排序,然後從 link 樹的下面往上進行線段樹合併,因為最大長度越長的節點一定在越靠近link樹的根節點
    for(int i = 1; i < sz; i++) tax[st[i].len]++;
    for(int i = 1; i <= n; i++) tax[i] += tax[i - 1];
    for(int i = 1; i < sz; i++) id[tax[st[i].len]--] = i;
    for(int i = 1; i <= n; i++) insert(1, n, i, root[ed[i]]);
    for(int i = sz - 1; i > 1; i--){
        root[st[id[i]].link] = merge(root[st[id[i]].link], root[id[i]]); // 將每個節點合併到它的父節點上
    }
}
 
int pos[maxn];
void solve(){
    int le, ri;
    scanf("%d%d%s", &le, &ri, t + 1);
    int len = strlen(t + 1);
    t[len + 1] = 'a' - 1; // 這句看不懂的先往下看,很容易理解
    int p = 0;
    int mal = 0;
    for(int i = 1; i <= len; i++){
        int x = t[i] - 'a';
        if(st[p].nex[x]){
            p = st[p].nex[x];
            pos[i] = p; // 記錄每個位置的 p 
            mal = i;
        } else {
            break;
        }
    }
 
    for(int i = mal; i >= 0; i--){
        int x = t[i + 1] - 'a' + 1;
        while(x < 26){
            p = st[pos[i]].nex[x];
            if(p){
                int res = Query(1, n, le + i, ri, root[p]);
                if(res) {
                    for(int j = 1; j <= i; j++){
                        printf("%c", t[j]);
                    }
                    printf("%c\n", x + 'a');
                    return ;
                }
            }
            x++;
        }
    }
    printf("-1\n");
}
int main()
{
    sam_init();
    scanf("%s", s + 1);
    n = strlen(s + 1);
    for(int i = 1; i <= n; i++){
        int x = s[i] - 'a';
        sam_extend(x);
        ed[i] = last; // 記錄文字串每個字首對應的節點,因為所有節點的endpoint,都是從這些點轉移過來的
    }
    pre();
    int q;
    scanf("%d", &q);
    while(q--){
        solve();
    }
    return 0;
}