1. 程式人生 > >JZOJ 3840. 【NOI2015模擬9.20】colortree (Standard IO) 題解

JZOJ 3840. 【NOI2015模擬9.20】colortree (Standard IO) 題解

Description
給定一棵N個結點的樹,每個點一開始都是白色,進行Q次操作,操作有以下兩種:
1、給定一個節點x,把x染成藍色
2、給定一個節點x,詢問x到其他所有藍色點的距離和輸入N, Q,startSeed,threshold, maxDist。
用以下方法生成這棵樹以及Q次操作:

int curValue = startSeed;
int genNextRandom() {
    curValue = (curValue * 1999 + 17) % 1000003;
    return curValue;
}
void generateInput() {
    for
(int i = 0; i < N-1; i++) { distance[i] = genNextRandom() % maxDist; parent[i] = genNextRandom(); if (parent[i] < threshold) { parent[i] = i; } else { parent[i] = parent[i] % (i + 1); } } for (int i = 0; i < Q; i++) { queryType[
i] = genNextRandom() % 2 + 1; queryNode[i] = genNextRandom() % N; } }

以上程式輸出四個陣列:parent,distance,queryType以及queryNode。
其中parent、distance有N-1個元素,對於每個i(0<=i<=n-2),(i+1)與parent[i]有一條邊相連,長度為distance[i]。注意0<=parent[i]<=i。
queryType、queryNode有Q個元素,對於每個i,操作種類是queryType[i],操作的節點是queryNode[i].
輸出所有操作2答案的異或值。

Input
輸入一行,包含N, Q,startSeed,threshold, maxDist

Output
輸出所有操作2答案的異或值。

Sample Input
輸入1:
4 6 15 2 5
輸入2:
4 5 2 9 10
輸入3:
14750 50 29750 1157 21610

Sample Output
輸出1:
7
輸出2:
30
輸出3:
2537640

Data Constraint
2<=N<=100,000
1<=Q<=100,000
0<=startSeed<=1,000,002
0<=threshold<=1,000,003
1<=maxDist<=1,000,003

Solution

由於查詢的是一個點到所有其他藍點路徑長度和,如果我們考慮了樹上任意一對點的路徑也就考慮了所有的詢問,於是考慮點分治。

首先離線詢問,對於每個點記錄其第一次染成藍色的時間 t i t_i ,把對某個點的詢問放在該點上,並記錄詢問的時間 q i q_i 。點分治時,把樹狀陣列 t i t_i 位置加上 i i 到當前點的距離,同時再把個數用樹狀陣列統計一下,然後列舉每個點,列舉該點上的詢問,在樹狀數組裡查詢在這次詢問之前的藍點的貢獻,記錄答案。但是這樣會有兩點在同一子樹內的情況,我們對於各子樹刪去這種情況即可。

複雜度 O ( ( n + q ) l o g 2 n ) O((n+q)log^2n)

Code

#include <cstdio>
#include <cstring>

typedef long long ll;
const int N = 100007, INF = 0x3f3f3f3f;
int min(int a, int b) { return a < b ? a : b; }
int max(int a, int b) { return a > b ? a : b; }

int n, q;

int curValue, startSeed, threshold, maxDist, distance[N], parent[N], queryType[N], queryNode[N];
int genNextRandom()
{
	curValue = (curValue * 1999 + 17) % 1000003;
	return curValue;
}
void generateInput()
{
	curValue = startSeed;
	for (int i = 0; i < n - 1; i++)
	{
		distance[i] = genNextRandom() % maxDist;
		parent[i] = genNextRandom();
		if (parent[i] < threshold) parent[i] = i;
		else parent[i] = parent[i] % (i + 1);
    }
    for (int i = 0; i < q; i++)
	{
		queryType[i] = genNextRandom() % 2 + 1;
		queryNode[i] = genNextRandom() % n;
	}
}

int p, sum;
int tot, st[N], to[N << 1], nx[N << 1], size[N], mxsiz[N], del[N], tim[N], alen, arr[N];
ll len[N << 1], dis[N], ans[N], ret;
void add(int u, int v, ll w) { if (!u || !v) return; to[++tot] = v, nx[tot] = st[u], len[tot] = w; st[u] = tot; }

int cnt, head[N], link[N], next[N];
void insert(int u, int id) { link[++cnt] = id, next[cnt] = head[u], head[u] = cnt; }

void getp(int u, int from)
{
	size[u] = 1, mxsiz[u] = 0;
	for (int i = st[u]; i; i = nx[i])
		if (to[i] != from && !del[to[i]])
			getp(to[i], u), size[u] += size[to[i]], mxsiz[u] = max(mxsiz[u], size[to[i]]);
	mxsiz[u] = max(mxsiz[u], sum - size[u]);
	if (mxsiz[u] < mxsiz[p]) p = u;
}

ll tr[N][2];
void plus(int po, ll val, int k) { for (; po <= q + 1; po += (po & (-po))) tr[po][k] += val; }
ll getsum(int po, int k) { ll ret = 0; for (; po; po -= (po & (-po))) ret += tr[po][k]; return ret; }
void clear(int po) { for (; po <= q + 1; po += (po & (-po))) tr[po][0] = tr[po][1] = 0; }
void getdis(int u, int from)
{
	arr[++alen] = u;
	for (int i = st[u]; i; i = nx[i]) if (to[i] != from && !del[to[i]]) dis[to[i]] = dis[u] + len[i], getdis(to[i], u);
}
void calc(int u, int val, int t)
{
	alen = 0, dis[u] = val, getdis(u, 0);
	for (int i = 1; i <= alen; i++) if (tim[arr[i]] < INF) plus(tim[arr[i]], 1, 0), plus(tim[arr[i]], dis[arr[i]], 1);
	for (int i = 1; i <= alen; i++)
	{
		int w = arr[i];
		for (int j = head[w]; j; j = next[j]) ans[link[j]] += t * (dis[w] * getsum(link[j] - 1, 0) + getsum(link[j] - 1, 1));
	}
	for (int i = 1; i <= alen; i++) if (tim[arr[i]] < INF) clear(tim[arr[i]]);
}
void solve(int u)
{
	calc(u, 0, 1); //總共統計一次答案
	del[u] = 1;
	for (int i = st[u]; i; i = nx[i]) if (!del[to[i]]) calc(to[i], len[i], -1); //各子樹分別刪去答案
	for (int i = st[u]; i; i = nx[i])
		if (!del[to[i]])
		{
			sum = size[to[i]], p = 0;
			getp(to[i], 0), solve(p);
		}
}

int main()
{
	scanf("%d%d%d%d%d", &n, &q, &startSeed, &threshold, &maxDist);
	generateInput();
	for (int i = 0; i < n - 1; i++) add(i + 2, parent[i] + 1, distance[i]), add(parent[i] + 1, i + 2, distance[i]);
	memset(tim, 0x3f, sizeof(tim));
	for (int i = 0; i < q; i++)
		if (queryType[i] == 1) tim[queryNode[i] + 1] = min(tim[queryNode[i] + 1], i + 1);
		else insert(queryNode[i] + 1, i + 1);
	mxsiz[0] = INF;
	sum = n, p = 0;
	getp(1, 0);
	solve(p);
	ret = 0;
	for (int i = 1; i <= q; i++) ret ^= ans[i];
	printf("%lld\n", ret);
	return 0;
}