1. 程式人生 > 實用技巧 >【無聊亂搞】如何用 std::set 過 gamma

【無聊亂搞】如何用 std::set 過 gamma

一道毒瘤題

\(\gamma\) by DPair

題目描述

維護一個正整數集 \(S\),元素 \(\in\) 值域 \(U\),需要支援:

  • \(\texttt{1 l r}\)\(S\gets S\cup [l,r]\)
  • \(\texttt{2 l r}\)\(S \gets \{x|x\in S \land x\notin [l,r]\}\)
  • \(\texttt{3 l r}\):求滿足 \(x\in [l,r]\land x\notin S\) 的最小 \(x\)
  • \(\texttt{4 l r}\):求 \(\sum_{x\in[l,r]}[x\in S]\)

資料規模

  • \(1\le U \le 10^{18}\)
  • \(1\le Q\le 5\times 10^6\)
  • \(1000\ ms,\texttt{-O2}\)
  • 隨機資料

Naive Solution

注意到操作 1、2 相當於區間賦值。

那麼不難想到 ODT。然而基於 std::set 的 ODT 實現常數過大,不過手寫連結串列可以通過。

但是為了挑戰自我 筆者決定使用 std::set 通過這道題。

下面是一份來自 DPair 的 Naive ODT 實現(我自己懶得寫):

struct NODE{
    LL l, r;
    mutable int val;
    NODE (LL tmp1, LL tmp2 = -1, int tmp3 = 0) : l(tmp1), r(tmp2), val(tmp3){}
    inline bool operator < (const NODE &tmp) const{return l < tmp.l;}
};
set <NODE> ODT;
typedef set <NODE> :: iterator IT;
inline IT split(LL x){
    IT it = ODT.lower_bound(NODE(x));
    if(it != ODT.end() && it -> l == x) return it;
    -- it;
    LL L = it -> l, R = it -> r;
    int Val = it -> val;
    ODT.erase(it);
    ODT.insert(NODE(L, x - 1, Val));
    return ODT.insert(NODE(x, R, Val)).first;
}
inline void assign(LL l, LL r, int val){
    IT R = split(r + 1), L = split(l);
    ODT.erase(L, R);
    ODT.insert(NODE(l, r, val));
}
inline LL getSum(LL l, LL r){
    IT R = split(r + 1), L = split(l);
    LL ret = 0;
    while(L != R){
        ret += (L -> r - L -> l + 1) * (L -> val);
        ++ L;
    }
    return ret;
}
inline LL getMex(LL l, LL r){
    IT R = split(r + 1), L = split(l);
    LL ret = 0;
    while(L != R){
        if(!(L -> val)) chmax(ret, L -> r);
        ++ L;
    }
    return ret;
}

上面這份程式碼複雜度為 \(O(Q\log U)\),但由於常數被連結串列吊打。

Improved Solution

我們並不打算更換演算法,而是在原來的程式碼上優化實現。

Improvement #1:只維護一種顏色

考慮到我們的值只有兩種:\(0,1\)。那麼考慮只保留其中一種值,這樣 set 維護的連續段數理論上會減少一半。

那麼到底維護 \(0\) 還是 \(1\) 呢?看詢問:4 操作其實 \(0,1\) 都差不多,但是 3 操作就不太一樣了,如果維護 \(1\) 的話需要找到第一個不連續的位置,如果存在大量虛假的斷點(即兩個不同的連續段實際上相鄰)就很浪費些時間,不過維護 \(0\) 就不太一樣了,我們只要找第一個迭代器的左端點就是第一個 \(0\)

的位置,或者左右迭代器相等判斷無解。

下面是在原來基礎上略加修改的 split 函式:

std::set<std::pair<LL, LL> > odt;
setIt CutItv(LL p) { // make breakpoint in front of position p.(split)
  setIt it = odt.lower_bound(std::make_pair(p, 0));
  if (it == odt.begin()) return it;
  else --it;
  if (it->second >= p) {
    std::pair<LL, LL> rec = *it; odt.erase(it);
    odt.insert(std::make_pair(rec.first, p - 1));
    return odt.insert(std::make_pair(p, rec.second)).first;
  }
  return ++it;
}

