「2017 山东一轮集训 Day1」Sum

题目描述

求有多少 $ n $ 位十进制数是 $ p $ 的倍数且每位之和小于等于 $ m_i (m_i = 0, 1, 2, \ldots, m - 1, m) $,允许前导 $ 0 $,答案对 $ 998244353 $ 取模。

数据范围 $n \le 10^9\ \ \ \ p \le 50\ \ \ \ m\le1000$

算法讨论

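首先有一个易得的 DP 方程:设 $f_{i,j,k}$ 表示长度为 $i$、模 $p$ 余 $j$、各位数字之和为 $k$ 的数的个数,枚举新加入的一位 $d \in [0, 9]$,有 $f_{i+1,\,(10j+d) \bmod p,\,k+d} \mathrel{+}= f_{i,j,k}$。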

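然而 $n$ 太大,无法直接 DP。但如果长度 $n$ 是 2 的次幂,就可以由长度为 $n/2$ 的答案方便地合并得到:$f_{n,\,(j_1 \cdot 10^{n/2} + j_2) \bmod p,\,k_1+k_2} \mathrel{+}= f_{n/2,\,j_1,\,k_1} \cdot f_{n/2,\,j_2,\,k_2}$。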

发现这个式子是卷积的形式,可以用 NTT 加速计算。

这样之后,我们就得到了所有长度为 2 的次幂的答案。

对于n的答案怎么搞呢?把n进行二进制拆分,把对应位的答案合并起来即可

然而这样写还是 TLE 了……原来是对 $p$ 的那一维卷积我是暴力做的。事实上这种二维的卷积也可以用 NTT 做,只需要把原来的二维数组展开成一维的就好啦:每个数位和对应的一行留出 $2p$ 个位置,卷积之后再把余数下标不小于 $p$ 的部分减去 $p$ 折回去(神奇的是展开了卷出来也是对的!)%%%Sparrow

代码如下:

#include <stdio.h>
#include <algorithm>
using namespace std;
// Capacity of every NTT buffer (rev, ntt_wi, Data::a).
// The transform length is the smallest power of two >= 4 * p * (m + 1).
// With the stated limits p <= 50 and m <= 1000 that is 2^18 = 262144, so the
// buffers must hold at least that many elements (the previous 200050 did not).
// NOTE(review): Data objects are also created as locals/temporaries by Data's
// operators, so each one now takes 2 MiB of stack - confirm the judge's stack
// limit is large enough.
const int N = 1 << 18;
typedef long long ll;
// mod is the prime modulus of the answer; G is the base from which the NTT
// roots of unity are derived.
const ll G = 3, mod = 998244353;

// Sets a to (a + b) % mod and returns the updated value of a.
ll add(ll &a, ll b) { return a = (a + b) % mod; }
// Sets a to (a * b) % mod and returns the updated value of a.
ll mul(ll &a, ll b) { return a = a * b % mod; }

// Binary exponentiation: returns a raised to the power b, reduced modulo c
// (c defaults to mod). b == 0 yields 1.
// Intermediate products are formed in ll, so a and c must be small enough
// that a * a does not overflow.
ll ksm(ll a, ll b, ll c = mod) {
    ll result = 1;
    while (b) {
        if (b & 1) {
            result = result * a % c;
        }
        b >>= 1;
        a = a * a % c;
    }
    return result;
}

// Input values read by main: number of digits n, divisor p, largest digit-sum bound m.
ll n;
ll p;
ll m;
// Stride of one digit-sum row in the flattened table; main sets it to 2 * p.
ll len;
// 4 * p * (m + 1), computed by __pre_ntt; the transform length is derived from it.
ll max_len;

// Transform length (a power of two) chosen by __pre_ntt.
ll ntt_len;
// Modular inverse of ntt_len, used to scale the inverse transform.
ll rev_len;
// ntt_wi[1][k] is the k-th power of the forward root, ntt_wi[0][k] of its inverse.
ll ntt_wi[2][N];
// Bit-reversal swap table: rev[i] is the partner index, or i itself when no swap is due.
ll rev[N];
void __pre_ntt() {
ll f, g, i, j, w;

max_len = p * (m + 1) * 4;
for (ntt_len = 1; ntt_len < max_len; ntt_len <<= 1);
rev_len = ksm(ntt_len, mod - 2);

for (i = j = 0; i < ntt_len; i ++) {
if (i > j) rev[i] = j;
else rev[i] = i;
for (f = ntt_len >> 1; (j ^= f) < f; f >>= 1);
}

f = ksm(G, (mod - 1) / (ntt_len << 1));
g = ksm(f, mod - 2);
ntt_wi[1][0] = ntt_wi[0][0] = 1;
for (i = 1; i < ntt_len; i ++)
ntt_wi[1][i] = ntt_wi[1][i - 1] * f % mod,
ntt_wi[0][i] = ntt_wi[0][i - 1] * g % mod;
}

struct Data {
ll a[N];

void ntt(int ty) {
ll t, i, j, k, w;
ll f, g;
for (i = 0; i < ntt_len; i ++) swap(a[i], a[rev[i]]);
for (w = 1; w < ntt_len; w <<= 1) {
t = ntt_len / w;
for (k = 0; k < ntt_len; k += w << 1)
for (i = k, j = 0; i < k + w; i ++, j += t) {
f = a[i];
g = a[i + w] * ntt_wi[ty][j] % mod;
a[i] = (f + g) % mod;
a[i + w] = (f - g + mod) % mod;
}
}
if (ty == 1) return;
for (i = 0; i < ntt_len; i ++) a[i] = a[i] * rev_len % mod;
}

void clear() {
fill(a, a + ntt_len, 0);
}

Data operator + (const Data& rhs) const {
Data ret;
for (int i = 0; i < ntt_len; i ++) ret.a[i] = (a[i] + rhs.a[i]) % mod;
return ret;
}

Data operator * (Data rhs) {
Data ret;

ntt(1), rhs.ntt(1);
for (int i = 0; i < ntt_len; i ++) ret.a[i] = a[i] * rhs.a[i] % mod;
ntt(0), ret.ntt(0);

for (int i = 0; i <= m; i ++)
for (int k, j = 0; j < p; j ++) {
k = i * len + j + p;
add(ret.a[k - p], ret.a[k]);
ret.a[k] = 0;
}
fill(ret.a + (max_len >> 1), ret.a + ntt_len, 0);

return ret;
}

void print() {
// ntt(0);
for (int i = 0; i < ntt_len; i ++)
printf("%lld%c", a[i], i == ntt_len - 1 ? '\n' : ' ');
// ntt(1);
}

} ans, f, tmp;

int main() {
ll i, j, k, w, t, flag = 0;

scanf("%lld%lld%lld", &n, &p, &m);
__pre_ntt();
len = p * 2;
ll pow = 10;

for (i = 0; i < 10 && i <= m; i ++)
f.a[i * len + i % p] = 1;
if (n & 1) ans = f, flag = 1;

for (i = 1; (w = 1LL << i) <= n; i ++) {
tmp.clear();
for (k = 0; k <= m; k ++)
for (t = k * len, j = 0; j < p; j ++)
add(tmp.a[t + j * pow % p], f.a[t + j]);
f = tmp * f;
pow = pow * pow % p;

if (w & n) {
if (!flag) flag = 1, ans = f;
else {
tmp.clear();
for (k = 0; k <= m; k ++)
for (t = k * len, j = 0; j < p; j ++)
add(tmp.a[t + j * pow % p], ans.a[t + j]);
ans = tmp * f;
}
}
}

// ans.print();
ll sum = 0;
for (i = 0; i <= m; i ++) {
sum += ans.a[i * len]; sum %= mod;
printf("%lld%c", sum, i == m ? '\n' : ' ');
}
}