后缀数组(Suffix Array)学习笔记

感觉 OI-wiki 在相关方面讲得很牛啊!拜谢 OI-wiki 的后缀数组部分


一些约定

字符串 \(S\) 的长度为 \(|S|\),是下标从 \(1\) 开始(到 \(|S|\))的字符串。特别地,若无特殊说明,约定字符串 \(s\) 的长度为 \(n\)

后缀 \(i\) 代表以 \(i\) 开头的后缀。

若无特殊说明,则两个字符串 \(s_1\)\(s_2\) 的大小关系就是它们字典序的大小关系。

后缀数组是什么?

后缀数组(Suffix Array),可以直接表示字符串 \(s\) 的后缀的字典序大小关系。

定义 \(\text{sa}_i\) 表示按照字典序排序后,第 \(i\) 大的后缀的开头的下标;定义 \(\text{rk}_i\) 表示后缀 \(i\) 按照字典序排序后的排名。

显然,在求出后缀数组(\(\text{sa}\))后,我们有 \(\text{sa}_{\text{rk}_i} = \text{rk}_{\text{sa}_i} = i\)

如何求出后缀数组?

考虑最暴力的求法:将 \(s\) 的所有后缀存储起来,并直接 sort 起来。容易得到,该做法的时间复杂度是 \(O(n^2 \log n)\) 的。

考虑优化它:只要我们是基于比较的排序(即使用重定义 cmpsort),复杂度中的 \(O(n \log n)\) 就无法去掉。考虑如何优化比较两个字符串 \(s_1\)\(s_2\) 的大小关系。一个想法是,判断 \(s_1\)\(s_2\) 的大小关系只需要找到它们第一个不相等的字符,并直接比较即可。我们可以用二分 + 哈希的方式做这件事。该做法的时间复杂度为 \(O(n \log^2 n)\)。可能再使用一些手法可以做到 \(O(n \log n)\) 地求出后缀数组,但我不会,所以这里先不说。以后记得补一下

这种 \(O(n \log^2 n)\) 做法的代码:

点击查看代码
#include <bits/stdc++.h>
#define ll long long
#define mid (l + r + 1 >> 1)

using namespace std;

constexpr int N = 1e6 + 5;
constexpr int base1 = 31, mod1 = 998244353;
constexpr int base2 = 171, mod2 = 998244853;

int n;
int sa[N];
ll p1[N], p2[N], hs1[N], hs2[N];
char s[N];

bool check(int l1, int r1, int l2, int r2) {
	int hsx1 = (hs1[r1] - hs1[l1 - 1] * p1[r1 - l1 + 1] % mod1 + mod1) % mod1;
	int hsx2 = (hs2[r1] - hs2[l1 - 1] * p2[r1 - l1 + 1] % mod2 + mod2) % mod2;
	int hsy1 = (hs1[r2] - hs1[l2 - 1] * p1[r2 - l2 + 1] % mod1 + mod1) % mod1;
	int hsy2 = (hs2[r2] - hs2[l2 - 1] * p2[r2 - l2 + 1] % mod2 + mod2) % mod2;
	if (hsx1 == hsy1 && hsx2 == hsy2) {
		return true;
	} else {
		return false;
	}
}

signed main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr); cout.tie(nullptr);
	
	cin >> s + 1;
	n = strlen(s + 1);
	
	p1[0] = p2[0] = 1ll;
	for (int i = 1; i <= n; ++i) {
		p1[i] = p1[i - 1] * base1 % mod1;
		p2[i] = p2[i - 1] * base2 % mod2;
		hs1[i] = (hs1[i - 1] * base1 + s[i]) % mod1;
		hs2[i] = (hs2[i - 1] * base2 + s[i]) % mod2;
	}
	
	for (int i = 1; i <= n; ++i) {
		sa[i] = i;
	}
	
	sort(sa + 1, sa + n + 1, [](int x, int y) {
		int len1 = n - x + 1, len2 = n - y + 1;
		int l = 0, r = min(len1, len2);
		while (l < r) {
			if (check(x, x + mid - 1, y, y + mid - 1)) l = mid;
			else r = mid - 1;
		}
		if (l == min(len1, len2)) {
			return x > y;
		} else {
			return s[x + l] < s[y + l];
		}
	});
	
	for (int i = 1; i <= n; ++i) {
		cout << sa[i] << " \n"[i == n];
	}
	
	return 0;
}

