FFT

牛顿迭代法在多项式运算的应用

Posted by Miskcoo's Space on June 8, 2015

总算是快把 FFT 和生成函数的各种东西补了好多,膜拜策爷的论文和 Picks 的博客 QAQ

这篇文章大概就是说如何用牛顿迭代法来解满足 $G(F(z)) \equiv 0 \pmod {z^n}$ 的 $F(z)$

然后这东西可以比较方便地计算 $\sqrt{A(z)}$、$e^{A(z)}$,也就是多项式开根、求指数对数之类鬼畜的东西,在生成函数计数中十分有用

顺便一提,这里说的“多项式”实际上你可以直接理解为生成函数或者形式幂级数

Newton’s Method

好,现在问题是这样的,已经知道了一个函数 $G(z)$,求一个多项式 $F(z) \bmod {z^n}$,满足方程

\[G(F(z)) \equiv 0 \pmod {z^n}\]

然后这个问题嘛…… 可以回忆一下多项式求逆的过程

首先 $n = 1$ 的时候,$G(F(z)) \equiv 0 \pmod z$,这是要单独求出来的

现在假设已经求出了

\[G(F_0(z)) \equiv 0 \pmod {z^{\lceil \frac{n}{2} \rceil}}\]

考虑如何扩展到 $\bmod z^n$ 下,可以把 $G(F(z))$ 在 $F_0(z)$ 这里进行泰勒展开

