一类用转置原理优化算法的 Trick

11k 词

对于一些算法,如果我们可以将其视作对一个向量的若干线性变换,则称这种算法为线性算法。

具体地,我们可以将算法过程中维护的变量写成一个向量 v\vec{v},然后如果对这些变量的操作都是在向量上的线性变换,那么这就是一个线性算法。

把线性变换写作一个矩阵 AA,就可以将该算法表示为矩阵与向量的乘法 A×vA \times \vec{v}

我们称 ATvA^T \vec{v} 为这个算法的转置算法。

但是我们显然不能把矩阵 AA 给直接写出来,因此需要考虑如何快速得到转置算法。

首先把矩阵 AA 拆成若干个初等矩阵相乘的形式:A=BkBk1B1A = B_k B_{k - 1} \ldots B_1,于是就有:

BkBk1B1v=B1TB2TBkTvB_k B_{k - 1} \ldots B_1 \vec{v} = B_1^T B_2^T \ldots B_k^T \vec{v}

这样就相当于把所有初等变换的顺序倒过来然后把每个初等变换转置一下,于是我们只需要求每个初等变换的转置就好了:

考虑一个初等矩阵 MM 对应的初等变换(以初等行变换为例):

  • 对于初等变换交换矩阵的某两行 i,ji, jMj,i=Mi,j=1,Mi,i=Mj,j=0M_{j, i} = M_{i, j} = 1, M_{i, i} = M_{j, j} = 0,转置后有 Mj,iT=Mi,jT=1,Mi,iT=Mj,jT=0M^T_{j, i} = M^T_{i, j} = 1, M^T_{i, i} = M^T_{j, j} = 0,因此这种初等变换的转置和原本一样,对应到向量上的操作就是交换两维。
  • 对于初等变换给第 ii 行乘上 kkMi,i=kM_{i, i} = k,转置后有 Mi,iT=kM^T_{i, i} = k,因此这种初等变换的转置和原本一样,对应到向量上的操作就是某一维乘上一个系数 kk
  • 对于初等变换给第 ii 行加上第 jj 行:Mi,j=1M_{i, j} = 1,转置后有 Mj,iT=1M^T_{j, i} = 1,对应到向量上的操作就是给第 ii 维加上第 jj 维变成给第 jj 维变成第 ii 维。

于是我们可以得到线性算法所有运算均为以下三种之一:

  • 交换变量 x,yx, y,转置依然是交换变量 x,yx, y
  • xkxx \gets kx,转置后依然是 xkxx \gets kx。注意 kk 可以等于 00
  • xx+kyx \gets x + ky,转置后变成 yy+kxy \gets y + kxxyx \gets y 可以写成上一运算与该运算的组合,即 x0x,xx+yx \gets 0x, x\gets x + y

显然一个算法转置后复杂度不变,如果我们能对转置后的算法进行优化,那优化后再转置一次就可以得到原算法的优化了。

另一方面,在某些题目中可以将转置后的算法进行一些小修改使其达到与原算法相同的效果。

例题

CF2039F1 Shohag Loves Counting (Easy Version) / CF2039F2 Shohag Loves Counting (Hard Version)

先考虑 F1,考虑什么样的序列是好的,显然每次长度增加 11,其权值都是原本权值的倍数,因此首先可以得到好的序列长度为 O(logm)O(\log m)

然后考虑区间 max\max 构成的集合,显然每次删掉的都必须是最小值且只删掉最小值,因此序列必须是单谷的。然后又可以得到序列中每个数都必须互不相同,否则会存在某个时刻中间有两个最小值,然后下一时刻最小值就不会被删掉。

因此问题转化成:求所有单调递减序列 ana_n 使得 i[1,n),gcd(a1,a2,,ai)>gcd(a1,a2,,ai,ai+1)\forall i \in [1, n), \gcd(a_1, a_2, \ldots, a_i) > \gcd(a_1, a_2, \ldots, a_i, a_{i + 1}),并且这个序列的贡献为 2n12^{n - 1}

fi,jf_{i, j} 表示 gcd=i\gcd = i,长度为 jj 的答案,转移就考虑加入最小值 xx。为了满足加入的数确实是最小值,我们可以从大到小枚举 xx 进行转移。

转移式子大概长这样:

fi,jfi,j+k=2fik,j1[gcd(ik,x)=i]fx,1=1f_{i, j} \gets f_{i, j} + \sum\limits_{k = 2} f_{ik, j - 1} [\gcd(ik, x) = i] \\ f_{x, 1} = 1

