[BJOI2018] 鏈上二次求和 題解
Description
給你一根 \(n\) 個點的鏈,每個點有權值 \(a_i\) ,要求支援兩種操作:
- 將 \(u,v\) 間的數加上 \(d\)
- 詢問鏈上所有長度在 \([l,r]\) 間的路徑權值和。
\(n\le2\times10^5\)
Sol
考慮每個點對於答案產生的貢獻,手玩一下可以得出下面的表(其中點 \((i,j)\) 表示點 \(j\) 在所有長度為 \(i\) 的路徑上的出現次數)。
\[\begin{gathered} \begin{bmatrix} 1 & 1 & 1 & 1 & 1 & 1 & 1 \\ 1 & 2 & 2 & 2 & 2 & 2 & 1 \\1 & 2 & 3 & 3 & 3 & 2 & 1 \\1 & 2 & 3 & 4 & 3 & 2 & 1 \\1 & 2 & 3 & 3 & 3 & 2 & 1 \\1 & 2 & 2 & 2 & 2 & 2 & 1 \\1 & 1 & 1 & 1 & 1 & 1 & 1 \\\end{bmatrix} \quad \end{gathered} \]觀察到矩陣很有規律,我們考慮線段樹維護所有長度為 \(i\) 的路徑的總全職和,即線段樹最底層一個點維護的是矩陣的一行的所有點的 \(a_i\times cnt_{i,j}\) ,直接維護因為每一行的 \(cnt\) 都不一樣,我們考慮換一種方式:
我們假設每一條路徑上結點的出現次數都是形如 \(1234321\) 型別的,那麼我們的答案就是總貢獻減去多出的貢獻。
總貢獻即 \(a_1+2a_2+3a_3+...+(k-1)a{k-1}+ka_k+(k-1)a_{k+1}+na_n\),拿線段樹簡單維護即可,考慮多出的貢獻,假設我們查詢長度在 \([2,5]\) 之間的路徑權值和,那麼多出來的部分即為:
那麼我們多出來的貢獻即為 \(a_2+4a_3+9a_4+4a_5+a_6\),正著維護 \(1^2a_1+2^2a_2+3^2a_3+...+n^2a_n\) 與 \(n^2a_1+(n-1)^2a_2+(n-2)^2a_3+...+a_n\) ,查詢時減去即可。
還有一種維護方式就是將上圖繼續分割,分成 \(4\) 個小塊,如圖所示:
分別計算每一塊的貢獻第一塊為 \(1,3\),第二塊為 \(1\),第三塊為 \(1,3,6\),第四塊為 \(1,3\)。
那麼我們維護 \(1a_1+3a_2+6a_3+10a_4+...\) 與 \(1a_n+3a_{n-1}+6a_{n-2}+...\)
修改時維護上面的貢獻即可。
時間複雜度 \(O(n\log n)\)。
Code
程式碼實現有那麼億點點噁心...
#include<bits/stdc++.h>
#define int long long
#define Mod 1000000007
#define inv6 166666668
#define il inline
#define re register
using namespace std;
il int Read() {
int x = 0, f = 1; char ch = getchar();
while(!isdigit(ch)) {if(ch == '-') f = -1; ch = getchar();}
while(isdigit(ch)) {x = (x << 3) + (x << 1) + ch - '0'; ch = getchar();}
return x * f;
}
struct node {
int sz, sum, sum2, sum3, _sum2, _sum3, addv;
node(int Sz = 0, int Sum = 0, int Sum2 = 0, int Sum3 = 0, int _Sum2 = 0, int _Sum3 = 0, int Addv = 0) {sz = Sz, _sum2 = _Sum2, _sum3 = _Sum3, sum = Sum, sum2 = Sum2, sum3 = Sum3, addv = Addv;}
}seg[800005];
il node Merge(node A, node B) {
node C;
C.sz = A.sz + B.sz;
C.sum = (A.sum + B.sum) % Mod;
C.sum2 = (A.sum2 + B.sum2 + A.sz * B.sum % Mod) % Mod;
C.sum3 = (A.sum3 + B.sum3 + A.sz * B.sum2 % Mod + (A.sz + 1) * A.sz / 2 % Mod * B.sum % Mod) % Mod;
C._sum2 = (B._sum2 + A._sum2 + B.sz * A.sum % Mod) % Mod;
C._sum3 = (B._sum3 + A._sum3 + B.sz * A._sum2 % Mod + (B.sz + 1) * B.sz / 2 % Mod * A.sum % Mod) % Mod;
return C;
}
int a[200005];
il void build(int o, int l, int r) {
seg[o].sz = r - l + 1;
if(l == r) {
seg[o].sum = seg[o].sum2 = seg[o].sum3 = seg[o]._sum2 = seg[o]._sum3 = a[l];
return ;
}
int mid = (l + r) >> 1;
build(o << 1, l, mid); build(o << 1 | 1, mid + 1, r);
seg[o] = Merge(seg[o << 1], seg[o << 1 | 1]);
}
il void pusha(int o, int l, int r, int x) {
int len = (r - l + 1);
(seg[o].sum3 += (len * (len + 1) % Mod * (len + 2) % Mod * inv6 % Mod) * x % Mod) %= Mod;
(seg[o]._sum3 += (len * (len + 1) % Mod * (len + 2) % Mod * inv6 % Mod) * x % Mod) %= Mod;
(seg[o].sum2 += (len * (len + 1) / 2 % Mod) * x % Mod) %= Mod;
(seg[o]._sum2 += (len * (len + 1) / 2 % Mod) * x % Mod) %= Mod;
(seg[o].sum += len * x % Mod) %= Mod;
seg[o].addv = (seg[o].addv + x) % Mod;
}
il void pushdown(int o, int l, int r) {
int mid = (l + r) >> 1;
if(seg[o].addv) {
pusha(o << 1, l, mid, seg[o].addv);
pusha(o << 1 | 1, mid + 1, r, seg[o].addv);
seg[o].addv = 0;
}
}
il void modify(int o, int l, int r, int nl, int nr, int x) {
if(nl <= l && r <= nr) return pusha(o, l, r, x);
pushdown(o, l, r); int mid = (l + r) >> 1;
if(nl <= mid) modify(o << 1, l, mid, nl, nr, x);
if(mid < nr) modify(o << 1 | 1, mid + 1, r, nl, nr, x);
seg[o] = Merge(seg[o << 1], seg[o << 1 | 1]);
}
il node query(int o, int l, int r, int nl, int nr) {
if(nl > nr) return (node){0ll, 0ll, 0ll, 0ll, 0ll, 0ll, 0ll};
if(nl <= l && r <= nr) return seg[o];
pushdown(o, l, r); int mid = (l + r) >> 1;
if(nl <= mid) {
if(mid < nr) return Merge(query(o << 1, l, mid, nl, nr), query(o << 1 | 1, mid + 1, r, nl, nr));
return query(o << 1, l, mid, nl, nr);
}
return query(o << 1 | 1, mid + 1, r, nl, nr);
}
signed main() {
int n = Read(), m = Read();
for(re int i = 1; i <= n; i++) a[i] = Read();
build(1, 1, n);
for(re int i = 1; i <= m; i++) {
int opt = Read(), x = Read(), y = Read(), z;
if(x > y) swap(x, y);
if(opt == 1) {
z = Read();
modify(1, 1, n, x, y, z);
}
else {
if(x > y) {puts("0"); continue;}
if(n % 2 == 1) {
int mid = (1 + n) >> 1; node A, B, C, D;
A = query(1, 1, n, 1, mid); B = query(1, 1, n, mid + 1, n);
int ans = (A.sum2 + B._sum2) % Mod * (y - x + 1) % Mod;
if(y <= mid) {
if(y == mid) --y;
A = query(1, 1, n, x + 1, mid); B = query(1, 1, n, mid + 1, 2 * mid - x - 1);
C = query(1, 1 ,n, y + 2, mid); D = query(1, 1, n, mid + 1, 2 * mid - y - 2);
ans -= (A.sum3 + B._sum3 - C.sum3 - D._sum3 + Mod) % Mod;
ans = (ans + Mod) % Mod;
}
else if(x >= mid) {
if(x == mid) ++x;
A = query(1, 1, n, 2 * mid - y + 1, mid); B = query(1, 1, n, mid + 1, y - 1);
C = query(1, 1, n, 2 * mid - x + 2, mid); D = query(1, 1, n, mid + 1, x - 2);
ans -= (A.sum3 + B._sum3 - C.sum3 - D._sum3 + Mod) % Mod;
ans = (ans + Mod) % Mod;
}
else {
A = query(1, 1, n, x + 1, mid), B = query(1, 1, n, mid + 1, 2 * mid - x - 1);
C = query(1, 1, n, 2 * mid - y + 1, mid), D = query(1, 1, n, mid + 1, y - 1);
ans -= (A.sum3 + C.sum3 + B._sum3 + D._sum3) % Mod;
ans = (ans + Mod) % Mod;
}
printf("%lld\n", ans);
}
if(n % 2 == 0) {
int Lmid = (1 + n) >> 1, Rmid = Lmid + 1; node A, B, C, D;
A = query(1, 1, n, 1, Lmid); B = query(1, 1, n, Rmid, n);
int ans = (A.sum2 + B._sum2) % Mod * (y - x + 1) % Mod;
if(y <= Rmid) {
if(y == Rmid) --y;
if(y == Lmid) --y;
A = query(1, 1, n, x + 1, Lmid); B = query(1, 1, n, Rmid, 2 * Lmid - x);
C = query(1, 1, n, y + 2, Lmid); D = query(1, 1, n, Rmid, 2 * Lmid - y - 1);
ans -= (A.sum3 + B._sum3 - C.sum3 - D._sum3 + Mod) % Mod;
ans = (ans + Mod) % Mod;
}
else if(x >= Lmid) {
if(x == Lmid) ++x;
if(x == Rmid) ++x;
A = query(1, 1, n, 2 * Lmid - y + 2, Lmid); B = query(1, 1, n, Rmid, y - 1);
C = query(1, 1, n, 2 * Lmid - x + 3, Lmid); D = query(1, 1, n, Rmid, x - 2);
ans -= (A.sum3 + B._sum3 - C.sum3 - D._sum3 + Mod) % Mod;
ans = (ans + Mod) % Mod;
}
else {
A = query(1, 1, n, x + 1, Lmid), B = query(1, 1, n, Rmid, 2 * Lmid - x);
C = query(1, 1, n, 2 * Lmid - y + 2, Lmid), D = query(1, 1, n, Rmid, y - 1);
ans -= (A.sum3 + C.sum3 + B._sum3 + D._sum3) % Mod;
ans = (ans + Mod) % Mod;
}
printf("%lld\n", ans);
}
}
}
return 0;
}