[CF1083D]The Fair Nut’s getting crazy[單調棧+線段樹]
阿新 • • 發佈:2019-01-03
題意
給定一個長度為 \(n\) 的序列 \(\{a_i\}\)。你需要從該序列中選出兩個非空的子段,這兩個子段滿足: - 兩個子段非包含關係。 - 兩個子段存在交。 - 位於兩個子段交中的元素在每個子段中只能出現一次。 求共有多少種不同的子段選擇方案。輸出總方案數對 \(10^9 + 7\) 取模後的結果。 需要注意的是,選擇子段 \([a, b]\)、\([c, d]\) 與選擇子段 \([c, d]\)、\([a, b]\)
分析
- 考慮列舉一個區間 \([b,c]\) 作為交,記錄 \(L_i,R_i\) 表示距離 \(i\) 最近的和 \(i\) 顏色相同的位置。
- 有: \(a\in[\max\limits_{i=b}^c{L_i},b),d\in(c,\min\limits_{i=b}^c{R_i}]\)。
- 記錄可以取到的左端點的最小值(滿足區間中不存在兩個相同的數) \(pos\)
考慮從左到右列舉交區間的右端點 \(i\) ,用單調棧維護每個位置的 \(mi, mx\) 。容易得到以 \(i\) 為交區間的右端點的方案數為 \(\sum_{j=pos}^i(mi_j-i)(j-mx_j)\),拆開然後用線段樹分別維護。
總時間複雜度為 \(O(nlogn)\)。
程式碼
#include<bits/stdc++.h> using namespace std; typedef long long LL; #define go(u) for(int i = head[u], v = e[i].to; i; i=e[i].lst, v=e[i].to) #define rep(i, a, b) for(int i = a; i <= b; ++i) #define pb push_back #define re(x) memset(x, 0, sizeof x) inline int gi() { int x = 0,f = 1; char ch = getchar(); while(!isdigit(ch)) { if(ch == '-') f = -1; ch = getchar();} while(isdigit(ch)) { x = (x << 3) + (x << 1) + ch - 48; ch = getchar();} return x * f; } template <typename T> inline void Max(T &a, T b){if(a < b) a = b;} template <typename T> inline void Min(T &a, T b){if(a > b) a = b;} const int N = 1e5 + 7, mod = 1e9 + 7; int n, vc; LL ans; int lst[N], L[N], R[N], V[N], a[N]; int st1[N], st2[N], tp1, tp2; #define Ls o << 1 #define Rs (o << 1 | 1) LL s1(int n) { return 1ll * n * (n + 1) / 2; } LL ami[N << 2], amx[N << 2]; struct data { LL mi, mx, smi, tm; data operator +(const data &rhs) const { return (data){ (mi + rhs.mi) % mod, (mx + rhs.mx) % mod, (smi + rhs.smi) % mod, (tm + rhs.tm) % mod}; } }t[N << 2]; void add(LL &a, LL b) { a += b;if(a >= mod) a -= mod; } void stmi(int l, int r, int o, int v) { add(ami[o], v); add(t[o].tm, 1ll * v * t[o].mx % mod); add(t[o].mi, 1ll * (r - l + 1) * v % mod); add(t[o].smi, (s1(r) - s1(l - 1)) % mod * v % mod); } void stmx(int l, int r, int o, int v) { add(amx[o], v); add(t[o].tm, 1ll * v * t[o].mi % mod); add(t[o].mx, 1ll * (r - l + 1) * v % mod); } void pushdown(int l, int r, int o) { int mid = l + r >> 1; if(ami[o]) { stmi(l, mid, Ls, ami[o]); stmi(mid + 1, r, Rs, ami[o]); } if(amx[o]) { stmx(l, mid, Ls, amx[o]); stmx(mid + 1, r, Rs, amx[o]); } ami[o] = amx[o] = 0; } void pushup(int o) { t[o] = t[Ls] + t[Rs]; } void modify(int L, int R, int l, int r, int o, int v, int opt) { if(L <= l && r <= R) { if(!opt) stmi(l, r, o, v); else stmx(l, r, o, v); return; } pushdown(l, r, o);int mid = l + r >> 1; if(L <= mid) modify(L, R, l, mid, Ls, v, opt); if(R > mid) modify(L, R, mid + 1, r, Rs, v, opt); pushup(o); } data query(int L, int R, int l, int r, int o) { if(L <= l && r <= R) return t[o]; pushdown(l, r, o);int mid = l + r >> 1; if(R <= mid) return query(L, R, l, mid, Ls); if(L > mid) return query(L, R, mid + 1, r, Rs); return query(L, R, l, mid, Ls) + query(L, R, mid + 1, r, Rs); } int main() { n = gi(); rep(i, 1, n) a[i] = gi(), V[i] = a[i]; sort(V + 1, V + 1 + n); vc = unique(V + 1, V + 1 + n) - V - 1; rep(i, 1, n) a[i] = lower_bound(V + 1, V + 1 + vc, a[i]) - V; rep(i, 1, n) { L[i] = lst[a[i]] + 1; lst[a[i]] = i; } rep(i, 1, vc) lst[i] = n + 1; for(int i = n; i; --i) { R[i] = lst[a[i]] - 1; lst[a[i]] = i; } for(int i = 1, gg = 1; i <= n; ++i) { for(; tp1 && L[i] >= L[st1[tp1]]; --tp1) { modify(st1[tp1 - 1] + 1, st1[tp1], 1, n, 1, mod - L[st1[tp1]], 1); } modify(st1[tp1] + 1, i, 1, n, 1, L[i], 1); st1[++tp1] = i; for(; tp2 && R[i] <= R[st2[tp2]]; --tp2) { modify(st2[tp2 - 1] + 1, st2[tp2], 1, n, 1, mod - R[st2[tp2]], 0); } modify(st2[tp2] + 1, i, 1, n, 1, R[i], 0); st2[++tp2] = i; Max(gg, L[i]); data res = query(gg, i, 1, n, 1); LL tmp = ((res.smi + i * res.mx % mod - res.tm - (s1(i) - s1(gg - 1)) % mod * i % mod) % mod + mod) % mod; add(ans, tmp); } printf("%lld\n", ans); return 0; }