学完AC自动机后一直想学SAM,结果看不懂,先来看看比它简单一点的SA。

约定

字符串和数组下标从1开始。
后缀i:以第i个字符开头的后缀。

什么是后缀数组

后缀数组由后缀数组sa和名次数组rk两个数组组成。

解释

sa[i]表示的是字符串S的所有后缀中第i小的是”后缀sa[i]”。
rk[i]表示的是”后缀i”是字符串S的所有后缀中第rk[i]小的。
简单的说,后缀数组是”排第几的是谁?”名次数组是”你排第几?”
引用OIWIKI上的例子:
图片说明
我们来逐一解释:
首先,我们知道所有的后缀,从小到大它们依次是
aaaab、aaab、aab、aabaaaab、ab、abaaaab、b、baaaab。
根据定义求出rk=[4,6,8,1,2,3,5,7]。
同理,根据定义求出sa=[4,5,6,1,7,2,8,3]。

神奇的性质

sa[rk[i]]=rk[sa[i]]=i

从暴力到线性

O(n^2logn)

不难想到,我们可以从右向左的求出字符串S的所有后缀,然后排序。

1
过于简单不予展示。

O(nlog^2 n)

利用倍增思想,减少排序次数。
先对所有长度为1的子串进行排序。
得到字符串长度为1时的rk,记为rk1。接着以rk1[i]作为第一关键字,rk1[i+1]作为第二关键字,再次排序。得到长度为2的rk,记为rk2。
接着以rk2[i]作为第一关键字,rk2[i+2]作为第二关键字,再次排序。得到长度为4的rk,记为rk4。
利用倍增思想求出长度为n时的rk。
n个数的排序做logn次。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
string s=" aabaaaab";//这里具体情况具体分析
int n,w;
int oldrk[maxn<<1];
int rk[maxn<<1];
int SA[maxn];
void init()
{
n=s.length();//这个多了前面的空格
n--;
for(int i=1;i<=n;i++)
{
rk[i]=s[i];
//这里可能会有人问,为什么这么赋值
//你想啊,反正我都要倍增的,初值是多少根本没人在意
SA[i]=i;
}
}
bool cmp(int x,int y)
{
if(rk[x]==rk[y])
{
return rk[x+w]<rk[y+w];
}
return rk[x]<rk[y];
}
int main()
{
init();
for(w=1;w<n;w*=2)//注意是< , w是增量
{
sort(SA+1,SA+n+1,cmp);//通过rk对SA排序

memcpy(oldrk,rk,sizeof(rk));

for(int p=0,i=1;i<=n;i++)
{
// 若两个子串相同,它们对应的 rk 也需要相同,所以要去重
if(oldrk[SA[i]]==oldrk[SA[i-1]]
&& oldrk[SA[i]+w]==oldrk[SA[i-1]+w])
{
rk[SA[i]]=p;
}
else
{
rk[SA[i]]=++p;
}
}
}
for(int i=1;i<=n;i++){
cout<<SA[i]<<" ";
}
cout<<endl;
return 0;
}

O(nlogn)

