牛客網位元組跳動冬令營網路賽J Sortable Path on Tree —— 點分治
阿新 • • 發佈:2018-12-26
題目:https://ac.nowcoder.com/acm/contest/296/J
用點分治;
記錄了值起伏的形態,二元組 (x,y) 表示有 x 個小於號,y 個大於號;
因為小於號和大於號都 >=2 就不合法了,所以狀態是 3×3 的;
然後根據各種形態拼接...寫了一晚上,最後連最簡單的樣例都過不了了...
感覺似乎走入歧途了,這樣討論太麻煩...
#include<cstdio> #include<cstring> #include<algorithm> #include<vector> #define pb push_back using艱難嘗試namespace std; typedef long long ll; int const xn=1e5+5,inf=1e9; int n,hd[xn],ct,to[xn<<1],nxt[xn<<1],w[xn],siz[xn],rt,mx,sum,wmx; ll ans,b[3][3][xn]; bool vis[xn]; struct N{ int x,y,v; N(int x=0,int y=0,int v=0):x(x),y(y),v(v) {} }; vector<N>tmp; int rd() { int ret=0,f=1; char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')f=0; ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return f?ret:-ret; } void add(int x,int y){to[++ct]=y; nxt[ct]=hd[x]; hd[x]=ct;} void ins(int x,int y,int p){for(;p<=wmx;p+=(p&-p))b[x][y][p]++;} int query(int x,int y,intp){int ret=0; for(;p;p-=(p&-p))ret+=b[x][y][p]; return ret;} int upquery(int x,int y,int p){return query(x,y,wmx)-query(x,y,p-1);} void getrt(int x,int fa) { siz[x]=1; int nmx=0; for(int i=hd[x],u;i;i=nxt[i]) { if((u=to[i])==fa)continue; getrt(u,x); siz[x]+=siz[u]; if(siz[u]>nmx)nmx=siz[u]; } nmx=max(nmx,sum-siz[x]); if(nmx<mx)mx=nmx,rt=x; } void update(int t1,int t2,int a1) { if(t1==0&&t2==0)ans+=query(2,0,wmx)+query(0,2,wmx)+query(0,1,wmx)+query(1,0,wmx)+query(1,1,wmx)+upquery(1,2,a1)+query(2,1,a1),ans++; //--\ , --/ , --__ , --`` , --``__ , --.\ , --/. if(t1==0&&t2==1)ans+=query(2,0,a1)+query(0,2,wmx)+query(1,0,wmx)+query(0,1,wmx)+upquery(1,1,a1)+upquery(1,2,a1),ans++; //--__/ , ``--\ , --__-- , ``--__ , ``--__`` , ``--\` if(t1==0&&t2==2)ans+=query(0,2,wmx)+query(0,1,wmx)+upquery(1,0,a1)+upquery(1,1,a1)+upquery(1,2,a1),ans++;//``\--\ , ``\--__ , --\__`` , ``\--__`` , ``\.` if(t1==1&&t2==0)ans+=query(0,1,a1)+query(2,0,wmx)+upquery(0,2,a1)+query(1,0,wmx)+query(1,1,a1)+query(2,1,a1),ans++; //__``-- , __--/ , __-- \ , __--`` , __--``__ , __--``. if(t1==1&&t2==1)ans+=query(1,0,a1)+query(2,0,a1)+upquery(0,1,a1)+upquery(0,2,a1),ans++;//--``__-- , --``__/ , --__``-- , --__``\ if(t1==1&&t2==2)ans+=upquery(0,1,a1)+upquery(0,2,a1),ans++;//.\--__ , .\--\ if(t1==2&&t2==0)ans+=query(2,0,wmx)+query(0,1,a1)+query(1,0,wmx)+query(1,1,a1)+query(2,1,a1),ans++;// // , /``__ , /--`` , /--``__ , /`. if(t1==2&&t2==1)ans+=query(1,0,a1)+query(2,0,a1),ans++;// /.--`` , /./ } void dfs(int x,int fa,int s1,int s2)//to rt { if(w[x]<w[fa])s1++; if(w[x]>w[fa])s2++; if((s1>=2&&s2==1&&w[rt]>w[x])||(s1==1&&s2>=2&&w[rt]<w[x]))return;//&& int t1=min(s1,2),t2=min(s2,2); ans+=query(0,0,wmx); update(t1,t2,w[x]); //printf("x=%d ans=%lld t1=%d t2=%d\n",x,ans,t1,t2); tmp.pb(N(t2,t1,w[x])); for(int i=hd[x],u;i;i=nxt[i]) if((u=to[i])!=fa&&!vis[u])dfs(u,x,s1,s2); } /* void update2(int t1,int t2,int a1) { if(t1==0&&t2==0)ans+=query(0,1,wmx)+query(0,2,wmx)+query(1,0,wmx)+query(1,1,wmx)+query(2,0,mx),ans++;//--__ , --\__ , --`` , --``__ , --/ if(t1==0&&t2==1)ans+=query(0,2,wmx)+query(1,0,wmx)+query(2,0,a1)+query(0,1,wmx),ans++;//``\--__ , __--__ , /``__ , ``--__ if(t1==0&&t2==2)ans+=query(0,1,wmx)+query(0,2,wmx)+query(1,0,a1),ans++;//``--\ , ``\--\ , __``\ if(t1==1&&t2==0)ans+=query(0,1,wmx)+query(0,2,a1)+query(1,0,wmx)+upquery(1,1,a1)+query(2,0,wmx),ans++;//--__-- , --\__`` , __--`` , --``__-- , /--`` if(t1==1&&t2==1)ans+=upquery(1,0,a1)+upquery(2,0,a1),ans++;//__--``__ , /--``__ if(t1==2&&t2==0)ans+=query(0,1,a1)+query(1,0,wmx)+upquery(1,1,a1)+query(2,0,wmx),ans++;//--__/ , __--/ , --``__/ , // } void dfs2(int x,int fa,int s1,int s2)//from rt { if(w[x]<w[fa])s2++; if(w[x]>w[fa])s1++; if((s1>=2&&s2==1)||(s1==1&&s2>=2))return; int t1=min(s1,2),t2=min(s2,2); ans+=query(0,0,wmx); update2(t1,t2,w[x]); tmp.pb(N(t1,t2,w[x])); for(int i=hd[x],u;i;i=nxt[i]) if((u=to[i])!=fa&&!vis[u])dfs2(u,x,s1,s2); } */ void work(int x,int ss) { printf("x=%d\n",x); vis[x]=1;// memset(b,0,sizeof b); for(int i=hd[x],u;i;i=nxt[i]) { if(vis[u=to[i]])continue; dfs(u,x,0,0); printf("u=%d\n",u); for(int j=0;j<tmp.size();j++){ins(tmp[j].x,tmp[j].y,tmp[j].v); printf("t(%d,%d,%d)\n",tmp[j].x,tmp[j].y,tmp[j].v);} printf("ans=%lld\n",ans); tmp.clear(); } /* //無序 memset(b,0,sizeof b); for(int i=hd[x],u;i;i=nxt[i]) { if(vis[u=to[i]])continue; dfs2(u,x,0,0); for(int j=0;j<tmp.size();j++)ins(tmp[j].x,tmp[j].y,tmp[j].v); tmp.clear(); } */ for(int i=hd[x],u;i;i=nxt[i]) { if(vis[u=to[i]])continue; sum=(siz[u]>siz[x]?ss-siz[x]:siz[x]); mx=inf; getrt(u,0); work(u,sum);//(u,0) } } int main() { int T=rd(); while(T--) { n=rd(); ct=0; memset(hd,0,sizeof hd); wmx=0; memset(vis,0,sizeof vis); for(int i=1;i<=n;i++)w[i]=rd(),wmx=max(wmx,w[i]); for(int i=1,x,y;i<n;i++)x=rd(),y=rd(),add(x,y),add(y,x); sum=n; mx=inf; getrt(1,0);// ans=0; work(rt,sum); printf("%lld\n",ans+n);// } return 0; }
然後看了看AC程式碼,也有一篇拼形態的,看了半天沒看懂...
其實應該是不用拼形態的,直接考慮拼起來以後的大於號、小於號情況來判斷是否對兩個端點的值有要求;
如果拼起來以後是 (0,~) 或 (~,0) 或 (1,1),那麼對前後端點值的大小是沒有限制的;
如果是 (1,~),小於號只有一個,那麼一定是兩個下降的段,要求前端的值 <= 後端的值,才可以從小於號的地方斷開,成為有序的;
如果是 (~,1),大於號只有一個,同理,前端的值 >= 後端的值;
所以只要考慮拼完以後的形態就可以了,討論變得簡單許多;
注意判掉不合法的情況,而且每個點自己單獨也是一種合法的情況;
清空樹狀陣列時不要暴力使複雜度劇增,可以再 dfs 一遍,把剛才加上的減去;
可以對值域先離散化,加快樹狀陣列的查詢;
注意點分治不要寫錯;
注意本題沒有對 1e9+7 取模呵呵。
程式碼如下:
#include<cstdio> #include<cstring> #include<algorithm> #include<vector> #define pb push_back using namespace std; typedef long long ll; int const xn=1e5+5,inf=1e9; int n,hd[xn],ct,to[xn<<1],nxt[xn<<1],w[xn],siz[xn],mx,sum,rt,wmx,b[3][3][xn]; ll ans; bool vis[xn]; struct N{ int u,v,w; N(int u=0,int v=0,int w=0):u(u),v(v),w(w) {} }; vector<N>tmp; int rd() { int ret=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=0; ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return f?ret:-ret; } void add(int x,int y){to[++ct]=y; nxt[ct]=hd[x]; hd[x]=ct;} void ins(int x,int y,int p,int v){for(;p<=wmx;p+=(p&-p))b[x][y][p]+=v;} int query(int x,int y,int p){if(!p)return 0; int ret=0; for(;p;p-=(p&-p))ret+=b[x][y][p]; return ret;} int upquery(int x,int y,int p){return query(x,y,wmx)-query(x,y,p-1);} void getrt(int x,int fa) { siz[x]=1; int nmx=0; for(int i=hd[x],u;i;i=nxt[i]) { if((u=to[i])==fa||vis[u])continue; getrt(u,x); siz[x]+=siz[u]; if(nmx<siz[u])nmx=siz[u]; } nmx=max(nmx,sum-siz[x]); if(nmx<mx)mx=nmx,rt=x; } void update(int s1,int s2,int w) { ans++; for(int i=0;i<=2;i++) for(int j=0;j<=2;j++) { int t1=s1+i,t2=s2+j; if(t1>=2&&t2>=2)continue; if(!t1||!t2||(t1==1&&t2==1))ans+=query(i,j,wmx); else if(t1==1)ans+=upquery(i,j,w); else if(t2==1)ans+=query(i,j,w); } } void dfs(int x,int fa,int s1,int s2) { if(w[x]<w[fa])s1++; if(w[x]>w[fa])s2++; if(s1>2)s1=2; if(s2>2)s2=2; if((s1==2&&s2==2)||(s1==1&&s2==2&&w[x]>w[rt])||(s1==2&&s2==1&&w[x]<w[rt]))return; update(s1,s2,w[x]); tmp.pb(N(s1,s2,w[x])); //printf("x=%d s1=%d s2=%d ans=%d\n",x,s1,s2,ans); for(int i=hd[x],u;i;i=nxt[i]) if((u=to[i])!=fa&&!vis[u])dfs(u,x,s1,s2); } void clear(int x,int fa,int s1,int s2) { if(w[x]<w[fa])s1++; if(w[x]>w[fa])s2++; if(s1>2)s1=2; if(s2>2)s2=2; if((s1==2&&s2==2)||(s1==1&&s2==2&&w[x]>w[rt])||(s1==2&&s2==1&&w[x]<w[rt]))return; ins(s2,s1,w[x],-1); for(int i=hd[x],u;i;i=nxt[i]) if((u=to[i])!=fa&&!vis[u])clear(u,x,s1,s2); } void work(int x,int ss) { //printf("x=%d\n",x); vis[x]=1; for(int i=hd[x],u;i;i=nxt[i]) { if(vis[u=to[i]])continue; dfs(u,x,0,0); int siz=tmp.size(); for(int j=0;j<siz;j++)ins(tmp[j].v,tmp[j].u,tmp[j].w,1); tmp.clear(); } for(int i=hd[x],u;i;i=nxt[i]) { if(vis[u=to[i]])continue; clear(u,x,0,0); } for(int i=hd[x],u;i;i=nxt[i]) { if(vis[u=to[i]])continue; sum=(siz[u]>siz[x]?ss-siz[x]:siz[u]);//siz[u]! mx=inf; getrt(u,0); work(rt,sum); } } int tt[xn]; int main() { int T=rd(); while(T--) { n=rd(); ct=0; for(int i=1;i<=n;i++)hd[i]=0; for(int i=1;i<=n;i++)w[i]=rd(),tt[i]=w[i]; sort(tt+1,tt+n+1); wmx=unique(tt+1,tt+n+1)-tt-1; for(int i=1;i<=n;i++)w[i]=lower_bound(tt+1,tt+wmx+1,w[i])-tt; for(int i=1,x,y;i<n;i++)x=rd(),y=rd(),add(x,y),add(y,x); ans=0; for(int i=1;i<=n;i++)vis[i]=0; mx=inf; sum=n; getrt(1,0); work(rt,n); printf("%lld\n",ans+n); } return 0; }