HDU 5799 This world need more Zhu (樹上莫隊)
阿新 • • 發佈:2021-10-02
HDU 5799 This world need more Zhu
Mean
給定一棵樹,每個點有一個點權。兩種詢問。
1.詢問子樹\(u\)中出現\(a\)次的權值的累加和與出現\(b\)次的權值的累加和的\(gcd\)。
2.詢問路徑\(u-v\)中出現\(a\)次的權值的累加和與出現\(b\)次的權值的累加和的\(gcd\)。
\(n,m\le 10^5\),測試資料組數 \(t\le 10\)。時限:\(5000\) ms。
Sol
樹上莫隊。
考慮將兩種詢問分開處理。
對於詢問1.
採用只記錄入口的 \(dfs\) 序(\(dfn\)),子樹 \(u\) 恰好對應連續區間 \([dfn[u],dfn[u]+siz[u]-1]\),直接在這個序列上做普通的莫隊即可。
對於詢問2.
樹上莫隊.
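採用尤拉序(每個點在進入和離開時各記錄一次,時間戳分別記為 \(st\) 和 \(ed\)),需要考慮兩種情況(預設在尤拉序中先訪問 \(u\),即 \(st[u]<st[v]\))。維護區間時,每經過一個位置就翻轉該點「是否計入」的狀態;在區間內出現兩次的點,其子樹已被完整走過、不在路徑上,兩次翻轉會自動抵消。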
當 \(u,v\) 在一條鏈上時(即 \(u\) 是 \(v\) 的祖先),轉換成詢問連續區間 \([st[u],st[v]]\)。
當 \(u,v\) 不在一條鏈上時,轉換成詢問連續區間 \([ed[u],st[v]]\)(程式中寫作 \([\min(st[v],ed[u]),\max(st[v],ed[u])]\))。由於這段區間不包含 \(LCA(u,v)\),所以統計答案前要先把 \(LCA(u,v)\) 的貢獻加上,統計完答案後再把 \(LCA(u,v)\) 的貢獻減去。
其他的就是普通的莫隊操作了。值得注意的是權值需要離散化,且答案(權值和)會超出 int 範圍,需要使用 long long。
另外此做法需要卡常:莫隊做法加了基於 fread/fwrite 的快速讀寫以及奇偶排序才能通過。
更快的做法可以把詢問1換成dsu on tree。
Code
#pragma GCC optimize(2)
#include<bits/stdc++.h>
#define lowbit(x) (x&(-x))
#define lson l,mid,rt<<1
#define rson mid+1,r,rt<<1|1
#define rep(i,a,b) for(int i=(a);i<=(b);++i)
#define dep(i,a,b) for(int i=(a);i>=(b);--i)
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
const int N=2e5+20;
const int MAX=10000007;
/**
 * Tree Mo's algorithm (alternatively: DSU on tree + Mo's algorithm).
 */
namespace IO {

// Capacity of the fread input buffer and of the fwrite output buffer.
const int BUF_SIZE = 100000;

// Set once nc() reaches end of input; the read() overloads stop early when it is set.
bool IOerror = 0;

// Returns the next byte of stdin, refilling a static buffer with fread when it
// is exhausted. At end of input sets IOerror and returns -1.
inline char nc() {
    static char buf[BUF_SIZE], *p1 = buf + BUF_SIZE, *pend = buf + BUF_SIZE;
    if (p1 == pend) {
        p1 = buf;
        pend = buf + fread(buf, 1, BUF_SIZE, stdin);
        if (pend == p1) {
            IOerror = 1;
            return -1;
        }
    }
    return *p1++;
}

// True for the whitespace characters the readers skip.
inline bool blank(char ch) { return ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t'; }

// Shared body of read(int&) / read(ll&): skips whitespace, accepts an optional
// leading '-', then consumes decimal digits. x stays 0 if no digits follow.
template <class Int>
inline void read_integer(Int &x) {
    bool sign = 0;
    char ch = nc();
    x = 0;
    for (; blank(ch); ch = nc());
    if (IOerror) return;
    if (ch == '-') sign = 1, ch = nc();
    for (; ch >= '0' && ch <= '9'; ch = nc()) x = x * 10 + ch - '0';
    if (sign) x = -x;
}

inline void read(int &x) { read_integer(x); }
inline void read(ll &x) { read_integer(x); }

// Reads an optionally signed decimal number with an optional fractional part.
inline void read(double &x) {
    bool sign = 0;
    char ch = nc();
    x = 0;
    for (; blank(ch); ch = nc());
    if (IOerror) return;
    if (ch == '-') sign = 1, ch = nc();
    for (; ch >= '0' && ch <= '9'; ch = nc()) x = x * 10 + ch - '0';
    if (ch == '.') {
        double tmp = 1;
        ch = nc();
        for (; ch >= '0' && ch <= '9'; ch = nc()) tmp /= 10.0, x += tmp * (ch - '0');
    }
    if (sign) x = -x;
}

// Reads one whitespace-delimited token into s and NUL-terminates it.
// The caller must provide a large enough buffer.
inline void read(char *s) {
    char ch = nc();
    for (; blank(ch); ch = nc());
    if (IOerror) return;
    for (; !blank(ch) && !IOerror; ch = nc()) *s++ = ch;
    *s = 0;
}

// Reads the next non-whitespace character; sets c = -1 at end of input.
inline void read(char &c) {
    for (c = nc(); blank(c); c = nc());
    if (IOerror) {
        c = -1;
        return;
    }
}

// Buffered stdout writer. Output accumulates in buf and is written with fwrite
// when the buffer fills, on flush(), and from the destructor.
struct Ostream_fwrite {
    char buf[BUF_SIZE];
    char *p1 = buf, *pend = buf + BUF_SIZE;

    Ostream_fwrite() = default;
    // p1/pend point into this object's own buffer, so copies would alias it.
    Ostream_fwrite(const Ostream_fwrite &) = delete;
    Ostream_fwrite &operator=(const Ostream_fwrite &) = delete;
    ~Ostream_fwrite() { flush(); }

    void out(char ch) {
        if (p1 == pend) {
            fwrite(buf, 1, BUF_SIZE, stdout);
            p1 = buf;
        }
        *p1++ = ch;
    }

    // Writes x in decimal. The magnitude is computed in the unsigned type so
    // INT_MIN / LLONG_MIN are printed correctly instead of overflowing on -x.
    template <class Int>
    void print_integer(Int x) {
        typedef typename std::make_unsigned<Int>::type U;
        U mag = static_cast<U>(x);
        if (x < 0) {
            out('-');
            mag = U(0) - mag;
        }
        char digits[24];
        int len = 0;
        do {
            digits[len++] = static_cast<char>('0' + mag % 10);
            mag /= 10;
        } while (mag);
        while (len) out(digits[--len]);
    }

    void print(int x) { print_integer(x); }
    void println(int x) { print_integer(x); out('\n'); }
    void print(ll x) { print_integer(x); }
    void println(ll x) { print_integer(x); out('\n'); }

    // Writes x rounded half-up to y digits after the decimal point (y in [0, 17]).
    void print(double x, int y) {
        static const ll mul[] = {1, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000,
                                 1000000000, 10000000000LL, 100000000000LL, 1000000000000LL,
                                 10000000000000LL, 100000000000000LL, 1000000000000000LL,
                                 10000000000000000LL, 100000000000000000LL};
        if (x < -1e-12) out('-'), x = -x;
        x *= mul[y];
        ll x1 = static_cast<ll>(floor(x));
        if (x - floor(x) >= 0.5) ++x1;
        ll x2 = x1 / mul[y], x3 = x1 - x2 * mul[y];
        print(x2);
        if (y > 0) {
            out('.');
            // Leading zeros of the fractional part.
            for (int i = 1; i < y && x3 * mul[i] < mul[y]; out('0'), ++i);
            print(x3);
        }
    }
    void println(double x, int y) { print(x, y); out('\n'); }

    // Takes const char* so string literals bind without the (C++11-removed)
    // literal-to-char* conversion.
    void print(const char *s) { while (*s) out(*s++); }
    void println(const char *s) { print(s); out('\n'); }

    void flush() {
        if (p1 != buf) {
            fwrite(buf, 1, p1 - buf, stdout);
            p1 = buf;
        }
    }
} Ostream;

inline void print(int x) { Ostream.print(x); }
inline void println(int x) { Ostream.println(x); }
inline void print(char x) { Ostream.out(x); }
inline void println(char x) { Ostream.out(x); Ostream.out('\n'); }
inline void print(ll x) { Ostream.print(x); }
inline void println(ll x) { Ostream.println(x); }
inline void print(double x, int y) { Ostream.print(x, y); }
inline void println(double x, int y) { Ostream.println(x, y); }
inline void print(const char *s) { Ostream.print(s); }
inline void println(const char *s) { Ostream.println(s); }
inline void println() { Ostream.out('\n'); }
inline void flush() { Ostream.flush(); }

}  // namespace IO

// Writes x in decimal directly with putchar. NOTE: this bypasses the IO::Ostream
// buffer, so mixing it with IO::print* can reorder output.
inline void out(int x) {
    unsigned mag = static_cast<unsigned>(x);
    if (x < 0) {
        putchar('-');
        mag = 0u - mag;
    }
    if (mag > 9) out(static_cast<int>(mag / 10));
    putchar(static_cast<int>(mag % 10) + '0');
}

int t, n, m;                    // number of test cases; node / query count of the current case
int nval[N];                    // per-node weight, indexed 1..n
int h[N], e[N * 2], ne[N * 2];  // adjacency list: head per node, edge target, next edge
int idx = 0, tot = 0, tot1 = 0; // edge counter and the two DFS timestamp counters
int block;                      // Mo's algorithm block size
struct node { int op,u,v,a,b; int id; int l,r; int lca; }askpa[N],ask[N],askt[N]; bool cmp(node x,node y){ if(x.l/block==y.l/block){ return x.r<y.r; } return x.l/block<y.l/block; } bool cmp1(node a,node b){//奇偶排序 if(a.l/block==b.l/block){ if((a.l/block)%2==1){ return a.r<b.r; } return a.r>b.r; } return a.l<b.l; } void add(int x,int y){ e[++idx] = y,ne[idx] = h[x] , h[x] =idx; } int siz[N]; int kepval[N]; int st[N],ed[N]; int dfn[N]; int fat[N]; int dep[N]; int son[N]; int ary[N],ary1[N]; int top[N]; ll ANS[N]; int stats[N]; void dfs(int x,int fa,int dp){ st[x]=++tot; ary[tot]=x; dfn[x]=++tot1; ary1[tot1]=x; siz[x]=1; dep[x]=dp; fat[x]=fa; int mxson = -1; for(int i=h[x];i;i=ne[i]){ int to = e[i]; if(to == fa)continue; dfs(to,x,dp+1); siz[x]+=siz[to]; if(siz[to]>mxson){ son[x]=to; mxson=siz[to]; } } ed[x]=++tot; ary[tot]=x; } void dfs1(int x,int topf){//x當前節點 topf當前鏈的最頂端的節點 top[x]=topf;//這個點所在鏈的頂端 if(!son[x])return ;//如果沒有兒子則返回 dfs1(son[x],topf);//優先處理重兒子,再處理輕兒子順序遞迴 for(int i=h[x];i;i=ne[i]){ int y=e[i]; if(y==fat[x]||y==son[x])continue;//遇到重兒子或父親節點跳過 dfs1(y,y);//對於每一個輕兒子都有一條從它自己開始的鏈 每一天重鏈的頂端都是輕兒子 } } int lca(int x,int y){ while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]])swap(x,y); x=fat[top[x]]; } if(dep[x]>dep[y])return y; return x; } int cnt; int con,pon; int rans = 0; ll fmp[N]; ll mp[N]; ll num[N]; void init(){ idx = tot = tot1 = 0; rep(i,1,n){ h[i]=0; dep[i]=siz[i]=top[i]=son[i]=ary[i]=ary1[i]=dfn[i]=fat[i]=st[i]=ed[i]=fmp[i]=0; } con=rans=0; } void add(int x){ num[mp[nval[x]]]-=fmp[nval[x]]; mp[nval[x]]++; num[mp[nval[x]]]+=fmp[nval[x]]; } void del(int x){ num[mp[nval[x]]]-=fmp[nval[x]]; mp[nval[x]]--; num[mp[nval[x]]]+=fmp[nval[x]]; } void cal(int x){ if(stats[x]==0){ add(x); } else{ del(x); } stats[x]^=1; } int ca=0; void solve(){ ++ca; IO::read(n),IO::read(m); block = 600;//最佳塊 rep(i,1,n){ IO::read(nval[i]); kepval[i]=nval[i]; } sort(kepval+1,kepval+1+n); int sz = unique(kepval+1,kepval+1+n)-(kepval+1); rep(i,1,n){ int val = nval[i]; nval[i] = 
lower_bound(kepval+1,kepval+1+sz,nval[i])-kepval; fmp[nval[i]]=val; } rep(i,1,n-1){ int u,v; IO::read(u),IO::read(v); add(u,v); add(v,u); } con=pon=0; dfs(1,0,1); dfs1(1,1); rep(i,1,m){ int op,u,v,a,b; IO::read(op),IO::read(u),IO::read(v),IO::read(a),IO::read(b); ask[i]=(node){op,u,v,a,b,i}; if(op==2)askpa[++con]=ask[i]; else askt[++pon]=ask[i]; } rep(i,1,con){ int x,y; x=askpa[i].u; y=askpa[i].v; if(st[x]>st[y])swap(x,y),swap(askpa[i].v,askpa[i].u); int Lc=lca(x,y); if(Lc==x){ askpa[i].l = st[x]; askpa[i].r = st[y]; askpa[i].lca = 0; } else{ askpa[i].l = min(st[y],ed[x]); askpa[i].r = max(st[y],ed[x]); askpa[i].lca = Lc; } } rep(i,1,pon){ int u = askt[i].u; askt[i].l=dfn[u]; askt[i].r=dfn[u]+siz[u]-1; askt[i].lca=0; } sort(askt+1,askt+1+pon,cmp1); sort(askpa+1,askpa+1+con,cmp1); int tl=1,tr=0; rep(i,1,n)stats[i]=mp[i]=num[i]=0; for(int i=1;i<=con;++i){ int L=askpa[i].l,R=askpa[i].r; while(tl<L)cal(ary[tl++]); while(tl>L)cal(ary[--tl]); while(tr<R)cal(ary[++tr]); while(tr>R)cal(ary[tr--]); if(askpa[i].lca){ cal(askpa[i].lca); } ANS[askpa[i].id] = __gcd(num[askpa[i].a],num[askpa[i].b]); if(askpa[i].lca){ cal(askpa[i].lca); } } rep(i,1,n)stats[i]=mp[i]=num[i]=0; tl=1,tr=0; for(int i=1;i<=pon;++i){ int L=askt[i].l,R=askt[i].r; while(tl<L)cal(ary1[tl++]); while(tl>L)cal(ary1[--tl]); while(tr<R)cal(ary1[++tr]); while(tr>R)cal(ary1[tr--]); ANS[askt[i].id] = __gcd(num[askt[i].a],num[askt[i].b]); } IO::print("Case #"); IO::print(ca); IO::println(":"); rep(i,1,m){ IO::println(ANS[i]); } } int main(){ IO::read(t); while(t--){ init(); solve(); } return 0; } /* 1 5 5 1 2 4 1 2 1 2 2 3 3 4 4 5 1 1 1 1 1 1 1 1 1 2 2 1 5 1 1 2 1 5 1 2 2 1 1 2 2 4 1 4 1 0 */