然后发现系数是 2j12^{j - 1},这个可以直接乘进去,所以修改状态定义令 fif_igcd=i\gcd = i 的所有序列的贡献和:

fifi+k=22fik[gcd(ik,x)=i]fxfx+1f_i \gets f_i + \sum\limits_{k = 2} 2f_{ik} [\gcd(ik, x) = i] \\ f_x \gets f_x + 1

考虑优化转移,显然可以改成容斥,令 si=k=1fiks_i = \sum_{k = 1} f_{ik},则转移变为:

fi=fi+ijjx2sjμ(ji)2fifxfx+1f_i = f_i + \sum\limits_{i | j \land j | x} 2 s_j \mu\left(\frac{j}{i}\right) - 2f_i \\ f_x \gets f_x + 1

后面减去的是求 gcd\gcd 后等于自己的方案数。

直接做是 O(mlog2m)O(m\log^2m),可以高维后缀和优化到 O(mlogmloglogm)O(m\log m\log\log m)

参考代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
// 長い夜の終わりを信じながら
// Think twice, code once.
#include <vector>
#include <cstdio>
#include <string>
#include <cstring>
#include <iostream>
#include <algorithm>
#define eputchar(c) putc(c, stderr)
#define eprintf(...) fprintf(stderr, __VA_ARGS__)
#define eputs(str) fputs(str, stderr), putc('\n', stderr)
using namespace std;

const int mod = 998244353;

int T, n, mu[1000005], dp[1000005], sum[1000005], tmp[1000005];
vector<int> vec[1000005];

int main() {
mu[1] = 1;
for (int i = 1; i <= 1000000; i++)
for (int j = i; j <= 1000000; j += i) {
vec[j].push_back(i);
if (j != i) mu[j] -= mu[i];
}
scanf("%d", &T);
while (T--) {
scanf("%d", &n);
for (int i = 1; i <= n; i++) dp[i] = sum[i] = 0;
for (int i = n; i >= 1; i--) {
for (int j : vec[i]) tmp[j] = (mod - dp[j] * 2 % mod) % mod;
for (int j : vec[i])
for (int k : vec[j]) tmp[k] = (tmp[k] + 2ll * sum[j] * mu[j / k] % mod + mod) % mod;
tmp[i] = (tmp[i] + 1) % mod;
for (int j : vec[i]) {
dp[j] = (dp[j] + tmp[j]) % mod;
for (int k : vec[j]) sum[k] = (sum[k] + tmp[j]) % mod;
}
}
printf("%d\n", sum[1]);
}
return 0;
}

由于我们是从后往前枚举的,因此无法通过 F2。

但是这显然是线性算法,我们可以将其视为有一个向量 v\vec{v}

v=(s1s2snf1f2fntmp1tmp2tmpn1)\vec{v} = \left(\begin{matrix} s_1 \\ s_2 \\ \vdots \\ s_n\\ f_1 \\ f_2 \\ \vdots \\ f_n \\ tmp_1 \\ tmp_2 \\ \vdots \\ tmp_n \\ 1 \end{matrix}\right)

初始时 v\vec{v} 除了最后一维都是 00

然后枚举的每个 ii 都视为将其乘上一个矩阵 AiA_i,则问题答案即为:

(1000)×A1A2Am×(0001)\left(\begin{matrix} 1 & 0 & 0 & \cdots & 0 \end{matrix}\right) \times A_1 A_2 \cdots A_m \times \left(\begin{matrix} 0 \\ 0 \\ \vdots \\ 0 \\ 1 \end{matrix}\right)

将其转置后即为:

(0001)×AmTA2TA1T×(1000)\left(\begin{matrix} 0 & 0 & \cdots & 0 & 1 \end{matrix}\right) \times A_m^T \cdots A_2^T A_1^T \times \left(\begin{matrix} 1 \\ 0 \\ 0 \\ \vdots \\ 0 \end{matrix}\right)

也就是初始时令 s1=1s_1 = 1,用一个变量 valval 来维护原向量中的那个 11,从前往后枚举 ii 并乘上转置后的矩阵,最后 valval 即为答案。

由于是从前往后做的,因此后一个的答案可以从前一个继承过来,只需要做一遍就可以求出所有 mm 的解然后 O(1)O(1) 查询了。

参考代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
// 長い夜の終わりを信じながら
// Think twice, code once.
#include <vector>
#include <cstdio>
#include <string>
#include <cstring>
#include <iostream>
#include <algorithm>
#define eputchar(c) putc(c, stderr)
#define eprintf(...) fprintf(stderr, __VA_ARGS__)
#define eputs(str) fputs(str, stderr), putc('\n', stderr)
using namespace std;

