LibreOJ 3102 「JSOI2019」神经网络

虽然做过类似的题,但是比较有趣的说(

首先应该能考虑到把每棵树划分成链,再拼成回路。

首先需要求 \(f_i(n)\) 表示第 \(i\) 棵树划分为 \(n\) 条链的方案数。
这个可以用一个提高水平的树形 DP 解决:设 \(g_{u,i,0/1/2}\) 表示 \(u\) 的子树内,已经确定形态的有 \(i\) 条链,\(u\) 所属的链还有 \(0/1/2\) 个端点没有确定(即可以连接其他的链)。
转移并不难。但考虑到对于点数大于 \(1\) 的链,都有正反两种方案,这是应当在转移时特殊处理的。

得到这个之后,因为划分的链拼成回路的过程中不能有相邻的链来自同一棵树,所以可以考虑容斥。
对于第 \(i\) 棵树,考虑构造 EGF \[ F_i(x) = \sum\limits_{j=1}^k \frac{x^j}{j!} \sum\limits_{u=j}^k (-1)^{u-j} \binom{u-1}{j-1} u! f_i(u) \]

然后将 EGF 暴力乘起来即可。
复杂度 \(O(\sum k_i^2)\)

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include <cstdio>
#include <vector>
#include <cstring>
using namespace std;
const int N = 5e3;
const int M = 300;
const int mod = 998244353;
inline int fpow(int a,int b)
{
int ret = 1;
for(;b;b >>= 1)
(b & 1) && (ret = (long long)ret * a % mod),a = (long long)a * a % mod;
return ret;
}
int n,m,k;
int to[(N << 1) + 5],pre[(N << 1) + 5],first[N + 5],edge_tot;
inline void add(int u,int v)
{
to[++edge_tot] = v,pre[edge_tot] = first[u],first[u] = edge_tot;
}
int fac[N + 5],ifac[N + 5];
inline int C(int n,int m)
{
return n < m ? 0 : (long long)fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}
int fa[N + 5],sz[N + 5];
int g[N + 5][N + 5][3],temp[N + 5][3];
void dfs(int p)
{
sz[p] = 1;
g[p][1][0] = g[p][0][1] = g[p][0][2] = 1,
g[p][1][1] = g[p][1][2] = g[p][0][0] = 0;
for(register int i = first[p];i;i = pre[i])
if(to[i] ^ fa[p])
{
fa[to[i]] = p,dfs(to[i]);
for(register int x = 0;x <= sz[p] + sz[to[i]];++x)
memset(temp[x],0,sizeof temp[x]);
for(register int x = 0;x <= sz[p];++x)
for(register int y = 0;y <= sz[to[i]];++y)
temp[x + y][0] = (temp[x + y][0] + (long long)g[p][x][0] * g[to[i]][y][0]) % mod,
temp[x + y][1] = (temp[x + y][1] + (long long)g[p][x][1] * g[to[i]][y][0]) % mod,
temp[x + y][2] = (temp[x + y][2] + (long long)g[p][x][2] * g[to[i]][y][0]) % mod,
temp[x + y][1] = (temp[x + y][1] + (long long)g[p][x][2] * g[to[i]][y][1]) % mod,
temp[x + y + 1][0] = (temp[x + y + 1][0] + 2LL * g[p][x][1] * g[to[i]][y][1]) % mod;
for(register int x = 0;x <= sz[p] + sz[to[i]];++x)
memcpy(g[p][x],temp[x],sizeof g[p][x]);
sz[p] += sz[to[i]];
}
}
vector<int> f[M + 5];
int ans;
int main()
{
fac[0] = 1;
for(register int i = 1;i <= N;++i)
fac[i] = (long long)fac[i - 1] * i % mod;
ifac[N] = fpow(fac[N],mod - 2);
for(register int i = N;i;--i)
ifac[i - 1] = (long long)ifac[i] * i % mod;
scanf("%d",&m),f[0].resize(1),f[0][0] = 1;
for(register int i = 1;i <= m;++i)
{
scanf("%d",&k),edge_tot = 0,memset(first + 1,0,sizeof(int) * k);
int u,v;
for(register int j = 2;j <= k;++j)
scanf("%d%d",&u,&v),add(u,v),add(v,u);
dfs(1),f[i].resize(k + 1);
for(register int x = 1;x <= k;++x)
for(register int j = 1;j <= x;++j)
f[i][j] = (f[i][j] + (long long)fac[x] * g[1][x][0] % mod * C(x - 1,j - 1) % mod * (x - j & 1 ? mod - 1 : 1) % mod * ifac[j]) % mod;
vector<int> res(n + k + 1);
for(register int x = 0;x <= n;++x)
for(register int y = 0;y <= k;++y)
res[x + y] = (res[x + y] + (long long)f[i - 1][x] * f[i][y]) % mod;
f[i] = res,n += k;
}
for(register int i = 1;i <= n;++i)
ans = (ans + (long long)f[m][i] * fac[i - 1]) % mod;
printf("%d\n",ans);
}