1. 程式人生 > 其它 >【CF163E】e-Government

【CF163E】e-Government

題目

題目連結:https://codeforces.com/problemset/problem/163/E
給定包含 \(n\) 個字串的集合 \(S\),有 \(m\) 個操作,操作有三種類型:

  • ? 開頭的操作為詢問操作,詢問當前字串集 \(S\) 中的每一個字串匹配詢問字串的次數之和;
  • + 開頭的操作為新增操作,表示將編號為 \(i\) 的字串加入到集合中;
  • - 開頭的操作為刪除操作,表示將編號為 \(i\) 的字串從集合中刪除。

注意當編號為 \(i\) 的字串已經在集合中時,允許存在新增編號為 \(i\) 的字串,刪除亦然。
\(n,m\leq 10^5,\sum |s_i|\leq 10^6\)

思路

\(n\) 個字串都扔進一個 AC 自動機裡。我們知道計算 \(t\) 串在 \(s\) 串中的出現次數,可以在 AC 自動機上不斷匹配 \(S\) 串,對於 \(S\) 串每一個字尾,找 fail 樹上它的祖先是否有 \(t\) 串結尾所表示的點。
那麼如果添加了 \(t\) 串進集合,等價於讓 \(t\) 串結尾所表示節點,在 fail 樹上的子樹全部多了一次匹配。
那麼只需要將 fail 樹的 dfs 序求出來,然後樹狀陣列維護區間加,單點查詢即可。
時間複雜度 \(O(m\log (\sum |S_i|))\)

程式碼

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int N=100010,M=1000010;
int n,m,lim,ed[N],id[M],siz[M];
bool vis[N];
char s[M];

struct BIT
{
	ll c[M];
	
	void add(int x,ll v)
	{
		for (int i=x;i<=lim;i+=i&-i)
			c[i]+=v;
	}
	
	ll query(int x)
	{
		ll ans=0;
		for (int i=x;i;i-=i&-i)
			ans+=c[i];
		return ans;
	}
}bit;

struct ACA
{
	int tot,fa[M],fail[M],ch[M][26];
	vector<int> e[M];
	
	void insert(char *s,int j,int st=1)
	{
		int p=0,len=strlen(s+1);
		for (int i=st;i<=len;i++)
		{
			if (!ch[p][s[i]-'a'])
				ch[p][s[i]-'a']=++tot,fa[tot]=p;
			p=ch[p][s[i]-'a'];
		}
		ed[j]=p;
	}
	
	void build()
	{
		queue<int> q;
		for (int i=0;i<26;i++)
			if (ch[0][i]) q.push(ch[0][i]);
		while (q.size())
		{
			int u=q.front(); q.pop();
			e[fail[u]].push_back(u);
			for (int i=0;i<26;i++)
				if (ch[u][i]) fail[ch[u][i]]=ch[fail[u]][i],q.push(ch[u][i]);
					else ch[u][i]=ch[fail[u]][i];
		}
	}
	
	void dfs(int x)
	{
		id[x]=++tot; siz[x]=1;
		for (int i=0;i<e[x].size();i++)
		{
			int v=e[x][i];
			dfs(v); siz[x]+=siz[v];
		}
	}
	
	void query(char *s)
	{
		int len=strlen(s+1),p=0;
		ll ans=0;
		for (int i=2;i<=len;i++)
		{
			p=ch[p][s[i]-'a'];
			ans+=bit.query(id[p]);
		}
		cout<<ans<<"\n";
	}
}AC;

int main()
{
	scanf("%d%d",&m,&n);
	for (int i=1;i<=n;i++)
	{
		scanf("%s",s+1);
		AC.insert(s,i);
	}
	AC.build();
	AC.tot=0; AC.dfs(0);
	lim=AC.tot;
	for (int i=1;i<=n;i++)
	{
		vis[i]=1;
		bit.add(id[ed[i]],1);
		bit.add(id[ed[i]]+siz[ed[i]],-1);
	}
	for (int i=1;i<=m;i++)
	{
		scanf("%s",s+1);
		int len=strlen(s+1);
		if (s[1]=='?') AC.query(s);
		if (s[1]=='+')
		{
			int x=0;
			for (int j=2;j<=len;j++)
				x=x*10+s[j]-48;
			if (!vis[x])
			{
				vis[x]=1;
				bit.add(id[ed[x]],1);
				bit.add(id[ed[x]]+siz[ed[x]],-1);
			}
		}
		if (s[1]=='-')
		{
			int x=0;
			for (int j=2;j<=len;j++)
				x=x*10+s[j]-48;
			if (vis[x])
			{
				vis[x]=0;
				bit.add(id[ed[x]],-1);
				bit.add(id[ed[x]]+siz[ed[x]],1);
			}
		}
	}
	return 0;
}