C. 痛苦的 01 矩陣 (推公式,樹狀陣列維護)
阿新 • • 發佈:2018-12-13
現有一個 n×n 的 01 矩陣 M。
定義 cost(i,j) 為:把第 i 行和第 j 列全部變成 1 最少需要改動多少個元素。
定義矩陣的痛苦值 pain(M) 為:
pain(M)=(∑i=1n∑j=1n(cost(i,j))2)mod(109+7)
要求求出初始矩陣的痛苦值和每次修改操作之後的痛苦值。
Input
第一行三個正整數 n,k,q (2≤n≤2⋅105, 1≤k≤min(n2,2⋅105), 0≤q≤2⋅105)。k 表示這個矩陣中有 k 個 1。q 表示修改操作次數。
接下來 k 行,每行兩個正整數 xi, yi (1≤xi,yi≤n),表示有一個 1 在第 xi 行,第 yi 列。保證所有 (xi,yi) 各不相同。
接下來 q 行,每行兩個正整數 ui, vi (1≤ui,vi≤n),表示修改第 ui 行,第 vi 列。如果該位置原先為 0,則改為 1;如果該位置原先為 1,則改為 0。
Output
輸出 q+1 行,依次為所有修改發生前的痛苦值,和每次修改操作後的痛苦值。
Examples
Input
3 4 9 1 1 1 2 2 3 3 1 3 3 1 2 1 3 2 2 2 2 2 1 3 1 1 1 2 3
Output
73 48 75 52 29 52 33 52 77 104
#include<bits/stdc++.h> using namespace std; typedef long long LL; typedef unsigned long long ULL; #define rep(i,a,b) for(int i=a;i<b;++i) #define per(i,a,b) for(int i=b-1;i>=a;--i) #define lowbit(x) (x&(-x)) const int mod=1e9+7; const int N=2e5+10; LL tr_r2[N],tr_c2[N],tr_r[N],tr_c[N]; LL n; void update(LL tr[],int x,LL val) { val%=mod; while(x<=n) { tr[x]=(tr[x]+val)%mod; x+=lowbit(x); } } LL query(LL tr[],LL x) { LL res=0; while(x>0) { res=(res+tr[x])%mod; if(res<0)res+=mod; x-=lowbit(x); } return res; } LL r[N],c[N]; set<int> st[N]; LL s; LL solve(LL n) { LL ans1=query(tr_r2,n),ans2=query(tr_c2,n); LL ans3=query(tr_r,n), ans4=query(tr_c,n); //printf("ans1:%lld ans2:%lld ans3:%lld ans4:%lld\n",ans1,ans2,ans3,ans4); ans1=ans1*(n-2)%mod; ans2=ans2*(n-2)%mod; ans3=ans3*ans4%mod; ans3=ans3*2%mod; ans1=(((ans1+ans2)%mod+s)%mod+ans3)%mod; return ans1; } void change(int x,int y,LL v) { update(tr_r2,x,-r[x]*r[x]); update(tr_r2,x,(r[x]+v)*(r[x]+v)); update(tr_c2,y,-c[y]*c[y]); update(tr_c2,y,(c[y]+v)*(c[y]+v)); update(tr_r,x,-r[x]); update(tr_r,x,r[x]+v); update(tr_c,y,-c[y]); update(tr_c,y,c[y]+v); r[x]=r[x]+v; //if(r[x]>=mod)r[x]-=mod; if(r[x]<=-mod)r[x]+=mod; c[y]=c[y]+v; //if(c[y]>=mod)c[y]-=mod; if(c[y]<=-mod)c[y]+=mod; // printf("y:%d c[y]:%lld\n\n",y,c[y]); s=s+v; s%=mod;//if(s>=mod)s-=mod; if(s<=-mod)s+=mod; } /* 3 4 9 1 1 1 2 2 3 3 1 3 3 1 2 1 3 2 2 2 2 2 1 3 1 1 1 2 3 */ int main() { LL K,Q; scanf("%lld %lld %lld",&n,&K,&Q); s=n*n%mod; for(int i=1; i<=n; i++) { update(tr_r2,i,n*n); update(tr_c2,i,n*n); update(tr_r,i,n); update(tr_c,i,n); c[i]=n;r[i]=n; } rep(i,0,K) { int x,y; scanf("%d %d",&x,&y); st[x].insert(y); change(x,y,-1); } // printf("s:%lld\n",s); //rep(i,1,n+1)printf("i:%d %lld %lld\n",i,r[i],c[i]); LL ans=solve(n); printf("%lld\n",ans); rep(i,0,Q) { int x,y; scanf("%d %d",&x,&y); if(st[x].count(y)){ change(x,y,1); st[x].erase(y); }else{ change(x,y,-1); st[x].insert(y); } ans=solve(n); printf("%lld\n",ans); } return 0; }