1. 程式人生 > 其它 >子串週期查詢

子串週期查詢

大概就是復讀集訓隊論文,大部分證明都略去了。

前置知識

  • WPL: \(s\) 有 period \(p_1 + p_2 \le n \implies\) \(s\) 有 period \(\gcd(p_1, p_2)\)
  • \(s\) 的長 \([l, 2l)\) borders 構成一段等差數列
  • \(s\) 的 borders 構成 \(\log |s|\) 段等差數列
  • \(2 |s| \ge |t| \implies s\)\(t\) 中出現的位置構成等差數列,且公差為 \(s\) 的最小週期(證明:反證,直接考慮 \(s\) 出現的位置覆蓋的段,應用 WPL 即可)

演算法流程

首先可以把一個 border \(b, k = \lfloor \log |b| \rfloor\) 分解成前 \(2^k\) 和後 \(2^k\)(類似 ST 表),分別比較即可。

現在考慮求出長度 \([2^k, 2^{k+1})\) 的 borders。

那麼把原串的 \(2^k\) 字首和 \(2^{k+1}\) 字尾匹配,\(2^k\) 字尾和 \(2^{k+1}\) 字首匹配(匹配位置都是等差數列),將等差數列求交即可。

處理這個需要將所有長為 \(2^k\) 的子串排序,直接用倍增法即可 \(\mathcal O(n \log n)\)

如果二分求出這段等差數列就可以得到 \(\mathcal O(\log^2 n)\)

的查詢。

考慮我們是要求一個子串 \(t\) 所有匹配位置和一段 \(2^k + 1\) 個數的區間求交,那麼將串按 \(2^k\) 分塊,一個求交的區間會恰好落在兩個塊裡,那麼我們處理出三元組 \((t, b, P)\) 表示長 \(2^k\) 子串 \(t\)\(b\) 塊中匹配位置為等差數列 \(P\)。這樣的組數不超過處理的子串總數,即 \(\mathcal O(n \log n)\)(沒有匹配任何位置則不存),那麼用字串雙 hash 和 hash 表即可 \(\mathcal O(1)\) 查詢。最後通過討論將兩個塊中查詢出的資訊合併為一個等差數列。

然後考慮對兩個等差數列求交。發現我們要求交的等差數列形如這樣:\(|x_1| = |x_2| = |y_1| = |y_2| = 2^k\)

\(x_1\)\(y_1y_2\) 中的匹配位置和 \(y_2\)\(x_1x_2\) 中的匹配位置,如果都匹配了至少 \(3\) 次,那麼公差必然一樣。

下面證明:

首先根據前置知識最後一條,設 \(r_1, r_2\) 分別為 \(x_1, x_2\) 最小週期,\(r_2 < r_1\)

畫出匹配圖,可以得出 \(x_1\) 的長度至少為 \(2r_1\) 的字尾有周期 \(r_2\)(通過觀察 \(x_1\) 的字尾匹配了 \(x_2\) 的一個字首)。使用 WPL 立即得到 \(x_1\) 長度至少 \(2r_1\) 的字尾有周期 \(\gcd(r_1, r_2)\),故 \(x_1\) 的 period 有整週期,與 \(r_1\) 是最小週期矛盾。

那麼通過一些討論也可以 \(\mathcal O(1)\) 合併等差數列。通過列舉 \(k\) 就可以得到 \(\mathcal O(\log n)\) 的演算法。

下面是 P4482 [BJWC2018]Border 的四種求法 的程式碼(求最長 border)

最好手寫固定大小 hash 表(unordered_map\(5 \times 10^6\) 次級別的查詢都可能耗費很長的時間),否則很可能跑不過 SAM 暴力 \(\mathcal O(\log^2 n)\)。原題資料不強,程式碼僅供參考。

#include <cstdio>
#include <cctype>
#include <cstring>
#include <algorithm>
#include <ctime>
#include <numeric>
#include <vector>
#include <cassert>
#include <unordered_map>
using namespace std;

