1. 程式人生 > >2018.8.24(dfs序,線段樹,動態樹直徑的維護)

2018.8.24(dfs序,線段樹,動態樹直徑的維護)

題目大意:
給定一棵以 1 為根的樹,每次詢問給出兩個點 x、y,求刪掉以 x 為根與以 y 為根的兩棵子樹之後,剩餘部分的直徑(以邊數計)

一棵子樹在 dfs 序(這裡用的是歐拉序)上對應一段連續區間,所以每次刪掉兩棵子樹,相當於在序列上挖掉兩段區間,剩下至多三段區間,再對這些區間求解
我們在dfs序上維護一段區間內點集的直徑
每次詢問的時候考慮合併
有這麼一個結論
在樹上合併兩個點集(不必相鄰或連通)時,新點集直徑的兩個端點一定取自原來兩個點集直徑的四個端點之中;因此只需比較原來的兩條直徑以及四對跨點集端點之間的距離

#include<bits/stdc++.h>
using namespace std;
// rep: ascending inclusive loop, i = j, j+1, ..., k.
#define rep(i,j,k) for(int i = j;i <= k;++i)
// repp: descending inclusive loop, i = j, j-1, ..., k.
#define repp(i,j,k) for(int i = j;i >= k;--i)
// rept: walk the adjacency list of vertex x (head linkk[x], next pointer e[i].n).
// linkk and e are declared later in the file; that is fine for a macro.
#define rept(i,x) for(int i = linkk[x];i;i = e[i].n)
// Each directive must sit on its own line: a #define runs to end of line, so packing
// them together turned everything after `#define P` into P's replacement text.
#define P pair<int,int>
#define pb push_back
#define pc putchar
#define mp make_pair
#define file(k) memset(k,0,sizeof(k));
#define ll long long
#define ls root * 2
#define rs root * 2 + 1

// Array sizes suggest n <= ~1e5 vertices (TODO confirm against the problem statement).
int n , q;
// Adjacency list: linkk[x] = head edge of x, e[i].n = next edge, e[i].y = far endpoint;
// t counts stored directed edges. dep[x] = depth of x (root 1 gets depth 1 in dfs).
int linkk[101000] , t , dep[101000];
// ola = Euler tour (a vertex is re-appended after each child); totw = its length.
// first[x] / last[x] = first / last position of x in ola, so x's subtree is ola[first[x] .. last[x]].
int ola[401000] , totw , first[101000] , last[101000];
// Sparse table over ola: mn[i][j] = shallowest vertex among ola[i .. i + 2^j - 1].
int mn[401000][20];
struct node{ int n , y; }e[201000];
// Summary of a vertex set: a farthest pair (l, r) and its distance len in edges.
struct tree{ int l,r,len; }tr[802000];
// Accumulator for the current query.
tree ans;

// Reads one (optionally negative) decimal integer from stdin.
// Returns 0 if EOF is hit before any digit instead of looping forever.
int read()
{
    int sum = 0;
    int c = getchar();  // int, not char, so EOF is distinguishable
    bool flag = true;
    while(c != EOF && (c < '0' || c > '9'))
    {
        if(c == '-') flag = false;
        c = getchar();
    }
    while(c >= '0' && c <= '9')
    {
        sum = sum * 10 + c - '0';
        c = getchar();
    }
    return flag ? sum : -sum;
}

// Adds the undirected edge x-y as two directed entries in the adjacency list.
void insert(int x,int y)
{
    e[++t].y = y; e[t].n = linkk[x]; linkk[x] = t;
    e[++t].y = x; e[t].n = linkk[y]; linkk[y] = t;
}

// Lowest common ancestor via range-minimum (by depth) over the Euler tour.
// Vertex 0 acts as "no vertex": Lca(0,0) = 0 and Lca(0,y) = Lca(x,0) = 0.
int Lca(int x,int y)
{
    if(x == y) return x;
    if(x == 0 || y == 0) return 0;
    if(first[x] > first[y]) std::swap(x,y);
    // y is first visited after x, so y is not an ancestor of x and the LCA occurs in
    // ola[first[x] .. first[y]-1]; two blocks of length 2^k cover that range.
    const int k = 31 - __builtin_clz(first[y] - first[x]);  // floor(log2), integer-exact
    const int u = mn[first[x]][k];
    const int v = mn[first[y] - (1<<k)][k];
    return dep[u] > dep[v] ? v : u;
}

