1. 程式人生 > 實用技巧 >[BJOI2018] 鏈上二次求和 題解

[BJOI2018] 鏈上二次求和 題解

Description

給你一根 \(n\) 個點的鏈,每個點有權值 \(a_i\) ,要求支援兩種操作:

  • \(u,v\) 間的數加上 \(d\)
  • 詢問鏈上所有長度在 \([l,r]\) 間的路徑權值和。

\(n\le2\times10^5\)

Sol

考慮每個點對於答案產生的貢獻,手玩一下可以得出下面的表(其中點 \((i,j)\) 表示點 \(j\) 在所有長度為 \(i\) 的路徑上的出現次數)。

\[\begin{gathered} \begin{bmatrix} 1 & 1 & 1 & 1 & 1 & 1 & 1 \\ 1 & 2 & 2 & 2 & 2 & 2 & 1 \\1 & 2 & 3 & 3 & 3 & 2 & 1 \\1 & 2 & 3 & 4 & 3 & 2 & 1 \\1 & 2 & 3 & 3 & 3 & 2 & 1 \\1 & 2 & 2 & 2 & 2 & 2 & 1 \\1 & 1 & 1 & 1 & 1 & 1 & 1 \\\end{bmatrix} \quad \end{gathered} \]

觀察到矩陣很有規律,我們考慮線段樹維護所有長度為 \(i\) 的路徑的總全職和,即線段樹最底層一個點維護的是矩陣的一行的所有點的 \(a_i\times cnt_{i,j}\) ,直接維護因為每一行的 \(cnt\) 都不一樣,我們考慮換一種方式:

我們假設每一條路徑上結點的出現次數都是形如 \(1234321\) 型別的,那麼我們的答案就是總貢獻減去多出的貢獻。

總貢獻即 \(a_1+2a_2+3a_3+...+(k-1)a{k-1}+ka_k+(k-1)a_{k+1}+na_n\),拿線段樹簡單維護即可,考慮多出的貢獻,假設我們查詢長度在 \([2,5]\) 之間的路徑權值和,那麼多出來的部分即為:

那麼我們多出來的貢獻即為 \(a_2+4a_3+9a_4+4a_5+a_6\),正著維護 \(1^2a_1+2^2a_2+3^2a_3+...+n^2a_n\)\(n^2a_1+(n-1)^2a_2+(n-2)^2a_3+...+a_n\) ,查詢時減去即可。

還有一種維護方式就是將上圖繼續分割,分成 \(4\) 個小塊,如圖所示:

分別計算每一塊的貢獻第一塊為 \(1,3\),第二塊為 \(1\),第三塊為 \(1,3,6\),第四塊為 \(1,3\)

那麼我們維護 \(1a_1+3a_2+6a_3+10a_4+...\)\(1a_n+3a_{n-1}+6a_{n-2}+...\)

即可。

修改時維護上面的貢獻即可。

時間複雜度 \(O(n\log n)\)

Code

程式碼實現有那麼億點點噁心...