const int mod = 998244353;

int T, n, mu[1000005], dp[1000005], sum[1000005], tmp[1000005], val, ans[1000005];
vector<int> vec[1000005];

int main() {
mu[1] = 1;
for (int i = 1; i <= 1000000; i++)
for (int j = i; j <= 1000000; j += i) {
vec[j].push_back(i);
if (j != i) mu[j] -= mu[i];
}
sum[1] = 1;
for (int i = 1; i <= 1000000; i++) {
for (int j : vec[i]) {
tmp[j] = (tmp[j] + dp[j]) % mod;
for (int k : vec[j]) tmp[j] = (tmp[j] + sum[k]) % mod;
}
val = (val + tmp[i]) % mod;
for (int j : vec[i])
for (int k : vec[j]) sum[j] = (sum[j] + 2ll * tmp[k] * mu[j / k] % mod + mod) % mod;
for (int j : vec[i]) dp[j] = (dp[j] - tmp[j] * 2 % mod) % mod;
for (int j : vec[i]) tmp[j] = 0;
ans[i] = val;
}
scanf("%d", &T);
while (T--) {
scanf("%d", &n);
printf("%d\n", ans[n]);
}
return 0;
}

优化 NTT

先贴个码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
static void NTT(poly &g, int flag) {
int n = g.size();
vector<unsigned long long> f(g.begin(), g.end());
vector<int> swp(n);
for (int i = 0; i < n; i++) {
swp[i] = swp[i >> 1] >> 1 | ((i & 1) * (n >> 1));
if (i < swp[i]) std::swap(f[i], f[swp[i]]);
}
for (int mid = 1; mid < n; mid <<= 1) {
int w1 = power(flag ? G : invG, (mod - 1) / mid / 2);
vector<int> w(mid);
w[0] = 1;
for (int i = 1; i < mid; i++) w[i] = (long long)w[i - 1] * w1 % mod;
for (int i = 0; i < n; i += mid << 1)
for (int j = 0; j < mid; j++) {
int t = (long long)w[j] * f[i + mid + j] % mod;
f[i + mid + j] = f[i + j] - t + mod;
f[i + j] += t;
}
if (mid == 1 << 10)
for (int i = 0; i < n; i++) f[i] %= mod;
}
int inv = flag ? 1 : power(n, mod - 2);
for (int i = 0; i < n; i++) g[i] = f[i] % mod * inv % mod;
return;
}

显然 NTT 也是线性算法,考虑它的转置长什么样。

首先翻转顺序,然后考虑中间这部分:

1
2
3
int t = (long long)w[j] * f[i + mid + j] % mod;
f[i + mid + j] = f[i + j] - t + mod;
f[i + j] += t;

以下用 w,x,yw, x, y 来表示 wj,fi+j,fi+mid+jw_j, f_{i + j}, f_{i + mid + j},这部分的运算可以写成:

t0,tt+wyy0,yy+x,yytxx+tt0t \gets 0, t \gets t + wy \\ y \gets 0, y \gets y + x, y \gets y - t \\ x \gets x + t \\ t \gets 0

转置后就变成:

t0tt+xtty,xx+y,y0yy+wt,t0t \gets 0 \\ t \gets t + x \\ t \gets t - y, x \gets x + y, y \gets 0 \\ y \gets y + wt, t \gets 0

于是我们可以写出转置后的 NTT:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
static void NTT(poly &g, int flag) {
int n = g.size();
vector<unsigned long long> f(g.begin(), g.end());
int inv = flag ? 1 : power(n, mod - 2);
for (int i = 0; i < n; i++) f[i] = f[i] % mod * inv % mod;
for (int mid = n >> 1; mid >= 1; mid >>= 1) {
int w1 = power(flag ? G : invG, (mod - 1) / mid / 2);
vector<int> w(mid);
w[0] = 1;
for (int i = 1; i < mid; i++) w[i] = (long long)w[i - 1] * w1 % mod;
for (int i = 0; i < n; i += mid << 1)
for (int j = 0; j < mid; j++) {
int t = (long long)(f[i + j] - f[i + mid + j] % mod + mod) * w[j] % mod;
f[i + j] += f[i + mid + j];
f[i + mid + j] = t;
}
if (mid == 1 << 10)
for (int i = 0; i < n; i++) f[i] %= mod;
}
vector<int> swp(n);
for (int i = 0; i < n; i++) {
swp[i] = swp[i >> 1] >> 1 | ((i & 1) * (n >> 1));
if (i < swp[i]) std::swap(f[i], f[swp[i]]);
}
for (int i = 0; i < n; i++) g[i] = f[i] % mod;
return;
}

