洛谷P4688 掉進兔子洞 題解
阿新 • • 發佈:2022-04-05
題面
傳送門
給定一個長度為 \(N\) 的序列 \(a\) 以及 \(M\) 次詢問,每次詢問為三個區間 \([l_1,r_1],[l_2,r_2],[l_3,r_3]\)。把三個區間中同時出現的數一個一個刪掉,問最後三個區間剩下的數的個數和。
注意這裡刪掉指的是一個一個刪,不是把等於這個值的數直接刪完,比如三個區間是 \([1,2,2,3,3,3,3]\),\([1,2,2,3,3,3,3]\) 與 \([1,1,2,3,3]\),就一起扔掉了 \(1\) 個 \(1\),\(1\) 個 \(2\),\(2\) 個 \(3\)。
\(1≤N,M≤10^5\),\(1≤a_i≤10^9\)。
題解
與普通莫隊不同,這題每個詢問為三個區間。那麼我們就將所有區間分開處理,每次將區間的資訊記錄下來,最後統一更新每個詢問的答案。但是我們發現,這題無法再用一個變數儲存區間的資訊,而需要用一個可以儲存所有數值資訊的資料結構。
先考慮一種較為暴力的解法:用陣列記錄,下標為數值,記錄出現的次數。加入和刪除均可以 \(O(1)\) 實現。對於每個詢問的區間,將陣列複製並記錄,最後統一處理。
但陣列自然不行,因為無法快速將實時迭代的資訊轉移到每個詢問的區間上,並且空間也無法承受。於是想到使用bitset
。
然而bitset
有其弊端,就是隻能記錄0/1
兩種狀態,而無法實現本題要求的“每個數出現了幾次”。這裡需要使用一種奇淫技巧:
設總共有 \(S\)
bitset
原要開 \(S\) 位;但我們開 \(N\) 位,並給予每一位一個二元組 \((x,y)\),表示數 \(x\) 是否出現了至少 \(y\) 次。顯然 \(N\) 位可以表示出所有的狀態。在這一題中,由於數字本來就要離散化,預處理很好實現。再來看莫隊的增減操作,顯然還是可以 \(O(1)\) 實現。
最後統一處理所有區間,容易想到只需將三個區間做與運算,結果為 \(1\) 的位數就是需要刪去的位數。那麼這一步也可以進一步簡化(主要是減少空間),在求每個區間答案時用迭代的
bitset
實時更新。這個做法對空間需求很大,需要開 \(10^5\) 個長度為 \(10^5\) 的
bitset
Code
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<bitset>
using namespace std;
const int N = 100010, M = 33340;
int n, a[N], mp[N], pos[N], len[M], cnt[N];
bitset<N> ans[M], res;
struct Query {
int id, l, r;
bool operator <(const Query &oth) const {
return pos[l] != pos[oth.l] ? pos[l] < pos[oth.l] : r < oth.r;
}
} q[N];
int read() {
int x = 0; char c = getchar();
while (c < '0' || c > '9') c = getchar();
while (c >= '0' && c <= '9') {x = (x << 3) + (x << 1) + (c ^ 48); c = getchar();}
return x;
}
void add(int x) {
res[x + cnt[x]] = 1; cnt[x]++;
}
void del(int x) {
cnt[x]--; res[x + cnt[x]] = 0;
}
void Solve(int m) {
for (int i = 1; i <= m; i++) {
int l1 = read(), r1 = read(), l2 = read(), r2 = read(), l3 = read(), r3 = read();
q[3 * i - 2] = (Query){i, l1, r1};
q[3 * i - 1] = (Query){i, l2, r2};
q[3 * i] = (Query){i, l3, r3};
len[i] = r1 + r2 + r3 - l1 - l2 - l3 + 3;
}
sort(q + 1, q + 3 * m + 1);
for (int i = 1; i <= m; i++) ans[i].set();
memset(cnt, 0, sizeof (cnt)); res.reset();
for (int i = 1, L = 1, R = 0; i <= 3 * m; i++) {
while (R < q[i].r) add(a[++R]);
while (R > q[i].r) del(a[R--]);
while (L < q[i].l) del(a[L++]);
while (L > q[i].l) add(a[--L]);
ans[q[i].id] &= res;
}
for (int i = 1; i <= m; i++) printf("%d\n", len[i] - ans[i].count() * 3);
}
int main() {
n = read(); int m = read();
for (int i = 1; i <= n; i++) mp[i] = a[i] = read();
sort(mp + 1, mp + n + 1);
for (int i = 1; i <= n; i++) a[i] = lower_bound(mp + 1, mp + n + 1, a[i]) - mp;
int len = max(1, (int)sqrt((double)n * n / m));
for (int i = 1; i <= n; i++) pos[i] = (i - 1) / len + 1;
Solve(m / 3); Solve(m / 3); Solve(m - m / 3 * 2);
return 0;
}