[NOI2014]动物园

Posted by Panda2134's Blog on July 6, 2018

一直以为自己是学过 KMP 的,然而却并没有真正理解它的精华。
做了这个题,总算是加深了对 KMP 的理解。
参考了 @Tony1312 的题解,讲的非常棒。

题意

对于一个长度为 $n$ 的串的每个前缀,求出它的不重合的相同前后缀个数。$n \le 10^6$.

思路

对于此题取 1-indexed 的字符串较为方便。
看到 border 显然想到 KMP. 我们不妨设 cnt[i] 表示前缀 i 的相同前后缀个数(可以重合),则可以递推。为了方便,我们令 cnt[i] 包括整个串作为相同前后缀的情况。(这点和 fail[] 不同,它不能包括整个串作为相同前后缀的情况。)
于是 cnt[1] = 1,然后在求 fail[] 的时候可以递推出剩下的 cnt[]

暴力

因为不会算 KMP 的复杂度,我就先写了个暴力,跳 fail[] 直到 $j \leq \lfloor i / 2\rfloor$,然后统计答案,果断 TLE 了……

其实考虑这个例子就知道为什么了:

1
aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa...

在这个数据上暴力跳 fail[] 的复杂度是 $O(n^2)$ 的。

KMP 时间复杂度证明

如何优化?
首先我们得知道为啥 KMP 是 线性复杂度的……

考虑以下的 KMP 实现:(假设串 a, b 长度分别为 n, m

fail[1] = 0;
for(int i = 2, j = 0; i <= m; i++) {
    while(j && b[j+1] != b[i]) j = fail[j]; // 1
    if(b[j+1] == b[i]) j++; // 2
    fail[i] = j;
}

for(int i = 1, j = 0; i <= n; i++) {
    while(j && b[j+1] != a[i]) j = fail[j]; // 3
    if(b[j+1] == a[i]) j++; // 4
    if(j == m) {
        printf("%d\n", i-m+1);
        j = fail[j];
    }
}

我们使用记账分析计算复杂度。

对于第一个循环:

  • 1 号位置每次执行 $j$ 至少减少 $1$
  • 2 号位置每次执行 $j$ 增加 $1$

而 2 号位置最多执行 $m$ 次,所以相当于存了 $m$ 块钱。
1 号位置每次执行至少花 $1$ 块钱,所以最多花 $m$ 块钱。

于是第一个循环复杂度是 $O(m)$.

对于第二个循环:

  • 3 号位置每次执行 $j$ 至少减少 $1$
  • 4 号位置每次执行 $j$ 增加 $1$

而 $4$ 号位置最多执行 $n$ 次,所以相当于存了 $n$ 块钱。 $3$ 号位置每次执行至少花 $1$ 块钱,所以最多花 $n$ 块钱。

于是第二个循环复杂度是 $O(n)$.

综上所述,KMP 算法总时间复杂度为 $O(n+m)$.

优化

我们考虑把暴力算 num[] 的复杂度优化到和计算 fail[] 同阶。这时就需要用到摊还分析的思想了:只要保证存钱不超过 $O(m)$,就可以在线性复杂度解决问题。

类似计算 fail[] ,我们有如下的性质。设对于每个前缀,它的合法的相同前后缀中,前缀结尾于 $j$ 位置,则:

  • 每当 $i$ 增加 $1$,$j$ 最多增加 $1$,也可能不变或者减少。

证明也很显然:我们不妨考虑 $j$ 增加 $2$ 的情况,如果此时的 $j$ 也满足 $j \le \lfloor i/2 \rfloor$,那么显然也有 b[j+1] = b[i-1]。对于前缀 $i-1$ 也满足不重合。所以上个前缀的前缀结尾根本就不等于 $j$,而是等于 $j+1$,所以假设不成立。对于 $j$ 增加 $3, 4, 5, \dots$ 等情况可以同理证明根本不存在。

有了这个结论,就保证了最多存 $O(n)$ 块钱。于是复杂度也有保证了。只需要从上次的 $j$ 开始跳 fail[];跳到位置后再加 $1$(这样就变成了不考虑重叠的答案,好想好写的多!),然后再继续跳,直到满足不重合为止。

这部分代码如下(此处 $n$ 是串长) :

int ans = 1;
        
for(int i = 2, j = 0; i <= n; i++) {
    while(j && b[j+1] != b[i]) j = fail[j];
    if(b[j+1] == b[i]) j++;
    while(j > i/2) j = fail[j];
    ans = mul(ans, cnt[j]+1);
}

根据记账分析,复杂度显然是 $O(n)$ 的。

代码

#include <bits/stdc++.h>
using namespace std;

const int MAXN = 1e6, MOD = int(1e9+7);
int T, n, fail[MAXN + 10], cnt[MAXN + 10];
char b[MAXN + 10];

void calc_fail() {
    fail[1] = 0; cnt[1] = 1;
    for(int i = 2, j = 0; i <= n; i++) {
        while(j && b[j+1] != b[i]) j = fail[j];
        if(b[j+1] == b[i]) j++;
        fail[i] = j;
        cnt[i] = cnt[fail[i]] + 1;
    }
}

int main() {
    scanf("%d", &T);
    while(T--) {
        scanf("%s", b + 1);
        n = strlen(b + 1);
        calc_fail();
        
        long long ans = 1;
        
        for(int i = 2, j = 0; i <= n; i++) {
            while(j && b[j+1] != b[i]) j = fail[j];
            if(b[j+1] == b[i]) j++;
            while(j > i/2) j = fail[j];
            ans = ans * (cnt[j]+1) % MOD;
        }
        printf("%lld\n", ans);
    }
    return 0;
}