1. 程式人生 > 其它 >CF380C Sereja and Brackets (線段樹)

CF380C Sereja and Brackets (線段樹)

CF380C Sereja and Brackets

Mean

給定一個括號串\(S\)\(|S|<=10^6\)\(m\)次詢問,每次詢問區間\([l,r]\)內最長的合法括號串子序列,輸出長度。

Sol

線段樹。

注意到題目要求的是子序列。

發現去掉合法子序列後,剩餘部分字串表現為\(...)))((((...\)

套路第一次見,直接講做法。

考慮線段樹自底向上區間合併資訊,每個節點維護未匹配的左括號數\(lsum\),未匹配的右括號數\(rsum\)

那麼兩個節點合併向上合併時有如下式子

\(tr[rt].lsum = tr[rt<<1].lsum+tr[rt<<1|1].lsum-min(tr[rt<<1].lsum,tr[rt<<1|1].rsum);\)


\(tr[rt].rsum = tr[rt<<1].rsum+tr[rt<<1|1].rsum-min(tr[rt<<1].lsum,tr[rt<<1|1].rsum);\)

用圖形來理解的話,左節點\(..)))(((...\),和右節點\(...)))((((...\)合併,左邊的未匹配左括號會和右邊的未匹配右括號匹配,數目少的會被全部匹配掉。

最後在查詢時採用同樣的合併操作,則可以得到查詢區間\([L,R]\)內的未匹配右括號數\(rsum\),與未匹配左括號數\(lsum\)。最後答案即為\(ans=(R-L+1)-lsum-rsum\)

Code

#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=(a);i<=(b);++i)
#define dep(i,a,b) for(int i=(a);i>=(b);--i)
#define lowbit(x) (x&(-x))
#define debug(x) cout<<#x<<" :"<<x<<endl
#define debug1(x) cout<<#x<<" :"<<x<<" "
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
const int N=1e6+20;
const int MAX=10000007;
inline int read() {
    char c=getchar();int x=0,f=1;
    while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9'){x=x*10+c-'0',c=getchar();}
    return x*f;
}
char s[N];
/**
 * 區間最長合法括號子序列
 * @param x [description]
 */
inline void out(int x) {   
    if(x>9) out(x/10);   
    putchar(x%10+'0'); 
}     
int q[N];
struct node{
    int lsum;//未匹配的左括號數量
    int rsum;//未匹配的右括號數量
}tr[N<<2];

#define mid ((l+r)>>1)
void pushup(int rt){
    tr[rt].lsum = tr[rt<<1].lsum+tr[rt<<1|1].lsum-min(tr[rt<<1].lsum,tr[rt<<1|1].rsum);
    tr[rt].rsum = tr[rt<<1].rsum+tr[rt<<1|1].rsum-min(tr[rt<<1].lsum,tr[rt<<1|1].rsum);
}
void build(int l,int r,int rt){
    if(l==r)
    {
        if(s[l]=='(')tr[rt].lsum+=1;
        else tr[rt].rsum+=1;
        return ;
    }
    build(l,mid,rt<<1);
    build(mid+1,r,rt<<1|1);
    pushup(rt);
}
node query(int L,int R,int l,int r,int rt){
    if(L<=l&&r<=R){
        return tr[rt];
    }
    node ls,rs,ans;
    ls = (node){0,0};
    rs = (node){0,0};
    ans = (node){0,0};
    if(L<=mid)ls=query(L,R,l,mid,rt<<1);
    if(R>mid)rs=query(L,R,mid+1,r,rt<<1|1);
    ans.lsum = ls.lsum+rs.lsum-min(ls.lsum,rs.rsum);
    ans.rsum = ls.rsum+rs.rsum-min(ls.lsum,rs.rsum);
    return ans;
}
int t;

int main(){
    scanf("%s",s+1);
    int lens=strlen(s+1);
    build(1,lens,1);
    scanf("%d",&t);
    while(t--){
        int l,r;
        scanf("%d%d",&l,&r);
        node ans = query(l,r,1,lens,1);
        printf("%d\n",r-l+1-ans.lsum-ans.rsum);
    }
    return 0;
}