这篇文章将会介绍如何求解区间 K 小问题。

无修改、无插入区间第K小问题

关于这题,可以直接用主席树(函数式线段树)来解决。假设现在有无限的内存和时间,可以对每个前缀构造一棵权值线段树,然后每次询问区间 [l, r] 就找出前缀 [0, l - 1] 和 [0, r] 的线段树,相减就可以得到所需要的权值线段树,然后在这上面根据子树大小跑就可以了。

由于空间有限,最开始不初始化整棵树,而是动态的添加节点,并且在插入的时候只有经过的那一条链上的东西会被更改,所以对于一些没有修改的子树,直接用指针指向上一次的版本,因为每一次插入最多经过 $\mathcal O(\log n)$ 个节点,这样总共空间不会超过 $\mathcal O(n\log n)$,同样,时间复杂度也是每次 $\mathcal O(\log n)$。

参考问题:BZOJ-2588: Spoj 10628. Count on a tree

#include <cstdio>
#include <algorithm>

const int MaxN = 100010, MaxL = 17, MaxV = ~0u >> 1;
int n, m, total;
int head[MaxN], next[MaxN << 1], point[MaxN << 1], weight[MaxN];
int que[MaxN], fa[MaxN], depth[MaxN], dist[MaxL][MaxN];

template<typename Type>
class memory_pool
{
	Type* data[1 << 16];
	int used, remain;
public:
	memory_pool() : used(0), remain(1 << 16) {}
	~memory_pool() 
	{
		for(int i = 0; i != used; ++i)
			delete[] data[i];
	}

	Type* fetch() 
	{
		if(remain + 1 < 1 << 16) 
			return data[used - 1] + ++remain;
		data[used++] = new Type[1 << 16];
		return data[used - 1] + (remain = 0);
	}
};

struct node_t
{
	int w;
	node_t *l, *r;
};

node_t *nil, *root[MaxN], *tmp[4];
memory_pool<node_t> mem;

void add_edge(int u, int v)
{
	point[++total] = v;
	next[total] = head[u];
	head[u] = total;
}

node_t* modify(node_t* now, unsigned head, unsigned tail, unsigned pos)
{
	node_t* n = mem.fetch();
	*n = *now; ++n->w;
	if(head == tail)
		return n;

	unsigned m = (head + tail) >> 1;
	if(pos <= m) n->l = modify(n->l, head, m, pos);
	else n->r = modify(n->r, m + 1, tail, pos);
	return n;
}

void solve_father()
{
	int qhead = 0, qtail = 0;
	que[qtail++] = 1;
	depth[1] = 1, fa[1] = 0;
	root[0] = nil;
	while(qhead != qtail)
	{
		int u = que[qhead++];
		root[u] = modify(root[fa[u]], 0, MaxV, weight[u]);
		for(int k = head[u]; k; k = next[k])
		{
			int v = point[k];
			if(v == fa[u]) continue;
			fa[v] = u;
			depth[v] = depth[u] + 1;
			que[qtail++] = v;
		}
	}
}

void init_lca()
{
	for(int i = 1; i <= n; ++i)
		dist[0][i] = fa[i];
	for(int l = 1; l != MaxL; ++l)
		for(int i = 1; i <= n; ++i)
			dist[l][i] = dist[l - 1][dist[l - 1][i]];
}

int get_lca(int u, int v)
{
	if(depth[u] > depth[v])
		std::swap(u, v);

	int diff = depth[v] - depth[u];
	for(int i = 0; diff; diff >>= 1, ++i)
		if(diff & 1) v = dist[i][v];

	if(u == v) return u;

	for(int p = MaxL - 1; u != v; p ? --p : 0)
	{
		if(dist[p][u] != dist[p][v] || p == 0)
		{
			v = dist[p][v];
			u = dist[p][u];
		}
	}

	return u;
}

int ask(unsigned head, unsigned tail, int k)
{
	if(head == tail) return head;

	unsigned m = (head + tail) >> 1;
	int num = tmp[0]->l->w + tmp[1]->l->w 
		- tmp[2]->l->w - tmp[3]->l->w;
	if(k <= num)
	{
		for(int i = 0; i != 4; ++i)
			tmp[i] = tmp[i]->l;
		return ask(head, m, k);
	} else {
		for(int i = 0; i != 4; ++i)
			tmp[i] = tmp[i]->r;
		return ask(m + 1, tail, k - num);
	}
}

