JZOJ5803.【2018.8.12省選模擬】girls(三元環)
PROBLEM
有0至n-1個元素,給出m對元素的衝突,給出A,B,C,定義一個滿足i<j<k的三元組(i,j,k)的貢獻為A∗i+B∗j+C∗k,求所有沒有衝突的三元組的貢獻和
SOLUTION
考慮容斥,設F(i)表示每個三元組考慮i組衝突的總和
Ans=F(0)−F(1)+F(2)−F(3)
沒有衝突的每個元素分別為A,B,C時獨立計算次數
一條衝突的列舉衝突,討論這兩個端點的前後情況
兩條衝突的同理同樣是討論。
考慮三條衝突在一個三元組裡面,如果把每一個衝突看成一個環的話,那麼就是對每一個三元環進行計數。問題在於三元環。
有兩種方法。
-
第一種:將點根據度數是否小於sqrt(m)分成兩類,小於sqrt(m)的點將其所有的連邊暴力,時間為sqrt(m)^ 3;對於另一類點,直接三次方暴力只在這一類裡的點,由於這裡的點數不會超過sqrt(m),所以總的複雜度也不會超過sqrt(m)^ 3.因此總的複雜度不超過sqrt(m)^3。但由於是雙向邊,常數巨大。
-
第二種更優的做法: 將無向邊定向為有向邊,根據度數小的點連向度數大的點、其次是編號小的點連向編號大的點,可以發現這是一個有向無環圖,直接列舉點,暴力列舉兩條邊並判斷這兩個點是否聯通(hash、set或map)即可。並且這種方法不用去重,因為有向邊只能從一個點出發計算,所以十分合適這題的打法,真是妙啊妙。
考慮證明這種做法的正確性與時間複雜度。
如果有一個長度為3的環,是不能夠從一個點掃描到的。由於連向的點的度數不減,那麼回到自己後度數不變,即說明這個環的度數相等,但我們又有編號小的點連向編號大的點,所以在這種情況下一定不會有環。
長度為3的非環,所以有兩條邊共了一個起點,可以考慮到這種情況。
因為每個點連向的點的度數比它大,所以度數不超過sqrt(n)。均攤複雜度不會超過m*sqrt(n),並不會證這種方法QwQ
-
還有另一種列舉的方法,時間複雜度可證:
列舉點,將其通往的所有點打標記,再列舉通往的點,再暴力列舉這個點的出邊是否有通往已標記的點上的。打標記為O(m),列舉每個通往的點總數為邊數m,再枚舉出邊sqrt(n),總複雜度O(m*sqrt(n))
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<vector>
#define maxn 400005
#define ll unsigned long long
using namespace std;
int n,m,x,y,z,du[maxn],a[maxn][2],tot,node[maxn];
int em,e[maxn],nx[maxn],ls[maxn],r[3];
ll A,B,C,ans,i,j,k,len1,len2,s;
ll num[maxn],sum[maxn],pc[maxn],ps[maxn];
vector<int> to[maxn];
void insert(int x,int y){em++;e[em]=y;nx[em]=ls[x];ls[x]=em;}
ll que(ll x){if (x<=0) return 0; else return x*(x+1)/2;}
const ll p1=998244353,p2=10007,mo=3000007;
int hx[mo],p[mo][2];
void link(int x,int y){
if (x>y) swap(x,y);
ll s=(x*p1+y*p2)%mo;
while (hx[s]&&(p[s][0]!=x||p[s][1]!=y))
{s++; if (s==mo) s-=mo;}
hx[s]=1,p[s][0]=x,p[s][1]=y;
}
int pd(int x,int y){
if (x>y) swap(x,y);
ll s=(x*p1+y*p2)%mo;
while (hx[s]&&(p[s][0]!=x||p[s][1]!=y))
{s++; if (s==mo) s-=mo;}
return hx[s];
}
int main(){
freopen("girls.in","r",stdin);
freopen("girls.out","w",stdout);
scanf("%d%d",&n,&m);
scanf("%lld%lld%lld",&A,&B,&C);
for(i=0;i<n;i++){
len1=i,len2=n-i-1;
ans+=A*i*((len2-1)*len2/2);
ans+=B*i*len1*len2;
ans+=C*i*((len1-1)*len1/2);
}
for(i=1;i<=m;i++) {
scanf("%d%d",&x,&y);
if (x>y) swap(x,y);
a[i][0]=x,a[i][1]=y;
num[y]++,sum[y]+=x;
to[y].push_back(x);
du[x]++,du[y]++;
ans-=(A*x+B*y)*(n-y-1)+(que(n-1)-que(y))*C;
ans-=(A*x+C*y)*(y-x-1)+(que(y-1)-que(x))*B;
ans-=(B*x+C*y)*x+que(x-1)*A;
}
for(i=0;i<n;i++) sort(to[i].begin(),to[i].end());
for(i=0;i<n;i++){
ans+=num[i]*(num[i]-1)/2*i*C;
for(j=0;j<to[i].size();j++){
x=to[i][j];
ans+=(num[i]-j-1)*to[i][j]*A;
ans+=j*to[i][j]*B;
ans+=num[to[i][j]]*(B*to[i][j]+C*i);
ans+=sum[to[i][j]]*A;
ans+=pc[to[i][j]]*(A*to[i][j]+C*i);
ans+=ps[to[i][j]]*B;
pc[to[i][j]]++,ps[to[i][j]]+=i;
}
}
for(i=1;i<=m;i++){
x=a[i][0],y=a[i][1];
if (du[x]<du[y]) insert(x,y); else
if (du[x]>du[y]) insert(y,x); else
insert(x,y);
link(x,y);
}
for(x=0;x<n;x++){
tot=0;
for(i=ls[x];i;i=nx[i]) node[++tot]=e[i];
for(i=1;i<tot;i++) for(j=i+1;j<=tot;j++){
r[0]=x,r[1]=node[i],r[2]=node[j];
if (!pd(r[1],r[2])) continue;
sort(r,r+3);
ans-=r[0]*A+r[1]*B+r[2]*C;
}
}
printf("%llu",ans);
}