「SCOI2016」萌萌噠
這篇題解來自 Dls,但因為某種原因被誤刪了,在此做一個搬運 & 格式修改 qwq
Problem
Solution
“兩段區間相等”不好處理。我們可以把這種關係具體化為“區間裡的每一對位置分別相等”。而兩個點相等這種關係,很容易用並查集來刻畫。最後求答案時查詢一下,序列(的並查集)中有多少個連通塊。設連通塊數量為 \(c\),則答案就是 \(9\cdot 10^{c-1}\),因為大數第一個位置不能為 \(0\)。於是這樣可以簡單做到 \(\mathcal O(nm)\) 的複雜度(帶並查集常數)。
考慮用倍增來優化這個過程。我們的並查集不建 \(n\) 個點,而是建 \(\mathcal O(n\log n)\)
對於一組條件 \((l_1,r_1,l_2,r_2)\),設 \(k=\lfloor\log_2(r_1-l_1+1)\rfloor\)。則我們把 \((l_1,k),(l_2,k)\) 以及 \((r_1-2^k+1,k),(r_2-2^k+1,k)\) 這兩對點用並查集並起來即可。這類似於 RMQ 的思想。這樣,處理所有限制的時間複雜度是 \(\mathcal O(m)\)
在求答案時,我們類似於線段樹下放懶標記,把所有相等關係下放到最底層。具體來說,我們從大到小列舉 \(k\),對於 \(k\) 這一層的兩個節點 \((i,k),(j,k)\),如果他們被並起來了,那麼 \((i,k-1),(j,k-1)\) 和 \((i+2^{k-1},k-1),(j+2^{k-1},k-1)\) 也要被分別並起來。這樣,列舉所有點下放一遍,時間複雜度是 \(\mathcal O(n\log n)\)。全部下放完成後,列舉最底層(長度為 \(1\))的點,統計連通塊數即可。
總時間複雜度 \(\mathcal O(m+n\log n)\)。
最後,我們反思一下為什麼使用倍增,而不用線段樹或者其他東西替代。因為倍增的性質可以保證劃分出的區間長度相等,這就方便我們直接給兩個區間畫上等號。
參考程式碼:
//problem:LOJ2014
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
const int MOD=1e9+7;
inline int pow_mod(int x,int i){int y=1;while(i){if(i&1)y=(ll)y*x%MOD;x=(ll)x*x%MOD;i>>=1;}return y;}
const int MAXN=1e5;
int n,fa[MAXN+5][17];
int get_fa(int p,int k){
if(fa[p][k]==p)return p;
else return fa[p][k]=get_fa(fa[p][k],k);
}
void union_s(int p1,int p2,int k){
int f1=get_fa(p1,k);
int f2=get_fa(p2,k);
if(f1!=f2)fa[f1][k]=f2;
}
int main() {
cin>>n;
static int _log2[MAXN+5];
_log2[0]=-1;
for(int i=1;i<=n;++i)_log2[i]=_log2[i>>1]+1;
for(int j=0;j<=16;++j)for(int i=1;i<=n;++i)fa[i][j]=i;
int m;cin>>m;for(int i=1;i<=m;++i){
int l1,r1,l2,r2;cin>>l1>>r1>>l2>>r2;
int k=_log2[r1-l1+1];
union_s(l1,l2,k);
union_s(r1-(1<<k)+1,r2-(1<<k)+1,k);
}
for(int j=16;j>=1;--j){
for(int i=1;i+(1<<(j-1))<=n;++i){
int f=get_fa(i,j);
if(f==i)continue;
union_s(i,f,j-1);
union_s(i+(1<<(j-1)),f+(1<<(j-1)),j-1);
}
}
int num=0;
for(int i=1;i<=n;++i){
if(get_fa(i,0)==i)num++;
}
cout<<9LL*pow_mod(10,num-1)%MOD<<endl;
return 0;
}
My Code QwQ
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e5+5,M=20,mod=1e9+7;
int n,m,k,f[M][N],l1,r1,l2,r2,ans;
int find(int x,int f[M]){
return x==f[x]?x:f[x]=find(f[x],f);
}
void merge(int x,int y,int k){
x=find(x,f[k]),y=find(y,f[k]);
if(x!=y) f[k][x]=y;
}
int mul(int x,int n,int mod){
int ans=mod!=1;
for(x%=mod;n;n>>=1,x=x*x%mod)
if(n&1) ans=ans*x%mod;
return ans;
}
signed main(){
scanf("%lld%lld",&n,&m);
for(int i=1;i<=n;i++)
for(int j=0;j<=log2(n);j++) f[j][i]=i;
for(int i=1;i<=m;i++){
scanf("%lld%lld%lld%lld",&l1,&r1,&l2,&r2);
k=log2(r1-l1+1),merge(l1,l2,k),merge(r1-(1<<k)+1,r2-(1<<k)+1,k);
}
for(int j=log2(n);j>=1;j--)
for(int i=1;i+(1<<(j-1))<=n;i++){
int x=find(i,f[j]);
if(x!=i) merge(i,x,j-1),merge(i+(1<<(j-1)),x+(1<<(j-1)),j-1);
}
for(int i=1;i<=n;i++)
if(find(i,f[0])==i) ans++;
printf("%lld\n",9*mul(10,ans-1,mod)%mod);
return 0;
}