FFT

多项式求逆元

Posted by Miskcoo's Space on May 10, 2015

概述

多项式求逆元是多项式除法和多项式取模的必要过程,在竞赛中,有时候多项式求逆元的应用比多项式的除法还要多,用快速傅里叶变换以及倍增算法可以做到在 $\mathcal O(n\log n)$ 的时间内求出一个多项式的逆元

基本概念

在介绍多项式的逆元之前,先说明一些概念:多项式的度、多项式的逆元、多项式的除法和取余

对于一个多项式 $A(x)$,称其最高项的次数为这个多项式的(degree),记作 $deg A$

对于多项式 $A(x), B(x)$,存在唯一的 $Q(x), R(x)$ 满足 $A(x) = Q(x)B(x) + R(x)$,其中 $deg R < deg B$,我们称 $Q(x)$ 为 $B(x)$ 除 $A(x)$ 的,$R(x)$ 为 $B(x)$ 除 $A(x)$ 的余数,可以记作

\[A(x) \equiv R(x) \pmod {B(x)}\]

多项式的逆元

对于一个多项式 $A(x)$,如果存在 $B(x)$ 满足 $deg B \leq deg A$ 并且

\[A(x)B(x) \equiv 1 \pmod {x^n}\]

那么称 $B(x)$ 为 $A(x)$ 在 $\bmod x^n$ 意义下的逆元(inverse element),记作 $A^{-1}(x)$

求解过程

现在考虑如何求 $A^{-1}(x)$,当 $n=1$ 时,$A(x) \equiv c \pmod x$,$c$ 是一个常数,这样,$A^{-1}(x)$ 就是 $c^{-1}$

对于 $n>1$ 的情况,设 $B(x) = A^{-1}(x)$ 由定义可以知道

\[\begin{equation} \label{eqn1} A(x)B(x) \equiv 1 \pmod {x^n} \end{equation}\]

假设在 $\bmod x^{\lceil \frac{n}{2} \rceil}$ 意义下 $A(x)$ 的逆元是 $B’(x)$ 并且我们已经求出,那么