Improvement #2:mutable

所謂 mutable,即“可變的”,具體解釋如下:

mutable 的意思是“可變的”,讓我們可以在後面的操作中修改 v 的值。在 C++ 中,mutable 是為了突破 const 的限制而設定的。被 mutable 修飾的變數(mutable 只能用於修飾類中的非靜態資料成員),將永遠處於可變的狀態,即使在一個 const 函式中。這意味著,我們可以直接修改已經插入 set 的元素的 v 值,而不用將該元素取出後重新加入 set

——OI Wiki

其中上面 DPair 的實現中也用到了 multable,不過,如上所說,僅僅是修飾了值的變數。

然而其實 r 也是可以 mutable 的,並且在新的 split(CutItv) 實現中也沒有用好這個特性,可以發現它可以使我們的 split 少一次 erase、少一次 insert,是非常可觀的一個優化。

Improvement #3:emplace

C++11 中,std::set 中有了一種新的插入元素的方法:emplace

它和 insert 的功能集合一樣(包括返回值),但是 emplace 是原位構造元素,相比 insert 可以避免大量的不必要的複製移動,從而常數進一步得到優化。

詳情可見 cppreference - std::set<Key,Compare,Allocator>::emplace

結合 優化#2 的程式碼:

struct Interval {
  LL l; mutable LL r;
  inline Interval(LL l, LL r) : l(l), r(r) { }
  inline bool operator < (const Interval& rhs) const { return l < rhs.l; }
};
std::set<Interval> odt({Interval(1, (LL)1e18)});
std::set<Interval>::iterator CutItv(LL p) {
  auto it = odt.lower_bound(Interval(p, 0ll));
  if (it == odt.begin()) return it;
  else --it;
  if (it->r >= p) {
    LL tr = it->r; it->r = p - 1;
    return odt.emplace(p, tr).first;
  }
  return ++it;
}

Improvement #4:emplace_hint

emplace 很快,但 emplace_hint 更快,前提是在用的好的時候。

emplace_hint 相比 emplace 又多了一個引數 hint(一個迭代器),插入操作會在容器中儘可能接近於 hint 的位置進行。這意味著插入操作可以節約很大一部分查詢的時間。

emplace_hint 改良實現:

std::set<Interval>::iterator CutItv(LL p) {
  auto it = odt.lower_bound(Interval(p, 0ll));
  if (it == odt.begin()) return it;
  else --it;
  if (it->r >= p) {
    LL tr = it->r; it->r = p - 1;
    return odt.emplace_hint(it, p, tr);
  }
  return ++it;
}

不僅僅是 split 部分,其他設計插入操作的都可以這樣操作:

void Insert(LL l, LL r) {
  auto itr = CutItv(r + 1), itl = CutItv(l);
  odt.emplace_hint(odt.erase(itl, itr), l, r);//其實 erase 也有返回值
}

Improvement #5:及時合併虛假斷點

也許現在的連續段應該是這樣:\([1,100]\)

但可能你的 std::set 中是這樣:\([1, 15],[16,51],\cdots,[81,89] , [89,100]\)。這很難受,白白增大了 set 的大小。

於是我們在 InsertgetMexgetSum 三個操作之後都加一個機制,把 set 中與區間對應的兩個迭代器周圍相鄰的段合併。

實測 \([1, 10^{18}]\) 這樣的區間,隨機資料下所有時刻 set 的大小的平均值僅為 \(12\)

Final Version

最後又發現 set 中的元素只按左端點排序,右端點有事可變的,於是又有了 std::map 的版本,詳見第二個程式碼:

#include <algorithm>
#include <cstdio>
#include <iterator>
#include <set>

