1. 程式人生 > >SPOJ COT2 樹上的莫隊演算法,樹上區間查詢

SPOJ COT2 樹上的莫隊演算法,樹上區間查詢

題意:n個節點形成的一棵樹。每個節點有一個值。m次查詢,求出(u,v)路徑上出現了多少個不同的數。

樹上的莫隊演算法,同樣將樹分成siz=sqrt(n)塊,然後離線操作。先對樹dfs一遍,每當子樹節點個數num>=siz,就將這num個分成一塊。讀取所有的查詢按左端點所在塊排序。

用倍增法求lca單次要用logn複雜度,要跑3200ms。有個地方可以優化,就是知道了所有的查詢,也就是事先知道了轉移路徑,可以用離線的方法求O(n)求出所有需要用到的lca,這個寫起來比較麻煩,不過可以優化到1800ns。程式碼寫的比較挫。。。。

logn求lca:3200+ms

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cctype>
#include <string>
#include <vector>
#include <map>
#include <set>
#include <vector>
#include <queue>
#include <stack>
#include <algorithm>
using namespace std;
const int maxn=4e4+10;
const int maxm=1e5+10;

int n,m, siz;
vector<int> g[maxn];
int a[maxn], b[maxn], ans[maxm];
int tot[maxn], in[maxn];
int fa[maxn][20], dep[maxn];

struct Query
{
    int l, r, id;
    int st,ed;
    bool operator <(const Query& a) const
    {
        return st!=a.st? st<a.st: ed<a.ed; //先按左端點所在塊先後排序,其次考慮又右端點所在塊
    }
};

Query q[maxm];
int tag, bel[maxn];
int st[maxn], top;
int dfs(int u, int par, int d, int &cnt)
{
    dep[u]=d; fa[u][0]=par;
    int num=0;
    for(int i=0; i<g[u].size(); i++){
        int v=g[u][i];
        if(v!=par){
            num+=dfs(v, u, d+1, cnt);
            if(num>=siz){ //子樹大小>=sqrt(n),分成一塊
                for(int i=0; i<num; i++)
                    bel[st[--top]]=tag;
                tag++;
                num=0;
            }
        }
    }

    st[top++]=u;//記錄子樹遍歷的點
    return num+1;
}

void init()
{
    for(int i=0; i<=n; i++) g[i].clear();
    memset(tot, 0, sizeof(tot));
    memset(in, 0, sizeof(in));

    siz=sqrt(n);

    for(int i=1;i<=n; i++) scanf("%d",&a[i]), b[i]=a[i];
    sort(b+1, b+n+1);
    for(int i=1; i<=n; i++)
        a[i]=lower_bound(b+1, b+n+1, a[i])-b;

    for(int i=0; i<n-1; i++){
        int u,v;
        scanf("%d%d", &u, &v);
        g[u].push_back(v);
        g[v].push_back(u);
    }

    int cnt=0; tag=top=0;
    int num=dfs(1, -1, 0, cnt);
    for(int i=0; i<num; i++)
        bel[st[--top]]=tag; //最後剩下的數也分成一塊

    for(int i=1; i<20; i++){
        for(int u=1; u<=n; u++)
            if(fa[u][i-1]==-1)
                fa[u][i]=-1;
            else fa[u][i]=fa[fa[u][i-1]][i-1];
    }

    for(int i=0; i<m; i++){
        scanf("%d%d", &q[i].l, &q[i].r);
        if(bel[q[i].l]>bel[q[i].r])
            swap(q[i].l, q[i].r);
        q[i].id=i;
        q[i].st=bel[q[i].l];
        q[i].ed=bel[q[i].r];
    }
    sort(q, q+m);
}

int lca(int u, int v)
{
    if(dep[u]>dep[v]) swap(u, v);
    for(int i=0; i<20; i++)
        if((dep[v]-dep[u])>>i&1)
            v=fa[v][i];

    if(u==v) return u;
    for(int i=19; i>=0; i--){
        if(fa[u][i]!=fa[v][i]){
            u=fa[u][i];
            v=fa[v][i];
        }
    }
    return fa[u][0];
}