然后考虑 NTT 的本质是什么,本质是求多项式在 ωn0,ωn1,,ωnn1\omega_n^0, \omega_n^1, \ldots, \omega_n^{n - 1} 处的点值(DFT)和根据点值求多项式的系数(IDFT),如果你把这个东西直接写成矩阵,你会发现它们分别是:

ωn0ωn0ωn0ωn0ωn0ωn1ωn2ωnn1ωn0ωn2ωn4ωn2n2ωn0ωnn1ωn2n2ωn(n1)2\begin{matrix} \omega_n^0 & \omega_n^0 & \omega_n^0 & \cdots & \omega_n^0 \\ \omega_n^0 & \omega_n^1 & \omega_n^2 & \cdots & \omega_n^{n - 1} \\ \omega_n^0 & \omega_n^2 & \omega_n^4 & \cdots & \omega_n^{2n - 2} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ \omega_n^0 & \omega_n^{n - 1} & \omega_n^{2n - 2} & \cdots & \omega_n^{(n - 1)^2} \end{matrix}

和:

ωn0nωn0nωn0nωn0nωn0nωn1nωn2nωn1nnωn0nωn2nωn4nωn22nnωn0nωn1nnωn22nnωn(n1)2n\begin{matrix} \frac{\omega_n^0}{n} & \frac{\omega_n^0}{n} & \frac{\omega_n^0}{n} & \cdots & \frac{\omega_n^0}{n} \\ \frac{\omega_n^0}{n} & \frac{\omega_n^{-1}}{n} & \frac{\omega_n^{-2}}{n} & \cdots & \frac{\omega_n^{1 - n}}{n} \\ \frac{\omega_n^0}{n} & \frac{\omega_n^{-2}}{n} & \frac{\omega_n^{-4}}{n} & \cdots & \frac{\omega_n^{2 - 2n}}{n} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ \frac{\omega_n^0}{n} & \frac{\omega_n^{1 - n}}{n} & \frac{\omega_n^{2 - 2n}}{n} & \cdots & \frac{\omega_n^{-(n - 1)^2}}{n} \end{matrix} \\

所以 NTT 的转置和 NTT 效果完全相同!

对于很多多项式操作(如多项式乘法、求逆等),在 DFT 之后 IDFT 之前,我们并不关心其每个值的顺序,因此如果只对 DFT 转置,那 DFT 和 IDFT 之间的两次蝴蝶变换就可以被省略掉了!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
// 这份代码还稍微优化了下取模常数。
static void NTT(poly &g, int flag) {
int n = g.size();
vector<int> f(g.begin(), g.end());
if (flag) {
for (int mid = n >> 1; mid >= 1; mid >>= 1) {
int w1 = power(G, (mod - 1) / mid / 2);
vector<int> w(mid);
w[0] = 1;
for (int i = 1; i < mid; i++) w[i] = (long long)w[i - 1] * w1 % mod;
for (int i = 0; i < n; i += mid << 1)
for (int j = 0; j < mid; j++) {
int t = (long long)(f[i + j] - f[i + mid + j] + mod) * w[j] % mod;
f[i + j] = f[i + j] + f[i + mid + j] >= mod ?
f[i + j] + f[i + mid + j] - mod : f[i + j] + f[i + mid + j];
f[i + mid + j] = t;
}
}
for (int i = 0; i < n; i++) g[i] = f[i];
} else {
for (int mid = 1; mid < n; mid <<= 1) {
int w1 = power(invG, (mod - 1) / mid / 2);
vector<int> w(mid);
w[0] = 1;
for (int i = 1; i < mid; i++) w[i] = (long long)w[i - 1] * w1 % mod;
for (int i = 0; i < n; i += mid << 1)
for (int j = 0; j < mid; j++) {
int t = (long long)w[j] * f[i + mid + j] % mod;
f[i + mid + j] = f[i + j] - t < 0 ? f[i + j] - t + mod : f[i + j] - t;
f[i + j] = f[i + j] + t >= mod ? f[i + j] + t - mod : f[i + j] + t;
}
}
int inv = power(n, mod - 2);
for (int i = 0; i < n; i++) g[i] = (long long)f[i] * inv % mod;
}
return;
}
留言