题目大意是要求下面这个式子
这个题目有三个版本:
这篇文章介绍 和
两种做法
为了方便,定义一个函数
然后使用“扰动法“
然后这个算法的复杂度是 的,但是这题最快可以做到
的!
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
#include <cstdio> const long long mod_const = 1000000007LL; long long comb[1001][1001]; long long f[1001]; void init_combination(int n) { comb[0][0] = 1; for(int i = 1; i <= n; ++i) { comb[i][0] = 1; comb[i][i] = 1; for(int j = 1; j != i; ++j) { comb[i][j] = comb[i - 1][j] + comb[i - 1][j - 1]; comb[i][j] %= mod_const; } } } long long power(long long base, int p) { long long v = 1; while(p) { if(p & 1) v = v * base % mod_const; base = base * base % mod_const; p >>= 1; } return v; } long long calc_sum(int n, int m) { long long invert = power(m - 1, mod_const - 2); f[0] = ((power(m, n + 1) - 1) * invert - 1) % mod_const; if(f[0] < 0) f[0] += mod_const; for(int i = 1; i <= m; ++i) { long long t = 0; for(int j = 0; j != i; ++j) { int sign = ((i ^ j) & 1) ? -1 : 1; t = (t + comb[i][j] * sign * f[j]) % mod_const; } f[i] = (t + power(n, i) * power(m, n + 1)) % mod_const; f[i] = f[i] * invert % mod_const; if(f[i] < 0) f[i] += mod_const; } return f[m]; } int main() { int N, M; std::scanf("%d %d", &N, &M); init_combination(M); if(M == 1) std::printf("%lld", (long long)(N + 1) * N / 2 % mod_const); else std::printf("%lld", calc_sum(N, M)); return 0; } |
我们现在记题目要求的和式为 ,首先我们可以把
比较小的时候的通项列出来试试看,
我们发现当 的时候
一定有这样的形式:
其中 是一个
次多项式,于是只要求出
就可以用这篇文章的方法在
的时间内计算出
了!
计算 可以得到
的递推式
然后现在我们可以将 都表示成
,一共得到
个方程,为了得到
还缺少一个方程?
我们利用上面所说的那篇文章最后的结论
当 的时候这是成立的没有问题,于是,我们令
可以得到
这就是我们需要的第 个方程!然后就可以解出来
了!然后剩下的就是根据上面文章的方法计算出答案
一些小细节我在这里说一下,因为你是需要计算 ,这一部分实际上是可以线性时间内预处理的,大概做法是这样,对于每个数,如果是质数,那么我们用快速幂
计算,如果不是质数,那么找出它的一个质因子,然后拆成两份已经计算过的比它小的数相乘可以
计算,由于质数个数是
级别的,因此总复杂度是
,然后质数我们可以用线性筛法预处理出来
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
#include <cstdio> const int MaxM = 500010; const long long mod_v = 1000000007; int n, m, ptot; long long P[MaxM], A[MaxM], B[MaxM], L[MaxM], R[MaxM]; long long inv[MaxM], fac[MaxM], inv_fac[MaxM], pw[MaxM]; int prime[MaxM], not_prime[MaxM]; long long pow(long long x, long long p) { long long v = 1; for(; p; p >>= 1, x = x * x % mod_v) if(p & 1) v = x * v % mod_v; return v; } void linear_sieve(int n) { pw[1] = 1; for(int i = 2; i <= n; ++i) { if(!not_prime[i]) { prime[ptot++] = i; pw[i] = pow(i, m); } for(int j = 0; j != ptot; ++j) { int t = prime[j] * i; if(t > n) break; not_prime[t] = 1; pw[t] = pw[i] * pw[prime[j]] % mod_v; if(i % prime[j] == 0) break; } } } long long comb(int r, int k) { return fac[r] * inv_fac[k] % mod_v * inv_fac[r - k] % mod_v; } void prework() { inv[1] = 1; for(int i = 2; i <= m + 1; ++i) inv[i] = mod_v - mod_v / i * inv[mod_v % i] % mod_v; fac[0] = inv_fac[0] = 1; for(int i = 1; i <= m + 1; ++i) { fac[i] = fac[i - 1] * i % mod_v; inv_fac[i] = inv_fac[i - 1] * inv[i] % mod_v; } A[0] = 1, B[0] = 0; for(int i = 0; i <= m; ++i) { A[i + 1] = A[i] * inv[m] % mod_v; B[i + 1] = (B[i] * inv[m] + pw[i + 1]) % mod_v; } // R*P_m(0) + K = 0 long long R = 0, K = 0; for(int i = 0; i <= m + 1; ++i) { long long coeffi = comb(m + 1, i) % mod_v; if(i & 1) coeffi = -coeffi; R = (R + coeffi * A[i]) % mod_v; K = (K + coeffi * B[i]) % mod_v; } P[0] = -K * pow(R, mod_v - 2) % mod_v; for(int i = 1; i <= m; ++i) P[i] = (A[i] * P[0] + B[i]) % mod_v; } long long solve() { L[0] = n - m, R[0] = n; for(int i = 1; i <= m; ++i) { R[i] = R[i - 1] * (n - i) % mod_v; L[i] = L[i - 1] * (n - m + i) % mod_v; } long long p = 0; for(int i = 0; i <= m; ++i) { long long coeffi = inv_fac[i] * inv_fac[m - i] % mod_v; if(i != m) coeffi = coeffi * L[m - i - 1] % mod_v; if(i) coeffi = coeffi * R[i - 1] % mod_v; if((m ^ i) & 1) coeffi = -coeffi; p = (p + coeffi * P[i]) % mod_v; } long long ans = (pow(m, n) * p - P[0]) % mod_v; return (ans + mod_v) % mod_v; } int main() { long long ans; std::scanf("%d %d", &n, &m); linear_sieve(m + 1); if(m == 1) { ans = n * (n + 1ll) % mod_v * pow(2, mod_v - 2) % mod_v; } else if(n <= m) { ans = 0; long long exp = m; for(int i = 1; i <= n; ++i, exp = exp * m % mod_v) ans = (ans + exp * pw[i]) % mod_v; } else { prework(); ans = solve(); } std::printf("%lld\n", ans); return 0; } |
扰动法是什么qwq