「快速沃尔什变换」学习笔记

想起来这玩意看懂了之后一直没写……
要不是 WJJ 今天问了我我还真就遗忘在任务列表里了(逃

思想

考虑类比 FFT 处理多项式乘法的方式,可以尝试设计一种变换,使得其能够将位运算卷积转化为对应位置乘积。

处理位运算卷积

按位或

对于序列 a,b,考虑序列 c,其中 ci=j|k=iajbk

考虑一个对于序列的变换 FWTor(a)FWTor(a)i=j|iaj

则注意到 FWTor(c)i=(j|k)|i=iajbk=j|ik|iajbk=FWTor(a)iFWTor(b)i

这显然正是我们设计这个变换而所期望得到的。

接下来考虑如何快速实现这个变换(O(n2) 肯定是接受不了的)。
考虑分治,每次将 a 数组按下标在二进制下的最高位分,即对半分。
设在 a 中下标最高位为 0 的部分为 a0,为 1 的部分为 a1
则由于只有 01 产生贡献,而 0,1 均对 1 产生贡献,有 FWTor(a)=merge(FWTor(a0),FWTor(a0)+FWTor(a1))

其中两个序列的加法意义为对应位置相加,merge 为将两个序列首尾相接形成一个新的序列。
如此即可分治。

按位与

对于序列 a,b,考虑序列 c,其中 ci=j&k=iajbk

类似地,考虑 FWTand(a)i=j&i=iaj

显然 FWTand(c)i=FWTand(a)iFWTand(b)i

类似地,0,1 都对 0 有贡献但只有 11 有贡献,故有 FWTand(a)i=merge(FWTand(a0)+FWTand(a1),FWTand(a1))

按位异或

对于序列 a,b,考虑序列 c,其中 ci=jxork=iajbk

定义 ab=popcount(a&b)mod2,其中 popcount(x) 表示 x 在二进制表示下 1 的个数。
则显然有 (ij)xor(ik)=(jxork)i

考虑 FWTxor(a)i=ij=0ajij=1aj

则有 FWTxor(a)iFWTxor(b)i=(ij=0ajij=1aj)(ik=0bkik=1bk)=(ij=0ik=0ajbk+ij=1ik=1ajbk)(ij=0ik=1ajbk+ij=1ik=0ajbk)=(ij)xor(ik)=0ajbk(ij)xor(ik)=1ajbk=(jxork)i=0ajbk(jxork)i=1ajbk=FWTxor(c)i

分治合并时考虑最高位会贡献 1 还是 1,有 FWTxor(a)=merge(FWTxor(a0)+FWTxor(a1),FWTxor(a0)FWTxor(a1))

逆变换

逆变换什么的,把 FWT 倒过来做就好啦(
反正都是自底向上的,还是比较好做的。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include <cstdio>
#define add(x,y) (x + y >= mod ? x + y - mod : x + y)
#define dec(x,y) (x < y ? x - y + mod : x - y)
using namespace std;
const int N = 1 << 17;
const int mod = 998244353;
const int inv = 499122177;
int n,a[N + 5],b[N + 5];
int f[N + 5],g[N + 5];
inline void fwt_or(int *a,int type)
{
for(register int w = 2,m = 1;w <= n;w <<= 1,m <<= 1)
for(register int i = 0;i < n;i += w)
for(register int j = 0;j < m;++j)
a[i | j | m] = type == 1 ? add(a[i | j | m],a[i | j]) : dec(a[i | j | m],a[i | j]);
}
inline void fwt_and(int *a,int type)
{
for(register int w = 2,m = 1;w <= n;w <<= 1,m <<= 1)
for(register int i = 0;i < n;i += w)
for(register int j = 0;j < m;++j)
a[i | j] = type == 1 ? add(a[i | j],a[i | j | m]) : dec(a[i | j],a[i | j | m]);
}
inline void fwt_xor(int *a,int type)
{
for(register int w = 2,m = 1;w <= n;w <<= 1,m <<= 1)
for(register int i = 0;i < n;i += w)
for(register int j = 0;j < m;++j)
{
int t = a[i | j | m];
a[i | j | m] = dec(a[i | j],t),
a[i | j] = add(a[i | j],t),
type == -1 && (a[i | j] = (long long)a[i | j] * inv % mod,a[i | j | m] = (long long)a[i | j | m] * inv % mod);
}
}
int main()
{
scanf("%d",&n),n = 1 << n;
for(register int i = 0;i < n;++i)
scanf("%d",a + i);
for(register int i = 0;i < n;++i)
scanf("%d",b + i);
for(register int i = 0;i < n;++i)
f[i] = a[i],g[i] = b[i];
fwt_or(f,1),fwt_or(g,1);
for(register int i = 0;i < n;++i)
f[i] = (long long)f[i] * g[i] % mod;
fwt_or(f,-1);
for(register int i = 0;i < n;++i)
printf("%d%c",f[i]," \n"[i == n - 1]);
for(register int i = 0;i < n;++i)
f[i] = a[i],g[i] = b[i];
fwt_and(f,1),fwt_and(g,1);
for(register int i = 0;i < n;++i)
f[i] = (long long)f[i] * g[i] % mod;
fwt_and(f,-1);
for(register int i = 0;i < n;++i)
printf("%d%c",f[i]," \n"[i == n - 1]);
for(register int i = 0;i < n;++i)
f[i] = a[i],g[i] = b[i];
fwt_xor(f,1),fwt_xor(g,1);
for(register int i = 0;i < n;++i)
f[i] = (long long)f[i] * g[i] % mod;
fwt_xor(f,-1);
for(register int i = 0;i < n;++i)
printf("%d%c",f[i]," \n"[i == n - 1]);
}

Related Issues not found

Please contact @Alpha1022 to initialize the comment