BZOJ-3157. 国王奇遇记

Posted by miskcoo on June 8, 2014

题目大意是要求下面这个式子
\(\sum_{i=1}^n m^i \cdot i^m\)

这个题目有三个版本:

这篇文章介绍 $\mathcal O(m^2)$ 和 $\mathcal O(m)$ 两种做法

为了方便,定义一个函数 $f(i)$ \(f(i) = \sum_{k=1}^n k^i \cdot m^k\)

然后使用”扰动法” \(\begin{eqnarray*} (m - 1) \cdot f(i) & = & \sum_{k=1}^n k^i \cdot m^{k + 1} - \sum_{k=1}^n k^i \cdot m^k \\ & = & \sum_{k=1}^{n + 1} (k - 1)^i \cdot m^k - \sum_{k=1}^n k^i \cdot m^k \\ & = & n^i \cdot m^{n + 1} + \sum_{k=1}^n m^k \sum_{j = 0}^{i - 1} {i \choose j} \cdot (-1)^{i - j} \cdot k^j \\ & = & n^i \cdot m^{n + 1} + \sum_{j = 0}^{i - 1} {i \choose j} \cdot (-1)^{i - j} \sum_{k = 1}^n k^j \cdot m^k \\ & = & n^i \cdot m^{n + 1} + \sum_{j = 0}^{i - 1} {i \choose j} \cdot (-1)^{i - j} \cdot f(j) \\ \end{eqnarray*}\)

这个算法的复杂度是 $\mathcal O(m^2)$ 的,但是这题最快可以做到 $\mathcal O(m)$ 的! 下面我们先给出刚刚的 $\mathcal O(m^2)$ 算法的代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include <cstdio>
 
const long long mod_const = 1000000007LL;
long long comb[1001][1001];
long long f[1001];
 
void init_combination(int n)
{
    comb[0][0] = 1;
    for(int i = 1; i <= n; ++i)
    {
        comb[i][0] = 1;
        comb[i][i] = 1;
        for(int j = 1; j != i; ++j)
        {
            comb[i][j] = comb[i - 1][j] + comb[i - 1][j - 1];
            comb[i][j] %= mod_const;
        }
    }
}
 
long long power(long long base, int p)
{
    long long v = 1;
    while(p)
    {
        if(p & 1) v = v * base % mod_const;
        base = base * base % mod_const;
        p >>= 1;
    }
    return v;
}
 
long long calc_sum(int n, int m)
{
    long long invert = power(m - 1, mod_const - 2);
    f[0] = ((power(m, n + 1) - 1) * invert - 1) % mod_const;
    if(f[0] < 0) f[0] += mod_const;
    for(int i = 1; i <= m; ++i)
    {
        long long t = 0;
        for(int j = 0; j != i; ++j)
        {
            int sign = ((i ^ j) & 1) ? -1 : 1;
            t = (t + comb[i][j] * sign * f[j]) % mod_const;
        }
 
        f[i] = (t + power(n, i) * power(m, n + 1)) % mod_const;
        f[i] = f[i] * invert % mod_const;
        if(f[i] < 0) f[i] += mod_const;
    }
    return f[m];
}
 
int main()
{
    int N, M;
    std::scanf("%d %d", &N, &M);
    init_combination(M);
    if(M == 1) std::printf("%lld", (long long)(N + 1) * N / 2 % mod_const);
    else std::printf("%lld", calc_sum(N, M));
    return 0;
}

我们现在记题目要求的和式为 $F_m(n)$,首先我们可以把 $m$ 比较小的时候的通项列出来试试看, \(\begin{eqnarray*} F_1(n) &=& \frac{1}{2}1^n(n^2+n) \\ F_2(n) &=& 2^n(2n^2-4n+6) - 6 \\ F_3(n) &=& \frac{3}{8}\left [ 3^n(4n^3-6n^2+12n-11) + 11 \right ] \\ F_4(n) &=& \frac{4}{81}\left [ 4^n(27n^4-36n^3+90n^2-132n+95) - 95 \right ] \end{eqnarray*}\)

我们发现当 $m > 1$ 的时候 $F_m(n)$ 一定有这样的形式: \(F_m(n) = m^n P_m(n) - P_m(0)\)

其中 $P_m(n)$ 是一个 $m$ 次多项式,于是只要求出 $P_m(0), P_m(1), \cdots, P_m(n)$ 就可以用这篇文章的方法在 $\mathcal O(m)$ 的时间内计算出 $P_m(n)$ 了! 计算 $F_m(n + 1) - F_m(n)$ 可以得到 $P_m$ 的递推式 \(\begin{eqnarray*} m^{n+1}(n+1)^m &=& m^{n+1}P_m(n+1) - m^nP_m(n) \\ P_m(n+1) &=& \frac{P_m(n)}{m} + (n+1)^m \end{eqnarray*}\)

