Codeforces 251D - Two Sets(異或方程組)
題意:
你有一個可重集 \(S=\{a_1,a_2,\dots,a_n\}\),你要把它劃分成兩個可重集 \(S_1,S_2\) 使得 \(S\) 中每個元素都恰好屬於 \(S_1\) 與 \(S_2\) 之一。
記 \(X_1\) 為 \(S_1\) 中所有元素的異或和,\(X_2\) 為 \(S_2\) 中所有元素的異或和。
最大化 \(X_1+X_2\),如果有多種分配方案,再最小化 \(X_1\)。
\(n \in [1,10^5],a_i \in [1,10^{18}]\)
暑假省選班講過這道題,當時聽得一臉懵B,還問lxr為什麼線性基本質上就是高斯消元。。。。。。wtclwtcl
設 \(X=a_1 \oplus a_2 \oplus a_3 \oplus\dots\oplus a_n\)
考慮 \(X\) 二進位制上的每一位,如果 \(X\) 的第 \(i\) 位為 \(1\),那麼意味著它只能拆成 \(0\) 和 \(1\),不會對 \(X_1+X_2\) 產生影響。
但如果 \(X\) 的第 \(i\) 位為 \(0\),那麼它可以拆成 \(0,0\) 或者 \(1,1\),我們的目標是讓 \(X_1+X_2\) 儘可能大,我們就要儘量選擇 \(1,1\),也就是要儘量讓 \(X_1\) 的第 \(i\) 位為 \(1\)。
我們假設 \(n\) 個未知數 \(x_1,x_2,\dots,x_n\),\(x_i=1\) 表示 \(i\)
那麼 \(X_1\) 的第 \(b\) 位為 \(1\) 等價於一個異或方程 \(t_1x_1\oplus t_2x_2\oplus\dots\oplus t_nx_n=1\),其中 \(t_i\) 表示 \(a_i\) 二進位制下的第 \(b\) 位是否為 \(1\)。
具體地來說,我們找到 \(X\) 中最高的為 \(0\) 的二進位制位 \(b\),根據之前的推論可以列出一個異或方程,如果該異或方程有解,那麼我們肯定要在這一位上放 \(1\)。因為如果你在這一位上放 \(1\),哪怕後面都是 \(0\),那 \(X_1+X_2\)
我們考慮這樣的貪心做法:從高位向低位列舉每一個 \(X\) 二進位制下為 \(0\) 的二進位制位 \(b\),我們嘗試著在這一位上放 \(1\),如果存在一種方案,它既能夠滿足前面的條件(在第 \(b\) 位前面放 \(1\) 的位都對應一個異或方程,把它們聯立起來得到的異或方程組),那麼我們就在這一位上放 \(1\),否則就在這一位上放 \(0\)。
最大化 \(X_1+X_2\) 之後,我們再考慮 \(X_1\) 儘量小這個條件。這時候 \(X\) 為 \(1\) 的二進位制位就要派上用場了。對於 \(X\) 二進位制下為 \(1\) 的位,它又可以細分為第 \(1\) 堆分配 \(0\),第 \(2\) 堆分配 \(1\),以及第 \(1\) 堆分配 \(1\),第 \(2\) 堆分配 \(0\)。我們肯定希望第一堆分配地儘可能少,於是我們重複一遍前面的操作,找到一個 \(1\) 位就嘗試填 \(0\),就可以了。
於是我們有了優秀的 \(n \log^3a_i\) 的做法,每次就聯立出一個異或方程組,然後高斯消元判斷這個異或方程組是否有解。
但其實並不用每次都重新消一遍,對於每個新的異或方程,都用前面的方程消去它的最高位(類似於線性基?)。這樣是 \(n\log^2a_i\) 的,再注意到每一位係數都是 \(0/1\),可以用 bitset
再搞掉一個 \(\log\)。
#include <bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define fz(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
#define ffe(it,v) for(__typeof(v.begin()) it=v.begin();it!=v.end();it++)
#define fill0(a) memset(a,0,sizeof(a))
#define fill1(a) memset(a,-1,sizeof(a))
#define fillbig(a) memset(a,63,sizeof(a))
#define pb push_back
#define ppb pop_back
#define mp make_pair
typedef pair<int,int> pii;
typedef long long ll;
const int MAXN=1e5+5;
const int MAXB=63+2;
int n;ll a[MAXN],s=0;
bitset<MAXN> bt[MAXB];
int hi[MAXN],pos[MAXN],cur=0,ans[MAXN];
int main(){
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%lld",&a[i]),s^=a[i];
for(int i=62;~i;i--) if(!(s>>i&1)){
++cur;
for(int j=1;j<=n;j++) if((a[j]>>i)&1) bt[cur][j]=1;
bt[cur][n+1]=1;
for(int j=1;j<cur;j++) if(bt[cur][hi[j]]) bt[cur]^=bt[j];
for(int j=1;j<=n;j++) if(bt[cur][j]){hi[cur]=j;break;}
if(!hi[cur]){bt[cur].reset();cur--;continue;}
for(int j=1;j<cur;j++){
if(bt[j][hi[cur]]) bt[j]^=bt[cur];
}
}
for(int i=62;~i;i--) if(s>>i&1){
++cur;
for(int j=1;j<=n;j++) if((a[j]>>i)&1) bt[cur][j]=1;
bt[cur][n+1]=0;
for(int j=1;j<cur;j++) if(bt[cur][hi[j]]) bt[cur]^=bt[j];
for(int j=1;j<=n;j++) if(bt[cur][j]){hi[cur]=j;break;}
if(!hi[cur]){bt[cur].reset();cur--;continue;}
for(int j=1;j<cur;j++){
if(bt[j][hi[cur]]) bt[j]^=bt[cur];
}
}
for(int i=1;i<=cur;i++) ans[hi[i]]=bt[i][n+1];
for(int i=1;i<=n;i++) printf("%d ",2-ans[i]);
return 0;
}