[笔记] 黎明前的巧克力

给定 $n$ 个数 $a_i$,求把 $n$ 个数分为三个集合,且前两个集合异或值相同的方案数。集合可区分而集合内的元素不可区分。

$1\le a_i,n\le 10^6$

题目链接


思路

首先可以 dp,$f_{i,j}$ 为考虑了前 $i$ 个数,且前两个集合异或值为 $j$ 的方案数:
$$
f_{i,j}=f_{i-1,j}+2f_{i-1,j\otimes a_i}
$$
貌似没法优化了?我们把它写成 FWT 的形式:

$$ \begin{aligned} c_i&=1+2x^{a_i}\\ F_i&=F_{i-1}\otimes c_i \end{aligned} $$

中间那个 $\otimes$ 是异或 FWT。注意到我们 $c_i$ 的项数很少,我们把 $c_i$ 的 FWT 拆开。首先异或 FWT 的意义是:

$$ \operatorname{FWT}(a)_i=\sum_j a_j\cdot (-1)^{|i\&j|} $$

我们发现由于 0 与任何数位与都是 0,所以 $\operatorname{FWT}(c_i)$ 首先每一项都有个 1,然后 2 的符号就是任意了。所以我们知道 $\operatorname{FWT}(c_i)$ 仅由 -1 和 3 组成。那么也就是说我们现在在算很多个这样数组的点积,于是考虑统计 -1 和 3 的个数。我们现在考虑 $i$ 这一位的 -1 和 3 的个数($\operatorname{FWT}(a_{?,i})$)。记这一位所有数和为 $sum_i$,我们注意到,$3cnt_{i,3}-cnt_{i,-1}=sum$,$cnt_{i,-1}=n-cnt_{i,3}$,所以只要知道 $sum_i$ 就可以算 $cnt_{i,3}$ 和 $cnt_{i,-1}$。于是有:

$$ \begin{aligned} sum_i&=\sum_j\operatorname{FWT}(c_j)_i\\ &=\sum_j1+2\cdot (-1)^{|i\&a_j|} \end{aligned} $$

欸我们一看!右边那个东西可以统一用一遍 FWT 弄出来。没了。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
// Mivik 2020.12.24
#include <mivik.h>

MI cin;

const int nmax = 1048576;
const int mod = 998244353;

int n, p3[nmax], c[nmax];
inline int pro(int x) { return x >= mod? x - mod: x; }
inline int per(int x) { return x < 0? x + mod: x; }
inline int n1(int v, int p) { return (p & 1)? per(-v): v; }
inline int div2(int x) { return ((x & 1)? x + mod: x) >> 1; }
inline int round_up(int x) { return 1 << (32 - __builtin_clz(x)); }
template<bool rev>
inline void fwt(int *v, int len) {
for (int i = 1, q = 2; i < len; q = (i = q) << 1)
for (int j = 0; j < len; j += q)
for (int k = 0; k < i; ++k) {
const int x(v[j | k]), y(v[i | j | k]);
v[j | k] = pro(x + y);
v[i | j | k] = per(x - y);
if (rev) {
v[j | k] = div2(v[j | k]);
v[i | j | k] = div2(v[i | j | k]);
}
}
}
int main() {
cin > n; int lim = 0;
for (int i = p3[0] = 1; i <= n; ++i) {
p3[i] = pro(p3[i - 1] + pro(p3[i - 1] << 1));
const int x(R); c[x] += 2;
if (x > lim) lim = x;
}
lim = round_up(lim);
fwt<0>(c, lim);
for (int i = 0; i < lim; ++i) {
const int sum(pro(n + c[i]));
const int c3(pro(sum + n) >> 2), cn1(n - c3);
c[i] = n1(p3[c3], cn1);
}
fwt<1>(c, lim);
cout < per(c[0] - 1) < endl;
}
作者

Mivik

发布于

2020-12-24

更新于

2022-11-11

许可协议

评论