使用基数排序代替sort

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
char s[MAXN];
int N, M, rak[MAXN], sa[MAXN], tax[MAXN], tp[MAXN];
void Debug() {
printf("*****************\n");
printf("下标"); for (int i = 1; i <= N; i++) printf("%d ", i); printf("\n");
printf("sa "); for (int i = 1; i <= N; i++) printf("%d ", sa[i]); printf("\n");
printf("rak "); for (int i = 1; i <= N; i++) printf("%d ", rak[i]); printf("\n");
printf("tp "); for (int i = 1; i <= N; i++) printf("%d ", tp[i]); printf("\n");
}
void Qsort() {
for (int i = 0; i <= M; i++) tax[i] = 0;
for (int i = 1; i <= N; i++) tax[rak[i]]++;
for (int i = 1; i <= M; i++) tax[i] += tax[i - 1];
for (int i = N; i >= 1; i--) sa[ tax[rak[tp[i]]]-- ] = tp[i];
}
void SuffixSort() {
M = 26;
for (int i = 1; i <= N; i++) rak[i] = s[i] - 'a' + 1, tp[i] = i;
Qsort();
Debug();
for (int w = 1, p = 0; p < N; M = p, w <<= 1) {
//w:当前倍增的长度,w = x表示已经求出了长度为x的后缀的排名,现在要更新长度为2x的后缀的排名
//p表示不同的后缀的个数,很显然原字符串的后缀都是不同的,因此p = N时可以退出循环
p = 0;//这里的p仅仅是一个计数器
for (int i = 1; i <= w; i++) tp[++p] = N - w + i;
for (int i = 1; i <= N; i++) if (sa[i] > w) tp[++p] = sa[i] - w; //这两句是后缀数组的核心部分
Qsort();//此时我们已经更新出了第二关键字,利用上一轮的rak更新本轮的sa
std::swap(tp, rak);//这里原本tp已经没有用了
rak[sa[1]] = p = 1;
for (int i = 2; i <= N; i++)
rak[sa[i]] = (tp[sa[i - 1]] == tp[sa[i]] && tp[sa[i - 1] + w] == tp[sa[i] + w]) ? p : ++p;
//这里当两个后缀上一轮排名相同时本轮也相同,至于为什么大家可以思考一下
Debug();
}
for (int i = 1; i <= N; i++)
printf("%d ", sa[i]);
}

O(n)

对不起,我看不懂。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
//DC3
#define F(X) ((X)/3+((X)%3==1?0:tb))
#define G(X) ((X)<tb?(X)*3+1:((X)-tb)*3+2)
int wa[maxn],wb[maxn],wv[maxn],ws[maxn];
int c0(int *r,int a,int b){
return r[a]==r[b]&&r[a+1]==r[b+1]&&r[a+2]==r[b+2];
}
int c12(int k,int *r,int a,int b ){
if(k==2)
return r[a]<r[b]||r[a]==r[b]&&c12(1,r,a+1,b+1);
else
return r[a]<r[b]||r[a]==r[b]&&wv[a+1]<wv[b+1];
}
void sort(int *r,int *a,int *b,int n,int m){
int i;
for(i=0;i<n;i++) wv[i]=r[a[i]];
for(i=0;i<m;i++) ws[i]=0;
for(i=0;i<n;i++) ws[wv[i]]++;
for(i=0;i<m;i++) ws[i]=ws[i-1];
for(i=n-1;i>=0;i--) b[--ws[wv[i]]]=a[i];
return ;

}
void dc3(int *r,int *sa,int n,int m){
int i,j,*rn=r+n,*san=sa+n,ta=0,tb=(n+1)/3,tbc=0,p;
r[n]=r[n+1]=0;
for(i=0;i<n;i++) if (i%3!=0) wa[tbc++]=i;
sort(r+2,wa,wb,tbc,m);
sort(r+1,wb,wa,tbc,m);
sort(r,wa,wb,tbc,m);
for(p=1,rn[F(wb[0])]=0,i=1;i<tbc;i++)
rn[F(wb[i])]=c0(r,wb[i-1],wb[i])?p-1:p++;
if(p<tbc) dc(rn,san,tbc,p);
else for (i=0;i<tbc;i++) san[rn[i]]=i;
for(i=0;i<tbc;i++) if(san[i]<tb) wb[ta++]=san[i]*3;
if(n%3==1) wb[ta++]=n-1;
sort(r,wb,wa,ta,m);
for(i=0;i<tbc;i++) wv[wb[i]=G(san[i])]=i;
for(i=0,j=0,p=0;i<ta&&j<tbc;p++)
sa[p]=c12(wb[j]%3,r,wa[i],wb[j])?wa[i++]:wb[j++];
for(;i<ta;p++) sa[p]=wa[i++];
for(;j<tbc;p++) sa[p]=wb[j++];
return
}