Find All Good Strings (hard): digit DP solutions

// Solution 1: digit DP that fills positions right to left. Because characters
// are prepended, `evil` is reversed and matched with a small automaton.
const int mxN=500, mxM=50, M=1e9+7;
class Solution {
public:
    // Type alias rather than `#define ll long long`: a macro ignores class
    // scope and would rewrite every later `ll` token in the file.
    using ll = long long;

    // dp[i][j][l]: number of ways to fill positions i..n-1 so that the
    // reversed-evil automaton is in state j and, when l == 1, the filled
    // suffix is <= the corresponding suffix of the bound string.
    // Column j == |evil| is the "evil completed" sink; it is written but never read.
    ll dp[mxN+1][mxM+1][2];
    // tr[j][c]: next automaton state after appending letter c to a text whose
    // longest suffix equal to a prefix of reversed evil has length j.
    int tr[mxM][26];

    // Returns the number of length-n lowercase strings <= s that do not contain
    // e, modulo M. The fixed array sizes assume n <= mxN and |e| <= mxM.
    ll solve(int n, const std::string& s, std::string e) {
        std::reverse(e.begin(), e.end());
        const int m = static_cast<int>(e.size());

        // Brute-force the transition table: the new state is the longest k such
        // that the last k characters of e[0..i) + c equal e[0..k).
        for (int i = 0; i < m; ++i) {
            const std::string f = e.substr(0, i);
            for (int j = 0; j < 26; ++j) {
                const std::string g = f + static_cast<char>('a' + j);
                // k == 0 compares two empty strings, so this loop always terminates.
                for (int k = i + 1; ; --k) {
                    if (g.substr(i + 1 - k) == e.substr(0, k)) {
                        tr[i][j] = k;
                        break;
                    }
                }
            }
        }

        // Reset every column including the sink j == m, so no indeterminate
        // value is ever read-modified-written in the transition below.
        for (int i = 0; i <= n; ++i)
            for (int j = 0; j <= m; ++j)
                dp[i][j][0] = dp[i][j][1] = 0;
        // Empty suffix: nothing matched yet, and "" <= "" holds.
        dp[n][0][1] = 1;

        for (int i = n - 1; i >= 0; --i) {
            const int bound = s[i] - 'a';
            for (int j = 0; j < m; ++j) {
                for (int k = 0; k < 26; ++k) {
                    const int to = tr[j][k];  // to == m: evil completed (discarded)
                    for (int l : {0, 1}) {
                        // The new leading character decides the comparison
                        // unless it ties with the bound's character.
                        const int nl = k < bound ? 1 : (k > bound ? 0 : l);
                        dp[i][to][nl] = (dp[i + 1][j][l] + dp[i][to][nl]) % M;
                    }
                }
            }
        }

        ll ans = 0;
        for (int j = 0; j < m; ++j)
            ans += dp[0][j][1];
        return ans % M;
    }

