后缀数组(Suffix Array)学习笔记
感觉 OI-wiki 在相关方面讲得很牛啊!拜谢 OI-wiki 的后缀数组部分。
一些约定
字符串 \(S\) 的长度为 \(|S|\),是下标从 \(1\) 开始(到 \(|S|\))的字符串。特别地,若无特殊说明,约定字符串 \(s\) 的长度为 \(n\)。
后缀 \(i\) 代表以 \(i\) 开头的后缀。
若无特殊说明,则两个字符串 \(s_1\) 和 \(s_2\) 的大小关系就是它们字典序的大小关系。
后缀数组是什么?
后缀数组(Suffix Array),可以直接表示字符串 \(s\) 的后缀的字典序大小关系。
定义 \(\text{sa}_i\) 表示按照字典序排序后,第 \(i\) 大的后缀的开头的下标;定义 \(\text{rk}_i\) 表示后缀 \(i\) 按照字典序排序后的排名。
显然,在求出后缀数组(\(\text{sa}\))后,我们有 \(\text{sa}_{\text{rk}_i} = \text{rk}_{\text{sa}_i} = i\)。
如何求出后缀数组?
考虑最暴力的求法:将 \(s\) 的所有后缀存储起来,并直接 sort 起来。容易得到,该做法的时间复杂度是 \(O(n^2 \log n)\) 的。
考虑优化它:只要我们是基于比较的排序(即使用重定义 cmp 的 sort),复杂度中的 \(O(n \log n)\) 就无法去掉。考虑如何优化比较两个字符串 \(s_1\) 和 \(s_2\) 的大小关系。一个想法是,判断 \(s_1\) 和 \(s_2\) 的大小关系只需要找到它们第一个不相等的字符,并直接比较即可。我们可以用二分 + 哈希的方式做这件事。该做法的时间复杂度为 \(O(n \log^2 n)\)。可能再使用一些手法可以做到 \(O(n \log n)\) 地求出后缀数组,但我不会,所以这里先不说。以后记得补一下
这种 \(O(n \log^2 n)\) 做法的代码:
点击查看代码
#include <bits/stdc++.h>
#define ll long long
#define mid (l + r + 1 >> 1)
using namespace std;
constexpr int N = 1e6 + 5;
constexpr int base1 = 31, mod1 = 998244353;
constexpr int base2 = 171, mod2 = 998244853;
int n;
int sa[N];
ll p1[N], p2[N], hs1[N], hs2[N];
char s[N];
bool check(int l1, int r1, int l2, int r2) {
int hsx1 = (hs1[r1] - hs1[l1 - 1] * p1[r1 - l1 + 1] % mod1 + mod1) % mod1;
int hsx2 = (hs2[r1] - hs2[l1 - 1] * p2[r1 - l1 + 1] % mod2 + mod2) % mod2;
int hsy1 = (hs1[r2] - hs1[l2 - 1] * p1[r2 - l2 + 1] % mod1 + mod1) % mod1;
int hsy2 = (hs2[r2] - hs2[l2 - 1] * p2[r2 - l2 + 1] % mod2 + mod2) % mod2;
if (hsx1 == hsy1 && hsx2 == hsy2) {
return true;
} else {
return false;
}
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);
cin >> s + 1;
n = strlen(s + 1);
p1[0] = p2[0] = 1ll;
for (int i = 1; i <= n; ++i) {
p1[i] = p1[i - 1] * base1 % mod1;
p2[i] = p2[i - 1] * base2 % mod2;
hs1[i] = (hs1[i - 1] * base1 + s[i]) % mod1;
hs2[i] = (hs2[i - 1] * base2 + s[i]) % mod2;
}
for (int i = 1; i <= n; ++i) {
sa[i] = i;
}
sort(sa + 1, sa + n + 1, [](int x, int y) {
int len1 = n - x + 1, len2 = n - y + 1;
int l = 0, r = min(len1, len2);
while (l < r) {
if (check(x, x + mid - 1, y, y + mid - 1)) l = mid;
else r = mid - 1;
}
if (l == min(len1, len2)) {
return x > y;
} else {
return s[x + l] < s[y + l];
}
});
for (int i = 1; i <= n; ++i) {
cout << sa[i] << " \n"[i == n];
}
return 0;
}
提交记录,可以看到哈希的常数和较劣的复杂度是无法通过本题的。
upd:好像基于比较的排序没有前途了 /ll
upd2:但是 wyd 讲了一个只利用哈希做到 \(O(n \log n)\) 求后缀数组的做法!该做法好像来源于 zak,拜谢。
我们首先可以将 \(s\) 的长度补全到 \(2\) 的次幂(设为 \(k\))来简化接下来的考虑(即,在 \(s\) 后面加入空字符使 \(n = 2^k\))。考虑首先处理出所有 \(n\) 个后缀的长度为 \(2^{k - 1}\) 的前缀的哈希值。这时,\(n\) 个后缀之间的关系就可以分为「哈希值相同」的部分和「哈希值不同」的部分。
对于哈希值相同的后缀,我们显然只需要比较它们长度不大于 \(2^{k - 1}\) 的后半部分的字典序即可。对于不同的每个哈希值,我们都只保留一个串进行接下来的比较。
递归地做这件事情。
复杂度证明,不会。
我怎么这么菜啊呜呜,如何变得可以自己想出这个的复杂度证明阿?
上面的 \(O(n \log^2 n)\) 做法好想也好写,不过复杂度还是稍微劣了一些,接下来将说一说最普遍的 \(O(n \log n)\) 的后缀数组求法。
我们可以认为 \(s\) 后面接着足够多的空字符(假定空字符的字典序最小),那么 \(n\) 个后缀的长度也可以被补全为一样的。
考虑对 \(n\) 个字符排序的过程。我们首先对每个后缀的第一个字符进行排序,设此时的 \(\text{rk}_i\) 表示 \(s_i\) 这个字符的排名,直接按照 \(\text{rk}_i\) 作为权值排序即可完成这一步。
接下来,我们当然可以继续做比较第 \(2\) 个字符的过程,但我们实际上可以优化它。考虑使用倍增的思想,求出新的 \(\text{rk}_i\) 表示 \(s_{[i, i + 1]}\) 的排名,那么我们再次按照 \(\text{rk}_{i}\) 排序即可,而求新的 \(\text{rk}_i\) 的过程又是 \(O(n)\) 的。一般地,设 \(\text{rk}_i\) 表示 \(s_{[i, i + 2^k - 1]}\) 的排名,并按照它排序计算 \(s_{[i, i + 2^{k + 1} - 1]} 的排名\)。该做法的时间复杂度为 \(O(n \log^2 n)\)。
代码(注意因为我们对 \(s\) 进行了长度上的“补全”,所以 \(\text{rk}\) 要开到 \(|s|\) 的两倍):
点击查看代码
#include <bits/stdc++.h>
using namespace std;
constexpr int N = 1e6 + 5;
int n, w;
int sa[N], rk[N << 1], rk2[N << 1];
char s[N];
signed main() {
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);
cin >> s + 1;
n = strlen(s + 1);
for (int i = 1; i <= n; ++i) {
sa[i] = i;
rk[i] = s[i];
}
for (w = 1; w < n; w <<= 1) {
sort(sa + 1, sa + n + 1, [](int x, int y) {
if (rk[x] == rk[y]) {
return rk[x + w] < rk[y + w];
} else {
return rk[x] < rk[y];
}
});
for (int i = 1; i <= n; ++i) {
rk2[i] = rk[i];
}
for (int p = 0, i = 1; i <= n; ++i) {
if (rk2[sa[i]] == rk2[sa[i - 1]] && rk2[sa[i] + w] == rk2[sa[i - 1] + w]) {
rk[sa[i]] = p;
} else {
rk[sa[i]] = ++p;
}
}
}
for (int i = 1; i <= n; ++i) {
cout << sa[i] << " \n"[i == n];
}
return 0;
}
而优化这个做法是不困难的:复杂度瓶颈在于求新的 \(\text{sa}\) 的排序而非求 \(\text{rk}\) 的过程。我们的排序只参考了 \(\text{rk}\) 数组,那么直接使用基数排序替换掉 sort 即可。
注意基数排序和计数排序的区别:计数排序是直接把所有数放到值域的桶里面,基数排序是对于多个关键字,按较低到较高优先级的关键字依次排序。
代码:
点击查看代码
#include <bits/stdc++.h>
using namespace std;
constexpr int N = 1e6 + 5;
int n;
int sa[N], sa2[N], buc[N], rk[N << 1], rk2[N << 1];
char s[N];
signed main() {
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);
cin >> s + 1;
n = strlen(s + 1);
int sz = max(n, 127);
for (int i = 1; i <= n; ++i) ++buc[rk[i] = s[i]];
for (int i = 1; i <= sz; ++i) buc[i] += buc[i - 1];
for (int i = 1; i <= n; ++i) sa[buc[rk[i]]--] = i;
for (int w = 1; w < n; w <<= 1) {
for (int i = 0; i <= sz; ++i) buc[i] = 0;
for (int i = 1; i <= n; ++i) sa2[i] = sa[i];
for (int i = 1; i <= n; ++i) ++buc[rk[sa2[i] + w]];
for (int i = 1; i <= sz; ++i) buc[i] += buc[i - 1];
for (int i = n; i > 0; --i) sa[buc[rk[sa2[i] + w]]--] = sa2[i];
for (int i = 0; i <= sz; ++i) buc[i] = 0;
for (int i = 1; i <= n; ++i) sa2[i] = sa[i];
for (int i = 1; i <= n; ++i) ++buc[rk[sa2[i]]];
for (int i = 1; i <= sz; ++i) buc[i] += buc[i - 1];
for (int i = n; i > 0; --i) sa[buc[rk[sa2[i]]]--] = sa2[i];
for (int i = 1; i <= n; ++i) rk2[i] = rk[i];
for (int cur = 0, i = 1; i <= n; ++i) {
if (rk2[sa[i]] == rk2[sa[i - 1]] && rk2[sa[i] + w] == rk2[sa[i - 1] + w]) {
rk[sa[i]] = cur;
} else {
rk[sa[i]] = ++cur;
}
}
}
for (int i = 1; i <= n; ++i) {
cout << sa[i] << " \n"[i == n];
}
return 0;
}
当然我们还能再进行一些常数上的优化,得到更为实用的版本。
- 化简对第二关键字(\(s_{[i + 2^k, i + 2^{k + 1} - 1]}\) 部分)进行的排序过程。
考虑对第二关键字进行排序的实质:对于 \(i = 1\) 到 \(n\),按照 \(\text{rk}_{i + 2^k - 1}\) 进行排序,最终的结果必定是把最后的 \(n\) 个(后缀 \(i - n + 1\) 到后缀 \(n\))放到最前面,并把其余的部分按照原顺序(倍增到 \(k\) 层时后半部分的排名顺序)往后平移。
-
在每次倍增结束后动态更新值域。
-
若 \(\text{rk}\) 数组内任意两项的值都不相同,则排序过程必定已经结束,可以停止接下来的操作。
代码:
点击查看代码
#include <bits/stdc++.h>
using namespace std;
constexpr int N = 1e6 + 5;
int n;
int buc[N], sa[N], sa2[N], rk[N << 1], rk2[N << 1];
char s[N];
signed main() {
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);
cin >> s + 1;
n = strlen(s + 1);
int sz = 127;
for (int i = 1; i <= n; ++i) ++buc[rk[i] = s[i]];
for (int i = 1; i <= sz; ++i) buc[i] += buc[i - 1];
for (int i = 1; i <= n; ++i) sa[buc[rk[i]]--] = i;
for (int w = 1; w < n; w <<= 1) {
int cur = 0;
for (int i = n - w + 1; i <= n; ++i) sa2[++cur] = i;
for (int i = 1; i <= n; ++i) {
if (sa[i] > w) {
sa2[++cur] = sa[i] - w;
}
}
for (int i = 0; i <= sz; ++i) buc[i] = 0;
for (int i = 1; i <= n; ++i) ++buc[rk[i]];
for (int i = 1; i <= sz; ++i) buc[i] += buc[i - 1];
for (int i = n; i > 0; --i) sa[buc[rk[sa2[i]]]--] = sa2[i];
cur = 0;
for (int i = 1; i <= n; ++i) rk2[i] = rk[i];
for (int i = 1; i <= n; ++i) {
if (rk2[sa[i]] == rk2[sa[i - 1]] && rk2[sa[i] + w] == rk2[sa[i - 1] + w]) {
rk[sa[i]] = cur;
} else {
rk[sa[i]] = ++cur;
}
}
sz = cur;
if (cur == n) {
break;
}
}
for (int i = 1; i <= n; ++i) {
cout << sa[i] << " \n"[i == n];
}
return 0;
}
\(\text{height}\) 数组
定义 \(h_i\) 表示字典序第 \(i - 1\) 大的后缀与第 \(i\) 大的后缀的 \(\text{LCP}\)(Longest Common Prefix,最长公共前缀)长度,\(h_1 = 0\)。
现有结论:\(h_{\text{rk}_i} \ge h_{\text{rk}_{i - 1}} - 1\)。
证明:当 \(h_{\text{rk}_{i - 1}} \le 1\) 时显然成立。考虑 \(h_{\text{rk}_{i - 1}} > 1\) 的情况。此时,字典序恰好比后缀 \(i - 1\) 小 \(1\) 的后缀去掉其第一个字符后,新的后缀与后缀 \(i\) 的 \(\text{LCP}\) 长度必定为 \(h_{\text{rk}_{i - 1}} - 1\)。而字典序比后缀 \(i\) 小 \(1\) 的后缀必定包含该 \(\text{LCP}\)。故结论成立。
于是从前往后扫描原串并依据上述结论暴力计算,即可在 \(O(n)\) 的时间复杂度内求出 \(h\) 数组。代码如下:
for (int i = 1, w = 0; i <= n; ++i) {
if (rk[i] == 1) {
continue;
}
w = max(0, w - 1);
while (s[i + w] == s[sa[rk[i] - 1] + w]) {
++w;
}
h[rk[i]] = w;
}
使用 \(h\) 数组可以求不同子串数目等信息。
P2870 [USACO07DEC] Best Cow Line G
容易发现这题本质是在比较原串与翻转串的大小关系。
显然可以哈希,但你也可以将原串翻转后接回去,并建后缀数组做。
P4248 [AHOI2013] 差异
一个结论是 \(\text{LCP}(t_i, t_j) = \min\limits_{k = i + 1}^{j} h_k\),直接利用该结论做即可。
P4248 [AHOI2013] 差异
考虑经典套路,在两个字符串间插入特殊字符并拼到一起,此时原先的两个字符串可能会有多的贡献。把这部分贡献统计一下并减掉即可。用 SA 算贡献的部分是简单的。
P2408 不同子串个数
从增量的角度考虑,按照排名顺序依次加入后缀,每个后缀对答案的贡献是 \(\text{len}_{\text{sa}_i} - h_i\) 的。因此本质不同子串个数为 \(\frac{n \times (n + 1)}{2} - \sum h_i\)。
P2178 [NOI2015] 品酒大会
统计方案数是平凡的,统计最大美味度可以通过在单调栈的过程中维护当前区间(栈顶位置到当前位置,显然只会有加入贡献的操作)\(\max\) 与 \(\min\) 来 \(O(n)\) 解决。细节挺多,可能会比较难写。

浙公网安备 33010602011771号