Matrix Equation(高斯消元)
技術標籤:數論
2020icpc-濟南站 A-Matrix Equation(高斯消元)
題意: 先給定了兩個 n × n n \times n n×n 的 01 01 01 矩陣 A A A 和 B B B,現在有兩種運算
×
\times
× :
D
n
×
n
=
A
n
×
n
×
B
n
×
n
D_{n \times n} = A_{n \times n} \times B_{n \times n}
Dn×n=An×n×Bn×
⊙
⊙
⊙ :
D
n
×
n
=
A
n
×
n
⊙
B
n
×
n
D_{n \times n} = A_{n \times n} ⊙ B_{n \times n}
Dn×n=An×n⊙Bn×n 表示連個矩陣點乘。即
D
i
j
=
A
i
j
∗
B
i
j
D_ij = A_{ij} * B_{ij}
現在要求 A × C = B ⊙ C A \times C = B\ ⊙\ C A×C=B⊙C ,求有多少 01 01 01 矩陣 C C C 滿足條件。
思路:
要滿足等式,即滿足
不知道為什麼,在Typora裡寫好的公式,複製過來就是顯示不了,直接圖片了
當
j
j
j 確定一個值時,此時對於
k
k
k 從
1
1
1 到
n
n
n ,可以列出
n
n
n 個同餘方程組。高斯消元得到自由元個數
x
x
x 後,矩陣
C
C
C 中 第
j
j
j 列中的可能組數即為
2
x
2^x
2x 中,每一列都算出來有多少組,最後累加即可。當有一組無解時,答案就直接記為
0
0
程式碼:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 210;
const ll mod = 998244353;
int a[maxn][maxn];//增廣矩陣
int x[maxn];//解集
int freeX[maxn];//自由變元
ll qpow(ll a, ll b){
ll ans = 1;
while(b > 0){
if(b & 1) ans = ans * a % mod;
a = a * a % mod;
b >>= 1;
}
return ans;
}
int Gauss(int equ, int var) {
for(int i = 0; i <= var; i++) {
x[i] = 0;
freeX[i] = 0;
}
int col = 0;//當前處理的列
int num = 0;//自由變元的序號
int row;//當前處理的行
for(row = 0; row < equ && col < var; row++, col++){//列舉當前處理的行
int maxRow = row;//當前列絕對值最大的行
for(int i = row + 1; i < equ; i++){//尋找當前列絕對值最大的行
if(abs(a[i][col]) > abs(a[maxRow][col]))
maxRow = i;
}
if(maxRow != row){//與第row行交換
for(int j = row; j < var + 1; j++)
swap(a[row][j], a[maxRow][j]);
}
if(a[row][col] == 0){//col列第row行以下全是0,處理當前行的下一列
freeX[num++] = col;//記錄自由變元
row--;
continue;
}
for(int i = row + 1; i < equ; i++){
if(a[i][col] != 0){
for(int j = col; j < var + 1; j++){//對於下面出現該列中有1的行,需要把1消掉
a[i][j] ^= a[row][j];
}
}
}
}
for(int i = row; i < equ; i++)
if(a[i][col] != 0)
return -1;
int temp = var - row;//自由變元有var-row個
if(row < var)//返回自由變元數
return temp;
return 0;
}
int A[maxn][maxn], B[maxn][maxn];
int main() {
int n;
scanf("%d", &n);
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
scanf("%d", &A[i][j]);
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
scanf("%d", &B[i][j]);
ll ans = 1;
for (int j = 0; j < n; j++) {
for (int i = 0; i < n; i++) {
for (int k = 0; k < n; k++) {
a[i][k] = A[i][k];
}
}
for (int i = 0; i < n; i++) {
a[i][i] = (A[i][i] - B[i][j] + 2) % 2;
}
int r = Gauss(n, n);
if (r == -1) {
ans = 0;
break;
}
ans *= qpow(2, r);
ans %= mod;
}
printf("%lld\n", ans);
return 0;
}