提交记录,可以看到哈希的常数和较劣的复杂度是无法通过本题的。

upd:好像基于比较的排序没有前途了 /ll

upd2:但是 wyd 讲了一个只利用哈希做到 \(O(n \log n)\) 求后缀数组的做法!该做法好像来源于 zak,拜谢。

我们首先可以将 \(s\) 的长度补全到 \(2\) 的次幂(设为 \(k\))来简化接下来的考虑(即,在 \(s\) 后面加入空字符使 \(n = 2^k\))。考虑首先处理出所有 \(n\) 个后缀的长度为 \(2^{k - 1}\) 的前缀的哈希值。这时,\(n\) 个后缀之间的关系就可以分为「哈希值相同」的部分和「哈希值不同」的部分。

对于哈希值相同的后缀,我们显然只需要比较它们长度不大于 \(2^{k - 1}\) 的后半部分的字典序即可。对于不同的每个哈希值,我们都只保留一个串进行接下来的比较。

递归地做这件事情。

复杂度证明,不会。

我怎么这么菜啊呜呜,如何变得可以自己想出这个的复杂度证明阿?


上面的 \(O(n \log^2 n)\) 做法好想也好写,不过复杂度还是稍微劣了一些,接下来将说一说最普遍的 \(O(n \log n)\) 的后缀数组求法。

我们可以认为 \(s\) 后面接着足够多的空字符(假定空字符的字典序最小),那么 \(n\) 个后缀的长度也可以被补全为一样的。

考虑对 \(n\) 个字符排序的过程。我们首先对每个后缀的第一个字符进行排序,设此时的 \(\text{rk}_i\) 表示 \(s_i\) 这个字符的排名,直接按照 \(\text{rk}_i\) 作为权值排序即可完成这一步。

接下来,我们当然可以继续做比较第 \(2\) 个字符的过程,但我们实际上可以优化它。考虑使用倍增的思想,求出新的 \(\text{rk}_i\) 表示 \(s_{[i, i + 1]}\) 的排名,那么我们再次按照 \(\text{rk}_{i}\) 排序即可,而求新的 \(\text{rk}_i\) 的过程又是 \(O(n)\) 的。一般地,设 \(\text{rk}_i\) 表示 \(s_{[i, i + 2^k - 1]}\) 的排名,并按照它排序计算 \(s_{[i, i + 2^{k + 1} - 1]} 的排名\)。该做法的时间复杂度为 \(O(n \log^2 n)\)

代码(注意因为我们对 \(s\) 进行了长度上的“补全”,所以 \(\text{rk}\) 要开到 \(|s|\) 的两倍):

点击查看代码
#include <bits/stdc++.h>

using namespace std;

constexpr int N = 1e6 + 5;

int n, w;
int sa[N], rk[N << 1], rk2[N << 1];
char s[N];

signed main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr); cout.tie(nullptr);
	
	cin >> s + 1;
	n = strlen(s + 1);
	
	for (int i = 1; i <= n; ++i) {
		sa[i] = i;
		rk[i] = s[i];
	}
	
	for (w = 1; w < n; w <<= 1) {
		sort(sa + 1, sa + n + 1, [](int x, int y) {
			if (rk[x] == rk[y]) {
				return rk[x + w] < rk[y + w];
			} else {
				return rk[x] < rk[y];
			}
		});
		
		for (int i = 1; i <= n; ++i) {
			rk2[i] = rk[i];
		}
		
		for (int p = 0, i = 1; i <= n; ++i) {
			if (rk2[sa[i]] == rk2[sa[i - 1]] && rk2[sa[i] + w] == rk2[sa[i - 1] + w]) {
				rk[sa[i]] = p;
			} else {
				rk[sa[i]] = ++p;
			}
		}
	}
	
	for (int i = 1; i <= n; ++i) {
		cout << sa[i] << " \n"[i == n];
	}
	
	return 0;
}

