1. 程式人生 > >高維字首和

高維字首和

高維字首和就是求關於一個集合子集(或超集)的狀態的和

牛客有一道題寫的很好
傳送門
題面就已經說明了高維字首和的原理
連結:https://ac.nowcoder.com/acm/contest/167/C
來源:牛客網

對於一個一維陣列求部分和,可以使用如下程式碼

for (int i = 1; i <= n; i++) {
    a[i] += a[i - 1];
}

對於一個二維陣列求部分和,可以使用如下程式碼

for (int i = 1; i <= n; i++) {
    for (int j = 1; j <= n; j++) {
        a[
i][j] += a[i - 1][j] + a[i][j - 1] - a[i - 1][j - 1]; } }

或如下程式碼

for (int i = 1; i <= n; i++) {
    for (int j = 1; j <= n; j++) {
        a[i][j] += a[i][j - 1]
    }
}
for (int i = 1; i <= n; i++) {
    for (int j = 1; j <= n; j++) {
        a[i][j] += a[i - 1][j]
    }
}

第二份程式碼看起來更麻煩更慢,來考慮一下三維的情況。

for (int i = 1; i <= n; i++) {
    for (int j = 1; j <= n; j++) {
        for (int k = 1; k <= n; k++) {
            a[i][j][k] += a[i][j][k - 1] + a[i][j - 1][k] + a[i - 1][j][k];
            a[i][j][k] -= a[i][j - 1][k - 1] + a[i - 1][j - 1][k] + a[i - 1][j][k - 1];
            a[i][j]
[k] += a[i - 1][j - 1][k - 1]; } } }

或如下程式碼

for (int i = 1; i <= n; i++) {
    for (int j = 1; j <= n; j++) {
        for (int k = 1; k <= n; k++) {
            a[i][j][k] += a[i][j][k - 1];
        }
    }
}
for (int i = 1; i <= n; i++) {
    for (int j = 1; j <= n; j++) {
        for (int k = 1; k <= n; k++) {
            a[i][j][k] += a[i][j - 1][k];
        }
    }
}
for (int i = 1; i <= n; i++) {
    for (int j = 1; j <= n; j++) {
        for (int k = 1; k <= n; k++) {
            a[i][j][k] += a[i - 1][j][k];
        }
    }
}

第二份程式碼就不一定更慢了(第二份複雜度大約3n3,第一份複雜度大概8n3)
隨著維度更高,第一份程式碼容斥時項數越來越多,而第二份只是多一次遍歷整個陣列,優勢越來越大。
同樣的思路能不能推廣到更高維的情況呢?

最後的解決方法就是高維字首和,從低到高列舉位數,然後列舉從 0 0 n 1 n-1 的所有元素
核心程式碼如下:

for(int i=0;(1<<i)<n;i++)
		for(int j=0;j<n;j++)
			if(j&(1<<i)) a[j]+=a[j^(1<<i)];

複雜度 O ( n × 2 n ) O(n\times 2^n)
完整程式碼:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
using namespace std;

inline int rd(){
	int x=0,f=1;char c=' ';
	while(c<'0' || c>'9') f=c=='-'?-1:1,c=getchar();
	while(c<='9' && c>='0') x=x*10+c-'0',c=getchar();
	return x*f; 
}

const int maxn=(1<<20)+5;
int n,a[maxn],sum[maxn];

int main(){
	n=rd();
	for(int i=0;i<n;i++) a[i]=rd();
	for(int i=0;(1<<i)<n;i++)
		for(int j=0;j<n;j++)
			if(j&(1<<i)) a[j]+=a[j^(1<<i)];
	for(int i=0;i<n;i++) printf("%d\n",a[i]);
	return 0;
}