codeforces round#353 trians and statistic dp+貪心+線段樹
題目描述:有n個車站,第i(1<= i <= n - 1)個車站可以買從i到i+1 , i + 2 , ...,a[i]的火車票,用p(i , j)表示從車站i到車站j最少買多少張車票,問
sum = Σp(i , j) (1 <= i < j <= n)是多少
思路:在車站i處,可以用一張車票到達[i + 1 , a[i]]中的一站,那麼,應選擇m車站再買一張車票,其中m∈[i + 1 , a[i] ]且a[m]最大,為什麼是這樣呢?如圖:
設m'是不同於m的一個車站,a[m'] < a[m] , 可以看到,假設現在人在位置i,
線段( i , a[m'] ]部分通過兩種換乘方式到達所需要的票數一樣;
而線段( a[m'] , a[m] ]的部分, 從m換乘的話兩張票就能到, 但從m'換乘的話兩張票不一定能到;
再看線段( a[m] , n ]部分,假設這個區間中有一點x ,如果從m'換乘最終到此區間需要k張票,那麼從m換乘需要的票數小於等於k張票,因為這兩種換乘方式相比,大於a[i]的部分,藍色的包在了紅色的裡面,所以藍色能有的買票方式,紅的一定能有,但紅色能有的買票方式,藍色的不一定能有;
如此一來,從m處換乘的原因得證。
設dp[i]表示Σp(i , j) (其中i + 1 <= j <= n) ,那麼dp[i] = dp[m] + n - i -( a[i] - m ) , 為什麼呢?
想象從i+1到n每個位置都對應了一張票:
那麼在( i , m ]區間,每個位置p上的票用來從i走到p;
在( m , n ] 區間,每個位置p上的票用來從i走到m,再從m走到p;
但是可以看到,按照上述規則( m , a[i] ] 部分的買票的方式是先從i走到m,再從m走到每個位置, 而事實上,從i走到每個位置只要一張票,因此要減掉 a[i] - m
另外提一句,有在找m的時候,如果遇到多個a[m]相等,找任一個都可以,這是可以從圖上看出來的,很多題解上說要找dp[m]最小的是錯的,因為所有dp[m]一樣大。
#pragma warning(disable:4786) #pragma comment(linker, "/STACK:102400000,102400000") #include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<stack> #include<queue> #include<map> #include<set> #include<vector> #include<cmath> #include<string> #include<sstream> #define LL long long #define FOR(i,f_start,f_end) for(int i=f_start;i<=f_end;++i) #define mem(a,x) memset(a,x,sizeof(a)) #define lson l,m,x<<1 #define rson m+1,r,x<<1|1 using namespace std; const int INF = 0x3f3f3f3f; const int mod = 1e9 + 7; const double PI = acos(-1.0); const double eps=1e-6; const int maxn = 1e5 + 5 ; int a[maxn]; LL dp[maxn] ; struct node { int v , p ; }t[4 * maxn]; void pushup(int l , int r , int x) { if(t[x<<1].v < t[x<<1|1].v){ t[x].v = t[x<<1|1].v; t[x].p = t[x<<1|1].p; } else if(t[x<<1].v > t[x<<1|1].v){ t[x].v = t[x<<1].v; t[x].p = t[x<<1].p; } else{ t[x].v = t[x<<1].v; t[x].p = t[x<<1|1].p ; } return ; } void build(int l , int r , int x) { if(l == r ){ t[x].v = a[l] ; t[x].p = l ; return ; } int m = l + (r - l) / 2 ; build(lson) ; build(rson); pushup(l , r , x) ; } node query(int L , int R , int l , int r , int x ) { if(L == l && R == r){ return t[x]; } int m = l + ( r - l ) / 2 ; if(R <= m) return query(L , R , lson) ; else if( L > m) return query(L , R , rson) ; else{ node ret1 = query(L , m , lson) ; node ret2 = query( m + 1 , R , rson) ; if(ret2.v > ret1.v) return ret2; else return ret1 ; } } int main() { int n ; scanf("%d",&n); for(int i = 1 ; i<= n - 1 ; i++){ scanf("%d",&a[i]); } build(1 , n , 1); LL ans = 1; dp[n - 1] = 1 ; for(int i = n - 2 ; i>= 1 ; i--){ node st = query(i + 1 , a[i] , 1 , n , 1 ) ; int m = st.p ; dp[i] = dp[m] + n - i - (a[i] - m) ; ans += dp[i] ; } printf("%lld\n",ans); return 0; }