void solve()
{
    int res=0;
    int cu=1, cv=1;

    for(int i=0; i<m; i++){
        int nu=q[i].l, nv=q[i].r;
        int par=lca(cu, nu);
        while(cu!=par){
            if(in[cu]){
                if(--tot[a[cu]]==0)
                    res--;
            }
            else if(++tot[a[cu]]==1)
                res++;
            in[cu]^=1;
            cu=fa[cu][0];
        }

        cu=nu;
        while(cu!=par){
            if(in[cu]){
                if(--tot[a[cu]]==0)
                    res--;
            }
            else if(++tot[a[cu]]==1)
                res++;
            in[cu]^=1;
            cu=fa[cu][0];
        }
        cu=nu;


        par=lca(cv, nv);
        while(cv!=par){
            if(in[cv]){
                if(--tot[a[cv]]==0)
                    res--;
            }
            else if(++tot[a[cv]]==1)
                res++;
            in[cv]^=1;
            cv=fa[cv][0];
        }

        cv=nv;
        while(cv!=par){
            if(in[cv]){
                if(--tot[a[cv]]==0)
                    res--;
            }
            else if(++tot[a[cv]]==1)
                res++;
            in[cv]^=1;
            cv=fa[cv][0];
        }
        cv=nv;

        par=lca(cu, cv);
        ans[q[i].id]=res+(!tot[a[par]]);
    }
}

int main()
{
    while(scanf("%d%d", &n, &m)==2){
        init();
        solve();
        for(int i=0; i<m; i++)
            printf("%d\n", ans[i]);
    }
    return 0;
}

離線查詢lca:1800+ms

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cctype>
#include <string>
#include <vector>
#include <map>
#include <set>
#include <vector>
#include <queue>
#include <stack>
#include <algorithm>
using namespace std;
#pragma comment(linker, "/STACK:1024000000,1024000000")
typedef pair<int,int> P;
#define fir first
#define sec second
const int maxn=4e4+10;
const int maxm=1e5+10;

int n,m, siz;
vector<int> g[maxn];
int first[maxn],ltot=0, nxt[6*maxm];
P lq[6*maxm];//所有需要查詢的lca,lq[i].first儲存v,second儲存查詢的id

int a[maxn], b[maxn], ans[maxm];
int tot[maxn], in[maxn], fa1[maxn];
int fa[maxn], lca[3*maxm], col[maxn];
int bel[maxn],st[maxn],top=0;
struct Query
{
    int l, r, id;
    int st,ed;
    bool operator <(const Query& a) const
    {
        return st!=a.st? st<a.st: ed<a.ed;
    }
};

Query q[maxm];
int tag;
int dfs(int u, int par, int &cnt)//分塊
{
    fa1[u]=par;
    int num=0;
    for(int i=0; i<g[u].size(); i++){
        int v=g[u][i];
        if(v!=par)
            num+=dfs(v, u, cnt);
        if(num>=siz){
            for(int i=0; i<num; i++)
                bel[st[--top]]=tag;
            tag++;
            num=0;
        }

    }
    st[top++]=u;
    return num+1;
}

int find(int u)
{
    return fa[u]==u?u:(fa[u]=find(fa[u]));
}

int unite(int x, int y)
{
    x=fa[x];
    y=fa[y];
    fa[y]=x;
}


void dfs2(int u, int par)//離線查詢所有lca
{
    col[u]=1;
    for(int i=first[u]; i!=-1; i=nxt[i]){
        int v=lq[i].fir, id=lq[i].sec;
        if(!col[v]) continue;
        else if(col[v]==1){
            lca[id]=v;
        }
        else{
            lca[id]=find(v);
        }
    }

    for(int i=0; i<g[u].size(); i++){
        int v=g[u][i];
        if(v!=par)
            dfs2(v, u);
    }
    col[u]=2;
    unite(par, u);
}

void add(int u, int v, int id)//查詢m<=1e5,數比較多所以用前向星實現優化
{
    lq[ltot]=P(v,id);
    nxt[ltot]=first[u];
    first[u]=ltot++;
}