\[\begin{equation} \label{eqn2} A(x)B'(x) \equiv 1 \pmod {x^{\lceil \frac{n}{2} \rceil}} \end{equation}\]

再将 $(\ref{eqn1})$ 放到 $\bmod x^{\lceil \frac{n}{2} \rceil}$ 意义下

\[\begin{equation} \label{eqn3} A(x)B(x) \equiv 1 \pmod {x^{\lceil \frac{n}{2} \rceil}} \end{equation}\]

然后 $(\ref{eqn2}) - (\ref{eqn3})$ 就可以得到

\[B(x) - B'(x) \equiv 0 \pmod {x^{\lceil \frac{n}{2} \rceil}}\]

两边平方

\[B^2(x) - 2B'(x)B(x) + B'^2(x) \equiv 0 \pmod {x^n}\]

这里解释一下平方后为什么模的 $x^{\lceil \frac{n}{2} \rceil}$ 也会平方,这是因为,左边多项式在 $\bmod x^n$ 意义下为 $0$,那么就说明其 $0$ 到 $n-1$ 次项系数都为 $0$,平方了之后,对于第 $0 \leq i \leq 2n-1$ 项,其系数 $a_i$ 为 $\sum_{j=0}^i a_ja_{i-j}$,很明显 $j$ 和 $i-j$ 之间必然有一个值小于 $n$,因此 $a_i$ 必然是 $0$,也就是说平方后在 $\bmod x^{2n}$ 意义下仍然为 $0$

然后同时乘上 $A(x)$,移项可以得到

\[B(x) \equiv 2B'(x) - A(x)B'^2(x) \pmod {x^n}\]

这样就可以得到 $\bmod x^n$ 意义下的逆元了,利用 FFT 加速之后可以做到在 $\mathcal O(n\log n)$ 时间内解决当前问题,最后总的时间复杂度也就是

\[T(n) = T(\frac{n}{2}) + \mathcal O(n \log n) = \mathcal O(n \log n)\]

顺便一提,由这个过程可以看出,一个多项式有没有逆元完全取决于其常数项是否有逆元

代码实现

假设我已经有了计算快速傅里叶变换的函数 transform(int deg, complex_t* x, complex_t* w) 以及单位根表 epsinv_eps

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
void polynomial_inverse(int deg, complex_t* a, complex_t* b, complex_t* tmp)
{
	if(deg == 1)
	{
		b[0] = 1.0 / a[0];
	} else {
		polynomial_inverse((deg + 1) >> 1, a, b, tmp);

		int p = 1;
		while(p < deg << 1) p <<= 1;
		copy(a, a + deg, tmp);
		fill(tmp + deg, tmp + p, 0.0);
		transform(p, tmp, eps);
		transform(p, b, eps);
		for(int i = 0; i != p; ++i)
			b[i] *= 2.0 - tmp[i] * b[i];
		transform(p, b, inv_eps);
		for(int i = 0; i != p; ++i)
			b[i] /= p;
		fill(b + deg, b + p, 0.0);
	}
}

下面是数论版的,并且是完整的代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include <cstdio>
#include <complex>
#include <cmath>
#include <algorithm>
#include <iostream>
using std::copy;
using std::fill;

const long long mod_v = 17ll * (1 << 27) + 1;
const int MaxN = 10010;
long long a[MaxN], b[MaxN], c[MaxN];
long long eps[MaxN], inv_eps[MaxN];
int tot;

long long power(long long x, long long p)
{
	long long v = 1;
	while(p)
	{
		if(p & 1) v = x * v % mod_v;
		x = x * x % mod_v;
		p >>= 1;
	}

	return v;
}

void init_eps(int n)
{
	tot = n;
	long long base = power(3, (mod_v - 1) / n);
	long long inv_base = power(base, mod_v - 2);
	eps[0] = 1, inv_eps[0] = 1;
	for(int i = 1; i < n; ++i)
	{
		eps[i] = eps[i - 1] * base % mod_v;
		inv_eps[i] = inv_eps[i - 1] * inv_base % mod_v;
	}
}

long long inc(long long x, long long d) 
{
	x += d; 
	return x >= mod_v ? x - mod_v : x; 
}

long long dec(long long x, long long d) 
{
	x -= d; 
	return x < 0 ? x + mod_v : x; 
}

void transform(int n, long long *x, long long *w)
{
	for(int i = 0, j = 0; i != n; ++i)
	{
		if(i > j) std::swap(x[i], x[j]);
		for(int l = n >> 1; (j ^= l) < l; l >>= 1);
	}

	for(int i = 2; i <= n; i <<= 1)
	{
		int m = i >> 1;
		for(int j = 0; j < n; j += i)
		{
			for(int k = 0; k != m; ++k)
			{
				long long z = x[j + m + k] * w[tot / i * k] % mod_v;
				x[j + m + k] = dec(x[j + k], z);
				x[j + k] = inc(x[j + k], z);
			}
		}
	}
}

void polynomial_inverse(int deg, long long* a, long long* b, long long* tmp)
{
	if(deg == 1)
	{
		b[0] = power(a[0], mod_v - 2);
	} else {
		polynomial_inverse((deg + 1) >> 1, a, b, tmp);

		int p = 1;
		while(p < deg << 1) p <<= 1;
		copy(a, a + deg, tmp);
		fill(tmp + deg, tmp + p, 0);
		transform(p, tmp, eps);
		transform(p, b, eps);
		for(int i = 0; i != p; ++i)
		{
			b[i] = (2 - tmp[i] * b[i] % mod_v) * b[i] % mod_v;
			if(b[i] < 0) b[i] += mod_v;
		}
		transform(p, b, inv_eps);
		long long inv = power(p, mod_v - 2);
		for(int i = 0; i != p; ++i)
			b[i] = b[i] * inv % mod_v;
		fill(b + deg, b + p, 0);

	}
}

int main()
{
	init_eps(2048);
	int n;
	std::cin >> n;
	for(int i = 0; i != n; ++i)
		std::cin >> a[i];
	polynomial_inverse(n, a, b, c);
	std::cout << "inverse: ";
	for(int i = 0; i != n; ++i)
		printf("%lld ", (b[i] + mod_v) % mod_v);
	std::cout << std::endl;
	return 0;
}

 应用

预处理 Bernoulli 数

Bernoulli 数的指数生成函数(EGF)是

\[\frac{x}{e^x-1}=\sum_{i=0}^\infty B_i \frac{x^i}{i!}\]

将 $e^x$ 泰勒展开就可以改写成

\[\frac{x}{e^x-1}=\frac{1}{\sum_{i=0}^\infty \frac{x^i}{(i+1)!}}\]

然后利用刚刚所说的方法,求出这个多项式的逆元就可以得到 Bernoulli 数了

计算有标号简单连通无向图个数

这篇文章,在列出方程后最后可以用多项式求逆优化