bzoj2243: [SDOI2011]染色(樹鏈剖分)
樹鏈剖分好題啊!
題目描述:給定一顆n個點的樹,有m個操作,操作有兩種。
1、將節點a到節點b路徑上所有的點都染成顏色c。
2、詢問節點a到節點b路徑上的顏色段數量(連續的被認為是同一段)。
輸入格式:第一行包含兩個整數n和m,表示節點數和操作個數。
第二行n個整數,表示每個節點的初始顏色。
接下來n - 1行,每行兩個整數描述一棵樹。
接下來m行,每行表示一個操作。
輸出格式:對於每一個詢問顏色段數量的操作,輸出一行一個整數,表示顏色段的數量。
輸入樣例:
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
輸出樣例:
3
1
2
解析:很顯然是樹剖,問題是如何線上段樹上維護不同顏色的段數。
用sum[o]表示一段的不同顏色的段數,lf[o]表示這一段最左邊的顏色,rt[o]表示這一段最右邊的顏色。
那麼在進行合併時,lf[o] = lf[o << 1],rt[o] = rt[o << 1 | 1]。需要注意的是sum[o]的維護,若左端點的右端顏色等於右端點的左端顏色,則sum[o]要減1。
即若顏色相同,sum[o] = sum[o << 1] + sum[o << 1 | 1] - 1;若顏色不同,sum[o] = sum[o << 1] + sum[o << 1 | 1]
另一個需要思考的地方便是如何計算答案,由於在樹上剖鏈時剖出的鏈是不連續的,所以不能單純進行累加。
這時就要用到 lca 了,可以求出兩個點的lca,分別對兩段進行累加,這樣答案就可以計算了。
由於在樹剖時是從深度大的往深度小的剖,所以線上段樹中較右的節點會先別訪問到,所以可以記錄一個last,表示上一次剖到的左端點顏色是last,這樣就可以將答案累加。
有很多細節需要注意,細節可以看程式碼。
程式碼如下:
1 #include<cstdio> 2 #include<vector> 3 #include<algorithm> 4 #include<cstring> 5 #define lc o << 1 6 #define rc o << 1 | 1 7 using namespace std; 8 9 const int maxn = 1e5 + 5; 10 int n, m, col[maxn], bj[maxn * 4], sum[maxn * 4], lf[maxn * 4], rt[maxn * 4], ans, last; 11 int dep[maxn], fa[maxn], size[maxn], heavy[maxn], seq[maxn], dfn[maxn], top[maxn], cnt; 12 char s[5]; 13 vector <int> ve[maxn]; 14 15 int read(void) { 16 char c; while (c = getchar(), c < '0' || c >'9'); int x = c - '0'; 17 while (c = getchar(), c >= '0' && c <= '9') x = x * 10 + c - '0'; return x; 18 } 19 20 void dfs1(int u, int pre) { 21 dep[u] = dep[pre] + 1; 22 fa[u] = pre; size[u] = 1; 23 for (int i = 0; i < ve[u].size(); ++ i) { 24 int v = ve[u][i]; 25 if (v == pre) continue; 26 dfs1(v, u); 27 size[u] += size[v]; 28 if (size[v] > size[heavy[u]]) heavy[u] = v; 29 } 30 } 31 32 void dfs2(int u, int cur) { 33 dfn[u] = ++ cnt; seq[cnt] = u; 34 top[u] = cur; 35 if (!heavy[u]) return; 36 dfs2(heavy[u], cur); 37 for (int i = 0; i < ve[u].size(); ++ i) { 38 int v = ve[u][i]; 39 if (v == fa[u] || v == heavy[u]) continue; 40 dfs2(v, v); 41 } 42 } 43 44 void maintain(int o) { //維護每段的資訊 45 lf[o] = lf[lc]; rt[o] = rt[rc]; 46 if (rt[lc] == lf[rc]) sum[o] = sum[lc] + sum[rc] - 1; 47 else sum[o] = sum[lc] + sum[rc]; 48 } 49 50 void pushdown(int o) { //標記下放 51 sum[lc] = sum[rc] = 1; 52 lf[lc] = lf[rc] = rt[lc] = rt[rc] = bj[o]; 53 bj[lc] = bj[rc] = bj[o]; bj[o] = -1; 54 } 55 56 void build(int o, int l, int r) { //建樹 57 if (l == r) { 58 lf[o] = col[seq[l]]; 59 rt[o] = col[seq[l]]; 60 sum[o] = 1; 61 return; 62 } 63 int mid = l + r >> 1; 64 build(lc, l, mid); build(rc, mid + 1, r); 65 maintain(o); 66 } 67 68 void modify(int o, int l, int r, int ql, int qr, int c) { //區間修改 69 if (ql <= l && qr >= r) { 70 lf[o] = rt[o] = c; 71 sum[o] = 1; bj[o] = c; 72 return; 73 } 74 int mid = l + r >> 1; 75 if (bj[o] != -1) pushdown(o); 76 if (ql <= mid) modify(lc, l, mid, ql, qr, c); 77 if (qr > mid) modify(rc, mid + 1, r, ql, qr, c); 78 maintain(o); 79 } 80 81 void query(int o, int l, int r, int ql, int qr) { 82 if (ql <= l && qr >= r) { 83 if (rt[o] == last) ans += sum[o] - 1; //如果右端的顏色和上一個左端相同,就-1 84 else ans += sum[o]; 85 last = lf[o]; //更新last表示的左端點 86 return; 87 } 88 int mid = l + r >> 1; 89 if (bj[o] != -1) pushdown(o); 90 if (qr > mid) query(rc, mid + 1, r, ql, qr); //由於是從右向左更新答案,所以線段樹上詢問時也要優先向右詢問! 91 if (ql <= mid) query(lc, l, mid, ql, qr); 92 } 93 94 void chain_modify(int x, int y, int c) { //樹上修改 95 int fax = top[x], fay = top[y]; 96 while (fax != fay) { 97 if (dep[fax] < dep[fay]) { 98 swap(fax, fay); 99 swap(x, y); 100 } 101 modify(1, 1, n, dfn[fax], dfn[x], c); 102 x = fa[fax]; 103 fax = top[x]; 104 } 105 if (dep[x] > dep[y]) swap(x, y); 106 modify(1, 1, n, dfn[x], dfn[y], c); 107 } 108 109 void chain_query(int x, int y) { //樹上詢問 110 int fax = top[x], fay = top[y]; 111 while (fax != fay) { 112 if (dep[fax] < dep[fay]) { 113 swap(fax, fay); 114 swap(x, y); 115 } 116 query(1, 1, n, dfn[fax], dfn[x]); 117 x = fa[fax]; 118 fax = top[x]; 119 } 120 if (dep[x] > dep[y]) swap(x, y); 121 query(1, 1, n, dfn[x], dfn[y]); 122 } 123 124 int getlca(int x, int y) { //求lca 125 int fax = top[x], fay = top[y]; 126 while (fax != fay) { 127 if (dep[fax] < dep[fay]) { 128 swap(fax, fay); 129 swap(x, y); 130 } 131 x = fa[fax]; 132 fax = top[x]; 133 } 134 if (dep[x] > dep[y]) swap(x, y); 135 return x; 136 } 137 138 int main() { 139 n = read(); m = read(); 140 for (int i = 1; i <= n; ++ i) col[i] = read(); 141 for (int i = 1; i < n; ++ i) { 142 int x = read(), y = read(); 143 ve[x].push_back(y); 144 ve[y].push_back(x); 145 } 146 dfs1(1, 0); 147 dfs2(1, 1); 148 build(1, 1, n); 149 memset(bj, -1, sizeof(bj)); //顏色可以為0!所以初始標記是-1 150 while (m --) { 151 scanf("%s", s + 1); 152 if (s[1] == 'C') { 153 int x = read(), y = read(), c = read(); 154 int lca = getlca(x, y); 155 chain_modify(x, lca, c); chain_modify(lca, y, c); 156 } 157 else { //要求兩次答案,並累加答案 158 int x = read(), y = read(); ans = 0; last = -1; 159 int lca = getlca(x, y); 160 chain_query(x, lca); 161 last = -1; 162 chain_query(lca, y); 163 printf("%d\n", ans - 1); //這裡ans必須-1,因為lca處的顏色必定相同 164 } 165 } 166 return 0; 167 }