JZOJ 5123. 【NOI2017模擬5.31】diyiti
阿新 • • 發佈:2018-12-12
Description
給定n 根直的木棍,要從中選出6 根木棍,滿足:能用這6 根木棍拼出一個正方形。注意木棍不能彎折。問方案數。
正方形:四條邊都相等、四個角都是直角的四邊形。
Input
第一行一個整數n。
第二行包含n 個整數ai,代表每根木棍的長度。
Output
一行一個整數,代表方案數。
Sample Input
8
4 5 1 5 1 9 4 5
Sample Output
3
Data Constraint
對於20% 的資料,滿足:n ≤ 30
對於40% 的資料,滿足:n ≤ 200
對於60% 的資料,滿足:n ≤ 1000
對於100% 的資料,滿足:n ≤ 5000; 1 ≤ ai ≤ 10^7
Solution
有點複雜的計數題。
首先排序,相同長度的木棍壓成一個,並記錄下有多少個。
這6根木棍構成正方形的情況只有兩種:
- 2+2+1+1,即選出兩根長度都為 的,再選兩根拼起來長度為 的,再選兩根拼起來長度為 的。
- 3+1+1+1,即選出三根長度都為 的,再選三根拼起來長度為 的。
現考慮2+2+1+1的怎麼求。
記
為第
根木棍個數,這個可以預處理。
先列舉兩根長度相等的"1",那麼兩個"2"有以下四種情況:
- aa+aa
先判斷 是不是 的倍數,如果是,就得到 之前 的出現次數 ,那麼對答案貢獻就是 。 - ab+ab
列舉一個 當作b,用一個單調指標找到對應的a的位置 ,為了避免重複要保證 ,對答案貢獻就是 。 - aa+bc
判斷 是否為 的倍數,然後用類似ab+ab的方法統計bc的方案數,與aa的方案相乘即可。 - ab+cd
用類似ab+ab的方法,記錄在ab之前找到多少個cd,統計一下答案就好了。
然後考慮3+1+1+1。
列舉一個
作為3最右邊的那個,那麼3的情況有四種:
- aaa
列舉一個 作為1,統計答案跟上面aa+aa的方法差不多。 - aab
列舉一個 作為1,用一個桶求出 之前 的出現次數,然後就可以統計答案了。 - abb
列舉一個 作為1,用aab的那個桶求出 的出現次數,就可以計算答案。 - abc
列舉一個 作為1,用另一個桶記錄 中所有二元組相加的情況,在這個桶裡找到二元組相加等於 的情況即可計算答案。
不重不漏地統計完上面這些情況,就OK了。
Code
#include <cstdio>
#include <cstring>
typedef unsigned long long ull;
const int N = 5007, A = 20000007;
int n, len, a[N], cnt[N], tmp[N];
int buc[A], buc1[A];
ull ans;
ull C2(int n) { return n * (n - 1) / 2; }
ull C3(int n) { return 1ll * n * (n - 1) * (n - 2) / 6; }
ull C4(int n) { return 1ll * n * (n - 1) * (n - 2) * (n - 3) / 24; }
void sort(int l, int r)
{
if (l >= r) return;
int mid = l + r >> 1;
sort(l, mid), sort(mid + 1, r);
int i = l, j = mid + 1, len = 0;
while (j <= r)
{
while (i <= mid && a[i] <= a[j]) tmp[++len] = a[i++];
tmp[++len] = a[j++];
}
while (i <= mid) tmp[++len] = a[i++];
for (int i = 1; i <= len; i++) a[l + i - 1] = tmp[i];
}
int main()
{
freopen("yist.in", "r", stdin);
freopen("yist.out", "w", stdout);
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%d", a + i);
sort(1, n);
for (int i = 1; i <= n; i++)
if (a[i] != a[i - 1]) a[++len] = a[i], cnt[len] = 1;
else cnt[len]++;
for (int i = 1; i <= len; i++) //2+2+1+1
{
ull ret = 0, sum = 0;
for (int j = 1; j <= i - 1; j++) if (a[j] + a[j] == a[i]) ans += C4(cnt[j]) * C2(cnt[i]), ret += C2(cnt[j]); //aa+aa
for (int j = 1, k = i - 1; j <= i - 1; j++)
{
while (a[j] + a[k] > a[i] && k > j) k--;
if (a[j] + a[k] == a[i] && k > j)
{
ans += C2(cnt[j]) * C2(cnt[k]) * C2(cnt[i]); //ab+ab
ans += ret * cnt[j] * cnt[k] * C2(cnt[i]); //aa+bc
ans += sum * cnt[j] * cnt[k] * C2(cnt[i]); //ab+cd
sum += cnt[j] * cnt[k];
}
}
}
for (int i = 1; i <= len; i++) //3+1+1+1
{
for (int j = 1; j <= i - 2; j++) buc[a[i - 1] + a[j]] += cnt[i - 1] * cnt[j];
for (int j = i + 1; j <= len; j++) if (a[i] + a[i] + a[i] == a[j]) ans += C3(cnt[j]) * C3(cnt[i]); //a+a+a
for (int j = i + 1; j <= len; j++)
if (a[i] + a[i] < a[j])
ans += C3(cnt[j]) * C2(cnt[i]) * buc1[a[j] - a[i] - a[i]]; //a+b+b
for (int j = i + 1, k = 1; j <= len; j++)
if ((a[j] - a[i]) % 2 == 0)
ans += C3(cnt[j]) * cnt[i] * C2(buc1[(a[j] - a[i]) / 2]);