给定一个长度为 $n ~(1 \leq n \leq 10^5)$ 的数组 $a_i~(0 \leq a_i \leq 3\cdot 10^4)$,问有多少对 $(i, j, k)$ 满足 $i < j < k$ 且 $a_i - a_j = a_j - a_k$

也就是统计有多少对长度为 $3$ 的等差子序列(题目连接 BZOJ-3509, CodeChef COUNTARI

例如 3 5 3 6 3 4 10 4 5 2,就有 (1, 3, 5), (1, 6, 9), (1, 8, 9), (3, 6, 9), (3, 8, 9), (5, 6, 9), (5, 8, 9), (4, 6, 10), (4, 8, 10) 一共 9 对

假设现在枚举 $j$,就变成要寻找有多少对 $(i, k)$ 满足 $2a_j=a_i+a_k$,如果不考虑 $i < j < k$ 这个限制,也就是说 $a_i$ 和 $a_k$ 可以随便在数组中选两个,并且加起来要是 $2a_j$

直接一个个去算 $j$ 的话再枚举 $(i, k)$ 复杂度十分高,大概是 $\mathcal O(n^3)$,肯定是不能做到。但是回头想想,现在要做的操作是从 $a$ 中取出一个数,再从 $a$ 中取出一个数,组合起来。这会想到母函数或者多项式乘法,因为多项式乘法的结果就是从每个多项式中取出一项,组合起来(指数相加)

所以现在可以构造 $a$ 的母函数(如果不太清楚的可以先做这题

\[A(x) = x^2+3x^3+2x^4+2x^5+x^6+x^{10}\]

那么将 $A(x)$ 平方后就可以得到要求的答案,$2a_j$ 那一项的系数就是 $(i, k)$ 的方案数,至于多项式乘法,只要用 FFT 优化就好了,关于 FFT 可以看这里

然后现在来考虑存在 $i < j < k$ 这个限制的情况,还是可以枚举 $j$,然后计算 $(i, k)$,你同样可以构造出母函数,只不过这回要构造两个母函数,一个表示 $j$ 左边的序列,一个表示 $j$ 右边的序列,因为要求从左右两边选出 $(i, k)$。例如说 $j = 4, a_j = 6$,它左边序列的母函数就是 $L(x) = 2x^3+x^5$,右边序列的母函数是 $R(x)=x^2+x^3+2x^4+x^5+x^{10}$,然后求出 $L(x)R(x)$ 中 $2a_j$ 项的系数就是答案,但是这样还不如暴力,因为你要求 $n$ 次 FFT,复杂度 $\mathcal O(n^2\log n)$

但是这样是可以优化的,因为一次 FFT 只求一个点的值太浪费了,我们可以一次性求出一段区间的值,然后对于在区间内的 $(i, j, k)$ 暴力求解,也就是对于区间 $[L, R]$,用 FFT 求出满足 $i < L, L \leq j \leq R, R < k$ 的方案数,再暴力求出区间内的方案数,如果设块的大小是 $\mathcal O(\sqrt{n\log n})$,最终复杂度大约会是 $\mathcal O(n\sqrt{n\log n})$

#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>

const long long mod_v = 17ll * (1 << 27) + 1;
const int MaxN = 100010, MaxA = 1 << 16;
int max_v, A[MaxN], cntl[MaxA], cntr[MaxA];
long long X[MaxA], Y[MaxA], eps[MaxA], inv_eps[MaxA];

long long calc_block(int n, int block_size)
{
	long long ans = 0;

	for(int i = 0; i < n; ++i)
		++cntr[A[i]];

	for(int i = 0; i < n; i += block_size)
	{
		int r = std::min(n, i + block_size) - 1;
		for(int j = i; j <= r; ++j)
			--cntr[A[j]];
		for(int j = i; j <= r; ++j)
		{
			for(int k = j + 1; k <= r; ++k)
			{
				int z = 2 * A[j] - A[k];
				if(z >= 0) ans += cntl[z];
				int x = 2 * A[k] - A[j];
				if(x >= 0) ans += cntr[x];
			}

			++cntl[A[j]];
		}
	}
	return ans;
}

long long power(long long x, long long p)
{
	long long v = 1;
	while(p)
	{
		if(p & 1) v = x * v % mod_v;
		x = x * x % mod_v;
		p >>= 1;
	}

	return v;
}

void init_eps(int n)
{
	long long base = power(3, (mod_v - 1) / n);
	long long inv_base = power(base, mod_v - 2);
	eps[0] = 1, inv_eps[0] = 1;
	for(int i = 1; i < n; ++i)
	{
		eps[i] = eps[i - 1] * base % mod_v;
		inv_eps[i] = inv_eps[i - 1] * inv_base % mod_v;
	}
}

long long inc(long long x, long long d) 
{
	x += d; 
	return x >= mod_v ? x - mod_v : x; 
}

long long dec(long long x, long long d) 
{
	x -= d; 
	return x < 0 ? x + mod_v : x; 
}

void transform(int p, int n, long long *x, long long *w)
{
	for(int i = 0, j = 0; i != n; ++i)
	{
		if(i > j) std::swap(x[i], x[j]);
		for(int l = n >> 1; (j ^= l) < l; l >>= 1);
	}

	for(int i = 2; i <= n; i <<= 1)
	{
		int m = i >> 1;
		for(int j = 0; j < n; j += i)
		{
			for(int k = 0; k != m; ++k)
			{
				long long z = x[j + m + k] * w[p / i * k] % mod_v;
				x[j + m + k] = dec(x[j + k], z);
				x[j + k] = inc(x[j + k], z);
			}
		}
	}
}

long long solve_other(int n, int block_size)
{
	int p = 1;
	while(p < max_v) p <<= 1;
	p <<= 1;
	init_eps(p);

	long long ans = 0, inv = power(p, mod_v - 2);
	for(int i = 0; i < n; i += block_size)
	{
		int m = 0;
		int r = std::min(n, i + block_size) - 1;
		for(int j = 0; j != i; ++j)
			++X[A[j]], m = std::max(m, A[j]);
		for(int j = r + 1; j < n; ++j)
			++Y[A[j]], m = std::max(m, A[j]);
		int len = 1;
		while(len < m) len <<= 1;
		len <<= 1;
		transform(p, len, X, eps);
		transform(p, len, Y, eps);
		for(int j = 0; j != len; ++j)
			X[j] = X[j] * Y[j] % mod_v;
		transform(p, len, X, inv_eps);
		for(int j = 0; j != len; ++j)
			X[j] = inc(X[j] * inv % mod_v, mod_v);
		for(int j = i; j <= r; ++j)
			ans += X[A[j] << 1];
		std::memset(X, 0, sizeof(long long) * len);
		std::memset(Y, 0, sizeof(long long) * len);
	}

	return ans;
}

int main()
{
	int n;
	std::scanf("%d", &n);
	for(int i = 0; i != n; ++i)
	{
		std::scanf("%d", A + i);
		if(A[i] > max_v) max_v = A[i];
	}

	int block_size = 8 * (int)sqrt(n);
	if(block_size > n) block_size = n;
	long long ans = 0;
	ans += calc_block(n, block_size);
	ans += solve_other(n, block_size);
	std::printf("%lld", ans);
	return 0;
}