題目大意
給定僅由數(shù)碼組成的串 a 和非負(fù)整數(shù) l 與 r甫煞,考察其的一個(gè)劃分蛾方,若該劃分中任意一個(gè)串都是一個(gè)正當(dāng)整數(shù)(沒(méi)有多余的前導(dǎo)零)且屬于閉區(qū)間 [l, r]暖释,則我們稱該劃分為一個(gè)美麗劃分灵份。
求一共有多少個(gè)美麗劃分哀军。因?yàn)榇鸢副容^大沉眶,輸出其對(duì) 998244353 取模的結(jié)果。
題目保證 a, l, r 的位數(shù)不超過(guò) 106 且 l 不超過(guò) r 杉适。
分析
這題很顯然可以用動(dòng)態(tài)規(guī)劃的方法來(lái)解決谎倔,令 dp[i] 恰好在 i 位置右邊結(jié)束某一串的方案數(shù),則:
其中 valid(s) 表示子串 s 是否合法猿推。一個(gè)比較顯然的結(jié)論是如果當(dāng)前串沒(méi)有前導(dǎo)零且它的長(zhǎng)度在 內(nèi)那么該串一定合法片习,如果在
外則一定非法。對(duì)于長(zhǎng)度等于邊界的情況蹬叭,我們需要比較字符串的大小藕咏。但是暴力的做法每次比較需要 O(n), 顯然不能承受秽五。
我的第一想法是用類似于后綴數(shù)組構(gòu)造的步驟來(lái)排序所有長(zhǎng)度為某個(gè)定值的字符串孽查。然后獲得了順序以后再用我們對(duì)應(yīng)的 l 或者 r 去求一個(gè) lower_bound 或 upper_bound,這樣的復(fù)雜度在 O(nlogn)坦喘。這道題時(shí)限只有 1s盲再,所以很可能會(huì)超時(shí)。
然后我們發(fā)現(xiàn)如果我們的當(dāng)前串與目標(biāo)串相等那么當(dāng)前一定可行瓣铣。如果不等的話洲胖,設(shè)當(dāng)前串為 s,目標(biāo)串為 t坯沪,那么它們的大小關(guān)系就是他們第一個(gè)不等位置的大小關(guān)系。如果我們已經(jīng)處理出了 a 中的所有長(zhǎng)度為 |t| 的串與 t 的最長(zhǎng)公共前綴長(zhǎng)度擒滑,那么單次判斷可以在 O(1) 時(shí)間完成腐晾。而這個(gè)預(yù)處理工作可以用構(gòu)造 z 數(shù)組的方法在 O(|a| + |t|) 時(shí)間完成叉弦。復(fù)雜度為線性,可以接受藻糖。
另外注意特判 l 為 0 的情況淹冰。
代碼
總復(fù)雜度為 O(|a| + |l| + |r|)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
#define FOR(i, a, b) for (int (i) = (a); (i) <= (b); (i)++)
#define ROF(i, a, b) for (int (i) = (a); (i) >= (b); (i)--)
#define REP(i, n) FOR(i, 0, (n)-1)
#define sqr(x) ((x) * (x))
#define all(x) (x).begin(), (x).end()
#define reset(x, y) memset(x, y, sizeof(x))
#define uni(x) (x).erase(unique(all(x)), (x).end());
#define BUG(x) cerr << #x << " = " << (x) << endl
#define pb push_back
#define eb emplace_back
#define mp make_pair
#define _1 first
#define _2 second
#define chkmin(a, b) a = min(a, b)
#define chkmax(a, b) a = max(a, b)
const int maxn = 1123456;
const int MOD = 998244353;
char s[maxn * 2], a[maxn], lo[maxn], hi[maxn];
int z[maxn * 2], dp[maxn];
vector<int> in[maxn], out[maxn];
inline void upd(int &a, int b) {
a += b;
if (a >= MOD) a -= MOD;
}
void get_z(int n) {
int l = 0, r = 0;
z[0] = 0;
FOR(i, 1, n - 1) {
if (r >= i) {
if (z[i - l] + i <= r) z[i] = z[i - l];
else {
int nxt = r - i;
while (r < n && s[r] == s[nxt]) r++, nxt++;
r--;
z[i] = nxt;
l = i;
}
} else {
r = i;
int nxt = 0;
while (r < n && s[r] == s[nxt]) r++, nxt++;
l = i;
r--;
z[i] = nxt;
}
}
}
int main() {
scanf("%s%s%s", a + 1, lo + 1, hi + 1);
int n = strlen(a + 1), len_l = strlen(lo + 1), len_r = strlen(hi + 1);
strcpy(s, lo + 1);
s[len_l] = '$';
strcpy(s + len_l + 1, a + 1);
get_z(len_l + n + 1);
FOR(i, 1, n - len_l + 1) if (a[i] != '0') {
int idx = len_l + i;
if (z[idx] == len_l || a[z[idx] + i] > lo[z[idx] + 1])
in[i + len_l - 1].eb(i);
else in[i + len_l].eb(i);
}
strcpy(s, hi + 1);
s[len_r] = '$';
strcpy(s + len_r + 1, a + 1);
get_z(len_r + n + 1);
if (lo[1] != '0') {
FOR(i, 1, n - len_r + 1) if (a[i] != '0') {
int idx = len_r + i;
if (z[idx] == len_r || a[z[idx] + i] < hi[z[idx] + 1])
out[i + len_r].eb(i);
else out[i + len_r - 1].eb(i);
}
} else {
FOR(i, 1, n - len_r + 1) if (a[i] != '0') {
int idx = len_r + i;
if (z[idx] == len_r || a[z[idx] + i] < hi[z[idx] + 1])
out[i + len_r].eb(i);
else out[i + len_r - 1].eb(i);
}
FOR(i, 1, n) if (a[i] == '0') {
in[i].eb(i);
out[i + 1].eb(i);
}
}
dp[0] = 1;
int way = 0;
FOR(i, 1, n) {
for (auto it : in[i]) upd(way, dp[it - 1]);
for (auto it : out[i]) upd(way, MOD - dp[it - 1]);
dp[i] = way;
}
printf("%d", dp[n]);
}