P2345 奶牛集會(樹狀陣列or模擬)
題目背景
MooFest, 2004 Open
題目描述
約翰的N 頭奶牛每年都會參加“哞哞大會”。哞哞大會是奶牛界的盛事。集會上的活動很
多,比如堆乾草,跨柵欄,摸牛仔的屁股等等。它們參加活動時會聚在一起,第i 頭奶牛的座標為Xi,沒有兩頭奶牛的座標是相同的。奶牛們的叫聲很大,第i 頭和第j 頭奶牛交流,會發出max{Vi; Vj}×|Xi − Xj | 的音量,其中Vi 和Vj 分別是第i 頭和第j 頭奶牛的聽力。假設每對奶牛之間同時都在說話,請計算所有奶牛產生的音量之和是多少。
輸入輸出格式
輸入格式:
• 第一行:單個整數N,1 ≤ N ≤ 20000
• 第二行到第N + 1 行:第i + 1 行有兩個整數Vi 和Xi,1 ≤ Vi ≤ 20000; 1 ≤ Xi ≤ 20000
輸出格式:
• 單個整數:表示所有奶牛產生的音量之和
輸入輸出樣例
輸入樣例#1: 複製
4
3 1
2 5
2 6
4 3
輸出樣例#1: 複製
57
說明
樸素O(N2)
類似於歸併排序的二分O(N logN)
樹狀陣列O(N logN)
思路:這個題要說思路還是不是很難的,不過容易出錯,為了防止出錯以及後續看不懂,我們先來一個簡單的模擬解法,程式碼如下
#include<iostream> #include<algorithm> using namespace std; const int maxn = 20010; long long n, tempSum, ans; struct Node { int x, v; }node[maxn]; int cmp(Node a, Node b) { return a.v > b.v; } int main() { cin >> n; for (int i = 1; i <= n; i++) { cin >> node[i].v >> node[i].x; } sort(node + 1, node + n + 1, cmp); for (int i = 1; i <= n; i++) { long long sum = 0; for (int j = i + 1; j <= n; j++) { if (node[j].x <= node[i].x) { sum += node[i]. x - node[j].x; } else { sum += (node[j].x - node[i].x); } } ans += node[i].v * sum; } cout << ans << endl; }
是不是非常簡單,要是你只滿足於做出這道題,那麼下面你就可以不用看了,若你還想進一步優化這個題目,那麼你可以繼續看看下面的樹狀陣列的講解
首先,根據上面我們模擬的做法,我們知道其實求的就是該點的V值乘以V值比他小的點到該點的橫座標的和,將這個過程執行n次並且把每次結果相加就可以得到最後的結果,在模擬的過程中,我們是把V值從大到小排列,那麼在樹狀陣列中,我們使用從小到大排列
我們用c【】來表示該樹狀數組裡面的值,也就是這個區間內X值的和,num【】陣列表示這區間內的點,sum表示X的座標小於等於傳入引數的和,cnt表示X的座標小於等於傳入引數的個數,因為我們是按V值從小到大排序,那麼我們取出的每一個Node可以和它之前的Node進行運算,然後運算結果有倆種情況需要討論:
1.之前的結點的X值比當前的節點的X值小,那麼 ans = 當前節點的V值(node【i】.v)乘以所有之前所有結點到當前結點的X值的和,所以ans = node【i].v * (node[i].x * cnt- sum);解釋一下就是當前節點的值已經在運算過程中抵消了,我們是這樣運算的,共有cnt個結點的X值小於當前的X值,那麼我們求他們到當前X值的距離可以先算原點到當前X值的總距離然後減去每個點到原點的總距離即可
2.之前的節點的X值比當前的節點的X值大,那麼我們每次用tot減去比節點X小的節點的座標的和,得到的就是所有比當前節點X的值大的和,然後減去比當前節點X大的節點的個數乘以當前的節點的X值,也就計算出比X值大的節點到X的總距離,稍微腦袋裡模擬一下便知道了,tot表示之前出現的點的X座標的和
#include<iostream>
#include<algorithm>
using namespace std;
#define lowbit(i) ((i) & (-i))
const int maxn = 20010;
struct Node{
int v, x;
}node[maxn];
int cmp(Node a, Node b) {
return a.v < b.v;
}
long long n, sum, cnt, ans, tot;
int c[maxn], num[maxn];
void update(int x, int v) {
for (int i = x; i < maxn; i += lowbit(i)) {
c[i] += v;
num[i] += 1;
}
}
void getSum(int x) {
for (int i = x; i > 0; i -= lowbit(i)) {
sum += c[i];
cnt += num[i];
}
}
int main() {
cin >> n;
for(int i = 1; i <= n; i++) {
cin >> node[i].v >> node[i].x;
}
sort(node + 1, node + 1 + n, cmp);
for (int i = 1; i <= n; i++) {
sum = 0, cnt = 0;
update(node[i].x, node[i].x);
getSum(node[i].x);
tot += node[i].x;
ans += node[i].v * (cnt * node[i].x - sum + tot - sum - node[i].x * (i - cnt));
}
cout << ans << endl;
return 0;
}