int main()
{
	std::scanf("%d %d", &n, &m);
	for(int i = 1; i <= n; ++i)
		std::scanf("%d", weight + i);
	for(int i = 1; i != n; ++i)
	{
		int u, v;
		std::scanf("%d %d", &u, &v);
		add_edge(u, v);
		add_edge(v, u);
	}

	nil = mem.fetch();
	nil->l = nil->r = nil, nil->w = 0;
	solve_father();
	init_lca();

	int lastans = 0;
	for(int i = 0; i != m; ++i)
	{
		int x, y, k, lca;
		std::scanf("%d %d %d", &x, &y, &k);
		x ^= lastans;
		lca = get_lca(x, y);
		tmp[0] = root[x], tmp[1] = root[y];
		tmp[2] = root[lca], tmp[3] = root[fa[lca]];
		lastans = ask(0, MaxV, k);
		std::printf("%d", lastans);
		if(i + 1 != m) std::puts("");
	}
	return 0;
}

带修改、无插入区间第K小问题

现在需要修改权值,如果暴力修改每棵线段树的话,最多会要修改 $\mathcal O(n)$ 级别的东西,复杂度也会退化。这时候想想,我们需要的操作是修改一个区间的线段树和询问一个区间的线段树的信息,也就是区间修改和区间查询,很容易会想到树状数组!它能够把每个区间用 $\mathcal O(\log n)$ 个区间来表示出来。所以只要在每个树状数组的节点建立可以权值线段树,表示这个区间的权值信息就可以完成所需要的操作了,时间和空间复杂度都是 $\mathcal O(n\log^2 n)$。

参考问题:BZOJ-1901: Zju2112 Dynamic Rankings

#include <cstdio>
#include <cstring>
#include <algorithm>

template<typename Type>
class memory_pool
{
	Type* data[1 << 16];
	int used, remain;
public:
	memory_pool() : used(0), remain(1 << 16) {}
	~memory_pool() 
	{
		for(int i = 0; i != used; ++i)
			delete[] data[i];
	}

	Type* fetch() 
	{
		if(remain != 1 << 16) return data[used - 1] + ++remain;
		data[used++] = new Type[1 << 16];
		return data[used - 1] + (remain = 0);
	}
};

struct node_t
{
	int w;
	node_t *l, *r;
};

const int MaxN = 20001;
memory_pool<node_t> mem;
int size, data[MaxN], map[MaxN], oper[MaxN][4];
node_t *left[MaxN], *right[MaxN], *root[MaxN], *nil;

node_t* update(node_t* n, int head, int tail, int pos, int val)
{
	if(n == nil)
	{
		n = mem.fetch();
		n->l = n->r = nil;
		n->w = val;
	} else n->w += val;

	if(head == tail) return n;
	int m = (head + tail) >> 1;
	if(pos <= m) n->l = update(n->l, head, m, pos, val);
	else n->r = update(n->r, m + 1, tail, pos, val);
	return n;
}

void modify(int n, int x, int pos, int val)
{
	for(; x <= n; x += x & -x)
		root[x] = update(root[x], 1, size, pos, val);
}

int ask(int lc, int rc, int head, int tail, int k)
{
	if(head == tail) return head;

	int l = 0, r = 0;
	for(int i = 0; i != lc; ++i) l += left[i]->l->w;
	for(int i = 0; i != rc; ++i) r += right[i]->l->w;

	int m = (head + tail) >> 1;
	if(k <= r - l)
	{
		for(int i = 0; i != lc; ++i) left[i] = left[i]->l;
		for(int i = 0; i != rc; ++i) right[i] = right[i]->l;
		return ask(lc, rc, head, m, k);
	} else {
		for(int i = 0; i != lc; ++i) left[i] = left[i]->r;
		for(int i = 0; i != rc; ++i) right[i] = right[i]->r;
		return ask(lc, rc, m + 1, tail, k - (r - l));
	}
}

int get_ans(int l, int r, int k)
{
	int lc = 0, rc = 0;
	for(--l; l; l -= l & -l) left[lc++] = root[l];
	for(; r; r -= r & -r) right[rc++] = root[r];
	return data[ask(lc, rc, 1, size, k)];
}

int main()
{
	nil = mem.fetch();
	nil->l = nil->r = nil, nil->w = 0;
	
	int n, m, tot = 0;
	std::scanf("%d %d", &n, &m);
	for(int i = 1; i <= n; ++i)
		std::scanf("%d", data + i);

	tot = n;
	for(int i = 0; i != m; ++i)
	{
		char op[2];
		std::scanf("%s", op);
		if(op[0] == 'C')
		{
			oper[i][0] = 1;
			std::scanf("%d %d", oper[i] + 1, oper[i] + 2);
			data[++tot] = oper[i][2];
		} else {
			oper[i][0] = 0;
			std::scanf("%d %d %d", oper[i] + 1, oper[i] + 2, oper[i] + 3);
		}
	}

	std::memcpy(map, data, sizeof(data));
	std::sort(data + 1, data + tot + 1);
	size = std::unique(data + 1, data + tot + 1) - data - 1;
	for(int i = 1; i <= tot; ++i)
		map[i] = std::lower_bound(data + 1, data + size + 1, map[i]) - data;

	for(int i = 1; i <= n; ++i)
		root[i] = nil;
	for(int i = 1; i <= n; ++i)
		modify(n, i, map[i], 1);
	for(int i = 0; i != m; ++i)
	{
		if(oper[i][0])
		{
			int pos = std::lower_bound(data + 1, data + size + 1, oper[i][2]) - data;
			int p = oper[i][1];
			modify(n, p, map[p], -1);
			modify(n, p, map[p] = pos, 1);

		} else {
			std::printf("%d\n", get_ans(oper[i][1], oper[i][2], oper[i][3]));
		}
	}
	return 0;
}

