题目给出 $n$ 个字符串 $s_1, s_2, \cdots, s_n$,以及 $m$ 个询问 $q_1, q_2, \cdots, q_m$,每个询问是一个字符串,要求计算 $q_i$ 在多少个 $s_1, s_2, \cdots, s_n$ 中出现过

例如 $s_1=abcabc, s_2=ca, s_3=a$,那么 $abc$ 在 $s_1$ 中出现过,$a$ 在 $s_1, s_2, s_3$ 中出现过

数据范围:$1 \leq n \leq 10^4, 1\leq q \leq 6\cdot 10^4, \sum_{i=0}^n |s_i| \leq 10^5, \sum_{i=0}^m |q_i| \leq 3.6 \cdot 10^5$

题目连接:BZOJ-2780SPOJ-JZPGYZ

这题是给出多个母串,询问给定串在多少个母串中出现过,并且允许离线

考虑先给 $s_1, s_2, \cdots, s_n$ 建立广义后缀自动机,关于建立方法可以参见《2015年国家队论文》中刘研绎的《后缀自动机在字典树上的拓展》

然后对于某个询问 $q_i$ 来说,将其放在自动机上运行(注意不能走 parent 边),最后如果走到某个状态 T,那么,如果 $q_i$ 在 $s_j$ 出现过,顺着 $s_j$ 路径上某个节点的 parent 往上走一定会走到 T

所以问题就变成了,在 parent 树中询问 T 节点的子树中有多少个不同的串,可以先求出 parent 树的 DFS 序,这样子树就变成连续的一段区间,问题变成询问某段区间有多少个不同的数字(注意因为广义后缀自动机上一个节点可能代表多个串,所以每个节点都要用链表记录下有多少个串)

这个问题在线的话可以用可持久化线段树 $\mathcal O(n\log n)$ 来做,但是这题允许离线,可以用莫队算法来解决,但是复杂度是 $\mathcal O(n \sqrt n)$,这里有一种用树状数组解决的办法,可以做到 $\mathcal O(n\log n)$ 的时间解决

首先可以将询问按照右端点排序,然后从左往右处理每个节点,现在考虑一个数字出现在两个不同位置,由于询问是按照右端点排序的,也就是出现在左边的那个数能够影响到的节点出现在右边的那个数一定能够影响到,所以当处理到某个点的时候,将这个点的数字上一次出现的位置的值减 1,这个位置的值加 1,然后处理所有右端点在当前位置的询问(这个方法可以先去做这题,它就是询问某个区间内有多少个不同数字)

到此这题解决,时间复杂度是 $\mathcal (m\log m + m\log n + n |\sum|)$,$|\sum|$ 是字符集大小

#include <cstdio>
#include <algorithm>

const int MaxN = 200010, MaxQ = 360000, MaxAlpha = 26;
struct node_t
{
	int len;
	node_t *fa, *ch[MaxAlpha];
} node[MaxN];

struct graph_t
{
	int total;
	int head[MaxN], point[MaxN * 10], next[MaxN * 10];
	
	void add_edge(int u, int v)
	{
		point[++total] = v;
		next[total] = head[u];
		head[u] = total;
	}
} g, s;

int used = 1;
char str[MaxQ];
node_t *sam_head = node + used++, *sam_tail = sam_head;

void sam_extend(int x)
{
	node_t *p = sam_tail;
	if(p->ch[x])
	{
		if(p->len + 1 == p->ch[x]->len)
		{
			sam_tail = p->ch[x];
		} else {
			node_t *q = p->ch[x], *r = node + used++;
			*r = *q; q->fa = r;
			r->len = p->len + 1;
			for(; p && p->ch[x] == q; p = p->fa)
				p->ch[x] = r;
			sam_tail = r;
		}
	} else {
		node_t *n = node + used++;
		n->len = p->len + 1;
		sam_tail = n;
		for(; p && !p->ch[x]; p = p->fa)
			p->ch[x] = n;
		if(!p)
		{
			n->fa = sam_head;
		} else {
			if(p->len + 1 == p->ch[x]->len)
			{
				n->fa = p->ch[x];
			} else {
				node_t *q = p->ch[x], *r = node + used++;
				*r = *q; 
				n->fa = q->fa = r;
				r->len = p->len + 1;
				for(; p && p->ch[x] == q; p = p->fa)
					p->ch[x] = r;
			}
		}
	}
}

struct tree_array
{
	int size, ta[MaxN];

	void modify(int x, int v)
	{
		for(; x <= size; x += x & -x)
			ta[x] += v;
	}

	int ask(int x)
	{
		int v = 0;
		for(; x; x -= x & -x)
			v += ta[x];
		return v;
	}
} ta;

struct ques_t
{
	int id, l, r, ans;
	bool operator < (const ques_t& q) const
	{
		return r < q.r;
	}
} ques[MaxQ];

int dfn_index, cnt[MaxN], ans[MaxN];
int enter[MaxN], leave[MaxN], dfn[MaxN];
void dfs(int u, int fa)
{
	dfn[++dfn_index] = u;
	enter[u] = dfn_index;
	for(int k = g.head[u]; k; k = g.next[k])
		if(g.point[k] != fa) dfs(g.point[k], u);
	leave[u] = dfn_index;
}

int main()
{
	int n, q;
	std::scanf("%d %d", &n, &q);
	for(int i = 0; i != n; ++i)
	{
		std::scanf("%s", str);
		sam_tail = sam_head;
		for(int j = 0; str[j]; ++j)
		{
			sam_extend(str[j] - 'a');
			s.add_edge(sam_tail - node, i);
		}
	}

	for(int i = 2; i != used; ++i)
		g.add_edge(node[i].fa - node, i);
	dfs(1, 0);

	for(int i = 0; i != q; ++i)
	{
		std::scanf("%s", str);
		node_t *now = sam_head;
		bool matched = true;
		for(int j = 0; str[j]; ++j)
		{
			int x = str[j] - 'a';
			if(!now->ch[x])
			{
				matched = false;
				break;
			} else {
				now = now->ch[x];
			}
		}

		ques[i].id = i;
		if(!matched)
		{
			ques[i].ans = 0;
			ques[i].l = -1;
		} else {
			int u = now - node;
			ques[i].l = enter[u];
			ques[i].r = leave[u];
		}
	}

	std::sort(ques, ques + q);

	int now = 0;
	while(now != q && ques[now].l == -1) ++now;
	ta.size = dfn_index;
	for(int i = 1; i <= dfn_index && now != q; ++i)
	{
		for(int k = s.head[dfn[i]]; k; k = s.next[k])
		{
			int v = s.point[k];
			if(cnt[v]) ta.modify(cnt[v], -1);
			ta.modify(cnt[v] = i, 1);
		}

		for(; now != q && ques[now].r == i; ++now)
			ques[now].ans = ta.ask(ques[now].r) - ta.ask(ques[now].l - 1);
	}

	for(int i = 0; i != q; ++i)
		ans[ques[i].id] = ques[i].ans;
	for(int i = 0; i != q; ++i)
		std::printf("%d\n", ans[i]);
	return 0;
}