基本概念
对于求和式 ∑anxn,如果是有限项相加,则称为多项式。记错 F(x)=∑i=0naixi,其中 ai 称为该多项式的 i 次项系数,记作 [xi]F(x),有时也用 F[i] 表示。
定义两个多项式的 ⊕ 运算卷积为:
[xn](F×G)(x)=i⊕j=n∑[xi]F(x)[xj]G(x)
以下提到卷积一般指加法卷积。
FFT
多项式的点值表达:只需要 n+1 个点值就可以唯一确定一个最高 n 次多项式。
两个多项式卷积就对应点值相乘。也就是对于任意 x,(F×G)(x)=F(x)G(x)。
FFT 的基本思想就是把两个系数表达的多项式转化成点值表达(DFT),点值相乘后再转换为系数表达(IDFT)。
单位根
n 次单位根(n 为正整数)是 n 次幂为 1 的复数。表现在复平面上就是单位圆上幅角 为 n2π 的倍数的点。
显然 n 次单位根有恰好 n 个,记作 ωn0,ωn1,ωn2,…,ωnn−1。其中 ωni 的幅角为 ni2π。
单位根的一些基本性质:
- ωn0=1。
- ωnk=(ωn1)k。
- ωniωnj=ωni+j。
- (ωnk)−1=ωn−k=ωnn−k。
- (ωni)j=ωnij。
- ω2n2k=ωnk。
- 若 n 为偶数,则 ωnk+n/2=−ωnk。
DFT
对于 n−1 次多项式 F(x),我们设两个多项式 FL(x),FR(x):
FL(x)=F[0]+F[2]x+F[4]x2+…+F[n−2]xn/2−1FR(x)=F[1]+F[3]x+F[5]x2+…+F[n−1]xn/2−1
这里钦定 n 是偶数。
则 F(x)=FL(x2)+xFR(x2)。
对于 0≤k<n/2,代入 x=ωnk:
F(ωnk)=FL((ωnk)2)+ωnkFR((ωnk)2)=FL(ωn/2k)+ωnkFR(ωn/2k)
代入 x=ωnk+n/2:
F(ωnk+n/2)=FL((ωnk+n/2)2)+ωnk+n/2FR((ωnk+n/2)2)=FL(ωn/2k)−ωnkFR(ωn/2k)
也就是只需要知道 FL(x) 和 FR(x) 在 ωn/20∼ωn/2n/2−1 处的点值表示,就可以知道 F(x) 在 ωn0∼ωnn−1 处的点值表示。
那么对于一个 n−1 次多项式,先扩展到 2k 次多项式,然后每次分治再合并就可以实现 DFT 了。
IDFT
设 n 个点值为 gi,则:
nF[k]=i=0∑n−1(ωn−k)iG[i]
证明:
i=0∑n−1(ωn−k)iG[i]=i=0∑n−1(ωn−k)ij=0∑n−1F[j](ωni)j=i=0∑n−1j=0∑n−1ωni(j−k)F[j]=j=0∑n−1F[j]i=0∑n−1ωni(j−k)=j=0∑n−1F[j]×[n∣(j−k)]×n=nF[k]
第三个等号到第四个等号是单位根反演。
那么就相当于 G[i] 作为系数再代入一遍。无非这次代入的是 ωn0,ωn−1,ωn−2,…,ωn1−n。那么就相当于 DFT 里求单位根后再求个逆元即可。
蝴蝶变换
我们发现 DFT 和 IDFT 差别不大,因此可以把这两个写在一起。
然后考虑优化分治的常数。
考虑过程中系数的变化,这里以 n=8 为例:
F[0]F[0]F[0]F[1]F[2]F[4]F[2]F[4]F[2]F[3]F[6]F[6]F[4]F[1]F[1]F[5]F[3]F[5]F[6]F[5]F[3]F[7]F[7]F[7]
稍微观察一下可以发现最终序列就是原序列下标二进制翻转。
这个东西显然可以 O(n) 递推。然后从下往上合并就可以了。
代码实现
复数类:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| struct complex { double a, b;
complex() = default; complex(double _a, double _b): a(_a), b(_b) {} complex operator+(const complex &x) const {return complex(a + x.a, b + x.b);} complex operator-(const complex &x) const {return complex(a - x.a, b - x.b);} complex operator*(const complex &x) const {return complex(a * x.a - b * x.b, a * x.b + b * x.a);} complex operator/(const complex &x) const { double t = b * b + x.b * x.b; return complex((a * x.a + b * x.b) / t, (b * x.a - a * x.b) / t); } complex &operator+=(const complex &x) {return *this = *this + x;} complex &operator-=(const complex &x) {return *this = *this - x;} complex &operator*=(const complex &x) {return *this = *this * x;} complex &operator/=(const complex &x) {return *this = *this / x;} };
|
FFT:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
| void FFT(vector<complex> &f, int flag) const { int n = f.size(); vector<int> swp(n); for (int i = 0; i < n; i++) { swp[i] = swp[i >> 1] >> 1 | ((i & 1) * (n >> 1)); if (i < swp[i]) std::swap(f[i], f[swp[i]]); } for (int mid = 1; mid < n; mid <<= 1) { complex w1(cos(pi / mid), flag * sin(pi / mid)); for (int i = 0; i < n; i += mid << 1) { complex w(1, 0); for (int j = 0; j < mid; j++, w *= w1) { complex x = f[i + j], y = w * f[i + mid + j]; f[i + j] = x + y, f[i + mid + j] = x - y; } } } return ; }
|
卷积:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| poly operator*(const poly &b) const { const poly &a(*this); int n = 1; while (n < (int)(a.size() + b.size()) - 1) n <<= 1; vector<complex> A(n), B(n); for (int i = 0; i < (int)a.size(); i++) A[i] = complex(a[i], 0); for (int i = 0; i < (int)b.size(); i++) B[i] = complex(b[i], 0); FFT(A, 1), FFT(B, 1); vector<complex> C(n); for (int i = 0; i < n; i++) C[i] = A[i] * B[i]; FFT(C, -1); poly c; for (int i = 0; i < (int)(a.size() + b.size()) - 1; i++) c.push_back(C[i].a / n + 0.5); return c; }
|
NTT
FFT 依赖于单位根,于是不可避免地会产生精度问题。考虑找一个单位根的平替。
然而数学家已经证明,在复数域下单位根是唯一一类满足要求的数。
好在大部分题目都在模意义下进行。考虑模意义下什么东西能替代单位根。
那么就要思考我们用到了单位根的哪些性质:
- ωnk=(ωn1)k。
- ωn0∼ωnn−1 互不相同。
- ωnk=ωnkmodn。
- ω2n2k=ωnk。
我们发现原根能很好地满足要求,更准确地说是令 ωn1=g(p−1)/n。此处 p 为模数且是质数。
容易发现只要 n∣(p−1) 则上述性质都是满足的。
由于 n 是 2 的幂次,因此我们希望 p−1 的质因子 2 尽可能多。比如 p=998244353=223×7×17+1。
常见 NTT 模数原根表:https://blog.miskcoo.com/2014/07/fft-prime-table。
代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
| void NTT(vector<int> &g, int flag) const { int n = g.size(); vector<unsigned long long> f(g.begin(), g.end()); vector<int> swp(n); for (int i = 0; i < n; i++) { swp[i] = swp[i >> 1] >> 1 | ((i & 1) * (n >> 1)); if (i < swp[i]) std::swap(f[i], f[swp[i]]); } for (int mid = 1; mid < n; mid <<= 1) { int w1 = power(flag ? G : invG, (mod - 1) / mid / 2); vector<int> w(mid); w[0] = 1; for (int i = 1; i < mid; i++) w[i] = (long long)w[i - 1] * w1 % mod; for (int i = 0; i < n; i += mid << 1) for (int j = 0; j < mid; j++) { int t = (long long)w[j] * f[i + mid + j] % mod; f[i + mid + j] = f[i + j] - t + mod; f[i + j] += t; } if (mid == 1 << 10) for (int i = 0; i < n; i++) f[i] %= mod; } int inv = flag ? 1 : power(n, mod - 2); for (int i = 0; i < n; i++) g[i] = f[i] % mod * inv % mod; return ; }
|
这里用了一个取模优化:用 unsigned long long
存储结果,只需在中间进行一次取模即可。