POJ 3468 "A Simple Problem with Integers" — splay tree solution
阿新 • • 發佈:2019-01-08
之前用線段樹寫過這題，這次改用 splay 樹寫了一遍。
// POJ 3468 "A Simple Problem with Integers": range add / range sum over a
// sequence, implemented with a splay tree keyed by in-order position.
//
// The input values a[2..n+1] are stored in-order together with two sentinel
// nodes of value 0 at positions 1 and n + 2, so the user range [l, r] maps to
// tree positions [l + 1, r + 1]. To operate on tree positions [L, R], the node
// at position L - 1 is splayed to the root and the node at position R + 1 is
// splayed directly below it; the requested range is then exactly the left
// subtree of that second node.
#include <cstdio>
#include <vector>
#include <algorithm>
#include <cstring>
using namespace std;

typedef long long ll;

const int maxn = 1e5 + 10;  // array capacity; must exceed n + 2 (values + 2 sentinels)

int root, tot;     // current root node; number of nodes allocated by build()
int fa[maxn];      // parent of each node (0 = none)
int tr[maxn][2];   // children: [0] = left, [1] = right (0 = none)
int sz[maxn];      // subtree size
ll tag[maxn];      // pending add not yet applied to the node's children
ll num[maxn];      // the node's own value (its own tag already applied)
ll sum[maxn];      // subtree sum (its own tag already applied)
ll a[maxn];        // input values consumed by build()

// Returns 1 if x is the right child of its parent, 0 otherwise.
int judge(int x) { return tr[fa[x]][1] == x; }

// Recomputes sz and sum of x from its children; does nothing for nil node 0.
void pushup(int x) {
    if (!x) return;
    int l = tr[x][0], r = tr[x][1];
    sz[x] = sz[l] + sz[r] + 1;
    sum[x] = sum[l] + sum[r] + num[x];
}

// Adds delta to every element in the subtree rooted at x: x's own value and
// subtree sum are updated now, and the add is recorded in tag[x] for x's
// children. Ignores the nil node 0 so it never accumulates junk values.
// The multiplication is done in ll, so large ranges cannot overflow int.
static void apply_add(int x, ll delta) {
    if (!x) return;
    tag[x] += delta;
    num[x] += delta;
    sum[x] += delta * sz[x];
}

// Pushes x's pending add down to its children and clears it.
void pushdown(int x) {
    if (!tag[x]) return;
    apply_add(tr[x][0], tag[x]);
    apply_add(tr[x][1], tag[x]);
    tag[x] = 0;
}

// Rotates x above its parent y, preserving in-order. Only y is pushed up here;
// splay() refreshes x's aggregates once x has stopped moving.
void rotate(int x) {
    int y = fa[x], z = fa[y];
    int d = judge(x);
    int inner = tr[x][d ^ 1];  // x's subtree on y's side; it moves under y

    tr[y][d] = inner;
    if (inner) fa[inner] = y;

    fa[x] = z;
    if (z) tr[z][judge(y)] = x;  // judge(y) still sees fa[y] == z here

    tr[x][d ^ 1] = y;
    fa[y] = x;
    pushup(y);
}

// Splays x upward until its parent is k; k == 0 makes x the new root.
void splay(int x, int k) {
    while (fa[x] != k) {
        int y = fa[x];
        if (fa[y] != k)
            rotate(judge(x) == judge(y) ? y : x);  // zig-zig : zig-zag
        rotate(x);
    }
    pushup(x);
    if (!k) root = x;
}

// Returns the node at in-order position k (1-based) within the subtree rooted
// at x. Pending tags are pushed down along the search path so the following
// splay only rotates clean nodes. The loop is iterative so that a long splay
// path cannot overflow the call stack. Assumes 1 <= k <= sz[x].
int kth(int x, int k) {
    for (;;) {
        pushdown(x);
        int left = sz[tr[x][0]];
        if (k == left + 1) return x;
        if (k <= left) {
            x = tr[x][0];
        } else {
            k -= left + 1;
            x = tr[x][1];
        }
    }
}

// Adds c to every element at tree positions [l, r]. Positions already include
// the +1 sentinel offset.
void update(int l, int r, int c) {
    int x = kth(root, l - 1);
    int y = kth(root, r + 1);
    splay(x, 0);
    splay(y, x);
    apply_add(tr[y][0], c);  // the isolated range; c is widened to ll
    pushup(y);
    pushup(x);
}

// Returns the sum of the elements at tree positions [l, r]. Positions already
// include the +1 sentinel offset.
ll query(int l, int r) {
    int x = kth(root, l - 1);
    int y = kth(root, r + 1);
    splay(x, 0);
    splay(y, x);
    return sum[tr[y][0]];
}

// Builds a balanced tree over a[l..r] and stores its root in x (0 when the
// range is empty). Nodes are allocated from ++tot in pre-order, and every
// per-node field is reset. The parent of the returned root is not set here.
void build(int &x, int l, int r) {
    if (l > r) {
        x = 0;
        return;
    }
    int mid = (l + r) / 2;
    x = ++tot;
    num[x] = a[mid];
    sz[x] = 1;
    tag[x] = sum[x] = 0;
    build(tr[x][0], l, mid - 1);
    build(tr[x][1], mid + 1, r);
    if (tr[x][0]) fa[tr[x][0]] = x;
    if (tr[x][1]) fa[tr[x][1]] = x;
    pushup(x);
}

int main() {
    int n, m;
    while (scanf("%d%d", &n, &m) == 2) {
        // Sentinels at positions 1 and n + 2 keep kth(l - 1) and kth(r + 1)
        // valid for every user range.
        a[1] = a[n + 2] = 0;
        for (int i = 2; i <= n + 1; i++)
            if (scanf("%lld", &a[i]) != 1) return 0;

        root = tot = 0;
        build(root, 1, n + 2);
        fa[root] = 0;

        while (m--) {
            char op[10];
            int l, r, c;
            if (scanf("%9s", op) != 1) return 0;
            if (op[0] == 'C') {
                if (scanf("%d%d%d", &l, &r, &c) != 3) return 0;
                update(l + 1, r + 1, c);
            } else if (op[0] == 'Q') {
                if (scanf("%d%d", &l, &r) != 2) return 0;
                printf("%lld\n", query(l + 1, r + 1));
            }
        }
    }
    return 0;
}