// Builds the Euler tour, first/last positions and depths for the subtree of x (parent fa).
// NOTE(review): recursive; depth equals tree height, so a path-shaped tree of ~1e5 vertices
// relies on a sufficiently large stack.
void dfs(int x,int fa)
{
    ola[++totw] = x;
    first[x] = totw;
    dep[x] = dep[fa] + 1;
    for(int i = linkk[x]; i; i = e[i].n)
    {
        if(e[i].y == fa) continue;
        dfs(e[i].y,x);
        ola[++totw] = x;  // re-append x so x's subtree stays one contiguous range
    }
    last[x] = totw;
}

// Orders tree summaries by decreasing diameter length.
bool mycmp(tree a,tree b){return a.len > b.len;}

// Distance (edges) between x and y. With the dep[0] sentinel set in work(), any pair
// involving vertex 0 (other than 0-0) gets a large negative distance and never wins.
int get(int x,int y){return dep[x] + dep[y] - 2 * dep[Lca(x,y)];}

// root := diameter summary of the union of sets a and b. The new farthest pair is either
// one of the old ones or a cross pair of their endpoints.
void update(tree &root,tree a,tree b)
{
    root = a.len > b.len ? a : b;  // a, b are copies, so aliasing root with a is safe
    const int ea[2] = {a.l, a.r};
    const int eb[2] = {b.l, b.r};
    for(int i = 0; i < 2; ++i)
        for(int j = 0; j < 2; ++j)
        {
            const int d = get(ea[i], eb[j]);
            if(d > root.len) root = tree{ea[i], eb[j], d};
        }
}

// Segment tree over Euler-tour positions [l, r]; each node stores its range's diameter.
void build(int root,int l,int r)
{
    if(l == r)
    {
        tr[root].l = tr[root].r = mn[l][0];
        tr[root].len = 0;
        return;
    }
    const int mid = (l + r)>>1;
    build(ls,l,mid);
    build(rs,mid+1,r);
    update(tr[root],tr[ls],tr[rs]);
}

// Reads the tree, builds the Euler tour and the depth sparse table.
void init()
{
    n = read(); q = read();
    for(int i = 1; i < n; ++i)
    {
        int x = read() , y = read();
        insert(x,y);
    }
    dfs(1,0);
    for(int i = 1; i <= totw; ++i) mn[i][0] = ola[i];
    for(int j = 1; (1<<j) <= totw; ++j)
        for(int i = 1; i + (1<<j) - 1 <= totw; ++i)  // block [i, i+2^j-1] must fit
        {
            const int u = mn[i][j-1];
            const int v = mn[i+(1<<(j-1))][j-1];
            mn[i][j] = dep[u] > dep[v] ? v : u;
        }
}

// Merges the summaries of positions [x, y] into the global ans. Empty ranges are ignored.
void New(int root,int l,int r,int x,int y)
{
    if(r < x || l > y) return;
    if(x > y) return;
    if(x <= l && r <= y)
    {
        update(ans,ans,tr[root]);
        return;
    }
    const int mid = (l + r)>>1;
    New(ls,l,mid,x,y);
    New(rs,mid+1,r,x,y);
}

// Answers each query: the diameter of the tree after removing the subtrees of x and y.
void work()
{
    // Sentinel: must be set after dfs (which needs dep[0] = 0 for the root).
    // It makes every distance involving vertex 0 hugely negative.
    dep[0] = 10000000;
    while(q--)
    {
        int x = read() , y = read();
        ans.l = ans.r = ans.len = 0;
        if(first[x] > first[y]) std::swap(x,y);
        const int a = first[x] , b = last[x] , c = first[y] , d = last[y];
        New(1,1,totw,1,a-1);                        // before x's subtree
        New(1,1,totw,std::max(b,d)+1,totw);         // after both subtrees
        if(b < c) New(1,1,totw,b+1,c-1);            // gap between disjoint subtrees
        printf("%d\n",ans.len);
    }
}

int main()
{
    freopen("find.in","r",stdin);
    freopen("find.out","w",stdout);
    init();
    build(1,1,totw);
    work();
    return 0;
}