    // Counts strings in [s1, s2] of length n avoiding evil, modulo M, as
    // count(<= s2) - count(<= predecessor of s1).
    int findGoodStrings(int n, std::string s1, std::string s2, std::string evil) {
        ll ans = solve(n, s2, evil);
        bool allA = true;
        for (int i = 0; i < n && allA; ++i)
            allA = s1[i] == 'a';
        // "aa...a" has no predecessor, so there is nothing to subtract.
        if (!allA) {
            // Decrement s1 as a base-26 number: trailing 'a's borrow and become 'z'.
            for (int i = n - 1; i >= 0; --i) {
                if (s1[i] > 'a') {
                    --s1[i];
                    break;
                }
                s1[i] = 'z';
            }
            const ll below = solve(n, s1, evil);
            ans = (ans + M - below) % M;
        }
        return static_cast<int>(ans);
    }
};
class Solution:
    def findGoodStrings(self, n: int, s1: str, s2: str, evil: str) -> int:
        """Count length-n strings s with s1 <= s <= s2 that avoid `evil`, mod 1e9+7.

        Computed as count(<= s2) - count(<= predecessor of s1). Each count is a
        digit DP that fills positions right to left while tracking how much of
        the reversed `evil` string has been matched.
        """
        MOD = 10**9 + 7
        m = len(evil)

        def count_upto(bound: str) -> int:
            """Number of length-n strings <= bound that do not contain evil."""
            rev = evil[::-1]
            # step[state][c]: length of the longest prefix of `rev` that is a
            # suffix of rev[:state] + c. A result equal to m means evil completed.
            step = []
            for state in range(m):
                row = []
                for c in range(26):
                    cand = rev[:state] + chr(ord('a') + c)
                    k = state + 1
                    while cand[state + 1 - k:] != rev[:k]:
                        k -= 1
                    row.append(k)
                step.append(row)

            # cur[state][le]: ways for the suffix filled so far; le == 1 means
            # that suffix is <= the same suffix of `bound`.
            cur = [[0, 0] for _ in range(m + 1)]
            cur[0][1] = 1
            for pos in range(n - 1, -1, -1):
                digit = ord(bound[pos]) - ord('a')
                nxt = [[0, 0] for _ in range(m + 1)]
                for state in range(m):
                    for c in range(26):
                        to = step[state][c]
                        for le in (0, 1):
                            if c < digit:
                                nl = 1
                            elif c > digit:
                                nl = 0
                            else:
                                nl = le
                            nxt[to][nl] = (cur[state][le] + nxt[to][nl]) % MOD
                cur = nxt
            return sum(cur[state][1] for state in range(m)) % MOD

        total = count_upto(s2)
        if set(s1) == {'a'}:
            return total

        # Step s1 back by one in base 26: trailing letters <= 'a' borrow to 'z'.
        chars = list(s1)
        idx = n - 1
        while idx >= 0 and not chars[idx] > 'a':
            chars[idx] = 'z'
            idx -= 1
        if idx >= 0:
            chars[idx] = chr(ord(chars[idx]) - 1)

        below = count_upto(''.join(chars))
        return (total + MOD - below) % MOD
// Solution 2: memoized digit DP with a KMP automaton for `evil`.
const int MOD = 1e9 + 7;
const int N = 500 + 10;  // capacity for string length n (assumes n <= 500)
const int M = 50 + 10;   // capacity for |evil| (assumes |evil| <= 50)
// Memo for Solution::solve, indexed [pos][matched][tight]; -1 = not computed.
int dp[N][M][2];
class Solution {
public:
    // Prefix function (KMP failure table) of t: f[i] is the length of the
    // longest proper prefix of t[0..i] that is also a suffix of it.
    std::vector<int> kmp(const std::string& t) {
        const int m = static_cast<int>(t.size());
        // Value-initialized to 0, so f[0] needs no separate assignment (an
        // explicit `f[0] = 0` would write out of bounds when t is empty).
        std::vector<int> f(m, 0);
        int k = 0;
        for (int i = 1; i < m; ++i) {
            while (k > 0 && t[k] != t[i]) k = f[k - 1];
            if (t[k] == t[i]) ++k;
            f[i] = k;
        }
        return f;
    }

    // True if t occurs as a contiguous substring of s (an empty t always does).
    bool contain(const std::string& s, const std::string& t) {
        return s.find(t) != std::string::npos;
    }

    std::string s, t;    // current upper bound and the evil pattern
    std::vector<int> f;  // prefix function of t