然后现在我们可以将 $P_m(1), P_m(2), \cdots, P_m(m + 1)$ 都表示成 $A\cdot P_m(0) + B$,一共得到 $m + 1$ 个方程,为了得到 $P_m(0)$ 还缺少一个方程? 我们利用上面所说的那篇文章最后的结论 \(P_m(x) = \sum_{j=0}^m (-1)^{m - j}{x \choose j}{ {x - j - 1} \choose {m - j}} P_m(j)\) 当 $x > m$ 的时候这是成立的没有问题,于是,我们令 $x = m + 1$ 可以得到 \(\begin{eqnarray*} P_m(m + 1) &=& \sum_{j=0}^m (-1)^{m - j}{ {m + 1} \choose j}{ {m - j} \choose {m - j}} P_m(j) \\ 0 &=& \sum_{j=0}^{m+1} (-1)^{m - j}{ {m + 1} \choose j} P_m(j)\\ \end{eqnarray*}\)

这就是我们需要的第 $n + 2$ 个方程!然后就可以解出来 $P_m(0)$ 了!然后剩下的就是根据上面文章的方法计算出答案 一些小细节我在这里说一下,因为你是需要计算 $k^n$,这一部分实际上是可以线性时间内预处理的,大概做法是这样,对于每个数,如果是质数,那么我们用快速幂 $\mathcal O(\log n)$ 计算,如果不是质数,那么找出它的一个质因子,然后拆成两份已经计算过的比它小的数相乘可以 $\mathcal O(1)$ 计算,由于质数个数是 $\mathcal O(\frac{n}{\ln n})$ 级别的,因此总复杂度是 $\mathcal O(m)$,然后质数我们可以用线性筛法预处理出来

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#include <cstdio>

const int MaxM = 500010;
const long long mod_v = 1000000007;

int n, m, ptot;
long long P[MaxM], A[MaxM], B[MaxM], L[MaxM], R[MaxM];
long long inv[MaxM], fac[MaxM], inv_fac[MaxM], pw[MaxM];
int prime[MaxM], not_prime[MaxM];

long long pow(long long x, long long p)
{
	long long v = 1;
	for(; p; p >>= 1, x = x * x % mod_v)
		if(p & 1) v = x * v % mod_v;
	return v;
}

void linear_sieve(int n)
{
	pw[1] = 1;
	for(int i = 2; i <= n; ++i)
	{
		if(!not_prime[i])
		{
			prime[ptot++] = i;
			pw[i] = pow(i, m);
		}

		for(int j = 0; j != ptot; ++j)
		{
			int t = prime[j] * i;
			if(t > n) break;
			not_prime[t] = 1;
			pw[t] = pw[i] * pw[prime[j]] % mod_v;
			if(i % prime[j] == 0) 
				break;
		}
	}
}

long long comb(int r, int k)
{
	return fac[r] * inv_fac[k] % mod_v * inv_fac[r - k] % mod_v;
}

void prework()
{
	inv[1] = 1;
	for(int i = 2; i <= m + 1; ++i)
		inv[i] = mod_v - mod_v / i * inv[mod_v % i] % mod_v;

	fac[0] = inv_fac[0] = 1;
	for(int i = 1; i <= m + 1; ++i)
	{
		fac[i] = fac[i - 1] * i % mod_v;
		inv_fac[i] = inv_fac[i - 1] * inv[i] % mod_v;
	}

	A[0] = 1, B[0] = 0;
	for(int i = 0; i <= m; ++i)
	{
		A[i + 1] = A[i] * inv[m] % mod_v;
		B[i + 1] = (B[i] * inv[m] + pw[i + 1]) % mod_v;
	}

	// R*P_m(0) + K = 0
	long long R = 0, K = 0;
	for(int i = 0; i <= m + 1; ++i)
	{
		long long coeffi = comb(m + 1, i) % mod_v;
		if(i & 1) coeffi = -coeffi;
		R = (R + coeffi * A[i]) % mod_v;
		K = (K + coeffi * B[i]) % mod_v;
	}

	P[0] = -K * pow(R, mod_v - 2) % mod_v;

	for(int i = 1; i <= m; ++i)
		P[i] = (A[i] * P[0] + B[i]) % mod_v;
}

long long solve()
{
	L[0] = n - m, R[0] = n;
	for(int i = 1; i <= m; ++i)
	{
		R[i] = R[i - 1] * (n - i) % mod_v;
		L[i] = L[i - 1] * (n - m + i) % mod_v;
	}

	long long p = 0;
	for(int i = 0; i <= m; ++i)
	{
		long long coeffi = inv_fac[i] * inv_fac[m - i] % mod_v;
		if(i != m) coeffi = coeffi * L[m - i - 1] % mod_v;
		if(i) coeffi = coeffi * R[i - 1] % mod_v;
		if((m ^ i) & 1) coeffi = -coeffi;
		p = (p + coeffi * P[i]) % mod_v;
	}

	long long ans = (pow(m, n) * p - P[0]) % mod_v;
	return (ans + mod_v) % mod_v;
}

int main()
{
	long long ans;
	std::scanf("%d %d", &n, &m);
	linear_sieve(m + 1);
	if(m == 1) 
	{
		ans = n * (n + 1ll) % mod_v * pow(2, mod_v - 2) % mod_v;
	} else if(n <= m) {
		ans = 0;
		long long exp = m;
		for(int i = 1; i <= n; ++i, exp = exp * m % mod_v)
			ans = (ans + exp * pw[i]) % mod_v;
	} else {
		prework();
		ans = solve();
	}

	std::printf("%lld\n", ans);
	return 0;
}