带修改、带插入区间第K小问题

现在我们需要支持可以在某个位置插入一个值,比如原先的序列是 1 2 3,然后在第二个位置插入 4 的话,序列就会变成 1 4 2 3。

因为树状数组不支持插入元素,但是类似 Problem 2 的思想,你或许会想到使用平衡树,它也在 $\mathcal O(\log n)$ 的时间内支持区间修改、查询,而且还支持插入元素。也就是平衡树上的节点就表示整棵子树中信息的权值线段树,但是仔细想就会发现,平衡树基本都是需要旋转来维持平衡的,这样一转,对权值线段树修改的时间复杂度就会迅速增大!

替罪羊树(Scapegoat Tree)是一种很神奇的平衡树,它不需要旋转来保持平衡,而是有一个平衡因子 $\alpha \in [0.5, 1]$ 每次插入如果发现某棵子树太不平衡,就把整棵子树暴力重构成完全二叉树。具体来说也就是如果某棵子树满足 $\alpha \cdot \text{size} > \max (\text{left_size}, \text{right_size})$,就认为是不平衡的,然后就重构它。

当 $\alpha = 0.5$ 的时候,这棵树就是完全二叉树,当 $\alpha = 1$ 的时候就永远不会重构,很显然,$\alpha$ 越小,询问的效率越高,$\alpha$ 越大,插入的效率越高。可以证明,替罪羊树的查询效率是 $\mathcal O(\log n)$,而插入均摊 $\mathcal O(\log n)$。 这样,这个问题只需要用替罪羊树再套上权值线段树就可以解决了 但是有一个问题就是你需要动态分配内存,并且在重构的时候要及时释放无用内存。并且在写权值线段树的时候不要写成函数式线段树那样,否则内存会爆(我就不小心写成这样被坑了好久)

参考问题:BZOJ-3065: 带插入区间K小值

#include <cstdio>
#include <algorithm>

const int MaxN = 70010, MaxV = 70000;

struct seg_t
{
	int w;
	seg_t *l, *r;
};

seg_t *seg_nil;

void destroy(seg_t* now)
{
	if(now == seg_nil) return;
	destroy(now->l);
	destroy(now->r);
	delete now;
}

seg_t* insert(seg_t* n, int l, int r, int pos, int v)
{
	if(n->w + v == 0) 
	{
		if(n != seg_nil)
			destroy(n);
		return seg_nil;
	}

	if(n == seg_nil)
	{
		n = new seg_t;
		n->l = n->r = seg_nil;
		n->w = v;
	} else n->w += v;
	if(l == r) return n;

	int m = (l + r) >> 1;
	if(pos <= m) n->l = insert(n->l, l, m, pos, v);
	else n->r = insert(n->r, m + 1, r, pos, v);
	return n;
}

struct scap_t
{
	int size, val;
	scap_t *l, *r;
	seg_t *seg;
};

scap_t *scap_nil, *root, *rebuild_node, *rebuild_fa;
double alpha;
int record_num, seg_num;
int value[MaxN], record[MaxN];
seg_t* seg_record[MaxN];

void destroy(scap_t* now)
{
	if(now == scap_nil) return;
	destroy(now->seg);
	destroy(now->l);
	record[++record_num] = now->val;
	destroy(now->r);
	delete now;
}

scap_t* scap_build(int l, int r)
{
	if(l > r) return scap_nil;

	scap_t* n = new scap_t;
	if(l == r) 
	{
		n->size = 1;
		n->val = record[l];
		n->l = n->r = scap_nil;
		n->seg = insert(seg_nil, 0, MaxV, record[l], 1);
		return n;
	}

	int m = (l + r) >> 1;
	n->val = record[m];
	n->l = scap_build(l, m - 1);
	n->r = scap_build(m + 1, r);
	n->size = n->l->size + n->r->size + 1;
	n->seg = seg_nil;
	for(int i = l; i <= r; ++i)
		n->seg = insert(n->seg, 0, MaxV, record[i], 1);
	return n;
}