\[\begin{eqnarray*} G(F(z)) & = & G(F_0(z)) \\ & + & \frac{G'(F_0(z))}{1!}\left(F(z) - F_0(z)\right) \\ & + & \frac{G''(F_0(z))}{2!}\left(F(z) - F_0(z)\right)^2 \\ & + & \cdots \end{eqnarray*}\]

因为 $F(z)$ 和 $F_0(z)$ 的最后 $\lceil \frac{n}{2} \rceil$ 项相同,所以 $\left(F(z) - F_0(z)\right)^2$ 的最低的非 $0$ 项次数大于 $2\lceil \frac{n}{2} \rceil$,所以可以得到

\[G(F(z)) \equiv G(F_0(z)) + G'(F_0(z))\left(F(z) - F_0(z)\right) \pmod {z^n}\]

然后因为 $ G(F(z)) \equiv 0 \pmod {z^n} $,可以得到

\[F(z) \equiv F_0(z) - \frac{G(F_0(z))}{G'(F_0(z))} \pmod {z^n}\]

然后好像就完了?现在来看看它能干什么用好了

应用

多项式开方

这是给出 $A(z)$,求 $B(z)$,满足方程

\[B^2(z) \equiv A(z) \pmod {z^n}\]

然后我们可以构造方程 $F^2(z) - A(z) = 0$,目的就是要求解 $F(z) \bmod z^n$

这时候 $G(F(z)) = F^2(z) - A(z)$,$G’(F(z)) = 2F(z)$,带到上面的迭代方程

\[\begin{eqnarray*} F(z) & \equiv & F_0(z) - \frac{F_0^2(z) - A(z)}{2F_0(z)} \\ & \equiv & \frac{F_0^2(z) + A(z)}{2F_0(z)} \pmod {z^n} \end{eqnarray*}\]

然后就可以计算了,复杂度是 $\mathcal O(n\log n)$

当系数是在模意义下的时候,有点麻烦的其实是常数项的确定,如果 $[z^0]A(z) = 1$ 还好,否则你可能要计算二次剩余

例. [Codeforces #250 Div1 E] The Child And Binary Tree

有一个含有 $n$ 个元素的正整数集合 $S = { c_1, c_2, \cdots, c_n }$,我们称一个节点带权的有根二叉树是好的,当且仅当对于每个节点 $v$,$v$ 的权值在 $S$ 内,并且我们称这棵树的权值为所有节点的权值和

现在给出一个正整数 $m$,你要计算出对于所有正整数 $1 \leq s \leq m$,有多少不同的好的二叉树满足它的权值是 $s$

答案对 $998244353 (7 \times 17 \times 2^{23} + 1$,一个质数$)$ 取模,其中 $1 \leq n, m, c_i \leq 10^5$

题目链接:Codeforces-438EBZOJ-3625

Example:比如说 $S = {1, 2}$,$s = 3$ 的时候一共有 $9$ 种方案

Solution:嗯…… 既然我放在了这里不用说肯定是和多项式开方有关系啦

直接用生成函数来考虑这个问题的话,对于一个节点来说,它的生成函数是

\[T(z) = \sum_{s \in S} z^s\]

现在假设答案的生成函数是 $F(z)$,那么由于二叉树的递归性质,可以由左右子树 $F^2(z)$ 加上一个根节点 $T(z)$ 组合而成,并且还有一个空树我们也算作一种,那么可以得到方程

\[F(z) = 1 + T(z) F^2(z)\]

解出来可以得到

\[F(z) = \frac{1 - \sqrt{1 - 4T(z)}}{2T(z)}\]

然后剩下的就是多项式开方和多项式求逆了

如果从递推的角度来考虑的话,我们可以设权值为 $s$ 的方案有 $f_s$ 种,那么可以得到方程

\[f_s = \sum_{i \in S, s - i \geq 0}\sum_{j = 0}^{s - i} f_j f_{s - i - j}\]

边界条件是 $f_0 = 1$,这样你会发现 $f_s$ 里面的求和是个卷积的形式,同样可以转换为生成函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
#include <cstdio>
#include <algorithm>

using std::swap;
using std::fill;
using std::copy;

typedef int value_t;
typedef long long calc_t;
const int MaxN = 1 << 19;
const value_t mod_base = 119, mod_exp = 23;
const value_t mod_v = (mod_base << mod_exp) + 1;
const value_t primitive_root = 3;
int epsilon_num;
value_t eps[MaxN], inv_eps[MaxN], inv2;

value_t dec(value_t x, value_t v) { x -= v; return x < 0 ? x + mod_v : x; }
value_t inc(value_t x, value_t v) { x += v; return x >= mod_v ? x - mod_v : x; }
value_t pow(value_t x, value_t p)
{
	value_t v = 1;
	for(; p; p >>= 1, x = (calc_t)x * x % mod_v)
		if(p & 1) v = (calc_t)x * v % mod_v;
	return v;
}

void init_eps(int num)
{
	epsilon_num = num;
	value_t base = pow(primitive_root, (mod_v - 1) / num);
	value_t inv_base = pow(base, mod_v - 2);
	eps[0] = inv_eps[0] = 1;
	for(int i = 1; i != num; ++i)
	{
		eps[i] = (calc_t)eps[i - 1] * base % mod_v;
		inv_eps[i] = (calc_t)inv_eps[i - 1] * inv_base % mod_v;
	}
}

void transform(int n, value_t *x, value_t *w = eps)
{
	for(int i = 0, j = 0; i != n; ++i)
	{
		if(i > j) swap(x[i], x[j]);
		for(int l = n >> 1; (j ^= l) < l; l >>= 1);
	}

	for(int i = 2; i <= n; i <<= 1)
	{
		int m = i >> 1, t = epsilon_num / i;
		for(int j = 0; j < n; j += i)
		{
			for(int p = 0, q = 0; p != m; ++p, q += t)
			{
				value_t z = (calc_t)x[j + m + p] * w[q] % mod_v;
				x[j + m + p] = dec(x[j + p], z);
				x[j + p] = inc(x[j + p], z);
			}
		}
	}
}

void inverse_transform(int n, value_t *x)
{
	transform(n, x, inv_eps);
	value_t inv = pow(n, mod_v - 2);
	for(int i = 0; i != n; ++i)
		x[i] = (calc_t)x[i] * inv % mod_v;
}

void polynomial_inverse(int n, value_t *A, value_t *B)
{
	static value_t T[MaxN];
	if(n == 1)
	{
		B[0] = pow(A[0], mod_v - 2);
		return;
	}

	int half = (n + 1) >> 1;
	polynomial_inverse(half, A, B);

	int p = 1;
	for(; p < n << 1; p <<= 1);

	fill(B + half, B + p, 0);
	transform(p, B);

	copy(A, A + n, T);
	fill(T + n, T + p, 0);
	transform(p, T);

	for(int i = 0; i != p; ++i)
		B[i] = (calc_t)B[i] * dec(2, (calc_t)T[i] * B[i] % mod_v) % mod_v;
	inverse_transform(p, B);
}

void polynomial_sqrt(int n, value_t *A, value_t *B)
{
	static value_t T[MaxN];
	if(n == 1)
	{
		B[0] = 1; // sqrt A[0], here is 1
		return;
	}

	int p = 1;
	for(; p < n << 1; p <<= 1);

	int half = (n + 1) >> 1;
	polynomial_sqrt(half, A, B);
	fill(B + half, B + n, 0);
	polynomial_inverse(n, B, T);
	fill(T + n, T + p, 0);
	transform(p, T);

	fill(B + half, B + p, 0);
	transform(p >> 1, B);
	for(int i = 0; i != p >> 1; ++i)
		B[i] = (calc_t)B[i] * B[i] % mod_v;
	inverse_transform(p >> 1, B);
	for(int i = 0; i != n; ++i)
		B[i] = (calc_t)inc(A[i], B[i]) * inv2 % mod_v;
	transform(p, B);
	for(int i = 0; i != p; ++i)
		B[i] = (calc_t)B[i] * T[i] % mod_v;
	inverse_transform(p, B);

}

value_t tmp[MaxN];
value_t A[MaxN], B[MaxN], C[MaxN], T[MaxN];

int main()
{
	int n, m;
	std::scanf("%d %d", &n, &m);
	int min_v = ~0u >> 1;
	for(int i = 0; i != n; ++i)
	{
		std::scanf("%d", tmp + i);
		if(min_v > tmp[i]) min_v = tmp[i];
	}

	inv2 = mod_v - mod_v / 2;

	int p = 1;
	for(; p < (m + min_v + 1) << 1; p <<= 1);
	init_eps(p);

	A[0] = 1;
	for(int i = 0; i != n; ++i)
	{
		int x = tmp[i];
		T[x - min_v] = 2;
		A[x] = mod_v - 4;
	}

	polynomial_inverse(m + min_v + 1, T, C);
	polynomial_sqrt(m + min_v + 1, A, B);
	B[0] = dec(1, B[0]);
	for(int i = 1; i <= m + min_v; ++i)
		B[i] = mod_v - B[i];
	for(int i = 0; i <= m; ++i)
		B[i] = B[i + min_v];
	fill(B + m + 1, B + p, 0);
	fill(C + m + 1, C + p, 0);
	transform(p, B);
	transform(p, C);
	for(int i = 0; i != p; ++i)
		B[i] = (calc_t)B[i] * C[i] % mod_v;
	inverse_transform(p, B);
	for(int i = 1; i <= m; ++i)
		std::printf("%d\n", B[i]);
	return 0;
}

 多项式的对数和指数函数

啥…… 多项式的对数是什么?其实我们可以认为它是一个多项式和麦克劳林级数的复合,也就是给出一个多项式 $A(z) = \sum_{i \geq 1} a_iz^i$

\[\ln (1 - A(z)) = -\sum_{i \geq 1} \frac{A^i(z)}{i}\]

指数函数同样可以这样定义

\(\exp(A(z)) = e^{A(z)} = \sum_{i \geq 0} \frac{A^i(z)}{i!}\)

对数的计算

对于一个多项式 $A(z) = 1 + \sum_{i \geq 1}a_iz^i $,现在要计算其对数 $\ln A(z)$

注意这里 $A(z)$ 的常数项是 $1$ 是因为上面的定义,因为到直接计算似乎很难计算,我们考虑对其求导后的结果

\[(\ln A(z))' = \frac{A'(z)}{A(z)}\]

也就是说,我们要计算出 $A(z)$ 的逆元,这可以在 $\mathcal O(n\log n)$ 的时间内完成,那么

\[\ln A(z) = \int \frac{A'(z)}{A(z)}\]

至于多项式的求导和积分,这是都是 $\mathcal O(n)$ 复杂度的,因此计算对数的时间是 $\mathcal O(n\log n)$

指数的计算

对于一个多项式 $A(z) = \sum_{i \geq 1}a_iz^i $,现在要计算其指数 $e^{A(z)}$

如果和对数计算一样直接求导,你会发现是不可行的,因为求导完还有指数函数自身,这时候就要利用刚刚所说的牛顿迭代,我们需要的方程实际上是

\[F(z) = e^{A(z)}\]

变形后变成

\[\ln F(z) - A(z) = 0\]

构造函数 $G(F(z)) = \ln F(z) - A(z)$,这时候 $G’(F(z)) = \frac{1}{F(z)}$,然后得到递推式

\[\begin{eqnarray*} F(z) & \equiv & F_0(z) - \frac{G(F_0(z))}{G'(F_0(z))} \\ & \equiv & F_0(z) \left (1 - \ln F_0(z) + A(z) \right ) \pmod {z^n}\end{eqnarray*}\]

然后 $F(z)$ 的常数项是 $1$,最后复杂度是

\(T(n) = T(\frac{n}{2}) + \mathcal O(n\log n) = \mathcal O(n\log n)\)

void polynomial_logarithm(int n, value_t *A, value_t *B)
{
	static value_t T[MaxN];
	int p = 1;
	for(; p < n << 1; p <<= 1);

	polynomial_inverse(n, A, T);
	fill(T + n, T + p, 0);
	transform(p, T);

	// derivative
	copy(A, A + n, B);
	for(int i = 0; i < n - 1; ++i)
		B[i] = (calc_t)B[i + 1] * (i + 1) % mod_v;
	fill(B + n - 1, B + p, 0);
	transform(p, B);

	for(int i = 0; i != p; ++i)
		B[i] = (calc_t)B[i] * T[i] % mod_v;
	inverse_transform(p, B);

	// integral
	for(int i = n - 1; i; --i)
		B[i] = (calc_t)B[i - 1] * inv[i] % mod_v;
	B[0] = 0;
}

void polynomial_exponent(int n, value_t *A, value_t *B)
{
	static value_t T[MaxN];
	if(n == 1)
	{
		B[0] = 1;
		return;
	}

	int p = 1; 
	for(; p < n << 1; p <<= 1);

	int half = (n + 1) >> 1;
	polynomial_exponent(half, A, B);
	fill(B + half, B + p, 0);

	polynomial_logarithm(n, B, T);
	for(int i = 0; i != n; ++i)
		T[i] = dec(A[i], T[i]);
	T[0] = inc(T[0], 1);
	transform(p, T);
	transform(p, B);
	for(int i = 0; i != p; ++i)
		B[i] = (calc_t)B[i] * T[i] % mod_v;
	inverse_transform(p, B);
}

任意次幂的计算

给出一个多项式 $A(z)$,你现在要计算 $A^k(z), k \in \mathbb Q$

对于 $k \in \mathbb N$ 的部分,可以直接使用快速幂来计算,这样复杂度是 $\mathcal O(n\log n\log k)$

现在有了求指数和求对数的运算,那么

\[A^k(z) = e^{k \ln A(z)}\]

这样就可以在 $\mathcal O(n\log n)$ 的时间内计算出 $A^k(z)$,包括开方等也可以这样来计算