1. 程式人生 > 實用技巧 >P4168 [Violet]蒲公英

P4168 [Violet]蒲公英

題意描述

蒲公英

題面好美好啊

強制線上查詢區間眾數。

演算法分析

分塊模板題了吧。

類似區間眾數等不滿足區間加法的問題很難用線段樹或樹狀陣列來實現,所以這裡採用分塊。

其實分塊幾乎都是趨近於大塊維護,小塊暴力的思想,所以程式碼實現難度和思維難度沒有上面提到的資料結構高。

思路是將 \(N\) 個數分為 \(T\) 個區間,至於 \(T\) 的取值之後再討論。

首先進行離散化是肯定的。

對於每個區間維護 \(c(i,j,k)\) 表示:區間 \(i\) 到區間 \(j\) 之間有多少個 \(k\)

同時維護 \(f(i,j)\)\(d(i,j)\) 分別表示區間 \(i\) 與區間 \(j\)

之間的眾數的數量和數字。

那麼每次查詢區間 \([x,y]\) 時,只需要對於兩邊的小塊進行處理即可。

時間複雜度為 \(O(NT^2+MN/T)\),顯然當 \(NT^2=MN/T\) 時總時間複雜度最低。

假設 \(N,M\) 為同一數量級,那麼當 \(T=\sqrt[3]{N}\) 時總時間複雜度最小,此時為 \(O(N^{5/3})\) 量級。

考慮一下優化,發現時間卡在 \(O(NT^2)\) 的預處理上,其實並不需要這麼高複雜度的預處理。

我們可以維護 \(sum(i,j)\) 表示前 \(i\) 個區間中數字 \(j\) 的個數。

此時時間複雜度變為 \(O(T^3+MN/T)\),此時 \(T\)

\(\sqrt N\) 可以實現 \(O(N\sqrt N)\) 的時間複雜度。

至此實現了分塊能達到的最快時間複雜度了吧。(dalao 有更多優化可以私信蒟蒻)

程式碼實現

首先是 \(O(N^{5/3})\) 的分塊:

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cmath>
#define N 40010
#define T 40
using namespace std;

int n,q,t;
int nowl,nowr,num,mx;
int a[N],fa[N],L[T],R[T];
int c[T][T][N],f[T][T],d[T][T];

int read(){
	int x=0,f=1;char c=getchar();
	while(c<'0' || c>'9') f=(c=='-')?-1:1,c=getchar();
	while(c>='0' && c<='9') x=x*10+c-48,c=getchar();
	return x*f; 
}

void pre_work(){
    //分塊
    t=(int)pow(n*1.0,1.0/3);
    int l;
    if(t) l=n/t;
    for(int i=1;i<=t;i++) L[i]=(i-1)*l+1,R[i]=i*l;
    if(R[t]<n) L[t+1]=R[t]+1,R[++t]=n;
    //離散化
    sort(fa+1,fa+n+1);
    int m=unique(fa+1,fa+n+1)-(fa+1);
    for(int i=1;i<=n;i++) a[i]=lower_bound(fa+1,fa+m+1,a[i])-fa;
    //預處理
    for(int i=1;i<=t;i++)
        for(int j=i;j<=t;j++){
            for(int k=L[i];k<=R[j];k++) c[i][j][a[k]]++;
            for(int k=1;k<=m;k++)
                if(c[i][j][k]>f[i][j] || c[i][j][k]==f[i][j] && k<d[i][j])
                    f[i][j]=c[i][j][k],d[i][j]=k; 
        }
    return;
}

void update(int x){
    c[nowl][nowr][a[x]]++;
    if(c[nowl][nowr][a[x]]>mx ||c[nowl][nowr][a[x]]==mx && a[x]<num) 
        num=a[x],mx=c[nowl][nowr][a[x]];
    return;
}

int solve(int x,int y){
    int l,r;
    if(x>y) swap(x,y);
    for(int i=1;i<=t;i++) if(x<=R[i]) {l=i;break;}
    for(int i=t;i>=1;i--) if(y>=L[i]) {r=i;break;}
    if(l+1<=r-1) nowl=l+1,nowr=r-1; else nowl=nowr=0;
    //如果兩者之間的距離不足一個區間的長度就直接暴力。
    num=d[nowl][nowr];mx=f[nowl][nowr];
    if(l==r){
        for(int i=x;i<=y;i++) update(i);
        for(int i=x;i<=y;i++) c[nowl][nowr][a[i]]--;
    }else{//維護兩邊的剩餘區間即可。
        for(int i=x;i<=R[l];i++) update(i);
        for(int i=L[r];i<=y;i++) update(i);
        for(int i=x;i<=R[l];i++) c[nowl][nowr][a[i]]--;
        for(int i=L[r];i<=y;i++) c[nowl][nowr][a[i]]--;
    }
    return fa[num];
}

