Description 题库链接 给定一个长度为 \(n\) 的字符串 \(S\) ,令 \(T_i\) 表示它从第 \(i\) 个字符开始的后缀。求 \[\sum_{1\leqslant ij\leqslant n}\text{len}(T_i)+\text{len}(T_j)-2\times\text{lcp}(T_i,T_j)\] 其中,
Description
题库链接
给定一个长度为 \(n\) 的字符串 \(S\) ,令 \(T_i\) 表示它从第 \(i\) 个字符开始的后缀。求
\[\sum_{1\leqslant i<j\leqslant n}\text{len}(T_i)+\text{len}(T_j)-2\times\text{lcp}(T_i,T_j)\]
其中, \(\text{len}(a)\) 表示字符串 \(a\) 的长度, \(\text{lcp}(a,b)\) 表示字符串 \(a\) 和字符串 \(b\) 的最长公共前缀。
\(2\leqslant n\leqslant 500000\) ,且均为小写字母。
Solution
注意到原式可化为
\[\begin{aligned}&(n-1)\sum_{i=1}^n\text{len}(T_i)-2\sum_{1\leqslant i<j\leqslant n}\text{lcp}(T_{sa_i},T_{sa_j})\\=&\frac{(n-1)n(n+1)}{2}-2\sum_{1\leqslant i<j\leqslant n}\text{lcp}(T_{sa_i},T_{sa_j})\end{aligned}\]
记 \(\text{LCP}(i,j)=\text{lcp}(T_{sa_i},T_{sa_j})\) ,由于
\[\text{LCP}(i,j)=\min_{i<k\leq j}\text{LCP}(k-1,k)\]
那么就可以单调栈预处理出两个数组 \(l_i,r_i\) 表示左边(右边)第一个大于(大于等于) \(height_i\) 的位置。注意,由于不能重复计算,等于只能一边取。
然后直接算贡献就好了。
Code
#include <bits/stdc++.h> #define ll long long using namespace std; const int N = 500000+5; char ch[N]; int n, m, x[N<<1], y[N<<1], c[N], sa[N], rk[N], height[N], l[N], r[N], s[N], top; void get_sa() { for (int i = 1; i <= n; i++) c[x[i] = ch[i]]++; for (int i = 2; i <= m; i++) c[i] += c[i-1]; for (int i = n; i >= 1; i--) sa[c[x[i]]--] = i; for (int k = 1; k <= n; k <<= 1) { int num = 0; for (int i = n-k+1; i <= n; i++) y[++num] = i; for (int i = 1; i <= n; i++) if (sa[i] > k) y[++num] = sa[i]-k; for (int i = 1; i <= m; i++) c[i] = 0; for (int i = 1; i <= n; i++) c[x[i]]++; for (int i = 2; i <= m; i++) c[i] += c[i-1]; for (int i = n; i >= 1; i--) sa[c[x[y[i]]]--] = y[i]; swap(x, y); x[sa[1]] = num = 1; for (int i = 2; i <= n; i++) x[sa[i]] = (y[sa[i]] == y[sa[i-1]] && y[sa[i]+k] == y[sa[i-1]+k]) ? num : ++num; if ((m = num) == n) break; } } void get_height() { for (int i = 1; i <= n; i++) rk[sa[i]] = i; for (int i = 1, k = 0; i <= n; i++) { if (rk[i] == 1) continue; if (k) --k; int j = sa[rk[i]-1]; while (j+k <= n && i+k <= n && ch[i+k] == ch[j+k]) ++k; height[rk[i]] = k; } } void work() { scanf("%s", ch+1), n = strlen(ch+1), m = 'z'; get_sa(); get_height(); ll ans = 1ll*(1+n)*n*(n-1)/2; s[++top] = n+1; for (int i = n; i >= 2; i--) { while (top != 1 && height[i] <= height[s[top]]) --top; r[i] = s[top]; s[++top] = i; } s[top = 1] = 1; for (int i = 2; i <= n; i++) { while (top != 1 && height[i] < height[s[top]]) --top; l[i] = s[top]; s[++top] = i; } for (int i = 2; i <= n; i++) ans -= 2ll*(i-l[i])*(r[i]-i)*height[i]; printf("%lld\n", ans); } int main() {work(); return 0; }