「快速傅里叶变换 / 快速数论变换」学习笔记

接触 FFT,是第一次用 Python 写高精度乘法之后知道可以用它做到 O(nlogn)
然鹅当时的我十分的 simple(并不意味着现在不是),对它望而却步。

FFT

何为 FFT

它可以在 O(nlogn) 内把一个系数表示的多项式转化为它的点值表示

补充 - 点值表示

A(x) 为一个 n1 次多项式,那么用 n 个不同的 x 带入 A,算出 ny
n(x,y) 可以唯一确定这个多项式

两个多项式相乘称为卷积
系数表示的多项式求卷积的复杂度是 O(n2) 的。
但点值表示的多项式的复杂度是 O(n) 的。

DFT 离散傅里叶变换

傅里叶教我们用特定的 x 求点值表示——单位根!

补充 - 复数

从前老师教我们 n 有意义当且仅当 n0
但是我们也会遇到 1 这种东西。
我们称其为虚数

虚数单位 i=1,一个复数 (x,y)=x+yi
其中的 x 称为实部y 称为虚部

把复数看成一个向量/点,它所在的平面直角坐标系有一个特殊的名称——复平面。

补充 - 单位根

把单位圆(圆心在原点,半径为 1 的圆)n 等分,从 (1,0) 开始逆时针将其编号,第 k 个记为 ωkn
显而易见 ωkn=(ω1n)k,所以 ω1n 称为 n 次单位根。

ωkn=(coskn2π,sinkn2π)
以及两个比较显然的性质: - ωkn=ωxkxn。 - ωk+n2n=ωkn

IDFT - 离散傅立叶逆变换

把多项式 A(x) 使用单位根的点值表示再次作为另一个多项式 B(x) 的系数表示,取 ω0n,ω1n,,ωn+1n 代入求得 B 的点值表示。
将其每一位除以 n,就得到了 A 的系数表示。

A(x) 的点值表示是 (b1,b2,,bn)B(x) 的点值表示是 (c1,c2,,cn)
上述结论的证明: ck=n1i=0bi(ωkn)i=n1i=0(n1j=0)(ωkn)i=n1i=0n1j=0(ωikn)jai

ik=0n1j=0(ωikn)j=n
其余时候根据等比数列求和公式,可知其值为 0

FFT 快速傅里叶变换

然鹅 DFT 仍然是 O(n2) 的……
我们考虑用分治来优化。

A(x)=n1i=0aixi
A0(x)=n21i=0a2ixi,A1(x)=n21i=0a2i+1xi
于是有 A(x)=A0(x2)+xA1(x2)

对于 k<n2,有 A(ωkn)=A0((ωkn)2)+ωknA1((ωkn)2)=A0(ωkn2)+ωknA1(ωkn2)

A(ωk+n2n)=A0((ωk+n2n)2)+ωknA1((ωk+n2n)2)=A0(ωkn2)ωknA1(ωkn2)

然后就可以递归地写出一个 FFT 了。

一些优化

非递归

睿智的先人们找到了一种神奇的规律:在 FFT 分治时,最后第 x 项所在的位置是 x 二进制翻转后的数。

蝴蝶变换

证明十分严(kong)谨(bu),实际上在代码实现里只是把一个地方换了一下而简化了代码。

参考代码

洛谷 3803.多项式乘法(FFT)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include <cstdio>
#include <cmath>
#include <complex>
#include <algorithm>
using namespace std;
const int N = 1 << 21;
const double PI = acos(-1);
typedef complex<double> cp;
int lena,lenb,n = 1,lg;
cp a[N + 5],b[N + 5],omg[N + 5],inv[N + 5];
void fft(cp *a,cp *omg)
{
for(register int i = 0;i < n;++i)
{
int t = 0;
for(register int j = 0;j < lg;++j)
if(i & (1 << j))
t |= (1 << lg - j - 1);
if(i < t)
swap(a[i],a[t]);
}
for(register int w = 2,m = 1;w <= n;w <<= 1,m <<= 1)
for(register int i = 0;i < n;i += w)
for(register int j = 0;j < m;++j)
{
cp t = omg[n / w * j] * a[i + j + m];
a[i + j + m] = a[i + j] - t,a[i + j] += t;
}
}
int main()
{
scanf("%d%d",&lena,&lenb);
++lena,++lenb;
for(;n < lena + lenb;n <<= 1,++lg);
for(register int i = 0;i < n;++i)
inv[i] = conj(omg[i] = cp(cos(2 * PI * i / n),sin(2 * PI * i / n)));
int x;
for(register int i = 0;i < lena;++i)
scanf("%d",&x),a[i].real(x);
for(register int i = 0;i < lenb;++i)
scanf("%d",&x),b[i].real(x);
fft(a,omg),fft(b,omg);
for(register int i = 0;i < n;++i)
a[i] *= b[i];
fft(a,inv);
for(register int i = 0;i < lena + lenb - 1;++i)
printf("%d ",(int)(a[i].real() / n + 0.5));
}


