daimayuan#884. 最長上升子序列計數 (線段樹優化dp)
阿新 • • 發佈:2022-05-19
http://oj.daimayuan.top/problem/884
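- 設 f[i] 表示以 a[i] 結尾的最長上升子序列的**長度**,cnt[i] 表示以 a[i] 結尾、長度恰為 f[i] 的上升子序列的個數。可以 $O(n^2)$ 轉移:$f[i] = 1 + \max\{\, f[j] \mid j < i,\ a[j] < a[i] \,\}$;$cnt[i] = \sum_{j < i,\ a[j] < a[i],\ f[j] + 1 = f[i]} cnt[j]$。若不存在滿足 $j < i,\ a[j] < a[i]$ 的 $j$,則 $f[i] = 1$,$cnt[i] = 1$。(逐個枚舉 $j$ 時:若 $f[j] + 1 > f[i]$ 則令 $f[i] = f[j] + 1$、$cnt[i] = cnt[j]$;若 $f[j] + 1 = f[i]$ 則 $cnt[i] \mathrel{+}= cnt[j]$。)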
- 發現 f 的轉移就是在 i 之前、值小於 a[i] 的位置中找最大的 f[j],並同時累加取到該最大值的 cnt[j]。將 a 離散化後,這可以用權值線段樹(以值為下標)處理。
- 單點修改,區間查詢,只需要考慮如何合併兩個節點:長度較大者優先;長度相等時個數相加(取模);長度為 0 的節點視為空,直接返回另一側:
// Merge two (length, count) summaries: the longer length wins; on equal
// lengths the counts are added modulo `mod`.
Tr comp(Tr a, Tr b)
{
    // Special case: a zero-length summary is empty, so return the other side.
    if (!a.len) return b;
    if (!b.len) return a;
    if (a.len == b.len) return Tr{ a.len, (a.cnt + b.cnt) % mod };
    return a.len > b.len ? a : b;
}
// daimayuan #884: count longest strictly increasing subsequences (mod 1e9+7).
//
// Reads n followed by n integers from stdin and prints the number of longest
// strictly increasing subsequences modulo 1e9+7. Values are coordinate-
// compressed to 1..limit and a segment tree indexed by compressed value keeps,
// for every value, the best (length, count) pair seen so far.
#include <algorithm>
#include <iostream>
#include <vector>

using ll = long long;

constexpr int mod = 1'000'000'007;

// Segment-tree payload: best subsequence length `len` and the number of
// subsequences achieving it, `cnt` (mod `mod`). len == 0 marks an empty node.
struct Tr {
    ll len, cnt;
};

// Segment tree over compressed values [0, limit]; sized in main() from the
// input, so there is no fixed upper bound on n.
std::vector<Tr> tree;

// Merge two summaries: the longer length wins, equal lengths add their counts,
// and a zero-length (empty) side acts as the identity.
Tr comp(Tr a, Tr b) {
    if (!a.len) return b;  // special case: empty side
    if (!b.len) return a;
    if (a.len == b.len) return Tr{a.len, (a.cnt + b.cnt) % mod};
    return a.len > b.len ? a : b;
}

// Initialise every leaf of [l, r] to the empty summary {0, 1} and fill the
// internal nodes of subtree `rt`.
void build(int l, int r, int rt) {
    if (l == r) {
        tree[rt] = Tr{0, 1};
        return;
    }
    const int mid = (l + r) >> 1;
    build(l, mid, rt << 1);
    build(mid + 1, r, rt << 1 | 1);
    tree[rt] = comp(tree[rt << 1], tree[rt << 1 | 1]);
}

// Merge `val` into the leaf at position `pos`, then refresh the ancestors.
void modify(int pos, Tr val, int l, int r, int rt) {
    if (l == r) {
        // Merging (instead of overwriting) keeps the longer length and adds
        // counts on a tie, so a shorter summary can never replace a longer one.
        tree[rt] = comp(tree[rt], val);
        return;
    }
    const int mid = (l + r) >> 1;
    if (pos <= mid)
        modify(pos, val, l, mid, rt << 1);
    else
        modify(pos, val, mid + 1, r, rt << 1 | 1);
    tree[rt] = comp(tree[rt << 1], tree[rt << 1 | 1]);
}

// Return the merged summary of positions [a, b] within node `rt` covering
// [l, r]. Disjoint ranges yield the empty summary {0, 1}.
Tr query(int a, int b, int l, int r, int rt) {
    if (b < l || a > r) return Tr{0, 1};
    if (a <= l && r <= b) return tree[rt];
    const int mid = (l + r) >> 1;
    return comp(query(a, b, l, mid, rt << 1),
                query(a, b, mid + 1, r, rt << 1 | 1));
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    std::cin >> n;
    if (n <= 0) {  // no elements: no subsequences to count
        std::cout << 0 << '\n';
        return 0;
    }

    std::vector<int> a(n + 1);  // 1-indexed input
    for (int i = 1; i <= n; ++i) std::cin >> a[i];

    // Coordinate compression to 1..limit. Slot 0 is never modified and stays
    // {0, 1}, so an element with no smaller predecessor gets f = {1, 1}.
    std::vector<int> ve(a.begin() + 1, a.end());
    std::sort(ve.begin(), ve.end());
    ve.erase(std::unique(ve.begin(), ve.end()), ve.end());
    for (int i = 1; i <= n; ++i)
        a[i] = static_cast<int>(std::lower_bound(ve.begin(), ve.end(), a[i]) - ve.begin()) + 1;

    const int limit = static_cast<int>(ve.size());
    tree.assign(4 * static_cast<std::size_t>(limit + 1), Tr{0, 1});
    build(0, limit, 1);

    // f[i]: (length, count) of the longest increasing subsequences ending at i.
    std::vector<Tr> f(n + 1);
    for (int i = 1; i <= n; ++i) {
        f[i] = query(0, a[i] - 1, 0, limit, 1);  // best over strictly smaller values
        f[i].len += 1;
        modify(a[i], f[i], 0, limit, 1);
    }

    // Sum the counts of all end positions achieving the global maximum length.
    ll mx = 1, cnt = 0;
    for (int i = 1; i <= n; ++i) {
        if (f[i].len > mx) {
            mx = f[i].len;
            cnt = f[i].cnt;
        } else if (f[i].len == mx) {
            cnt = (cnt + f[i].cnt) % mod;
        }
    }
    std::cout << cnt << '\n';
    return 0;
}