int main(){
    n=read(),q=read();
    for(int i=1;i<=n;i++) a[i]=read(),fa[i]=a[i];
    pre_work();
    int ans=0;
    for(int i=1;i<=q;i++){
        int x=read(),y=read();
        ans=solve((x+ans-1)%n+1,(y+ans-1)%n+1);
        printf("%d\n",ans);
    }
    return 0;
}

然後是字首和優化之後的 \(O(N\sqrt N)\)

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cmath>
#define N 40010
#define T 220
using namespace std;

int n,q,t;
int a[N],fa[N],b[N];
int L[T],R[T];
int sum[T][N];
struct node{
    int num,cnt;
}p[T][T];

int read(){
	int x=0,f=1;char c=getchar();
	while(c<'0' || c>'9') f=(c=='-')?-1:1,c=getchar();
	while(c>='0' && c<='9') x=x*10+c-48,c=getchar();
	return x*f; 
}

void pre_work(){
    //分塊
    t=sqrt(n);
    int l;
    if(t) l=n/t;
    for(int i=1;i<=t;i++) L[i]=(i-1)*l+1,R[i]=i*l;
    if(R[t]<n) L[t+1]=R[t]+1,R[++t]=n;
    //離散化
    sort(fa+1,fa+n+1);
    int m=unique(fa+1,fa+n+1)-(fa+1);
    for(int i=1;i<=n;i++) a[i]=lower_bound(fa+1,fa+m+1,a[i])-fa;
    //預處理
    for(int i=1;i<=t;i++){
        memset(b,0,sizeof(b));node now;
        now.cnt=now.num=0;
        for(int j=i;j<=t;j++){
            for(int k=(j-1)*l+1;k<=min(n,j*l);k++){
                b[a[k]]++;
                if(b[a[k]]>now.cnt || b[a[k]]==now.cnt && a[k]<now.num)
                    now.num=a[k],now.cnt=b[a[k]];
            }
            p[i][j]=now;
        }
    }
    for(int i=1;i<=t;i++){
        for(int j=1;j<=m;j++) sum[i][j]=sum[i-1][j];
        for(int j=(i-1)*l+1;j<=min(n,i*l);j++) sum[i][a[j]]++;
    }
    return;
}

int num,mx,nowl,nowr;

void update(int x){
    sum[nowr][a[x]]++;
    //這裡有一點特殊的改變,因為如果 nowl=0,那麼 nowl-1=-1 導致陣列越界。
    int k=(!nowl)?sum[nowr][a[x]]:sum[nowr][a[x]]-sum[nowl-1][a[x]];
    if(k>mx || k==mx && a[x]<num) mx=k,num=a[x];
}

int solve(int x,int y){
    int l,r;
    if(x>y) swap(x,y);
    for(int i=1;i<=t;i++) if(x<=R[i]) {l=i;break;}
    for(int i=t;i>=1;i--) if(y>=L[i]) {r=i;break;}
    if(l+1<=r-1) nowl=l+1,nowr=r-1; else nowl=nowr=0;
    num=p[nowl][nowr].num;mx=p[nowl][nowr].cnt;
    if(l==r){
        for(int i=x;i<=y;i++) update(i);
        for(int i=x;i<=y;i++) sum[nowr][a[i]]--;
    }else{
        for(int i=x;i<=R[l];i++) update(i);
        for(int i=L[r];i<=y;i++) update(i);
        for(int i=x;i<=R[l];i++) sum[nowr][a[i]]--;
        for(int i=L[r];i<=y;i++) sum[nowr][a[i]]--;
    }
    return fa[num];
}

int main(){
    n=read(),q=read();
    for(int i=1;i<=n;i++) a[i]=read(),fa[i]=a[i];
    pre_work();
    int ans=0;
    for(int i=1;i<=q;i++){
        int x=read(),y=read();
        ans=solve((x+ans-1)%n+1,(y+ans-1)%n+1);
        printf("%d\n",ans);
    }
    return 0;
}

完結撒❀。