1. 程式人生 > 其它 >AtCoder Beginner Contest 234 G - Divide a Sequence

AtCoder Beginner Contest 234 G - Divide a Sequence

傳送門

題目描述

把一個長度為 \(N\) 的陣列 \(A\), 分為幾個連續的子序列 \(B_1, B_2, ... , B_k\),有 \(2^{N-1}\) 種劃分方式

先給出陣列 \(A\) 求出所有劃分方式的價值之和,並對 \(998244353\) 取模.

對於一種劃分方式 \(B_1, B_2, ... , B_k\) 的價值為 \({\textstyle \prod_{i=1}^{k}} (\max(B_i) - \min(B_i))\)

對於一個子序列\(B_i = (B_{i,1}, B_{i,1}, ..., B_{i,j} )\) ,其中最大的元素為 \(\max(B_i)\)

,最小的元素為 \(\min(B_i)\)

思路

首先我們可以想到一個 \(n^2\)\(DP\)
定義 \(f_i\) 為前 \(i\) 個數字的所有劃分方式的價值之和
那麼可以得到轉移方程 \(f_i = \sum_{j=1}^{i-1} f_j \times (\max(a_{j+1},.., a_{i}) - \min(a_{j+1},.., a_{i}))\)
(不包含 \(f_{i-1}\) 的原因是 單個數字的價值為 \(0\))
通過倒序遍歷可以維護出 \(\max, \min\) 因此複雜度為 \(O(n^2)\)

這麼寫肯定是會 \(TLE\)

因此我們考慮如何去優化

我們可以把 \(\max\)\(\min\) 的貢獻單獨去考慮 (這也是常用的一個套路)

首先分析 \(\max\)

轉移方程 \(f_i = \sum_{j=1}^{i-1} f_j \times (\max(a_{j+1},.., a_{i}) - \min(a_{j+1},.., a_{i}))\)
可以轉換為 \(f_i = \sum_{j=1}^{i-1} f_j \times \max(a_{j+1},.., a_{i}) - \sum_{j=1}^{i-1} f_j \times \min(a_{j+1},.., a_{i})\)

我們可以用 \(m_i\) 代表當前的 \(\max\)

的價值和
那麼 \(m_i\)\(m_{i-1}\) 是否存在什麼聯絡呢? 答案是存在的

我們考慮的是最大值對答案的貢獻
對於 \(f_i\) 來說,產生貢獻的最大值一定是單調上升的一些數字,因為我們 \(f_i\) 進行轉移的時候是根據最後一段子序列進行分類的
\(f_i\) 是由 \(f_{i-2}, f_{i-3}, ..., f_{1}, f_{0}\) 轉移過來的,我們字尾最大值的貢獻一定是在一個連續的區間,並且是單調遞增的
那麼我們就可以利用一個單調棧,每次彈出棧頂的時候減去棧頂的所有貢獻,然後在最後加上當前位置的貢獻即可

最小值同理

CODE
/********************
Author:  Nanfeng1997
Contest: AtCoder - AtCoder Beginner Contest 234
URL:     https://atcoder.jp/contests/abc234/tasks/abc234_g
When:    2022-03-16 10:28:26

Memory:  1024MB
Time:    2000ms
********************/
#include  <bits/stdc++.h>
using namespace std;
using LL = long long;
const int MOD = 998244353;

inline int mod(int x) {return x >= MOD ? x - MOD : x;}

inline int ksm(int a, int b) {
  int ret = 1; a = mod(a);
  for(; b; b >>= 1, a = 1LL * a * a % MOD) if(b & 1) ret = 1LL * ret * a % MOD;
  return ret;
}

template<int MOD> 
struct modint {
  int x;
  modint() {x = 0; }
  modint(int y) {x = y;}
  inline modint inv() const { return modint{ksm(x, MOD - 2)}; }
  explicit inline operator int() { return x; }
  friend inline modint operator + (const modint &a, const modint& b) { return modint(mod(a.x + b.x)); }
  friend inline modint operator - (const modint &a, const modint& b) { return modint(mod(a.x - b.x + MOD)); }
  friend inline modint operator * (const modint &a, const modint& b) { return modint(1ll * a.x * b.x % MOD); }
  friend inline modint operator - (const modint &a) { return modint(mod(MOD - a.x)); }
  friend inline modint& operator += (modint &a, const modint& b) { return a = a + b; }
  friend inline modint& operator -= (modint &a, const modint& b) { return a = a - b; }
  friend inline modint& operator *= (modint &a, const modint& b) { return a = a * b; }
  inline int operator == (const modint &b) { return x == b.x; }
  inline int operator != (const modint &b) { return x != b.x; }
  inline int operator < (const modint &a) { return x < a.x; }
  inline int operator <= (const modint &a) { return x <= a.x; }
  inline int operator > (const modint &a) { return x > a.x; }
  inline int operator >= (const modint &a) { return x >= a.x; }
};

typedef modint<MOD> mint;

inline mint ksm(mint a, int b) {
  mint ret = 1; 
  for(; b; b >>= 1, a = a * a ) if(b & 1) ret = ret * a ;
  return ret;
}

const int N = 3e5 + 10;

int n;
int a[N], s1[N], s2[N];
mint dp[N], tr[N];

void add(int a, mint k) {
  while(a <= n) tr[a] += k, a += a & -a;
}

mint query(int x) {
  mint ret = 0; if(x >= 0) ret += 1; //樹狀陣列的邊界是1, 因此我們手動加上0處的貢獻
  while(x > 0) ret += tr[x], x -= x & -x;
  return ret;
}

mint ask(int l, int r) {return query(r) - query(l - 1); }

void solve() {
  scanf("%d", &n);
  for(int i = 1; i <= n; i ++ ) scanf("%d", &a[i]);
  int t1 = 0, t2 = 0;
  mint mx = 0, mi = 0; 

  //dp[0] = 1, 因為我們進行轉移的時候是乘法,乘法的么元是 1
  //mx 是最大值的貢獻 
  //mi 是最小值的貢獻
  
  for(int i = 1; i <= n; i ++ ) {

    while(t1 && a[s1[t1]] <= a[i]) {
      int t = s1[t1 --];
      mint ret = ask(s1[t1], t - 1);
      mx = mx - ret * a[t];
    }

    mx = mx + ask(s1[t1], i) * a[i];
    s1[++ t1] = i;
    
    while(t2 && a[s2[t2]] >= a[i]) {
      int t = s2[t2 --];
      mint ret = ask(s2[t2], t - 1);
      mi = mi - ret * a[t];
    }

    mi = mi + ask(s2[t2], i) * a[i];
    s2[++ t2] = i;

    dp[i] = mx - mi;
    add(i, dp[i]);
  }
  printf("%d", (int)dp[n]);

}

int main() {
  // ios::sync_with_stdio(false);
  // cin.tie(nullptr);
  int T = 1; //cin >> T;
  while(T --) solve();

  return 0;
}