494B - Obsessive String

Solution:

I like this problem, and for me it is a particularly challenging one. Thankfully the editorial for this round is very well-written and detailed. I learnt quite a lot from this problem.

First we need an efficient way to mark all occurrences of t in s, and here we can employ the Knuth-Morris-Pratt string matching algorithm, which runs in \(O(N + |t|)\). We record the result in an array g: g[i] = 1 if s[i-|t|+1 .. i] matches t exactly (that is, an occurrence of t ends at index i), and g[i] = 0 otherwise. The hardest part is coming up with the right DP state. The one that works here is: let f[i] be the number of ways to choose the substrings from [1..i] such that the last (right-most) substring ends exactly at index i.
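The occurrence-marking step described above can be sketched on its own. This is a minimal sketch, not the code from the implementation further down: it uses the prefix function on the concatenation t + '#' + s, which is a common variant of KMP. The helper name markOccurrences is made up for illustration, the result is 0-based, and the sketch assumes that t is non-empty and that '#' occurs in neither string.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper (not from the original solution).
// Returns g with g[i] == 1 iff an occurrence of t ends at s[i] (0-based).
// Assumes t is non-empty and that '#' appears in neither s nor t.
std::vector<int> markOccurrences(const std::string& s, const std::string& t) {
    const std::string joined = t + '#' + s;
    const std::size_t M = t.size();
    // pi[i] = length of the longest proper border of joined[0..i]
    std::vector<int> pi(joined.size(), 0);
    std::vector<int> g(s.size(), 0);
    for (std::size_t i = 1; i < joined.size(); ++i) {
        int k = pi[i - 1];
        while (k > 0 && joined[i] != joined[k]) k = pi[k - 1];
        if (joined[i] == joined[k]) ++k;
        pi[i] = k;
        // A border of length |t| inside the s part means t ends here.
        if (i > M && static_cast<std::size_t>(k) == M) g[i - M - 1] = 1;
    }
    return g;
}
```

For s = "ababa" and t = "aba" the two occurrences end at 0-based indexes 2 and 4, so only those two entries of g are set.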

By partitioning the problem this way, we can calculate the total number of ways to choose the substrings by summing up f[i] from i = 1 to N, where N = |s|. From here, the idea is pretty neat:

1. introduce another array sum[1..N], where sum[i] is the sum of f[j] for j = 1 to i. This represents the number of ways to choose at least one substring from [1..i].

2. if g[i] is not set, then any substring ending at i must contain an occurrence of t that ends before i, so we have no choice but to take a last substring ending at i-1 and extend it by one element. So f[i] = f[i-1].

3. otherwise, g[i] is set, so s[i-|t|+1 .. i] is an occurrence of t. Here we have the luxury of extending this substring to the left, i.e. we can have s[k .. i] as the last (right-most) substring, where k ranges over [1 .. i-|t|+1]. For each k, we have (sum[k-1] + 1) ways of choosing the rest of the substrings in [1..k-1] (the "plus one" is for not choosing any other substring). Hence f[i] = sum of (sum[k-1] + 1) for k = 1 to i-|t|+1.
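The three steps above can be written down literally before any optimisation. The following is a sketch for cross-checking only, under a few assumptions: the name countWaysNaive is made up, g is passed 1-based with an unused g[0] to match the indexing in the text, and g is assumed to be consistent with M = |t| (no occurrence ends before position M). It evaluates the inner sum directly, so it is quadratic in the worst case.

```cpp
#include <vector>

// Hypothetical reference version (not from the original solution).
// g is 1-based: g[i] == 1 iff an occurrence of t ends at position i.
// g[0] is unused padding, so g must have at least one element.
long long countWaysNaive(const std::vector<int>& g, int M) {
    const long long MOD = 1000000007LL;
    const int N = static_cast<int>(g.size()) - 1;
    std::vector<long long> f(N + 1, 0), sum(N + 1, 0);
    for (int i = 1; i <= N; ++i) {
        if (!g[i]) {
            // Step 2: extend the last substring ending at i-1.
            f[i] = f[i - 1];
        } else {
            // Step 3: the last substring is s[k..i] for every valid start k.
            for (int k = 1; k <= i - M + 1; ++k)
                f[i] = (f[i] + sum[k - 1] + 1) % MOD;
        }
        // Step 1: sum[i] = f[1] + ... + f[i].
        sum[i] = (sum[i - 1] + f[i]) % MOD;
    }
    return sum[N];
}
```

For s = "ababa" and t = "aba" (occurrences ending at positions 3 and 5) this returns 5, and for s = "ddd" and t = "d" it returns 12.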

Computing sum[i] and the prefix sums of sum[i] beforehand in DP tables lets us evaluate each f[i] in O(1), which gives us O(N) running time overall.
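To see why the prefix sums are enough, the sum from step 3 can be rewritten as follows. Here tot is the name used for the running total of sum in the implementation, and sum[0] = 0:

\[
\mathrm{tot}[j] = \sum_{m=0}^{j} \mathrm{sum}[m],
\qquad
f[i] = \sum_{k=1}^{i-|t|+1} \bigl(\mathrm{sum}[k-1] + 1\bigr)
     = \sum_{m=0}^{i-|t|} \mathrm{sum}[m] + (i-|t|+1)
     = \mathrm{tot}[i-|t|] + (i-|t|+1).
\]

As a small worked example, take s = ababa and t = aba, so g[3] = g[5] = 1. Then f = 0, 0, 1, 1, 3 and sum = 0, 0, 1, 2, 5 for i = 1..5, where f[5] = tot[2] + 3 = 0 + 3, and the answer is sum[5] = 5.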

Implementation:

    #include <iostream>
    #include <cstdio>
    #include <algorithm>
    #include <string>
    using namespace std;

    long long MOD = (long long) 1e9 + 7LL;
    string s, t;
    int par[100005];  // KMP failure links for t (-1 means no border)
    long long f[100005], g[100005], sum[100005], tot[100005];
    int N, M;

    int main(){
        cin >> s >> t;
        N = s.size();
        M = t.size();

        // Build the failure links of t.
        int k = -1;
        par[0] = -1;
        for(int i=1;i<M;++i){
            while(k>=0 && t[k+1] != t[i]) k = par[k];
            if(t[k+1] == t[i]) ++k;
            par[i] = k;
        }

        // Scan s; g[i] = 1 when an occurrence of t ends at s[i] (0-based).
        k = -1;
        for(int i=0;i<N;++i){
            if(t[k+1] == s[i]) ++k;
            else {
                while(k>=0 && t[k+1] != s[i]) k = par[k];
                if(t[k+1] == s[i]) ++k;
            }
            if(k == M-1) {
                g[i] = 1;
            }
        }

        // DP over 1-based positions; tot[i] is the prefix sum of sum[0..i].
        f[0] = sum[0] = tot[0] = 0;
        long long ans = 0;
        for(int i=1;i<=N;++i){
            if(g[i-1]){
                // f[i] = sum over k = 1..i-M+1 of (sum[k-1] + 1)
                f[i] = tot[i-M] + i-M+1;
                f[i] %= MOD;
                sum[i] = f[i] + sum[i-1];
                sum[i] %= MOD;
            } else {
                f[i] = f[i-1];
                sum[i] = f[i] + sum[i-1];
                sum[i] %= MOD;
            }
            tot[i] = sum[i] + tot[i-1];
            tot[i] %= MOD;
            ans += f[i];
            ans %= MOD;
        }
        cout << ans << endl;
        return 0;
    }