    // Number of ways to fill positions pos..|s|-1 so that the full string never
    // completes t and, while `tight`, stays <= s. `matched` is how many leading
    // characters of t are matched at the current end of the prefix.
    int solve(int pos, int matched, int tight) {
        if (matched == static_cast<int>(t.size())) return 0;  // evil completed
        if (pos == static_cast<int>(s.size())) return 1;
        int& ret = dp[pos][matched][tight];
        if (ret >= 0) return ret;
        ret = 0;
        for (int i = 0; i < 26; ++i) {
            if (tight && i > s[pos] - 'a') continue;
            const int nxt_tight = tight && i == s[pos] - 'a';
            // Advance the KMP automaton by character i, falling back on mismatch.
            int nxt_matched = matched;
            while (nxt_matched > 0 && t[nxt_matched] - 'a' != i) nxt_matched = f[nxt_matched - 1];
            if (t[nxt_matched] - 'a' == i) nxt_matched += 1;
            ret = (ret + solve(pos + 1, nxt_matched, nxt_tight)) % MOD;
        }
        return ret;
    }

    // Counts good strings in [s1, s2]: count(<= s2) - count(<= s1), plus one
    // when s1 itself avoids evil. n is implied by s1/s2 and is not read.
    int findGoodStrings(int n, std::string s1, std::string s2, std::string evil) {
        this->f = kmp(evil);
        t = evil;

        std::memset(dp, -1, sizeof(dp));  // every byte 0xFF -> every int == -1
        s = s2;
        const int up_to_s2 = solve(0, 0, 1);

        std::memset(dp, -1, sizeof(dp));
        s = s1;
        const int up_to_s1 = solve(0, 0, 1);

        int ret = (up_to_s2 + MOD - up_to_s1) % MOD;
        if (!contain(s1, evil)) ret = (ret + 1) % MOD;
        return ret;
    }
};
from functools import lru_cache


def srange(a, b):
    """Yield every character from `a` to `b` inclusive, in code-point order."""
    yield from (chr(i) for i in range(ord(a), ord(b) + 1))


def failure(pat):
    """KMP failure table of `pat`.

    res[i] is the length of the longest proper prefix of pat[:i+1] that is also
    a suffix of it. Returns [0] for an empty pattern.
    """
    res = [0]
    i, target = 1, 0
    while i < len(pat):
        if pat[i] == pat[target]:
            target += 1
            res += target,
            i += 1
        elif target:
            target = res[target - 1]  # fall back to the next shorter border
        else:
            res += 0,
            i += 1
    return res


class Solution:
    def findGoodStrings(self, n: int, s1: str, s2: str, evil: str) -> int:
        """Count length-n strings in [s1, s2] that do not contain `evil`, mod 1e9+7."""
        MOD = 10**9 + 7
        f = failure(evil)

        @lru_cache(None)
        def dfs(idx, max_matched=0, lb=True, rb=True):
            '''
            idx: current_idx_on_s1_&_s2,
            max_matched: nxt_idx_to_match_on_evil,
            lb, rb: is_left_bound, is_right_bound
            '''
            if max_matched == len(evil): return 0  # evil completed: prefix is bad
            if idx == n: return 1  # base case: a full good string

            l = s1[idx] if lb else 'a'  # valid left bound
            r = s2[idx] if rb else 'z'  # valid right bound
            candidates = [*srange(l, r)]

            res = 0
            for i, c in enumerate(candidates):
                # Advance the KMP automaton by c, falling back on mismatch.
                nxt_matched = max_matched
                while evil[nxt_matched] != c and nxt_matched:
                    nxt_matched = f[nxt_matched - 1]
                # Only the first/last candidate can keep the lower/upper bound tight.
                res += dfs(idx + 1, nxt_matched + (c == evil[nxt_matched]),
                           lb=(lb and i == 0), rb=(rb and i == len(candidates) - 1))
            # Reduce per state so cached values stay small instead of growing
            # into huge integers (counts can reach 26**n).
            return res % MOD

        return dfs(0) % MOD

--

--

Get the Medium app

A button that says 'Download on the App Store', and if clicked it will lead you to the iOS App store
A button that says 'Get it on, Google Play', and if clicked it will lead you to the Google Play store
Jimmy Shen

Jimmy Shen

Data Scientist/MLE/SWE @takemobi