scap_t* scap_rebuild(scap_t* now)
{
	record_num = 0;
	destroy(now);
	return scap_build(1, record_num);
}

int scap_modify(scap_t* now, int pos, int v)
{
	int old_val = 0;
	int sz = now->l->size;
	if(sz + 1 == pos)
	{
		old_val = now->val;
		now->val = v;
	} else if(pos <= sz) {
		old_val = scap_modify(now->l, pos, v);
	} else {
		old_val = scap_modify(now->r, pos - sz - 1, v);
	}

	now->seg = insert(now->seg, 0, MaxV, old_val, -1);
	now->seg = insert(now->seg, 0, MaxV, v, 1);
	return old_val;
}

scap_t* scap_insert(scap_t* now, int pos, int v)
{
	if(now == scap_nil)
	{
		scap_t *n = new scap_t;
		n->val = v, n->size = 1;
		n->l = n->r = scap_nil;
		n->seg = insert(seg_nil, 0, MaxV, v, 1);
		return n;
	}

	now->seg = insert(now->seg, 0, MaxV, v, 1);
	int sz = now->l->size;
	if(pos <= sz) now->l = scap_insert(now->l, pos, v);
	else now->r = scap_insert(now->r, pos - sz - 1, v);
	now->size = now->l->size + now->r->size + 1;
	if(now->size * alpha < std::max(now->l->size, now->r->size))
		rebuild_node = now;
	if(now->l == rebuild_node || now->r == rebuild_node)
		rebuild_fa = now;
	return now;
}

void scap_query(scap_t* now, int l, int r)
{
	if(l > r) return;
	int lsz = now->l->size + 1, sz = now->size;
	if(l == 1 && r == sz)
	{
		seg_record[seg_num++] = now->seg;
		return;
	}

	if(l <= lsz && r >= lsz)
		record[record_num++] = now->val;

	if(r < lsz) 
	{
		scap_query(now->l, l, r);
	} else if(l > lsz) { 
		scap_query(now->r, l - lsz, r - lsz);
	} else {
		scap_query(now->l, l, lsz - 1);
		scap_query(now->r, 1, r - lsz);
	}
}

int find_kth(int l, int r, int k)
{
	seg_num = record_num = 0;
	scap_query(root, l, r);
	int low = 0, high = MaxV;
	while(low < high)
	{
		int mid = (low + high) >> 1, sum = 0;
		for(int i = 0; i != seg_num; ++i)
			sum += seg_record[i]->l->w;
		for(int i = 0; i != record_num; ++i)
			sum += record[i] >= low && record[i] <= mid;
		if(k <= sum)
		{
			high = mid;
			for(int i = 0; i != seg_num; ++i)
				seg_record[i] = seg_record[i]->l;
		} else {
			for(int i = 0; i != seg_num; ++i)
				seg_record[i] = seg_record[i]->r;
			k -= sum;
			low = mid + 1;
		}
	}

	return low;
}

void init()
{
	static seg_t seg_nil_base;
	seg_nil = &seg_nil_base;
	seg_nil->w = 0;
	seg_nil->l = seg_nil->r = seg_nil;

	static scap_t scap_nil_base;
	scap_nil = &scap_nil_base;
	scap_nil->l = scap_nil->r = scap_nil;
	scap_nil->seg = seg_nil;
	scap_nil->size = 0;

	alpha = 0.8;
}

int main()
{
	int n;
	std::scanf("%d", &n);
	for(int i = 1; i <= n; ++i)
		std::scanf("%d", record + i);

	init();
	root = scap_build(1, n);

	int m, lastans = 0;
	std::scanf("%d", &m);
	for(int i = 0; i != m; ++i)
	{
		char op[2];
		std::scanf("%s", op);
		if(*op == 'Q')
		{
			int x, y, k;
			std::scanf("%d %d %d", &x, &y, &k);
			x ^= lastans, y ^= lastans, k ^= lastans;
			std::printf("%d\n", lastans = find_kth(x, y, k));
		} else {
			int x, val;
			std::scanf("%d %d", &x, &val);
			x ^= lastans, val ^= lastans;
			if(*op == 'M') 
			{
				scap_modify(root, x, val);
			} else {
				rebuild_fa = 0;
				rebuild_node = 0;
				root = scap_insert(root, x - 1, val);
				if(rebuild_node)
				{
					if(rebuild_node == root) root = scap_rebuild(root);
					else if(rebuild_fa->l == rebuild_node)
						rebuild_fa->l = scap_rebuild(rebuild_node);
					else rebuild_fa->r = scap_rebuild(rebuild_node);
				}
			} 
		}
	}

	destroy(root);
	return 0;
}