1. 程式人生 > >尋找符合子序列要求的區間個數 - 牛客

尋找符合子序列要求的區間個數 - 牛客

連結:https://ac.nowcoder.com/acm/contest/217/B
來源:牛客網

題目描述

msc和mcc是一對好朋友,有一天他們得到了一個長度為n的字串s.

這個字串s十分妙,其中只有’m’,’s’和’c’三種字元。

定義s[i,j]表示s中從第i個到第j個字元按順序拼接起來得到的字串。

定義一個字串t的子序列為從t中選出一些位置並且將這些位置上面的字元按順序拼接起來得到的字串。

兩個子序列重合當且僅當存在一個位置x使得兩個子序列同時選擇了位置x。

由於msc和mcc是一對很好很好的好朋友,所以她們希望選擇兩個數字x和y滿足1≤x≤y≤n使得s[x,y]中同時存在兩個**不重合的子序列**使得其中一個是’msc’且另外一個是’mcc’

現在給出n和字串s,問她們可以選出多少對不同的(x,y)。

輸入描述:

第一行一個正整數n,表示字串s的長度。

第二行一個長度為n的字串s,其中s只包含字元’m’,’s’和,’c’。

輸出描述:

一行一個正整數,表示答案。
示例1

輸入

複製
6
mscmcc

輸出

複製
1

備註:

1≤n≤100,000

題意 : 給你一個字串, 尋找有多少個子區間存在 "msc" 以及 "mcc" , 並且要求找到的子序列中沒有共用的字母
思路分析 :
  暴力的想法是直接列舉所有的區間, n ^ 3 的做法,顯然是TLE
  要怎麼優化一下呢 ?
  考慮以每一個字母為起始的情況,去尋找到哪個位置剛好存在一個這樣的兩個子序列,這時可以直接統計答案
  而且因為字母比較少,所有符合要求的最短的序列總共有 8 種,
  再預處理一下從每個位置到 m, s, c 的最近的位置是多少,這樣每次尋找一個子序列就是 O(6)
  因此總體複雜度是 6*8*n
程式碼示例 :
#define ll long long
const ll maxn = 1e5+5;
const ll inf = 0x3f3f3f3f;

ll n;
char s[maxn];
ll f[maxn][4]; // m - 1 , s - 2, c - 3
string str[10];

void init() {
    ll p1 = 0, p2 = 0, p3 = 0;
    char ch1='\0', ch2='\0', ch3='\0';
    for(ll i = 1; i <= n; i++){
        while((p1 <= i || ch1 != 'm') && p1 <= n) {
            p1++; ch1 = s[p1];
        }
        while((p2 <= i || ch2 != 's') && p2 <= n) {
            p2++; ch2 = s[p2];
        }
        while((p3 <= i || ch3 != 'c') && p3 <= n) {
            p3++; ch3 = s[p3];
        }
        f[i][1] = p1, f[i][2] = p2, f[i][3] = p3;
    }
    str[1] = "mscmcc";  str[5] = "mmccsc";
    str[2] = "msmccc";  str[6] = "mccmsc";
    str[3] = "mmsccc";  str[7] = "mcmscc";
    str[4] = "mmcscc";  str[8] = "mcmcsc";
}

ll fun(ll pos, ll p){
    ll pt, start = 0;
    if (s[pos] == 'm') start = 1;
    for(ll i = start; i < 6; i++){
        if (str[p][i] == 'm') pt = 1;
        else if (str[p][i] == 's') pt = 2;
        else pt = 3; 
        pos = f[pos][pt]; 
        if (pos > n) return inf;
    }
    return pos;
}

void solve() {
    ll ans = 0;
    
    for(ll i = 1; i <= n; i++){
        ll len = inf;
        for(ll j = 1; j <= 8; j++){
            len = min(len, fun(i, j));
        }
        if (len > n) break;
        ans += n-len+1; 
    }   
    printf("%lld\n", ans);
}

int main() {
    cin >> n;
    scanf("%s", s+1);
    init();
    solve();
    return 0;
}