1. 程式人生 > 其它 >AT4168 [ARC100C] Or Plus Max

AT4168 [ARC100C] Or Plus Max

高維字首和

Or Plus Max

給你一個長度為 \(2^n\) 的序列 \(a\),每個 \(1\le K\le 2^n-1\),找出最大的 \(a_i+a_j \left( i\ | \ j\le K \ ,i,j \le 2^n\right)\) 並輸出。

第一次接觸這個比較有趣的思想,主要演算法數高維字首和

高維字首和主要用於解決形似 \(\sum_{i\subset S} f(i)\) 的問題,是一種不可多得的人類智慧巧妙演算法。

首先一維字首和不用講,但二維字首和其實還有一種不為人知的形式:

for(int i = 1; i <= n; i ++)
    for(int j = 1; j <= n; j ++) a[i][j] += a[i - 1][j];
for(int i = 1; i <= n; i ++)
    for(int j = 1; j <= n; j ++) a[i][j] += a[i][j - 1];

簡單分析一下還是比較顯然的。然後拓展到三維:

for(int i = 1; i <= n; i ++)
    for(int j = 1; j <= n; j ++)
        for(int k = 1; k <= n; k ++) a[i][j][k] += a[i - 1][j][k];
for(int i = 1; i <= n; i ++)
    for(int j = 1; j <= n; j ++)
        for(int k = 1; k <= n; k ++) a[i][j][k] += a[i][j - 1][k];
for(int i = 1; i <= n; i ++)
    for(int j = 1; j <= n; j ++)
        for(int k = 1; k <= n; k ++) a[i][j][k] += a[i][j][k - 1];

看出規律了對吧,逐次按位考慮,每次微調一維。於是利用位運算模擬維數,我們不難得到 \(n\) 維字首和。

for(int j = 0; j < n; j ++)
    for(int i = 0; i < (1 << n); i ++)
        if(i >> j & 1) a[i] += a[i ^ (1 << j)];

然而基礎演算法是一方面,應用又是一方面。

本題我們按照 \(K\) 的不同來考慮,設 \(A_k=\max_{i|j=k}\{a_i+a_j\}\),那麼 \(ans_k=\max_{1\leq i\leq k}\{A_i\}\)

考慮到 \(\max\) 的區間可交性,不妨將限制放寬,與 ST 表的原理類似,設 \(A_k=\max_{i,j\subset k}\{a_i+a_j\}\)

然後就是比較裸的高維字首和問題,只是把求和變為取 \(\max\) 了,然後每個數代表兩個狀態,即最大和次大值。

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;

const int N = 1 << 19;
const int INF = 1 << 30;
int n;
struct node{int x, y;} a[N];

int read(){
	int x = 0, f = 1; char c = getchar();
	while(c < '0' || c > '9') f = (c == '-') ? -1 : 1, c = getchar();
	while(c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar();
	return x * f;
}

node Merge(node a, node b){
	if(a.x < b.x) swap(a, b);
	node now = a;
	if(a.y < b.x) now.y = b.x;
	return now;
}

int main(){
	n = read();
	for(int i = 0; i < (1 << n); i ++){
		a[i].x = read();
		a[i].y = - INF;
	}
	for(int j = 0; j < n; j ++)
		for(int i = 0; i < (1 << n); i ++)
			if(i >> j & 1) a[i] = Merge(a[i], a[i ^ (1 << j)]);
	int ans = 0;
	for(int i = 1; i < (1 << n); i ++){
		ans = max(ans, a[i].x + a[i].y);
		printf("%d\n", ans);
	}
	return 0;
}