Steven_MengのBlog

可以这么写$w$维前缀和(原文):

1
2
3
4
5
6
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
a[i][j]+=a[i][j-1];
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
a[i][j]+=a[i-1][j];
1
2
3
4
5
6
7
8
9
10
11
12
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
for(int k=1;k<=p;k++)
a[i][j][k]+=a[i-1][j][k];
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
for(int k=1;k<=p;k++)
a[i][j][k]+=a[i][j-1][k];
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
for(int k=1;k<=p;k++)
a[i][j][k]+=a[i][j][k-1];

这样其实是$O(w \times n^w)$,但是如果用容斥原理的话,是$O(2^w \times n^w)$($2^n = \sum _{i=0}^n C_n^i$)。

注意到$n^w$相比$2^w$来说增长很快,所以这个优化很屑吗?

很多情况我们使用的是$n=2$,这样第一种是$O(w \times 2^w)$,第二种是$O(2^w \times 2^w)$。这样差的就很大了。

比如说子集统计,有$m$个大小为$20$的集合$S_i$。($m \le 1000000$)要你求$F(S)=\sum _{i=1}^m [S \in S_i]$,容易看出$w=20,n=2$,代表每个元素选或者不选。

可以写出以下的代码:

1
2
3
4
5
6
7
8
9
10
for(int i=1;i<=1;i++)
for(int j=0;j<=1;j++)
for(int k=0;k<=1;k++)
......
a[i][j][k][][][]+=a[i-1][j][k][][][];
for(int i=0;i<=1;i++)
for(int j=1;j<=1;j++)
for(int k=0;k<=1;k++)
......
a[i][j][k][][][]+=a[i][j-1][k][][][];

但是显然这样又臭又长。

考虑将$a[][][][]$数组后面的$w$维压成$1$维,用二进制表示。

可以这么写:

1
2
3
4
5
for (int i=0;i<w;++i){
for (int j=0;j<max(a[i]);++j){
if (j&(1<<i)) f[j^(1<<i)]=(f[j^(1<<i)]+f[j])%MOD;
}
}

例题

CF449D Jzzhu and Numbers

考虑一个容斥:令$g(x)$为使得这些$a[i_j]$与起来有至少$x$个$1$的方案数。

那么显然有$ans=\sum g(i) (-1)^i$

那么考虑预处理$g(x)$,可以考虑预处理与起来为$status$的方案数$f(status)$(这时$g(x)==\sum _{bitcount(i)==x} f(i)$),那就是要求$status \in a[i_j]$,假设我们有$cnt$个$status \in a[i]$,那么答案就是$2^{cnt}-1$(去掉什么都不选的)

可以用上面说的前缀和方法预处理出来$f(status)$,就做完了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#include <bits/stdc++.h>
#define MAXN 1000005
#define MAXM 25
#define MOD 1000000007
using namespace std;
inline int read(){
int x=0,f=1;
char ch=getchar();
while (ch<'0'||ch>'9'){
if (ch=='-') f=-1;
ch=getchar();
}
while (ch>='0'&&ch<='9'){
x=(x<<3)+(x<<1)+(ch^'0');
ch=getchar();
}
return x*f;
}
int f[MAXN];
#define lowbit(x) x&-x
int cnt[MAXN],pow2[MAXN];
int main(){
int n=read();
cnt[0]=0;
for (register int i=1;i<MAXN;++i){
cnt[i]=cnt[i^(lowbit(i))]+1;
}
for (register int i=1;i<=n;++i){
int x=read();
f[x]++;
}
pow2[0]=1;
for (register int i=1;i<MAXN;++i){
pow2[i]=(pow2[i-1]*2)%MOD;
}
for (register int i=0;i<MAXM;++i){
for (register int j=0;j<MAXN;++j){
if (j&(1<<i)) f[j^(1<<i)]=(f[j^(1<<i)]+f[j])%MOD;
}
}
int ans=0;
for (register int i=0;i<MAXN;++i){
ans=(ans+(cnt[i]&1?-1:1)*(pow2[f[i]]-1))%MOD;
}
printf("%d\n",(ans%MOD+MOD)%MOD);
return 0;
}

 评论