typedef long long LL;
namespace My_Rand{
  int index, MT[624];
  inline void sd(int seed){
    index = 0;
    MT[0] = seed;
    for (register int i = 1;i < 624;i ++){
      int t = 1812433253 * (MT[i - 1] ^ (MT[i - 1] >> 30)) + i;
      MT[i] = t & 0xffffffff;
    }
  }
  inline void rotate(){
    for (register int i = 0;i < 624;i ++){
      int tmp = (MT[i] & 0x80000000) + (MT[(i + 1) % 624] & 0x7fffffff);
      MT[i] = MT[(i + 397) % 624] ^ (tmp >> 1);
      if(tmp & 1) MT[i] ^= 2567483615;
    }
  }
  inline int rd(){
    if(!index) rotate();
    int ret = MT[index];
    ret = ret ^ (ret >> 11);
    ret = ret ^ ((ret << 7) & 2636928640);
    ret = ret ^ ((ret << 15) & 4022730752);
    ret = ret ^ (ret >> 18);
    index = (index + 1) % 624;
    return ret;
  }
  const LL limit = 1000000000;
  inline void gen(int &opt, LL &l, LL &r, LL ans){
    opt = rd() % 4 + 1;
    ans = ans % limit;
    l = ((rd() ^ ans) % limit) * limit + ((rd() ^ ans) % limit);
    r = ((rd() ^ ans) % limit) * limit + ((rd() ^ ans) % limit);
    if(l > r) std::swap(l, r);
  }
} // namespace My_Rand

struct Interval {
  LL l; mutable LL r;
  inline Interval(LL l, LL r) : l(l), r(r) { }
  inline bool operator < (const Interval& rhs) const { return l < rhs.l; }
};
std::set<Interval> odt({Interval(1, (LL)1e18)});

std::set<Interval>::iterator CutItv(LL p) { // make breakpoint in front of position p.
  auto it = odt.lower_bound(Interval(p, 0ll));
  if (it == odt.begin()) return it;
  else --it;
  if (it->r >= p) {
    LL tr = it->r; it->r = p - 1;
    return odt.emplace_hint(it, p, tr);
  }
  return ++it;
}
void Insert(LL l, LL r) {
  auto itr = CutItv(r + 1), itl = CutItv(l);
  auto it = odt.emplace_hint(odt.erase(itl, itr), l, r);
  if (it != odt.begin())
    if (prev(it)->r + 1 == l) prev(it)->r = it->r, it = odt.erase(it);
  if (it != odt.begin())
    if (prev(it)->r + 1 == l) prev(it)->r = it->r, it = odt.erase(it);
}
void Erase(LL l, LL r) {
  auto itr = CutItv(r + 1), itl = CutItv(l);
  odt.erase(itl, itr);
}
LL getMex(LL l, LL r) {
  auto itr = CutItv(r + 1), itl = CutItv(l);
  if (itl == itr) return 0;
  LL ans = itl->l;
  if (itl != odt.begin())
    if (prev(itl)->r + 1 == l) prev(itl)->r = itl->r, odt.erase(itl);
  if (itr != odt.end())
    if (itr->l == r + 1) prev(itr)->r = itr->r, odt.erase(itr);
  return ans;
}
LL getSum(LL l, LL r) {
  auto itr = CutItv(r + 1), itl = CutItv(l);
  LL ret = 0;
  for (auto it = itl; it != itr; it++) ret += it->r - it->l + 1;
  if (itl != odt.begin())
    if (prev(itl)->r + 1 == l) prev(itl)->r = itl->r, odt.erase(itl);
  if (itr != odt.end())
    if (itr->l == r + 1) prev(itr)->r = itr->r, odt.erase(itr);
  return r - l + 1 - ret;
}

signed main() {
  int seed, Q;
  scanf("%d%d", &Q, &seed);
  My_Rand::sd(seed);

  LL last = 0ll, axor = 0ll;
  while (Q--) {
    int opt; LL l, r;
    My_Rand::gen(opt, l, r, last);
    
    if (opt == 2) Insert(l, r);
    else if (opt == 1) Erase(l, r);
    else if (opt == 3) axor ^= (last = getMex(l, r));
    else axor ^= (last = getSum(l, r));
  }

  printf("%lld\n", axor);
  return 0;
}
#include <algorithm>
#include <cstdio>
#include <map>

