1. 程式人生 > >[BZOJ2738]矩陣乘法

[BZOJ2738]矩陣乘法

整體二分+二維樹狀陣列。

好題啊!寫了一個來小時。

一看這道題,主席樹不會搞,只能用離線的做法了。

整體二分真是個好東西,啥都可以搞,尤其是區間第 \(k\) 大這種東西。

我們二分答案,然後用二維樹狀陣列實現 \(\log^2 n\) 的單點修改,時間複雜度 \(O(q\log^2 n\log 10^9)\)

\(Code\ Below:\)

#include <bits/stdc++.h>
#define lowbit(x) ((x)&(-(x)))
#define id(x,y) (((x)-1)*n+(y)) 
using namespace std;
const int maxn=300000+10;
const int lim=1e9;
int n,m,c[510][510],ans[maxn];
 
struct Element{
    int x,y,k;
}e[maxn],e1[maxn],e2[maxn];
 
bool cmp(Element a,Element b){
    return a.k<b.k;
}
 
struct Query{
    int x1,y1,x2,y2,k,id;
}q[maxn],q1[maxn],q2[maxn];
 
inline int read(){
    register int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
    return (f==1)?x:-x;
}
 
inline void add(int x,int y,int z){
    for(int i=x;i<=n;i+=lowbit(i))
        for(int j=y;j<=n;j+=lowbit(j)) c[i][j]+=z;
}
 
inline int sum(int x,int y){
    int ans=0;
    for(int i=x;i;i-=lowbit(i))
        for(int j=y;j;j-=lowbit(j)) ans+=c[i][j];
    return ans;
}
 
void solve(int L,int R,int Le,int Ri,int l,int r){
    if(L>R||Le>Ri) return ;
    if(l==r){
        for(int i=Le;i<=Ri;i++) ans[q[i].id]=l;
        return ;
    }
    int mid=(l+r)>>1,cnt1=0,cnt2=0,tot1=0,tot2=0,val;
    for(int i=L;i<=R;i++){
        if(e[i].k<=mid) add(e[i].x,e[i].y,1),e1[++cnt1]=e[i];
        else e2[++cnt2]=e[i];
    }
    for(int i=1;i<=cnt1;i++) e[L+i-1]=e1[i];
    for(int i=1;i<=cnt2;i++) e[L+i+cnt1-1]=e2[i];
    for(int i=Le;i<=Ri;i++){
        val=sum(q[i].x2,q[i].y2)-sum(q[i].x1-1,q[i].y2)-sum(q[i].x2,q[i].y1-1)+sum(q[i].x1-1,q[i].y1-1);
        if(val>=q[i].k) q1[++tot1]=q[i];
        else q[i].k-=val,q2[++tot2]=q[i];
    }
    for(int i=L;i<=L+cnt1-1;i++) add(e[i].x,e[i].y,-1);
    for(int i=1;i<=tot1;i++) q[Le+i-1]=q1[i];
    for(int i=1;i<=tot2;i++) q[Le+i+tot1-1]=q2[i];
    solve(L,L+cnt1-1,Le,Le+tot1-1,l,mid);
    solve(L+cnt1,R,Le+tot1,Ri,mid+1,r);
}
 
int main()
{
    n=read(),m=read();
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++) e[id(i,j)].x=i,e[id(i,j)].y=j,e[id(i,j)].k=read();
    sort(e+1,e+n*n+1,cmp);
    for(int i=1;i<=m;i++) q[i].x1=read(),q[i].y1=read(),q[i].x2=read(),q[i].y2=read(),q[i].k=read(),q[i].id=i;
    solve(1,n*n,1,m,0,lim);
    for(int i=1;i<=m;i++) printf("%d\n",ans[i]);
    return 0;
}