BZOJ-3771. Triple
这题挺有意思的,来源是 SPOJ-TSUM,题意大概是,给出 $n$ 个物品,第 $i$ 个物品的价值为 $x_i(0 < x_i \leq 4\cdot10^4)$,问在这 $n$ 个物品中选出 $1$ 个或 $2$ 个或 $3$ 个价值和是多少,对于每个价值求出方案个数,在这题中 $(a, b)$ 和 $(b, a)$ 算作一种方案
例如有三个物品
$x_1$ | $x_2$ | $x_3$ |
1 | 2 | 3 |
选一个就可以得到 $1, 2, 3$ 三种价值
选两个就可以得到 $1+2, 1+3, 2+3$ 三种价值
选三个就可以得到 $1+2+3$ 一种价值
所以最后答案就是
价值 | 1 | 2 | 3 | 4 | 5 | 6 |
方案数 | 1 | 1 | 2 | 1 | 1 | 1 |
这题可以先用母函数表示出选一个的方案(系数是物品出现次数,指数是物品价值)
\[A(x) = x^1 + x^2 + x^3\]然后由于多项式乘法会将系数相乘,指数相加,将两个这样方案的母函数 $A(x), B(x)$ 相乘就相当于在 $A(x)$ 表示的物品中选出一个再在 $B(x)$ 表示的物品中选出一个,组合起来
所以不考虑重复,在物品中选出三个的方案就是 $\frac{1}{6}A^3(x)$
现在用 $B(x), C(x)$ 分别表示一种物品选了 $2$ 次和 $3$ 次的方案
\[B(x) = x^2 + x^4 + x^6 \\ C(x) = x^3 + x^6 + x^9\]选三个有可能是 AAB, ABA, BAA, AAA 这几种重复情况,所以扣掉后方案就是
\[\frac{A^3(x) - 3A(x)\cdot B(x) + 2C(x)}{6}\]选两个的方案是
\[\frac{A^2(x)-B(x)}{2}\]选一个的方案是
\[A(x)\]加起来后就会得到总的方案
由于这里用到了多项式乘法,用FFT优化即可
/* BZOJ-3771: Triple
* 母函数+FFT */
#include <cstdio>
#include <cmath>
#include <complex>
#include <algorithm>
#include <iostream>
typedef std::complex<double> complex_t;
void shuffle(int n, complex_t *x)
{
for(int i = 0, j = 0; i != n; ++i)
{
if(i > j) std::swap(x[i], x[j]);
for(int l = n >> 1; (j ^= l) < l; l >>= 1);
}
}
void transform(int n, complex_t *x, complex_t *w)
{
shuffle(n, x);
for(int i = 2; i <= n; i <<= 1, ++w)
{
int m = i >> 1;
complex_t *tmp = new complex_t[m];
tmp[0] = 1;
for(int p = 1; p != m; ++p)
tmp[p] = tmp[p - 1] * *w;
for(int p = 0; p < n; p += i)
{
for(int t = 0; t != m; ++t)
{
complex_t z = x[p + m + t] * tmp[t];
x[p + m + t] = x[p + t] - z;
x[p + t] += z;
}
}
delete[] tmp;
}
}
const int MaxL = 17, MaxN = (1 << MaxL) + 10;
complex_t eps[MaxL + 1], eps_inv[MaxL + 1];
complex_t A[MaxN], B[MaxN], C[MaxN];
void init_eps()
{
double pi = std::acos(-1.0);
double angle = 2.0 * pi / (1 << MaxL);
eps[MaxL] = complex_t(std::cos(angle), std::sin(angle));
for(int i = MaxL - 1; i >= 0; --i)
eps[i] = eps[i + 1] * eps[i + 1];
for(int i = 0; i <= MaxL; ++i)
eps_inv[i] = 1.0 / eps[i];
}
int main()
{
init_eps();
int n, max_v = 0;
std::scanf("%d", &n);
for(int i = 0; i != n; ++i)
{
int v;
std::scanf("%d", &v);
A[v] += 1.0;
B[v * 2] += 1.0;
C[v * 3] += 1.0;
if(v * 3 > max_v) max_v = v * 3;
}
int m = 1;
while(m < max_v) m <<= 1;
transform(m, A, eps + 1);
transform(m, B, eps + 1);
for(int i = 0; i != m; ++i)
A[i] = A[i] * A[i] * A[i] / 6.0 + (A[i] * (A[i] - B[i]) - B[i]) / 2.0 + A[i];
transform(m, A, eps_inv + 1);
for(int i = 0; i != m; ++i)
A[i] = A[i] / double(m) + C[i] / 3.0;
for(int i = 0; i != m; ++i)
{
int x = int(A[i].real() + 0.5);
if(x) std::printf("%d %d\n", i, x);
}
return 0;
}