#define LOG(f...) fprintf(stderr, f)
// #define DBG(f...) printf(f)
#define DBG(f...) void()
#define all(cont) begin(cont), end(cont)
#ifdef __linux__
#define getchar getchar_unlocked
#define putchar putchar_unlocked
#endif

using ll = long long;
using ull = unsigned long long;

template <class T> void read(T &x) {
  char ch; x = 0;
  int f = 1;
  while (isspace(ch = getchar()));
  if (ch == '-') ch = getchar(), f = -1;
  do x = x * 10 + (ch - '0'); while(isdigit(ch = getchar()));
  x *= f;
}
template <class T, class ...A> void read(T &x, A&... args) { read(x); read(args...); }

const int N = 200005;
const int M = 0x7FFFFFFF;
const ull MAGIC = 0x21b699768c4aed5f;
const int B1 = 131, B2 = 248;
int cnt = 0;

// arithmetic progression
struct ap {
  int s, t, d;
};
const ap EMPTY = {1, 0, 0};
bool contains(const ap &a, int x) {
  if (a.s > a.t) return false;
  if (!a.d) return x == a.s || x == a.t;
  return a.s <= x && x <= a.t && (x - a.s) % a.d == 0;
}

char s[N];
int n;
int h1[N], h2[N], np1[N], np2[N];

void init_hash() {
  np1[0] = np2[0] = M - 1;
  for (int i = 0; i < n; ++i) {
    np1[i + 1] = (ll)np1[i] * B1 % M;
    np2[i + 1] = (ll)np2[i] * B2 % M;
  }
  for (int i = 0; i < n; ++i) {
    h1[i + 1] = ((ll)B1 * h1[i] + s[i]) % M;
    h2[i + 1] = ((ll)B2 * h2[i] + s[i]) % M;
  }
}
ull range(int l, int r) {
  return ull((h1[r] + (ull)h1[l] * np1[r - l]) % M) << 32 | ull((h2[r] + (ull)h2[l] * np2[r - l]) % M);
}

struct hasher {
  ull operator()(const pair<ull, int> &p) const { return p.first + p.second * MAGIC; }
};

struct hashtable {
  static const int MASK = (1 << 22) - 1;
  struct node {
    ull k;
    ap v;
    node *nxt;
  } v[N * 18];
  node *hd[MASK + 1], *alloc = v;

  void emplace(ull p, ap v) { *alloc = {p, v, hd[p & MASK]}; hd[p & MASK] = alloc++; }
  node *find(ull p) { node *n = hd[p & MASK]; while (n && n->k != p) n = n->nxt; return n; }
} dict;

// unordered_map<pair<ull, int>, ap, hasher> dict;
int maxw;

namespace internal {
  int sa[N], rk[N], sec[N], m;
  int pos[N];

  void radix_sort(int n) {
    memset(pos, 0, sizeof(pos));
    for (int i = 0; i < n; ++i)
      ++pos[rk[i]];
    partial_sum(pos, pos + m, pos);
    for (int i = n - 1; i >= 0; --i)
      sa[--pos[rk[sec[i]]]] = sec[i];
  }