typedef long long LL;
namespace My_Rand{
  int index, MT[624];
  inline void sd(int seed){
    index = 0;
    MT[0] = seed;
    for (register int i = 1;i < 624;i ++){
      int t = 1812433253 * (MT[i - 1] ^ (MT[i - 1] >> 30)) + i;
      MT[i] = t & 0xffffffff;
    }
  }
  inline void rotate(){
    for (register int i = 0;i < 624;i ++){
      int tmp = (MT[i] & 0x80000000) + (MT[(i + 1) % 624] & 0x7fffffff);
      MT[i] = MT[(i + 397) % 624] ^ (tmp >> 1);
      if(tmp & 1) MT[i] ^= 2567483615;
    }
  }
  inline int rd(){
    if(!index) rotate();
    int ret = MT[index];
    ret ^= (ret >> 11);
    ret ^= ((ret << 7) & 2636928640);
    ret ^= ((ret << 15) & 4022730752);
    ret ^= (ret >> 18);
    (++index) %= 624;
    return ret;
  }
  const LL limit = 1000000000;
  inline void gen(int &opt, LL &l, LL &r, LL ans){
    opt = (rd() & 3) + 1;
    ans = ans % limit;
    l = ((rd() ^ ans) % limit) * limit + ((rd() ^ ans) % limit);
    r = ((rd() ^ ans) % limit) * limit + ((rd() ^ ans) % limit);
    if(l > r) std::swap(l, r);
  }
} // namespace My_Rand

std::map<LL, LL> odt({std::make_pair(1, (LL)1e18)});

std::map<LL, LL>::iterator CutItv(LL p) {
  auto it = odt.upper_bound(p);
  if (it == odt.begin()) return it;
  if ((--it)->second >= p) {
    LL tr = it->second; it->second = p - 1;
    return odt.emplace_hint(it, p, tr);
  }
  return ++it;
}
void Insert(LL l, LL r) {
  auto itr = CutItv(r + 1), itl = CutItv(l);
  auto it = odt.emplace_hint(--odt.erase(itl, itr), l, r);
  if (it != odt.begin()) if (prev(it)->second + 1 == l)
    prev(it)->second = it->second, it = odt.erase(it);
  if (it != odt.begin()) if (prev(it)->second + 1 == l)
    prev(it)->second = it->second, it = odt.erase(it);
}
void Erase(LL l, LL r) {
  auto itr = CutItv(r + 1), itl = CutItv(l);
  odt.erase(itl, itr);
}
LL getMex(LL l, LL r) {
  auto itr = CutItv(r + 1), itl = CutItv(l);
  if (itl == itr) return 0;
  LL ans = itl->first;
  if (itl != odt.begin()) if (prev(itl)->second + 1 == l)
    prev(itl)->second = itl->second, odt.erase(itl);
  if (itr != odt.end()) if (itr->first == r + 1)
    prev(itr)->second = itr->second, odt.erase(itr);
  return ans;
}
LL getSum(LL l, LL r) {
  auto itr = CutItv(r + 1), itl = CutItv(l);
  LL ans = 0;
  for (auto it = itl; it != itr; it++)
    ans += it->second - it->first + 1;
  if (itl != odt.begin()) if (prev(itl)->second + 1 == l)
    prev(itl)->second = itl->second, odt.erase(itl);
  if (itr != odt.end()) if (itr->first == r + 1)
    prev(itr)->second = itr->second, odt.erase(itr);
  return r - l + 1 - ans;
}

signed main() {
  int seed, Q;
  scanf("%d%d", &Q, &seed);
  My_Rand::sd(seed);

  LL last = 0ll, axor = 0ll;
  ++Q; while (--Q) {
    int opt; LL l, r;
    My_Rand::gen(opt, l, r, last);

    if (opt == 2) Insert(l, r);
    else if (opt == 1) Erase(l, r);
    else if (opt == 3) axor ^= (last = getMex(l, r));
    else axor ^= (last = getSum(l, r));
  }

  return printf("%lld\n", axor), 0;
}

End

這道題就這樣卡過去了,甚至比連結串列還快一點。

也許有人問:為什麼不手寫平衡樹?然而開了 O2 的 std::set 說實話並不比手寫慢,而且手寫實現難度更大。

所以千萬不要低估 STL 的實力,在用得好的情況下並不會遜色於手寫 DS。

當然前提是對 STL 足夠熟悉,並且能夠靈活運用。