#include<bits/stdc++.h>
#define int long long
#define Mod 1000000007
#define inv6 166666668
#define il inline
#define re register
using namespace std;
il int Read() {
    int x = 0, f = 1; char ch = getchar();
    while(!isdigit(ch)) {if(ch == '-')  f = -1; ch = getchar();}
    while(isdigit(ch)) {x = (x << 3) + (x << 1) + ch - '0'; ch = getchar();}
    return x * f;
}
struct node {
    int sz, sum, sum2, sum3, _sum2, _sum3, addv;
    node(int Sz = 0, int Sum = 0, int Sum2 = 0, int Sum3 = 0, int _Sum2 = 0, int _Sum3 = 0, int Addv = 0) {sz = Sz, _sum2 = _Sum2, _sum3 = _Sum3, sum = Sum, sum2 = Sum2, sum3 = Sum3, addv = Addv;}
}seg[800005];
il node Merge(node A, node B) {
    node C;
    C.sz = A.sz + B.sz;
    C.sum = (A.sum + B.sum) % Mod;
    C.sum2 = (A.sum2 + B.sum2 + A.sz * B.sum % Mod) % Mod;
    C.sum3 = (A.sum3 + B.sum3 + A.sz * B.sum2 % Mod + (A.sz + 1) * A.sz / 2 % Mod * B.sum % Mod) % Mod;
    C._sum2 = (B._sum2 + A._sum2 + B.sz * A.sum % Mod) % Mod;
    C._sum3 = (B._sum3 + A._sum3 + B.sz * A._sum2 % Mod + (B.sz + 1) * B.sz / 2 % Mod * A.sum % Mod) % Mod;
    return C;
}
int a[200005];
il void build(int o, int l, int r) {
    seg[o].sz = r - l + 1;
    if(l == r) {
        seg[o].sum = seg[o].sum2 = seg[o].sum3 = seg[o]._sum2 = seg[o]._sum3 = a[l];
        return ;
    }
    int mid = (l + r) >> 1;
    build(o << 1, l, mid); build(o << 1 | 1, mid + 1, r);
    seg[o] = Merge(seg[o << 1], seg[o << 1 | 1]);
}
il void pusha(int o, int l, int r, int x) {
    int len = (r - l + 1);
    (seg[o].sum3 += (len * (len + 1) % Mod * (len + 2) % Mod * inv6 % Mod) * x % Mod) %= Mod;
    (seg[o]._sum3 += (len * (len + 1) % Mod * (len + 2) % Mod * inv6 % Mod) * x % Mod) %= Mod;
    (seg[o].sum2 += (len * (len + 1) / 2 % Mod) * x % Mod) %= Mod;
    (seg[o]._sum2 += (len * (len + 1) / 2 % Mod) * x % Mod) %= Mod;
    (seg[o].sum += len * x % Mod) %= Mod;
    seg[o].addv = (seg[o].addv + x) % Mod;
}
il void pushdown(int o, int l, int r) {
    int mid = (l + r) >> 1;
    if(seg[o].addv) {
        pusha(o << 1, l, mid, seg[o].addv);
        pusha(o << 1 | 1, mid + 1, r, seg[o].addv);
        seg[o].addv = 0;
    }
}
il void modify(int o, int l, int r, int nl, int nr, int x) {
    if(nl <= l && r <= nr)  return pusha(o, l, r, x);
    pushdown(o, l, r); int mid = (l + r) >> 1;
    if(nl <= mid)  modify(o << 1, l, mid, nl, nr, x);
    if(mid < nr)  modify(o << 1 | 1, mid + 1, r, nl, nr, x);
    seg[o] = Merge(seg[o << 1], seg[o << 1 | 1]);
}
il node query(int o, int l, int r, int nl, int nr) {
    if(nl > nr)  return (node){0ll, 0ll, 0ll, 0ll, 0ll, 0ll, 0ll};
    if(nl <= l && r <= nr)  return seg[o];
    pushdown(o, l, r); int mid = (l + r) >> 1;
    if(nl <= mid) {
        if(mid < nr)  return Merge(query(o << 1, l, mid, nl, nr), query(o << 1 | 1, mid + 1, r, nl, nr));
        return query(o << 1, l, mid, nl, nr);
    }
    return query(o << 1 | 1, mid + 1, r, nl, nr);
}
signed main() {
    int n = Read(), m = Read();
    for(re int i = 1; i <= n; i++)  a[i] = Read();
    build(1, 1, n);
    for(re int i = 1; i <= m; i++) {
        int opt = Read(), x = Read(), y = Read(), z;
        if(x > y)  swap(x, y);
        if(opt == 1) {
            z = Read();
            modify(1, 1, n, x, y, z);
        }
        else {
            if(x > y)  {puts("0"); continue;}
            if(n % 2 == 1) {
                int mid = (1 + n) >> 1; node A, B, C, D;
                A = query(1, 1, n, 1, mid); B = query(1, 1, n, mid + 1, n);
                int ans = (A.sum2 + B._sum2) % Mod * (y - x + 1) % Mod;
                if(y <= mid) {
                    if(y == mid)  --y;
                    A = query(1, 1, n, x + 1, mid); B = query(1, 1, n, mid + 1, 2 * mid - x - 1);
                    C = query(1, 1 ,n, y + 2, mid); D = query(1, 1, n, mid + 1, 2 * mid - y - 2);
                    ans -= (A.sum3 + B._sum3 - C.sum3 - D._sum3 + Mod) % Mod;
                    ans = (ans + Mod) % Mod;
                }
                else if(x >= mid) {
                    if(x == mid)  ++x;
                    A = query(1, 1, n, 2 * mid - y + 1, mid); B = query(1, 1, n, mid + 1, y - 1);
                    C = query(1, 1, n, 2 * mid - x + 2, mid); D = query(1, 1, n, mid + 1, x - 2);
                    ans -= (A.sum3 + B._sum3 - C.sum3 - D._sum3 + Mod) % Mod;
                    ans = (ans + Mod) % Mod;
                }
                else {
                    A = query(1, 1, n, x + 1, mid), B = query(1, 1, n, mid + 1, 2 * mid - x - 1);
                    C = query(1, 1, n, 2 * mid - y + 1, mid), D = query(1, 1, n, mid + 1, y - 1);
                    ans -= (A.sum3 + C.sum3 + B._sum3 + D._sum3) % Mod;
                    ans = (ans + Mod) % Mod;
                }
                printf("%lld\n", ans);
            }
            if(n % 2 == 0) {
                int Lmid = (1 + n) >> 1, Rmid = Lmid + 1; node A, B, C, D;
                A = query(1, 1, n, 1, Lmid); B = query(1, 1, n, Rmid, n);
                int ans = (A.sum2 + B._sum2) % Mod * (y - x + 1) % Mod;
                if(y <= Rmid) {
                    if(y == Rmid)  --y; 
                    if(y == Lmid)  --y;
                    A = query(1, 1, n, x + 1, Lmid); B = query(1, 1, n, Rmid, 2 * Lmid - x);
                    C = query(1, 1, n, y + 2, Lmid); D = query(1, 1, n, Rmid, 2 * Lmid - y - 1);
                    ans -= (A.sum3 + B._sum3 - C.sum3 - D._sum3 + Mod) % Mod;
                    ans = (ans + Mod) % Mod;
                }
                else if(x >= Lmid) {
                    if(x == Lmid)  ++x;
                    if(x == Rmid)  ++x;
                    A = query(1, 1, n, 2 * Lmid - y + 2, Lmid); B = query(1, 1, n, Rmid, y - 1);
                    C = query(1, 1, n, 2 * Lmid - x + 3, Lmid); D = query(1, 1, n, Rmid, x - 2);
                    ans -= (A.sum3 + B._sum3 - C.sum3 - D._sum3 + Mod) % Mod;
                    ans = (ans + Mod) % Mod;
                }
                else {
                    A = query(1, 1, n, x + 1, Lmid), B = query(1, 1, n, Rmid, 2 * Lmid - x);
                    C = query(1, 1, n, 2 * Lmid - y + 2, Lmid), D = query(1, 1, n, Rmid, y - 1);
                    ans -= (A.sum3 + C.sum3 + B._sum3 + D._sum3) % Mod;
                    ans = (ans + Mod) % Mod;
                }
                printf("%lld\n", ans);
            }    
        }
    }
    return 0;
}