[ZJOI2009]對稱的正方形 (二維雜湊+二分)
阿新 • • 發佈:2020-10-13
題意
求矩陣中上下對稱且左右對稱的正方形子矩陣的個數。
思路
快速對比矩陣是否上下對稱或者左右對稱可以考慮對二維矩陣雜湊。
雜湊之後通過列舉正方形中心點位置,對正方形邊長進行二分。
在列舉正方形中心點時,需考慮正方形邊長分別為奇偶的情況,即中心點為格子交接點還是格子中心點。
二維矩陣雜湊值維護:
\(hash[x][y] = hash[x][y-1] \times base1 + hash[x - 1][y] \times base2 + value[x][y]\)
二維矩陣雜湊值求解:
\(ans = hash[x][y] - hash[x - len_x][y] \times fac1[len_x] - hash[x][y -len_y] \times fac2[len_y] + hash[x - len_x][y - len_y] \times fac1[len_x] \times fac2[len_y]\)
其中 \(fac1\), \(fac2\) 分別為 \(base1\), \(base2\) 的 \(len\) 次方
程式碼
#include <cstdio> #include <algorithm> #include <iostream> using namespace std; typedef unsigned long long ull; typedef long long ll; const int maxn = 2010; const int base1 = 87; const int base2 = 31; int n, m; int mp[maxn][maxn]; int lr[maxn][maxn]; int up[maxn][maxn]; ull hash_mp[maxn][maxn]; ull hash_lr[maxn][maxn]; ull hash_up[maxn][maxn]; ull fac1[maxn], fac2[maxn]; bool check(int x, int y, int len) { if (x > n || y > m) return 0; if (x < len || y < len) return 0; ull ans1 = hash_mp[x][y] - hash_mp[x - len][y] * fac2[len] - hash_mp[x][y - len] * fac1[len] + hash_mp[x - len][y - len] * fac1[len] * fac2[len]; int cow_y = m - (y - len); ull ans2 = hash_lr[x][cow_y] - hash_lr[x - len][cow_y] * fac2[len] - hash_lr[x][cow_y - len] * fac1[len] + hash_lr[x - len][cow_y - len] * fac1[len] * fac2[len]; if (ans1 != ans2) return 0; int row_x = n - (x - len); ull ans3 = hash_up[row_x][y] - hash_up[row_x - len][y] * fac2[len] - hash_up[row_x][y - len] * fac1[len] + hash_up[row_x - len][y - len] * fac1[len] * fac2[len]; if (ans1 != ans3) return 0; return (ans1 == ans3 && ans2 == ans3); } void solve() { scanf("%d%d", &n, &m); for (int i = 1; i <= n; ++i) { for (int j = 1; j <= m; ++j) scanf("%d", &mp[i][j]); } for (int i = 1; i <= n; ++i) { for (int j = 1; j <= m; ++j) { lr[i][j] = mp[i][m - j + 1]; up[i][j] = mp[n - i + 1][j]; } } fac1[0] = fac2[0] = 1; for (int i = 1; i <= n; ++i) fac1[i] = fac1[i - 1] * base1; for (int i = 1; i <= m; ++i) fac2[i] = fac2[i - 1] * base2; for (int i = 1; i <= n; ++i) { for (int j = 1; j <= m; ++j) { hash_mp[i][j] = hash_mp[i][j - 1] * base1 + mp[i][j]; hash_lr[i][j] = hash_lr[i][j - 1] * base1 + lr[i][j]; hash_up[i][j] = hash_up[i][j - 1] * base1 + up[i][j]; } } for (int i = 1; i <= n; ++i) { for (int j = 1; j <= m; ++j) { hash_mp[i][j] += hash_mp[i - 1][j] * base2; hash_lr[i][j] += hash_lr[i - 1][j] * base2; hash_up[i][j] += hash_up[i - 1][j] * base2; } } ll ans = 0; int R = max(n, m); for (int i = 1; i <= n; ++i) { for (int j = 1; j <= m; ++j) { int l = 1, r = R; while (l <= r) { int mid = l + r >> 1; if (check(i + mid, j + mid, mid << 1 | 1)) { l = mid + 1; } else { r = mid - 1; } } // cout << "~ " << l << " " << r << endl; ans += r; } } for (int i = 1; i <= n; ++i) { for (int j = 1; j <= m; ++j) { int l = 1, r = R; while (l <= r) { int mid = l + r >> 1; if (check(i + mid, j + mid, mid << 1)) { l = mid + 1; } else { r = mid - 1; } } // cout << "~~ " << i << " " << j << " " << l << " " << r << endl; ans += r; } } printf("%lld\n", ans + n * m); } int main() { int T = 1; // scanf("%d", &T); while (T--) solve(); return 0; }