BZOJ-3509. [CodeChef] COUNTARI

Posted by miskcoo on April 27, 2015

给定一个长度为 $n ~(1 \leq n \leq 10^5)$ 的数组 $a_i~(0 \leq a_i \leq 3\cdot 10^4)$,问有多少对 $(i, j, k)$ 满足 $i < j < k$ 且 $a_i - a_j = a_j - a_k$

也就是统计有多少对长度为 $3$ 的等差子序列(题目连接 BZOJ-3509, CodeChef COUNTARI

例如 3 5 3 6 3 4 10 4 5 2,就有 (1, 3, 5), (1, 6, 9), (1, 8, 9), (3, 6, 9), (3, 8, 9), (5, 6, 9), (5, 8, 9), (4, 6, 10), (4, 8, 10) 一共 9 对

假设现在枚举 $j$,就变成要寻找有多少对 $(i, k)$ 满足 $2a_j=a_i+a_k$,如果不考虑 $i < j < k$ 这个限制,也就是说 $a_i$ 和 $a_k$ 可以随便在数组中选两个,并且加起来要是 $2a_j$

直接一个个去算 $j$ 的话再枚举 $(i, k)$ 复杂度十分高,大概是 $\mathcal O(n^3)$,肯定是不能做到。但是回头想想,现在要做的操作是从 $a$ 中取出一个数,再从 $a$ 中取出一个数,组合起来。这会想到母函数或者多项式乘法,因为多项式乘法的结果就是从每个多项式中取出一项,组合起来(指数相加)

所以现在可以构造 $a$ 的母函数(如果不太清楚的可以先做这题

\[A(x) = x^2+3x^3+2x^4+2x^5+x^6+x^{10}\]

那么将 $A(x)$ 平方后就可以得到要求的答案,$2a_j$ 那一项的系数就是 $(i, k)$ 的方案数,至于多项式乘法,只要用 FFT 优化就好了,关于 FFT 可以看这里

然后现在来考虑存在 $i < j < k$ 这个限制的情况,还是可以枚举 $j$,然后计算 $(i, k)$,你同样可以构造出母函数,只不过这回要构造两个母函数,一个表示 $j$ 左边的序列,一个表示 $j$ 右边的序列,因为要求从左右两边选出 $(i, k)$。例如说 $j = 4, a_j = 6$,它左边序列的母函数就是 $L(x) = 2x^3+x^5$,右边序列的母函数是 $R(x)=x^2+x^3+2x^4+x^5+x^{10}$,然后求出 $L(x)R(x)$ 中 $2a_j$ 项的系数就是答案,但是这样还不如暴力,因为你要求 $n$ 次 FFT,复杂度 $\mathcal O(n^2\log n)$

但是这样是可以优化的,因为一次 FFT 只求一个点的值太浪费了,我们可以一次性求出一段区间的值,然后对于在区间内的 $(i, j, k)$ 暴力求解,也就是对于区间 $[L, R]$,用 FFT 求出满足 $i < L, L \leq j \leq R, R < k$ 的方案数,再暴力求出区间内的方案数,如果设块的大小是 $\mathcal O(\sqrt{n\log n})$,最终复杂度大约会是 $\mathcal O(n\sqrt{n\log n})$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>

const long long mod_v = 17ll * (1 << 27) + 1;
const int MaxN = 100010, MaxA = 1 << 16;
int max_v, A[MaxN], cntl[MaxA], cntr[MaxA];
long long X[MaxA], Y[MaxA], eps[MaxA], inv_eps[MaxA];

long long calc_block(int n, int block_size)
{
	long long ans = 0;

	for(int i = 0; i < n; ++i)
		++cntr[A[i]];

	for(int i = 0; i < n; i += block_size)
	{
		int r = std::min(n, i + block_size) - 1;
		for(int j = i; j <= r; ++j)
			--cntr[A[j]];
		for(int j = i; j <= r; ++j)
		{
			for(int k = j + 1; k <= r; ++k)
			{
				int z = 2 * A[j] - A[k];
				if(z >= 0) ans += cntl[z];
				int x = 2 * A[k] - A[j];
				if(x >= 0) ans += cntr[x];
			}

			++cntl[A[j]];
		}
	}
	return ans;
}

long long power(long long x, long long p)
{
	long long v = 1;
	while(p)
	{
		if(p & 1) v = x * v % mod_v;
		x = x * x % mod_v;
		p >>= 1;
	}

	return v;
}

void init_eps(int n)
{
	long long base = power(3, (mod_v - 1) / n);
	long long inv_base = power(base, mod_v - 2);
	eps[0] = 1, inv_eps[0] = 1;
	for(int i = 1; i < n; ++i)
	{
		eps[i] = eps[i - 1] * base % mod_v;
		inv_eps[i] = inv_eps[i - 1] * inv_base % mod_v;
	}
}

long long inc(long long x, long long d) 
{
	x += d; 
	return x >= mod_v ? x - mod_v : x; 
}

long long dec(long long x, long long d) 
{
	x -= d; 
	return x < 0 ? x + mod_v : x; 
}

void transform(int p, int n, long long *x, long long *w)
{
	for(int i = 0, j = 0; i != n; ++i)
	{
		if(i > j) std::swap(x[i], x[j]);
		for(int l = n >> 1; (j ^= l) < l; l >>= 1);
	}

	for(int i = 2; i <= n; i <<= 1)
	{
		int m = i >> 1;
		for(int j = 0; j < n; j += i)
		{
			for(int k = 0; k != m; ++k)
			{
				long long z = x[j + m + k] * w[p / i * k] % mod_v;
				x[j + m + k] = dec(x[j + k], z);
				x[j + k] = inc(x[j + k], z);
			}
		}
	}
}

long long solve_other(int n, int block_size)
{
	int p = 1;
	while(p < max_v) p <<= 1;
	p <<= 1;
	init_eps(p);

	long long ans = 0, inv = power(p, mod_v - 2);
	for(int i = 0; i < n; i += block_size)
	{
		int m = 0;
		int r = std::min(n, i + block_size) - 1;
		for(int j = 0; j != i; ++j)
			++X[A[j]], m = std::max(m, A[j]);
		for(int j = r + 1; j < n; ++j)
			++Y[A[j]], m = std::max(m, A[j]);
		int len = 1;
		while(len < m) len <<= 1;
		len <<= 1;
		transform(p, len, X, eps);
		transform(p, len, Y, eps);
		for(int j = 0; j != len; ++j)
			X[j] = X[j] * Y[j] % mod_v;
		transform(p, len, X, inv_eps);
		for(int j = 0; j != len; ++j)
			X[j] = inc(X[j] * inv % mod_v, mod_v);
		for(int j = i; j <= r; ++j)
			ans += X[A[j] << 1];
		std::memset(X, 0, sizeof(long long) * len);
		std::memset(Y, 0, sizeof(long long) * len);
	}

	return ans;
}

int main()
{
	int n;
	std::scanf("%d", &n);
	for(int i = 0; i != n; ++i)
	{
		std::scanf("%d", A + i);
		if(A[i] > max_v) max_v = A[i];
	}

	int block_size = 8 * (int)sqrt(n);
	if(block_size > n) block_size = n;
	long long ans = 0;
	ans += calc_block(n, block_size);
	ans += solve_other(n, block_size);
	std::printf("%lld", ans);
	return 0;
}