1. 程式人生 > >Atcoder Beginner 077 C題題解

Atcoder Beginner 077 C題題解

題意:
有三個陣列 a[]、 b[]、c[],有一個數n表示這三個陣列的元素個數,從這三個數組裡各挑一個數,組成三元組<a[i],b[j],c[k]>,且要求,a[i] < b[j] < c[k]。問能組成多少個這樣的三元組。
注意一點,比如 n = 3時,
三個陣列分別為:
1  2  3
1  2  3
1  2  3
答案是27,不是1,因為儘管只能構成三元組<1,2,3>,但是每次選取的位置是不盡相同的。這就是大致題意。

思路:
首先我對這三個陣列,均進行sort升序排序,然後我用upper_bound( )函式處理出,針對b陣列的每個元素,c陣列中比它大的元素的個數。
這裡要注意三點,
第一點就是,upper_bound( )函式是一個封裝好的二分,所以複雜度和二分相同。
第二點就是,upper_bound( )函式的功能是,返回一個數組內,第一個大於指定元素的位置。比如說,我要在c陣列內找到,第一個,比b陣列內pos這個位置的元素,大的位置,即第一個比b[pos]大的位置,我就只需要
int POS = upper_bound(c + 1,c + n + 1,b[pos]) - c;
POS即為所求,由於我編碼習慣,我就經常初始化c[n + 1] = Inf,防止一些細節錯誤,在哪個數組裡找,就初始化陣列末尾元素tmp[n]的下一項tmp[n + 1]為無窮大可以了。
第三點就是,明確了upper_bound( )函式的功能,我要處理出,c陣列內比當前b[i]大的元素個數,而不是位置,所以這個個數cnt,就是n - POS + 1,另外來一個數組map存每次的cnt即可。
接下來是重點,我map陣列,每個map[i],對應的是,在c陣列中,比當前的b[i]大的元素的個數。
然後,舉個例子,
12  13  14
15  16  17
17  18  25
我for迴圈對每個a[i]進行一次upper_bound( ),找出b陣列內第一個大於a[i]的位置,那我這個位置的元素,其後面的元素一定都比它大,因為我sort排序了。所以後面的所有情況都要算進最後的ans,那麼我就不能單獨累加每個b[i]所對應的map[i]了,而是要記一個字尾和陣列sumbk,初始化sumbk[n + 1] = 0,然後相當於一個逆向的字首和,就這麼每次記一個字尾就ok了,這樣所有情況就都算進去了,對每個a[i],我upper_bound( )在b陣列內找出第一個大於它的位置p,那麼我p用字尾維護出來的sumbk[p],就是當前b[p]後面能形成的所有情況。這裡有點只可意會不可言傳的感覺了,描述的比較抽象,如果這裡聽的不太懂,仔細理解一下我AC程式碼,就知道字尾和陣列的作用了,每次把當前位置p對應的字尾和sumbk[p]累加在ans上就可以了。

核心程式碼如下:

for(int i = 1; i <= n; i++) {
        mp[i] = n - (upper_bound(c + 1, c + n + 1, b[i]) - c) + 1;
    }
    sumbk[n + 1] = 0;
    for(int i = n; i >= 1; i--) {
        sumbk[i] = sumbk[i + 1] + mp[i];
    }
    for(int i = 1; i <= n; i++) {
        int pos = upper_bound(b + 1, b + n + 1, a[i]) - b;
        ans += sumbk[pos];
    }



我的AC程式碼:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxx = 1e5 + 7;
const int Inf = 1 << 30;
int a[maxx], b[maxx], c[maxx];
ll mp[maxx];
ll sumbk[maxx];
int n;
ll ans;

int main() {
    scanf("%d", &n);
    for(int i = 1; i <= n; i++) scanf("%d", &a[i]);
    for(int i = 1; i <= n; i++) scanf("%d", &b[i]);
    for(int i = 1; i <= n; i++) scanf("%d", &c[i]);
    sort(a + 1, a + n + 1);
    sort(b + 1, b + n + 1);
    sort(c + 1, c + n + 1);
    b[n + 1] = Inf, c[n + 1] = Inf;
    for(int i = 1; i <= n; i++) {
        mp[i] = n - (upper_bound(c + 1, c + n + 1, b[i]) - c) + 1;
    }
    sumbk[n + 1] = 0;
    for(int i = n; i >= 1; i--) {
        sumbk[i] = sumbk[i + 1] + mp[i];
    }
    for(int i = 1; i <= n; i++) {
        int pos = upper_bound(b + 1, b + n + 1, a[i]) - b;
        ans += sumbk[pos];
    }
    printf("%lld\n", ans);
}