Codeforces 710F - String Set Queries(AC 自動機)
阿新 • • 發佈:2020-12-17
題意:強制線上的 AC 自動機。
\(n,\sum|s|\leq 3\times 10^5\)
如果不是強制線上那此題就是道 sb 題,加了強制線上就不那麼 sb 了。
這裡介紹兩種做法:
- 根號分治
考慮到 KMP 擅長處理單個字串匹配的情況,但對於多模式串的情況複雜度就不那麼優秀了。
而 AC 自動機擅長處理多模式串匹配的情況,但預處理複雜度是線性的,每加進來一個字串都預處理一遍複雜度顯然吃不消。
考慮將二者結合,設一個臨界值 \(B\),每 \(B\) 個串建一個 AC 自動機。剩餘的若干個串暴力跑 KMP。
算下複雜度,最多要在 AC 自動機上匹配 \(\frac{n}{B}\)
再考慮 KMP,單次 KMP 是 \(\mathcal O(n+m)\) 的,每個模式串最多跟 \(B\) 個詢問串匹配,而每個詢問串也最多跟 \(B\) 個模式串匹配,所以貢獻為 \(B\sum|s|\)。
最後是刪除操作,眾所周知,AC 自動機是不支援刪除操作的,但發現一加一減,貢獻可以抵消掉,所有我們可以建兩個 AC 自動機,一個處理所有 \(1\) 操作加入進來的字串,一個處理所有 \(2\) 操作刪除的字串,二者相減即可。
總複雜度 \(\mathcal O(\sum|s|(B+\frac{n}{B}))\)
#include <bits/stdc++.h> using namespace std; #define fi first #define se second #define fz(i,a,b) for(int i=a;i<=b;i++) #define fd(i,a,b) for(int i=a;i>=b;i--) #define ffe(it,v) for(__typeof(v.begin()) it=v.begin();it!=v.end();it++) #define fill0(a) memset(a,0,sizeof(a)) #define fill1(a) memset(a,-1,sizeof(a)) #define fillbig(a) memset(a,63,sizeof(a)) #define pb push_back #define ppb pop_back #define mp make_pair template<typename T1,typename T2> void chkmin(T1 &x,T2 y){if(x>y) x=y;} template<typename T1,typename T2> void chkmax(T1 &x,T2 y){if(x<y) x=y;} typedef pair<int,int> pii; typedef long long ll; template<typename T> void read(T &x){ char c=getchar();T neg=1; while(!isdigit(c)){if(c=='-') neg=-1;c=getchar();} while(isdigit(c)) x=x*10+c-'0',c=getchar(); x*=neg; } const int BLK=548; const int MAXN=3e5+5; const int ALPHA=26; char buf[MAXN+5]; string reads(){ scanf("%s",buf+1);int len=strlen(buf+1); string ret;for(int i=1;i<=len;i++) ret+=buf[i]; return ret; } class solver{ public: struct ACAM{ int rt[BLK+5]; int ch[MAXN+BLK+5][ALPHA+2],fail[MAXN+BLK+5],cnt[MAXN+BLK+5],ncnt=0; void insert(int r,string s){ int cur=r; for(int i=0;i<s.size();i++){ if(!ch[cur][s[i]-'a']) ch[cur][s[i]-'a']=++ncnt; cur=ch[cur][s[i]-'a']; } cnt[cur]++; } void getfail(int r){ queue<int> q; for(int i=0;i<ALPHA;i++){ if(ch[r][i]) q.push(ch[r][i]),fail[ch[r][i]]=r; else ch[r][i]=r; } while(!q.empty()){ int x=q.front();q.pop(); for(int i=0;i<ALPHA;i++){ if(ch[x][i]) fail[ch[x][i]]=ch[fail[x]][i],q.push(ch[x][i]),cnt[ch[x][i]]+=cnt[fail[ch[x][i]]]; else ch[x][i]=ch[fail[x]][i]; } } } int query(int r,string s){ int cur=r,ret=0; for(int i=0;i<s.size();i++){ cur=ch[cur][s[i]-'a'];ret+=cnt[cur]; } return ret; } } a; int fail[MAXN+5]; int getkmp(string s,string t){ int ls=s.size(),lt=t.size();s=" "+s;t=" "+t; int pos=0,ret=0; for(int i=2;i<=lt;i++){ while(pos&&t[pos+1]!=t[i]) pos=fail[pos]; if(t[pos+1]==t[i]) pos++;fail[i]=pos; } pos=0; for(int i=1;i<=ls;i++){ while(pos&&t[pos+1]!=s[i]) pos=fail[pos]; if(t[pos+1]==s[i]) pos++; if(pos==lt) ret++; } return ret; } string ss[MAXN+5]; int num=0,pre=0,cnt=0; void insert(string s){ ss[++num]=s; if(num%BLK==0){ a.rt[++cnt]=++a.ncnt; for(int i=pre+1;i<=num;i++) a.insert(a.rt[cnt],ss[i]); a.getfail(a.rt[cnt]);pre=num; } } int query(string s){ int ret=0; for(int i=1;i<=cnt;i++) ret+=a.query(a.rt[i],s); for(int i=pre+1;i<=num;i++) ret+=getkmp(s,ss[i]); return ret; } } s1,s2; int T; int main(){ scanf("%d",&T); while(T--){ int opt;string s;scanf("%d",&opt);s=reads(); if(opt==1) s1.insert(s); else if(opt==2) s2.insert(s); else printf("%d\n",s1.query(s)-s2.query(s)),fflush(stdout); } return 0; }
- 二進位制分組
感覺這應該是正解吧,本題官方題解給的就這個做法。
還是將加入的字串分成若干組,每組建一個 AC 自動機。
不過與之前不同的是這次我們按二進位制分組,即每組的大小都是 \(2\) 的整數次冪。
那麼這玩意兒怎麼支援插入操作呢?假設我們插入字串 \(s\),我們先建一個只有一個串的 AC 自動機,然後不斷與前面的 AC 自動機像啟發式合併堆一樣合併。最後暴力重構一發。
是不是有點抽象?打個形象的比喻,2048,假設我們現在有 \(23\) 個串,那麼會分為 \(4\) 組,大小分別為 \(16,4,2,1\),此時你再加入一個串,就變為 \(16,4,2,1,1\),最後兩個 \(1\) 合併,變為一個 \(2\);最後兩個 \(2\) 合併,變為一個 \(4\)……以此類推。最後會得到 \(16,8\)。
算下複雜度,暴力重構複雜度是 \(\sum|s|\) 的,而我們每個字串最多被重構 \(\log n\) 次,故插入的總複雜度是 \(n\log n\) 的,而查詢的時候你最多在 \(\log n\) 個 AC 自動機中查詢,故總複雜度為 \(n\log n\),碾壓演算法 1。
實測 1s,可能因為有個 \(26\) 的常數吧。
#include <bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define fz(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
#define ffe(it,v) for(__typeof(v.begin()) it=v.begin();it!=v.end();it++)
#define fill0(a) memset(a,0,sizeof(a))
#define fill1(a) memset(a,-1,sizeof(a))
#define fillbig(a) memset(a,63,sizeof(a))
#define pb push_back
#define ppb pop_back
#define mp make_pair
template<typename T1,typename T2> void chkmin(T1 &x,T2 y){if(x>y) x=y;}
template<typename T1,typename T2> void chkmax(T1 &x,T2 y){if(x<y) x=y;}
typedef pair<int,int> pii;
typedef long long ll;
template<typename T> void read(T &x){
char c=getchar();T neg=1;
while(!isdigit(c)){if(c=='-') neg=-1;c=getchar();}
while(isdigit(c)) x=x*10+c-'0',c=getchar();
x*=neg;
}
const int MAXN=3e5;
const int LOG_N=19;
const int ALPHA=26;
class solver{
public:
int sz[LOG_N+3],rt[LOG_N+3],cnt=0;
int oc[MAXN*2+5][ALPHA+2],ch[MAXN*2+5][ALPHA+2],ncnt=0,ed[MAXN*2+5],val[MAXN*2+5],fail[MAXN+5];
void insert(char *s,int r){
int len=strlen(s+1),cur=r;
for(int i=1;i<=len;i++){
if(!oc[cur][s[i]-'a']) oc[cur][s[i]-'a']=++ncnt;
cur=oc[cur][s[i]-'a'];
} ed[cur]++;
}
void getfail(int r){
queue<int> q;
for(int i=0;i<ALPHA;i++){
if(oc[r][i]){
fail[oc[r][i]]=r;q.push(oc[r][i]);
val[oc[r][i]]=ed[oc[r][i]];ch[r][i]=oc[r][i];
} else ch[r][i]=r;
}
while(!q.empty()){
int x=q.front();q.pop();
for(int i=0;i<ALPHA;i++){
if(oc[x][i]){
ch[x][i]=oc[x][i];
fail[oc[x][i]]=ch[fail[x]][i];
val[oc[x][i]]=val[fail[oc[x][i]]]+ed[oc[x][i]];
q.push(oc[x][i]);
} else ch[x][i]=ch[fail[x]][i];
}
}
}
int merge(int x,int y){
if(!x||!y) return x+y;
ed[x]+=ed[y];
for(int i=0;i<ALPHA;i++) oc[x][i]=merge(oc[x][i],oc[y][i]);
return x;
}
void insert(char *s){
sz[++cnt]=1;rt[cnt]=++ncnt;insert(s,rt[cnt]);
while(sz[cnt]==sz[cnt-1]){
rt[cnt-1]=merge(rt[cnt-1],rt[cnt]);
sz[cnt-1]<<=1;sz[cnt]=0;cnt--;
} getfail(rt[cnt]);
}
int query(char *s,int r){
int len=strlen(s+1),cur=r,ret=0;
for(int i=1;i<=len;i++){
cur=ch[cur][s[i]-'a'];
ret+=val[cur];
} return ret;
}
int query(char *s){
int ret=0;
for(int i=1;i<=cnt;i++) ret+=query(s,rt[i]);
return ret;
}
} s1,s2;
char buf[MAXN+5];
int main(){
int T;scanf("%d",&T);
while(T--){
int opt;scanf("%d%s",&opt,buf+1);
if(opt==1) s1.insert(buf);
else if(opt==2) s2.insert(buf);
else printf("%d\n",s1.query(buf)-s2.query(buf)),fflush(stdout);
}
return 0;
}