void init()
{
    for(int i=0; i<=n; i++) g[i].clear();
    memset(tot, 0, sizeof(tot));
    memset(in, 0, sizeof(in));

    siz=sqrt(n);

    for(int i=1;i<=n; i++) scanf("%d", a+i), b[i]=a[i];
    sort(b+1, b+n+1);
    for(int i=1; i<=n; i++)
        a[i]=lower_bound(b+1, b+n+1, a[i])-b;

    for(int i=0; i<n-1; i++){
        int u,v;
        scanf("%d%d", &u, &v);
        g[u].push_back(v);
        g[v].push_back(u);
    }


    int cnt=0; top=0; tag=0;
    int num=dfs(1, -1, cnt);
    for(int i=0; i<num; i++)
        bel[st[--top]]=tag;

    for(int i=0; i<m; i++){
        scanf("%d%d", &q[i].l, &q[i].r);
        if(bel[q[i].l]>bel[q[i].r])
            swap(q[i].l, q[i].r);
        q[i].id=i;
        q[i].st=bel[q[i].l];
        q[i].ed=bel[q[i].r];
    }
    sort(q, q+m);

    cnt=0; ltot=0;
    memset(first, -1, sizeof(first));
    add(1, q[0].l, cnt);
    add(q[0].l, 1, cnt++);
    add(1, q[0].r, cnt);
    add(q[0].r, 1, cnt++);
    add(q[0].r, q[0].l, cnt);
    add(q[0].l, q[0].r, cnt++);
    //add(q[0].r, q[0].l, cnt++);


    for(int i=0; i<m-1; i++){
    add(q[i].l, q[i+1].l, cnt);//第i個查詢左端點向第i+1個左端點轉移,所以需要它們之間的lca
    add(q[i+1].l, q[i].l, cnt++);
    add(q[i].r, q[i+1].r, cnt);//第i個查詢右端點向第i+1個右端點轉移
    add(q[i+1].r, q[i].r, cnt++);
    add(q[i+1].r, q[i+1].l, cnt);//左端點和右端點的lca
    add(q[i+1].l, q[i+1].r,cnt++);
    }
    for(int i=0; i<=n; i++) fa[i]=i;
    memset(col, 0, sizeof(col));
    dfs2(1, 0);
}



void solve()
{
    int res=0;
    int cu=1, cv=1;

    for(int i=0; i<m; i++){
        int nu=q[i].l, nv=q[i].r;
        //cout<<lca[i*3]<<' '<<lca[i*3+1]<<' '<<lca[i*3+2]<<endl;
        int par=lca[i*3];
        while(cu!=par){
            if(in[cu]){
                if(--tot[a[cu]]==0)
                    res--;
            }
            else if(++tot[a[cu]]==1)
                res++;
            in[cu]^=1;
            cu=fa1[cu];
        }

        cu=nu;
        while(cu!=par){
            if(in[cu]){
                if(--tot[a[cu]]==0)
                    res--;
            }
            else if(++tot[a[cu]]==1)
                res++;
            in[cu]^=1;
            cu=fa1[cu];
        }
        cu=nu;


        par=lca[i*3+1];
        while(cv!=par){
            if(in[cv]){
                if(--tot[a[cv]]==0)
                    res--;
            }
            else if(++tot[a[cv]]==1)
                res++;
            in[cv]^=1;
            cv=fa1[cv];
        }

        cv=nv;
        while(cv!=par){
            if(in[cv]){
                if(--tot[a[cv]]==0)
                    res--;
            }
            else if(++tot[a[cv]]==1)
                res++;
            in[cv]^=1;
            cv=fa1[cv];
        }
        cv=nv;

        par=lca[i*3+2];
        ans[q[i].id]=res+(!tot[a[par]]);
    }
}

int main()
{
    while(cin>>n>>m){
        init();
        solve();
        for(int i=0; i<m; i++)
            printf("%d\n", ans[i]);
    }
    return 0;
}