一直以为自己是学过 KMP 的,然而却并没有真正理解它的精华。
做了这个题,总算是加深了对 KMP 的理解。
参考了 @Tony1312 的题解,讲的非常棒。
题意
对于一个长度为 $n$ 的串的每个前缀,求出它的不重合的相同前后缀个数。$n \le 10^6$.
思路
对于此题取 1-indexed 的字符串较为方便。
看到 border 显然想到 KMP. 我们不妨设 cnt[i]
表示前缀 i 的相同前后缀个数(可以重合),则可以递推。为了方便,我们令 cnt[i]
包括整个串作为相同前后缀的情况。(这点和 fail[]
不同,它不能包括整个串作为相同前后缀的情况。)
于是 cnt[1] = 1
,然后在求 fail[]
的时候可以递推出剩下的 cnt[]
。
暴力
因为不会算 KMP 的复杂度,我就先写了个暴力,跳 fail[]
直到 $j \leq \lfloor i / 2\rfloor$,然后统计答案,果断 TLE 了……
其实考虑这个例子就知道为什么了:
1
aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa...
在这个数据上暴力跳 fail[]
的复杂度是 $O(n^2)$ 的。
KMP 时间复杂度证明
如何优化?
首先我们得知道为啥 KMP 是 线性复杂度的……
考虑以下的 KMP 实现:(假设串 a, b
长度分别为 n, m
)
fail[1] = 0;
for(int i = 2, j = 0; i <= m; i++) {
while(j && b[j+1] != b[i]) j = fail[j]; // 1
if(b[j+1] == b[i]) j++; // 2
fail[i] = j;
}
for(int i = 1, j = 0; i <= n; i++) {
while(j && b[j+1] != a[i]) j = fail[j]; // 3
if(b[j+1] == a[i]) j++; // 4
if(j == m) {
printf("%d\n", i-m+1);
j = fail[j];
}
}
我们使用记账分析计算复杂度。
对于第一个循环:
- 1 号位置每次执行 $j$ 至少减少 $1$
- 2 号位置每次执行 $j$ 增加 $1$
而 2 号位置最多执行 $m$ 次,所以相当于存了 $m$ 块钱。
1 号位置每次执行至少花 $1$ 块钱,所以最多花 $m$ 块钱。
于是第一个循环复杂度是 $O(m)$.
对于第二个循环:
- 3 号位置每次执行 $j$ 至少减少 $1$
- 4 号位置每次执行 $j$ 增加 $1$
而 $4$ 号位置最多执行 $n$ 次,所以相当于存了 $n$ 块钱。 $3$ 号位置每次执行至少花 $1$ 块钱,所以最多花 $n$ 块钱。
于是第二个循环复杂度是 $O(n)$.
综上所述,KMP 算法总时间复杂度为 $O(n+m)$.
优化
我们考虑把暴力算 num[]
的复杂度优化到和计算 fail[]
同阶。这时就需要用到摊还分析的思想了:只要保证存钱不超过 $O(m)$,就可以在线性复杂度解决问题。
类似计算 fail[]
,我们有如下的性质。设对于每个前缀,它的合法的相同前后缀中,前缀结尾于 $j$ 位置,则:
- 每当 $i$ 增加 $1$,$j$ 最多增加 $1$,也可能不变或者减少。
证明也很显然:我们不妨考虑 $j$ 增加 $2$ 的情况,如果此时的 $j$ 也满足 $j \le \lfloor i/2 \rfloor$,那么显然也有 b[j+1] = b[i-1]
。对于前缀 $i-1$ 也满足不重合。所以上个前缀的前缀结尾根本就不等于 $j$,而是等于 $j+1$,所以假设不成立。对于 $j$ 增加 $3, 4, 5, \dots$ 等情况可以同理证明根本不存在。
有了这个结论,就保证了最多存 $O(n)$ 块钱。于是复杂度也有保证了。只需要从上次的 $j$ 开始跳 fail[]
;跳到位置后再加 $1$(这样就变成了不考虑重叠的答案,好想好写的多!),然后再继续跳,直到满足不重合为止。
这部分代码如下(此处 $n$ 是串长) :
int ans = 1;
for(int i = 2, j = 0; i <= n; i++) {
while(j && b[j+1] != b[i]) j = fail[j];
if(b[j+1] == b[i]) j++;
while(j > i/2) j = fail[j];
ans = mul(ans, cnt[j]+1);
}
根据记账分析,复杂度显然是 $O(n)$ 的。
代码
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 1e6, MOD = int(1e9+7);
int T, n, fail[MAXN + 10], cnt[MAXN + 10];
char b[MAXN + 10];
void calc_fail() {
fail[1] = 0; cnt[1] = 1;
for(int i = 2, j = 0; i <= n; i++) {
while(j && b[j+1] != b[i]) j = fail[j];
if(b[j+1] == b[i]) j++;
fail[i] = j;
cnt[i] = cnt[fail[i]] + 1;
}
}
int main() {
scanf("%d", &T);
while(T--) {
scanf("%s", b + 1);
n = strlen(b + 1);
calc_fail();
long long ans = 1;
for(int i = 2, j = 0; i <= n; i++) {
while(j && b[j+1] != b[i]) j = fail[j];
if(b[j+1] == b[i]) j++;
while(j > i/2) j = fail[j];
ans = ans * (cnt[j]+1) % MOD;
}
printf("%lld\n", ans);
}
return 0;
}