1. 程式人生 > 其它 >Codeforces 780 F2. Promising String (hard version) (樹狀陣列)

Codeforces 780 F2. Promising String (hard version) (樹狀陣列)

Codeforces 780 F2. Promising String (hard version) (樹狀陣列)


題目

題目大意是一個字串,每項為- 或者 +,如果-,+個數一樣這個字串是平衡的。你能進行操作:將兩個相鄰的-轉變成+,如果一個字串能通過操作變成平衡的被稱為 “有希望平衡的” 字串。問在給出的字串的所有子串中有多少是“有希望平衡”的子串,輸出數量。

思路

easy version 暴力就能做。對每段子串,sum表示+數量,則len - sum是-數量,只要-數量大於+數量,且差值是3的倍數,則這段是有希望平衡的。暴力算一下就行。
在hard version中 n <= 2e5, 沒法暴力算。我們發現我們要找一段子串[l...r] ,令 val = (r - l + 1) - 2 * (sum[r] - sum[l - 1])

,只要val > 0 && val % 3 == 0則對答案有貢獻。

這道題解法利用一個結論:如果區間[1,i]%3a, 並且區間[1,j]%3a (j < i) 則 區間[j, i] % 3 == 0, 這樣我們只需要遍歷一次1-i區間,求出結果[1,i]%3==a, 然後我們只要能快速計算出前i個字首裡面有幾個結果是a。這裡說的不太清楚,只是一個大概的hint。具體見程式碼,總之就是這個快速求得方法能用樹狀陣列解決。

程式碼

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<map>
#include<queue>
#include<stack>
#include<vector>
#include<string>
#include<set>
#include<fstream>
using namespace std;
#define rep(i, a, n) for(int i = a; i <= n; ++ i)
#define per(i, a, n) for(int i = n; i >= a; -- i)
typedef long long ll;
typedef pair<int, int> PII;
typedef pair<ll, int> PLI;
typedef pair<ll, ll> PLL;
const int N = 2e6 + 50;
const int mod = 998244353;
const double Pi = acos(- 1.0);
const ll INF = 1e12;
const int G = 3, Gi = 332748118;
ll qpow(ll a, ll b, ll p) { ll res = 1; while(b){ if(b & 1) res = (res * a) % p; a = (a * a) % p; b >>= 1;} return res % p; }
ll gcd(ll a, ll b) { return b ? gcd(b, a % b) : a; }

char s[N];
int n;
ll c[3][N];
ll lowbit(int x){ return x & -x; }

ll ask(ll i, ll x){
	int ans = 0;
	for(; x; x -= lowbit(x)) ans += c[i][x];
	return ans;
}

void add(ll i, ll x,ll y){
	for( ; x<=N ; x += lowbit(x)) c[i][x] += y;
}

void solve(){
    scanf("%d",&n);
    ll n2 = (n + 1) << 1;
    for(int i = 0; i <= n2; ++ i) c[0][i] = c[1][i] = c[2][i] = 0;
    scanf("%s",s + 1);

    ll sum = 0, res = 0;
    add((n + 1) % 3, n + 1, 1);
    for(int i = 1; i <= n; ++ i){
        sum += (s[i] == '+');
        int tp = i - 2 * sum + n + 1;
        res += ask(tp % 3, tp);
        add(tp % 3, tp, 1);
    }

    printf("%lld\n",res);
}

int main() {
    freopen("temp.in", "r", stdin);
    freopen("temp.out", "w", stdout);
    int T;scanf("%d",&T);
    while(T --){
        solve();
    }
    return 0;
}