1. 程式人生 > 實用技巧 >線段樹優化DP之Monotonicity

線段樹優化DP之Monotonicity

題目

P3506 [POI2010]MOT-Monotonicity 2

思路

定義f[i]為處理到第i位,所得匹配的最長長度,根據f[i]我們可以求出它後面要跟的符號(可以用符號填滿,避免一些取模運算),對於i,我們列舉每一個i前面的j,判斷是否合法,那麼\(n^2\)的做法就可以寫出來了

#include<bits/stdc++.h>
using namespace std;
const int maxn=20000+10;
int f[maxn],a[maxn];
char op[100+10];
int sta[maxn],top,ans,id;
int path[maxn];
void print(int id){
    if(id==0)return;
    print(path[id]);
    printf("%d ",a[id]);
}
int main(){
    int n,k;
    scanf("%d%d",&n,&k);
    for(int i=1;i<=n;i++)scanf("%d",&a[i]);
    for(int i=1;i<=k;i++){
        scanf(" %c",&op[i]);
    }
    f[1]=ans=id=1;
    for(int i=2;i<=n;i++){
        f[i]=1;
        for(int j=1;j<i;j++){
            char ch=op[(f[j]-1)%k+1];
            if((ch=='>'&&a[j]>a[i])||(ch=='<'&&a[j]<a[i])||(ch=='='&&a[j]==a[i])){
                if(f[i]<f[j]+1){
                    f[i]=f[j]+1;
                    path[i]=j;
                }
            }
        }
    }
    for(int i=1;i<=n;i++){
        if(ans<f[i]){
            ans=f[i];
            id=i;
        }
    }
    printf("%d\n",ans);
    print(id);
}

不過根據這道題的資料,\(n^2\)顯然是過不了的,那麼我們就考慮優化,所以就有了下面的線段樹優化做法

  • 維護三棵線段樹(貌似兩棵也可以),分別維護後面該接 =(root1為根),<(root2為根), >(root3為根) 的位置
  • insert 根據當前的f[i]求出下一個符號該接什麼,然後放到相應的線段樹裡面
    --->後面接”=“, insert(root1,1,1e6,i,f[i]);
    --->後面接“<”, insert(root2,1,1e6,i,f[i]);
    --->後面接“>”, insert(root3,1,1e6,i,f[i]);
  • query 分別在三棵線段樹中找一個符合情況的最大f[j](這裡就是優化所在,在查詢最優j時變成了\(log\)
    級別),因為要記錄路徑,所以我們返回值為位置
    --->在等於的樹中取f值最大對應的位置 query(root1,1,1e6,a[i],a[i]);
    --->在小於的樹中取f值最大對應的位置 query(root2,1,1e6,1,a[i]-1);
    --->在大於的樹中取f值最大對應的位置 query(root3,1,1e6,a[i]+1,1e6);
    對於樹上的每一個節點,我們維護tree[i](f值),pos[i] (在陣列中對應的位置),ls[i] (左兒子),rs[i] (右兒子),線段樹開值域應該比較好維護。

附上程式碼

#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5 + 10;
int f[maxn * 21], a[maxn];
int nodecnt, root1, root2, root3, ans, id, poss;
int path[maxn], ls[maxn * 21], rs[maxn * 21], tree[maxn * 21], pos[maxn * 21];
char op[maxn];
inline int read() {
    int s = 0, w = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9') {
        if (ch == '-')
            w = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') s = s * 10 + ch - '0', ch = getchar();
    return s * w;
}

void print(int id) {
    if (id == 0)
        return;
    print(path[id]);
    printf("%d ", a[id]);
}
void up(int rt) {
    if (tree[ls[rt]] < tree[rs[rt]]) {
        tree[rt] = tree[rs[rt]];
        pos[rt] = pos[rs[rt]];
    } else {
        tree[rt] = tree[ls[rt]];
        pos[rt] = pos[ls[rt]];
    }
}
void insert(int &rt, int l, int r, int val, int p) {
    if (rt == 0)
        rt = ++nodecnt;
    if (l == r) {
        if (tree[rt] < val) {
            pos[rt] = p;
            tree[rt] = val;
        }
        return;
    }
    int mid = (l + r) >> 1;
    if (a[p] <= mid)
        insert(ls[rt], l, mid, val, p);
    else
        insert(rs[rt], mid + 1, r, val, p);
    up(rt);
    return;
}
int query(int rt, int l, int r, int s, int t) {
    if (l >= s && r <= t)
        return pos[rt];
    if (s > t || !rt)
        return 0;
    int mid = (l + r) >> 1;
    int ans = 0, poss = 0;
    if (s <= mid) {
        int lpos = query(ls[rt], l, mid, s, t);
        if (ans < f[lpos]) {
            ans = f[lpos];
            poss = lpos;
        }
    }
    if (t > mid) {
        int rpos = query(rs[rt], mid + 1, r, s, t);
        if (ans < f[rpos]) {
            ans = f[rpos];
            poss = rpos;
        }
    }
    return poss;
}
int main() {
    int n, k;
    n = read(), k = read();
    for (int i = 1; i <= n; ++i) a[i] = read(), f[i] = 1;
    for (int i = 1; i <= k; ++i) scanf(" %c", &op[i]);
    for (int i = k + 1; i < n; ++i) op[i] = op[(i - 1) % k + 1];
    for (int i = 1, poss; i <= n; ++i) {
        poss = query(root1, 1, 1e6, a[i], a[i]);
		if (f[i] < f[poss] + 1) {
            f[i] = f[poss] + 1;
            path[i] = poss;
        }
        poss = query(root2, 1, 1e6, 1, a[i] - 1);
		if (f[i] < f[poss] + 1) {
            f[i] = f[poss] + 1;
            path[i] = poss;
        }
        poss = query(root3, 1, 1e6, a[i] + 1, 1e6);
        if (f[i] < f[poss] + 1) {
            f[i] = f[poss] + 1;
            path[i] = poss;
        }
        if (ans < f[i]) {
            ans = f[i];
            id = i;
        }
        if (op[f[i]] == '=')
            insert(root1, 1, 1e6, f[i], i);
        else if (op[f[i]] == '<')
            insert(root2, 1, 1e6, f[i], i);
        else if (op[f[i]] == '>')
            insert(root3, 1, 1e6, f[i], i);
    }
    printf("%d\n", ans);
    print(id);
    return 0;
}