而优化这个做法是不困难的:复杂度瓶颈在于求新的 \(\text{sa}\) 的排序而非求 \(\text{rk}\) 的过程。我们的排序只参考了 \(\text{rk}\) 数组,那么直接使用基数排序替换掉 sort 即可。

注意基数排序和计数排序的区别:计数排序是直接把所有数放到值域的桶里面,基数排序是对于多个关键字,按较低到较高优先级的关键字依次排序。

代码:

点击查看代码
#include <bits/stdc++.h>

using namespace std;

constexpr int N = 1e6 + 5;

int n;
int sa[N], sa2[N], buc[N], rk[N << 1], rk2[N << 1];
char s[N];

signed main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr); cout.tie(nullptr);
	
	cin >> s + 1;
	n = strlen(s + 1);
	
	int sz = max(n, 127);
	for (int i = 1; i <= n; ++i) ++buc[rk[i] = s[i]];
	for (int i = 1; i <= sz; ++i) buc[i] += buc[i - 1];
	for (int i = 1; i <= n; ++i) sa[buc[rk[i]]--] = i;
	
	for (int w = 1; w < n; w <<= 1) {
		for (int i = 0; i <= sz; ++i) buc[i] = 0;
		for (int i = 1; i <= n; ++i) sa2[i] = sa[i];
		for (int i = 1; i <= n; ++i) ++buc[rk[sa2[i] + w]];
		for (int i = 1; i <= sz; ++i) buc[i] += buc[i - 1];
		for (int i = n; i > 0; --i) sa[buc[rk[sa2[i] + w]]--] = sa2[i];
		
		for (int i = 0; i <= sz; ++i) buc[i] = 0;
		for (int i = 1; i <= n; ++i) sa2[i] = sa[i];
		for (int i = 1; i <= n; ++i) ++buc[rk[sa2[i]]];
		for (int i = 1; i <= sz; ++i) buc[i] += buc[i - 1];
		for (int i = n; i > 0; --i) sa[buc[rk[sa2[i]]]--] = sa2[i];
		
		for (int i = 1; i <= n; ++i) rk2[i] = rk[i];
		for (int cur = 0, i = 1; i <= n; ++i) {
			if (rk2[sa[i]] == rk2[sa[i - 1]] && rk2[sa[i] + w] == rk2[sa[i - 1] + w]) {
				rk[sa[i]] = cur;
			} else {
				rk[sa[i]] = ++cur;
			}
		}
	}
	
	for (int i = 1; i <= n; ++i) {
		cout << sa[i] << " \n"[i == n];
	}
	
	return 0;
}

当然我们还能再进行一些常数上的优化,得到更为实用的版本。

  1. 化简对第二关键字(\(s_{[i + 2^k, i + 2^{k + 1} - 1]}\) 部分)进行的排序过程。

考虑对第二关键字进行排序的实质:对于 \(i = 1\)\(n\),按照 \(\text{rk}_{i + 2^k - 1}\) 进行排序,最终的结果必定是把最后的 \(n\) 个(后缀 \(i - n + 1\) 到后缀 \(n\))放到最前面,并把其余的部分按照原顺序(倍增到 \(k\) 层时后半部分的排名顺序)往后平移。

  1. 在每次倍增结束后动态更新值域。

  2. \(\text{rk}\) 数组内任意两项的值都不相同,则排序过程必定已经结束,可以停止接下来的操作。

代码:

点击查看代码
#include <bits/stdc++.h>

using namespace std;

constexpr int N = 1e6 + 5;

int n;
int buc[N], sa[N], sa2[N], rk[N << 1], rk2[N << 1];
char s[N];