NTT

FFT 到 NTT

傅立叶把单位根的性质应用到了 FFT 中,但是是不是只有单位根有这样的性质呢?
——不,还有原根

补充 - 原根

对于 g,P,如果 1i,j<P,ij,gigj(modP),则称 gP 的原根。

NTT 的特点

必须取模,而且模数形如 P=2kr+1

常用的模数有 998244353,1004535809,其原根均为 3
对于其他的模数,此处引用一个表,来源见参考文献
其中 gP=2kr+1 的原根。

P r k g
3 1 1 2
5 1 2 2
17 1 4 3
97 3 5 5
193 3 6 5
257 1 8 3
7681 15 9 17
12289 3 12 11
40961 5 13 3
65537 1 16 3
786433 3 18 10
5767169 11 19 3
7340033 7 20 3
23068673 11 21 3
104857601 25 22 3
167772161 5 25 3
469762049 7 26 3
1004535809 479 21 3
2013265921 15 27 31
2281701377 17 27 3
3221225473 3 30 5
75161927681 35 31 3
77309411329 9 33 7
206158430209 3 36 22
2061584302081 15 37 7
2748779069441 5 39 3
6597069766657 3 41 5
39582418599937 9 42 5
79164837199873 9 43 5
263882790666241 15 44 7
1231453023109121 35 45 3
1337006139375617 19 46 3
3799912185593857 27 47 5
4222124650659841 15 48 19
7881299347898369 7 50 6
31525197391593473 7 52 3
180143985094819841 5 55 6
1945555039024054273 27 56 5
4179340454199820289 29 57 3

为什么用原根

NTT 中把所有的 ωkn 全部替换成了 g(P1)kn
为什么可以呢?
因为 FFT 中用到的单位根的性质原根都满足。

参考代码

题目同上。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
const int N = 1 << 21;
const long long mod = 998244353;
const long long G = 3;
const long long Gi = 332748118;
int lena,lenb,n = 1,lg;
long long fpow(long long a,long long b)
{
long long ret = 1;
for(;b;b >>= 1)
(b & 1) && (ret = ret * a % mod),a = a * a % mod;
return ret;
}
long long a[N + 5],b[N + 5],omg[N + 5],inv[N + 5];
void ntt(long long *a,long long *omg)
{
for(register int i = 0;i < n;++i)
{
int t = 0;
for(register int j = 0;j < lg;++j)
if(i & (1 << j))
t |= (1 << lg - j - 1);
if(i < t)
swap(a[i],a[t]);
}
for(register int w = 2,m = 1;w <= n;w <<= 1,m <<= 1)
for(register int i = 0;i < n;i += w)
for(register int j = 0;j < m;++j)
{
long long t = omg[n / w * j] * a[i + j + m] % mod;
a[i + j + m] = (a[i + j] - t + mod) % mod,a[i + j] = (a[i + j] + t) % mod;
}
}
int main()
{
scanf("%d%d",&lena,&lenb);
++lena,++lenb;
for(;n < lena + lenb;n <<= 1,++lg);
for(register int i = 0;i < n;++i)
omg[i] = fpow(G,(mod - 1) / n * i),inv[i] = fpow(Gi,(mod - 1) / n * i);
int x;
for(register int i = 0;i < lena;++i)
scanf("%d",&x),a[i] = x;
for(register int i = 0;i < lenb;++i)
scanf("%d",&x),b[i] = x;
ntt(a,omg),ntt(b,omg);
for(register int i = 0;i < n;++i)
a[i] *= b[i];
ntt(a,inv);
long long n_inv = fpow(n,mod - 2);
for(register int i = 0;i < lena + lenb - 1;++i)
printf("%lld ",a[i] * n_inv % mod);
}

参考文献

Related Issues not found

Please contact @Alpha1022 to initialize the comment