BZOJ-3509. [CodeChef] COUNTARI
给定一个长度为 $n ~(1 \leq n \leq 10^5)$ 的数组 $a_i~(0 \leq a_i \leq 3\cdot 10^4)$,问有多少对 $(i, j, k)$ 满足 $i < j < k$ 且 $a_i - a_j = a_j - a_k$
也就是统计有多少对长度为 $3$ 的等差子序列(题目连接 BZOJ-3509, CodeChef COUNTARI)
例如 3 5 3 6 3 4 10 4 5 2,就有 (1, 3, 5), (1, 6, 9), (1, 8, 9), (3, 6, 9), (3, 8, 9), (5, 6, 9), (5, 8, 9), (4, 6, 10), (4, 8, 10) 一共 9 对
假设现在枚举 $j$,就变成要寻找有多少对 $(i, k)$ 满足 $2a_j=a_i+a_k$,如果不考虑 $i < j < k$ 这个限制,也就是说 $a_i$ 和 $a_k$ 可以随便在数组中选两个,并且加起来要是 $2a_j$
直接一个个去算 $j$ 的话再枚举 $(i, k)$ 复杂度十分高,大概是 $\mathcal O(n^3)$,肯定是不能做到。但是回头想想,现在要做的操作是从 $a$ 中取出一个数,再从 $a$ 中取出一个数,组合起来。这会想到母函数或者多项式乘法,因为多项式乘法的结果就是从每个多项式中取出一项,组合起来(指数相加)
所以现在可以构造 $a$ 的母函数(如果不太清楚的可以先做这题)
\[A(x) = x^2+3x^3+2x^4+2x^5+x^6+x^{10}\]那么将 $A(x)$ 平方后就可以得到要求的答案,$2a_j$ 那一项的系数就是 $(i, k)$ 的方案数,至于多项式乘法,只要用 FFT 优化就好了,关于 FFT 可以看这里
然后现在来考虑存在 $i < j < k$ 这个限制的情况,还是可以枚举 $j$,然后计算 $(i, k)$,你同样可以构造出母函数,只不过这回要构造两个母函数,一个表示 $j$ 左边的序列,一个表示 $j$ 右边的序列,因为要求从左右两边选出 $(i, k)$。例如说 $j = 4, a_j = 6$,它左边序列的母函数就是 $L(x) = 2x^3+x^5$,右边序列的母函数是 $R(x)=x^2+x^3+2x^4+x^5+x^{10}$,然后求出 $L(x)R(x)$ 中 $2a_j$ 项的系数就是答案,但是这样还不如暴力,因为你要求 $n$ 次 FFT,复杂度 $\mathcal O(n^2\log n)$
但是这样是可以优化的,因为一次 FFT 只求一个点的值太浪费了,我们可以一次性求出一段区间的值,然后对于在区间内的 $(i, j, k)$ 暴力求解,也就是对于区间 $[L, R]$,用 FFT 求出满足 $i < L, L \leq j \leq R, R < k$ 的方案数,再暴力求出区间内的方案数,如果设块的大小是 $\mathcal O(\sqrt{n\log n})$,最终复杂度大约会是 $\mathcal O(n\sqrt{n\log n})$
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
const long long mod_v = 17ll * (1 << 27) + 1;
const int MaxN = 100010, MaxA = 1 << 16;
int max_v, A[MaxN], cntl[MaxA], cntr[MaxA];
long long X[MaxA], Y[MaxA], eps[MaxA], inv_eps[MaxA];
long long calc_block(int n, int block_size)
{
long long ans = 0;
for(int i = 0; i < n; ++i)
++cntr[A[i]];
for(int i = 0; i < n; i += block_size)
{
int r = std::min(n, i + block_size) - 1;
for(int j = i; j <= r; ++j)
--cntr[A[j]];
for(int j = i; j <= r; ++j)
{
for(int k = j + 1; k <= r; ++k)
{
int z = 2 * A[j] - A[k];
if(z >= 0) ans += cntl[z];
int x = 2 * A[k] - A[j];
if(x >= 0) ans += cntr[x];
}
++cntl[A[j]];
}
}
return ans;
}
long long power(long long x, long long p)
{
long long v = 1;
while(p)
{
if(p & 1) v = x * v % mod_v;
x = x * x % mod_v;
p >>= 1;
}
return v;
}
void init_eps(int n)
{
long long base = power(3, (mod_v - 1) / n);
long long inv_base = power(base, mod_v - 2);
eps[0] = 1, inv_eps[0] = 1;
for(int i = 1; i < n; ++i)
{
eps[i] = eps[i - 1] * base % mod_v;
inv_eps[i] = inv_eps[i - 1] * inv_base % mod_v;
}
}
long long inc(long long x, long long d)
{
x += d;
return x >= mod_v ? x - mod_v : x;
}
long long dec(long long x, long long d)
{
x -= d;
return x < 0 ? x + mod_v : x;
}
void transform(int p, int n, long long *x, long long *w)
{
for(int i = 0, j = 0; i != n; ++i)
{
if(i > j) std::swap(x[i], x[j]);
for(int l = n >> 1; (j ^= l) < l; l >>= 1);
}
for(int i = 2; i <= n; i <<= 1)
{
int m = i >> 1;
for(int j = 0; j < n; j += i)
{
for(int k = 0; k != m; ++k)
{
long long z = x[j + m + k] * w[p / i * k] % mod_v;
x[j + m + k] = dec(x[j + k], z);
x[j + k] = inc(x[j + k], z);
}
}
}
}
long long solve_other(int n, int block_size)
{
int p = 1;
while(p < max_v) p <<= 1;
p <<= 1;
init_eps(p);
long long ans = 0, inv = power(p, mod_v - 2);
for(int i = 0; i < n; i += block_size)
{
int m = 0;
int r = std::min(n, i + block_size) - 1;
for(int j = 0; j != i; ++j)
++X[A[j]], m = std::max(m, A[j]);
for(int j = r + 1; j < n; ++j)
++Y[A[j]], m = std::max(m, A[j]);
int len = 1;
while(len < m) len <<= 1;
len <<= 1;
transform(p, len, X, eps);
transform(p, len, Y, eps);
for(int j = 0; j != len; ++j)
X[j] = X[j] * Y[j] % mod_v;
transform(p, len, X, inv_eps);
for(int j = 0; j != len; ++j)
X[j] = inc(X[j] * inv % mod_v, mod_v);
for(int j = i; j <= r; ++j)
ans += X[A[j] << 1];
std::memset(X, 0, sizeof(long long) * len);
std::memset(Y, 0, sizeof(long long) * len);
}
return ans;
}
int main()
{
int n;
std::scanf("%d", &n);
for(int i = 0; i != n; ++i)
{
std::scanf("%d", A + i);
if(A[i] > max_v) max_v = A[i];
}
int block_size = 8 * (int)sqrt(n);
if(block_size > n) block_size = n;
long long ans = 0;
ans += calc_block(n, block_size);
ans += solve_other(n, block_size);
std::printf("%lld", ans);
return 0;
}