signed main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr); cout.tie(nullptr);
	
	cin >> s + 1;
	n = strlen(s + 1);
	
	int sz = 127;
	for (int i = 1; i <= n; ++i) ++buc[rk[i] = s[i]];
	for (int i = 1; i <= sz; ++i) buc[i] += buc[i - 1];
	for (int i = 1; i <= n; ++i) sa[buc[rk[i]]--] = i;
	
	for (int w = 1; w < n; w <<= 1) {
		int cur = 0;
		for (int i = n - w + 1; i <= n; ++i) sa2[++cur] = i;
		for (int i = 1; i <= n; ++i) {
			if (sa[i] > w) {
				sa2[++cur] = sa[i] - w;
			}
		}
		
		for (int i = 0; i <= sz; ++i) buc[i] = 0;
		for (int i = 1; i <= n; ++i) ++buc[rk[i]];
		for (int i = 1; i <= sz; ++i) buc[i] += buc[i - 1];
		for (int i = n; i > 0; --i) sa[buc[rk[sa2[i]]]--] = sa2[i];
		
		cur = 0;
		for (int i = 1; i <= n; ++i) rk2[i] = rk[i];
		for (int i = 1; i <= n; ++i) {
			if (rk2[sa[i]] == rk2[sa[i - 1]] && rk2[sa[i] + w] == rk2[sa[i - 1] + w]) {
				rk[sa[i]] = cur;
			} else {
				rk[sa[i]] = ++cur;
			}
		}
		sz = cur;
		
		if (cur == n) {
			break;
		}
	}
	
	for (int i = 1; i <= n; ++i) {
		cout << sa[i] << " \n"[i == n];
	}
	
	return 0;
}

\(\text{height}\) 数组

定义 \(h_i\) 表示字典序第 \(i - 1\) 大的后缀与第 \(i\) 大的后缀的 \(\text{LCP}\)(Longest Common Prefix,最长公共前缀)长度,\(h_1 = 0\)

现有结论:\(h_{\text{rk}_i} \ge h_{\text{rk}_{i - 1}} - 1\)

证明:当 \(h_{\text{rk}_{i - 1}} \le 1\) 时显然成立。考虑 \(h_{\text{rk}_{i - 1}} > 1\) 的情况。此时,字典序恰好比后缀 \(i - 1\)\(1\) 的后缀去掉其第一个字符后,新的后缀与后缀 \(i\)\(\text{LCP}\) 长度必定为 \(h_{\text{rk}_{i - 1}} - 1\)。而字典序比后缀 \(i\)\(1\) 的后缀必定包含该 \(\text{LCP}\)。故结论成立。

于是从前往后扫描原串并依据上述结论暴力计算,即可在 \(O(n)\) 的时间复杂度内求出 \(h\) 数组。代码如下:

for (int i = 1, w = 0; i <= n; ++i) {
	if (rk[i] == 1) {
		continue;
	}
	w = max(0, w - 1);
	while (s[i + w] == s[sa[rk[i] - 1] + w]) {
		++w;
	}
	h[rk[i]] = w;
}

使用 \(h\) 数组可以求不同子串数目等信息。

P2870 [USACO07DEC] Best Cow Line G

容易发现这题本质是在比较原串与翻转串的大小关系。

显然可以哈希,但你也可以将原串翻转后接回去,并建后缀数组做。

P4248 [AHOI2013] 差异

一个结论是 \(\text{LCP}(t_i, t_j) = \min\limits_{k = i + 1}^{j} h_k\),直接利用该结论做即可。

P4248 [AHOI2013] 差异

考虑经典套路,在两个字符串间插入特殊字符并拼到一起,此时原先的两个字符串可能会有多的贡献。把这部分贡献统计一下并减掉即可。用 SA 算贡献的部分是简单的。

P2408 不同子串个数

从增量的角度考虑,按照排名顺序依次加入后缀,每个后缀对答案的贡献是 \(\text{len}_{\text{sa}_i} - h_i\) 的。因此本质不同子串个数为 \(\frac{n \times (n + 1)}{2} - \sum h_i\)

P2178 [NOI2015] 品酒大会

统计方案数是平凡的,统计最大美味度可以通过在单调栈的过程中维护当前区间(栈顶位置到当前位置,显然只会有加入贡献的操作)\(\max\)\(\min\)\(O(n)\) 解决。细节挺多,可能会比较难写。

posted @ 2025-07-04 20:37  zyb_txdy  阅读(26)  评论(0)    收藏  举报