1. 程式人生 > 其它 >洛谷P4688 掉進兔子洞 題解

洛谷P4688 掉進兔子洞 題解

題面

傳送門
給定一個長度為 \(N\) 的序列 \(a\) 以及 \(M\) 次詢問,每次詢問為三個區間 \([l_1,r_1],[l_2,r_2],[l_3,r_3]\)。把三個區間中同時出現的數一個一個刪掉,問最後三個區間剩下的數的個數和。
注意這裡刪掉指的是一個一個刪,不是把等於這個值的數直接刪完,比如三個區間是 \([1,2,2,3,3,3,3]\)\([1,2,2,3,3,3,3]\)\([1,1,2,3,3]\),就一起扔掉了 \(1\)\(1\)\(1\)\(2\)\(2\)\(3\)
\(1≤N,M≤10^5\)\(1≤a_i≤10^9\)

題解

與普通莫隊不同,這題每個詢問為三個區間。那麼我們就將所有區間分開處理,每次將區間的資訊記錄下來,最後統一更新每個詢問的答案。但是我們發現,這題無法再用一個變數儲存區間的資訊,而需要用一個可以儲存所有數值資訊的資料結構。
先考慮一種較為暴力的解法:用陣列記錄,下標為數值,記錄出現的次數。加入和刪除均可以 \(O(1)\) 實現。對於每個詢問的區間,將陣列複製並記錄,最後統一處理。
但陣列自然不行,因為無法快速將實時迭代的資訊轉移到每個詢問的區間上,並且空間也無法承受。於是想到使用bitset
然而bitset有其弊端,就是隻能記錄0/1兩種狀態,而無法實現本題要求的“每個數出現了幾次”。這裡需要使用一種奇淫技巧:
設總共有 \(S\)

個數,那麼bitset原要開 \(S\) 位;但我們開 \(N\) 位,並給予每一位一個二元組 \((x,y)\),表示數 \(x\) 是否出現了至少 \(y\) 次。顯然 \(N\) 位可以表示出所有的狀態。在這一題中,由於數字本來就要離散化,預處理很好實現。
再來看莫隊的增減操作,顯然還是可以 \(O(1)\) 實現。
最後統一處理所有區間,容易想到只需將三個區間做與運算,結果為 \(1\) 的位數就是需要刪去的位數。那麼這一步也可以進一步簡化(主要是減少空間),在求每個區間答案時用迭代的bitset實時更新。
這個做法對空間需求很大,需要開 \(10^5\) 個長度為 \(10^5\)bitset
。估算為1200M,而本題空間限制為500M。可以將所有詢問分成三次做,時間複雜度少有變化,空間可以減小為三分之一。

Code

#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<bitset>
using namespace std;
const int N = 100010, M = 33340;
int n, a[N], mp[N], pos[N], len[M], cnt[N];
bitset<N> ans[M], res;
struct Query {
    int id, l, r;
    bool operator <(const Query &oth) const {
        return pos[l] != pos[oth.l] ? pos[l] < pos[oth.l] : r < oth.r;
    }
} q[N];
int read() {
    int x = 0; char c = getchar();
    while (c < '0' || c > '9') c = getchar();
    while (c >= '0' && c <= '9') {x = (x << 3) + (x << 1) + (c ^ 48); c = getchar();}
    return x;
}
void add(int x) {
    res[x + cnt[x]] = 1; cnt[x]++;
}
void del(int x) {
    cnt[x]--; res[x + cnt[x]] = 0;
}
void Solve(int m) {
    for (int i = 1; i <= m; i++) {
        int l1 = read(), r1 = read(), l2 = read(), r2 = read(), l3 = read(), r3 = read();
        q[3 * i - 2] = (Query){i, l1, r1};
        q[3 * i - 1] = (Query){i, l2, r2};
        q[3 * i] = (Query){i, l3, r3};
        len[i] = r1 + r2 + r3 - l1 - l2 - l3 + 3;
    }
    sort(q + 1, q + 3 * m + 1);
    for (int i = 1; i <= m; i++) ans[i].set();
    memset(cnt, 0, sizeof (cnt)); res.reset();
    for (int i = 1, L = 1, R = 0; i <= 3 * m; i++) {
        while (R < q[i].r) add(a[++R]);
        while (R > q[i].r) del(a[R--]);
        while (L < q[i].l) del(a[L++]);
        while (L > q[i].l) add(a[--L]);
        ans[q[i].id] &= res;
    }
    for (int i = 1; i <= m; i++) printf("%d\n", len[i] - ans[i].count() * 3);
}
int main() {
    n = read(); int m = read();
    for (int i = 1; i <= n; i++) mp[i] = a[i] = read();
    sort(mp + 1, mp + n + 1);
    for (int i = 1; i <= n; i++) a[i] = lower_bound(mp + 1, mp + n + 1, a[i]) - mp;
    int len = max(1, (int)sqrt((double)n * n / m));
    for (int i = 1; i <= n; i++) pos[i] = (i - 1) / len + 1;
    Solve(m / 3); Solve(m / 3); Solve(m - m / 3 * 2);
    return 0;
}