「清华集训 2017」生成树计数

:pig:

题目描述

在一个 $s$ 个点的图中,存在 $s-n$ 条边,使图中形成了 $n$ 个连通块,第 $i$ 个连通块中有 $a_i$ 个点。

现在我们需要再连接 $n-1$ 条边,使该图变成一棵树。对一种连边方案,设原图中第 $i$ 个连通块连出了 $d_i$ 条边,那么这棵树 $T$ 的价值为:

你的任务是求出所有可能的生成树的价值之和,对 $998244353$ 取模。

输入格式

输入的第一行包含两个整数 $n,m$,意义见题目描述。

接下来一行有 $n$ 个整数,第 $i$ 个整数表示 $a_i$ $(1\le a_i< 998244353)$。

  • 你可以由 $a_i$ 计算出图的总点数 $s$,所以在输入中不再给出 $s$ 的值。

输出格式

输出包含一行一个整数,表示答案。

算法讨论

由于每个联通块有大小,所以答案要乘上$a_i^{d_i}$

需要枚举生成树,无法做

很不容易想到prufer序列,即这棵树可以由一个$n - 2$的序列表示,序列中每个元素出现次数$k_i+1$即为它在树上的度数,那么

中间一步$k_1+k_2+…+k_n=n-2$不好处理,于是构造生成函数:

于是答案就是$F(x)的第n-2项再乘(n-2)!$

由于n比较大,还是不好处理:

同理$B_i(x) = e^{a_ix}\sum_{j=0}^{2m}S(m+1,j+1)a_i^{j+1}x^j$

现在$A和B$的初始长度最多2m,可以分治NTT解决

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#include <stdio.h>
#include <algorithm>
using namespace std;
const int N = 1 << 17, mod = 998244353;

char buf[1 << 20], *p1, *p2;
#define GC (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1000000, stdin), p1 == p2) ? 0 : *p1 ++)
inline int _R() {
int d = 0; char t; bool ty = 1;
while (t = GC, (t < '0' || t > '9') && t != '-');
t == '-' ? (ty = 0) : (d = t - '0');
while (t = GC, t >= '0' && t <= '9') d = (d << 3) + (d << 1) + t - '0';
return ty ? d : -d;
}

inline int add(int a, int b) { a += b; return a >= mod ? a - mod : a; }
inline int sub(int a, int b) { a -= b; return a < 0 ? a + mod : a; }
inline int mul(int a, int b) { return 1LL * a * b % mod; }
inline int ksm(int a, int b) { int r; for (r = 1; b; b >>= 1, a = mul(a, a)) if (b & 1) r = mul(r, a); return r; }


int s, n, m, a[N];
int fact[N], inv_fact[N], inv[N], e[N], S[88][88];

void init() {
int i, j;
fact[0] = inv[1] = inv_fact[0] = e[0] = 1;
for (i = 2; i <= n; i ++) inv[i] = mul(mod - mod / i, inv[mod % i]);
for (i = 1; i <= n; i ++) fact[i] = mul(fact[i - 1], i);
for (i = 1; i <= n; i ++) inv_fact[i] = mul(inv_fact[i - 1], inv[i]);
for (i = 1; i <= n - 2; i ++) e[i] = mul(e[i - 1], s);
S[0][0] = 1;
for (i = 1; i <= 2 * m + 1; i ++)
for (j = 1; j <= 2 * m + 1; j ++)
S[i][j] = add(S[i - 1][j - 1], mul(S[i - 1][j], j));
}

int ntt_wi[2][N], MAX;
int __pre_ntt(int p) {
MAX = p;
int i, f = ksm(3, (mod - 1) / MAX), g = ksm(f, mod - 2);
ntt_wi[1][0] = ntt_wi[0][0] = 1;
for (i = 1; i < p; i ++) {
ntt_wi[1][i] = mul(ntt_wi[1][i - 1], f);
ntt_wi[0][i] = mul(ntt_wi[0][i - 1], g);
}
}

void ntt(int A[], int n, int ty) {
int i, j, k, t, w, f, g;
for (i = j = 0; i < n; i ++) {
if (i < j) swap(A[i], A[j]);
for (k = n >> 1; (j ^= k) < k; k >>= 1);
}
for (w = 1; t = MAX / (w << 1), w < n; w <<= 1)
for (k = 0; k < n; k += w << 1)
for (i = k, j = 0; i < k + w; i ++, j += t) {
f = A[i], g = mul(A[i + w], ntt_wi[ty][j]);
A[i] = add(f, g), A[i + w] = sub(f, g);
}
if (ty == 1) return ;
f = ksm(n, mod - 2);
for (i = 0; i < n; i ++) A[i] = mul(A[i], f);
}

int A[N], B[N], SA[20][N], SB[20][N];
int solve(int l, int r, int d) {
int i, j, k, len;
if (l == r) {
len = min(m * 2 + 1, n - 1);
for (i = 0; i < len; i ++)
A[i] = mul(S[2 * m + 1][i + 1], ksm(a[l], i + 1)), B[i] = 0;
for (i = 0; i < min(len, m + 1); i ++)
B[i] = mul(S[m + 1][i + 1], ksm(a[l], i + 1));
return len;
}

int mid = l + r >> 1, llen, rlen, p, *A0 = SA[d], *B0 = SB[d];
llen = solve(l, mid, d + 1);
copy(A, A + llen, A0);
copy(B, B + llen, B0);

rlen = solve(mid + 1, r, d + 1);
len = llen + rlen - 1;

for (p = 1; p < len; p <<= 1) ;
fill(A0 + llen, A0 + p, 0);
fill(B0 + llen, B0 + p, 0);
fill(A + rlen, A + p, 0);
fill(B + rlen, B + p, 0);

ntt(A0, p, 1), ntt(B, p, 1);
ntt(B0, p, 1), ntt(A, p, 1);
for (i = 0; i < p; i ++) {
A[i] = add(mul(A[i], B0[i]), mul(A0[i], B[i]));
B[i] = mul(B[i], B0[i]);
}
ntt(A, p, 0);
ntt(B, p, 0);

return min(len, n - 1);
}

int main() {
int i, j, k;
n = _R(), m = _R();
for (i = 1; i <= n; i ++) a[i] = _R(), s = add(s, a[i]);
init();

int p;
for (p = 1; p < n << 1; p <<= 1);
__pre_ntt(p);

solve(1, n, 0);

int ans = 0;
for (i = 0; i <= n - 2; i ++)
ans = add(ans, mul(mul(e[i], inv_fact[i]), A[n - 2 - i]));
printf("%d\n", mul(ans, fact[n - 2]));
}