  void build() {
    // dict.reserve(n * 40);
    for (int i = 0; i < n; ++i)
      rk[i] = s[i] - 'a', sec[i] = i;
    m = 26;
    radix_sort(n);

    for (int w = 2; w < n; w <<= 1) {
      int p = 0, l = w >> 1, cnt = n - w + 1;
      int bw = __lg(w);
      for (int i = 0; i < n; ++i)
        if (sa[i] + l <= n && sa[i] >= l)
          sec[p++] = sa[i] - l;
      radix_sort(cnt);
      memcpy(sec, rk, sizeof(rk));
      rk[sa[0]] = 0;
      for (int i = 1; i < cnt; ++i)
        rk[sa[i]] = rk[sa[i - 1]] + (sec[sa[i]] != sec[sa[i - 1]] || sec[sa[i] + l] != sec[sa[i - 1] + l]);
      m = rk[sa[cnt - 1]] + 1;
      if (m == cnt) break;
      maxw = bw;

      for (int l = 0, r; l < cnt; l = r) {
        r = l;
        while (r != cnt && rk[sa[r]] == rk[sa[l]]) ++r;
        ull hsh = range(sa[l], sa[l] + w);
        int last = -1;
        ap prog;
        for (int i = l; i < r; ++i) {
          if (sa[i] >> bw != last) {
            if (~last) {
              // dict.insert({make_pair(hsh, last), prog});
              dict.emplace(hsh + last * MAGIC, prog);
            }
            prog = {sa[i], sa[i], 0};
            last = sa[i] >> bw;
          }
          else {
            prog.d = sa[i] - sa[i - 1];
            prog.t = sa[i];
          }
        }
        dict.emplace(hsh + last * MAGIC, prog);
        // dict.insert({make_pair(hsh, last), prog});
      }
    }
  }
}

ap _reduce(ap a, int l, int r) {
  if (a.s > a.t) return a;
  if (a.s + a.d == a.t) {
    if (l <= a.s && a.t < r) return a;
    if (l <= a.s && a.s < r) return {a.s, a.s, 0};
    if (l <= a.t && a.t < r) return {a.t, a.t, 0};
    return EMPTY;
  }
  if (a.s < l) a.s += (l - a.s + a.d - 1) / a.d * a.d;
  if (a.t >= r) a.t -= (a.t - r + a.d) / a.d * a.d;
  return a;
}
ap occurence(int l, int r, int pl, int pr, int bs) {
  ull hsh = range(l, r);
  int bl = pl >> bs;
  auto it1 = dict.find(hsh + bl * MAGIC), it2 = dict.find(hsh + (bl + 1) * MAGIC);
  ap a = it1 ? it1->v : EMPTY;
  ap b = it2 ? it2->v : EMPTY; 
  // auto it1 = dict.find(make_pair(hsh, bl)), it2 = dict.find(make_pair(hsh, bl + 1));
  // ap a = it1 == dict.end() ? EMPTY : it1->second;
  // ap b = it2 == dict.end() ? EMPTY : it2->second;
  ++cnt;
  a = _reduce(a, pl, pr); b = _reduce(b, pl, pr);
  if (a.s > a.t) return b;
  if (b.s > b.t) return a;
  return {a.s, b.t, b.s - a.t};
}

int query(int l, int r) {
  if (l == r) return 0;
  int k = __lg(r - l);
  for (int i = min(k, maxw); i; --i) {
    int lb = 1 << i, rb = min(1 << (i + 1), r - l - 1);
    ap a = occurence(l, l + lb, r - rb, r - lb + 1, i);
    ap b = occurence(r - lb, r, l, l + lb + 1, i);
    if (a.s > a.t || b.s > b.t) continue;
    tie(a.s, a.t) = make_pair(l + r - a.t - lb, l + r - a.s - lb);
    if (b.s + b.d == b.t) swap(a, b);
    int max_inter = -1;
    if (a.s + a.d == a.t) {
      if (contains(b, a.t)) max_inter = a.t;
      else if (contains(b, a.s)) max_inter = a.s;
    }
    else {
      if ((b.s - a.s) % b.d != 0) continue;
      int l = max(a.s, b.s), r = min(a.t, b.t);
      if (l <= r) max_inter = r;
    }
    if (~max_inter)
      return max_inter - l + lb;
  }
  return l + 1 != r && s[l] == s[r - 1];
}

int main() {
#ifdef LOCAL
  freopen("input.txt", "r", stdin);
  freopen("output.txt", "w", stdout);
#endif
  scanf("%s", s);
  n = strlen(s);
  init_hash();
  internal::build();
  int qc;
  read(qc);
  while (qc--) {
    int l, r;
    read(l, r);
    --l;
    printf("%d\n", query(l, r));
  }
  LOG("hashes : %d\n", cnt);
  return 0;
}