LibreOJ 3124 「CTS2019」氪金手游

首先考虑这张图是外向树的情况。
\(s_u\) 表示 \(u\) 子树内结点的 \(W\) 值之和。
则答案为 \[ \begin{aligned} & \prod\limits_{i=1}^n \frac{W_i}{\sum_{j=1}^n W_j} \sum\limits_{k=0}^{\infty} \left(1-\frac{s_i}{\sum_{j=1}^n W_j}\right)^k \\ =& \prod\limits_{i=1}^n \frac{W_i}{s_i} \end{aligned} \] 从而树形 DP,设 \(f_{u,i}\) 表示 \(u\) 的子树内 \(W\) 值之和为 \(i\),子树内的概率乘积。

然后发现不一定是外向树。
对于反向边,一个容易想到的处理方式是容斥:用不存在这条边的概率减去这条边是正向边的概率。
把容斥系数放到 DP 里即可。

复杂度 \(O(n^2)\)

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include <cstdio>
#include <cstring>
using namespace std;
const int N = 1e3;
const int mod = 998244353;
inline int fpow(int a,int b)
{
int ret = 1;
for(;b;b >>= 1)
(b & 1) && (ret = (long long)ret * a % mod),a = (long long)a * a % mod;
return ret;
}
int n,a[N + 5][4];
int to[(N << 1) + 5],pre[(N << 1) + 5],first[N + 5];
inline void add(int u,int v)
{
static int tot = 0;
to[tot] = v,pre[tot] = first[u],first[u] = tot++;
}
int fa[N + 5],sz[N + 5],f[N + 5][N * 3 + 5],temp[N * 3 + 5];
void dfs(int p)
{
sz[p] = 1;
for(register int i = 1;i <= 3;++i)
f[p][i] = (long long)i * a[p][i] % mod;
for(register int i = first[p];~i;i = pre[i])
if(to[i] ^ fa[p])
{
fa[to[i]] = p,dfs(to[i]);
memset(temp,0,sizeof temp);
for(register int j = 1;j <= 3 * sz[p];++j)
for(register int k = 1,v;k <= 3 * sz[to[i]];++k)
v = (long long)f[p][j] * f[to[i]][k] % mod,
i & 1 ? (temp[j] = (temp[j] + v) % mod,temp[j + k] = (temp[j + k] - v + mod) % mod) : (temp[j + k] = (temp[j + k] + v) % mod);
memcpy(f[p],temp,sizeof f[p]);
sz[p] += sz[to[i]];
}
for(register int i = 1;i <= 3 * sz[p];++i)
f[p][i] = (long long)f[p][i] * fpow(i,mod - 2) % mod;
}
int ans;
int main()
{
memset(first,-1,sizeof first);
scanf("%d",&n);
int a1,a2,a3,inv;
for(register int i = 1;i <= n;++i)
scanf("%d%d%d",&a1,&a2,&a3),inv = fpow(((a1 + a2) % mod + a3) % mod,mod - 2),
a[i][1] = (long long)a1 * inv % mod,
a[i][2] = (long long)a2 * inv % mod,
a[i][3] = (long long)a3 * inv % mod;
int u,v;
for(register int i = 1;i < n;++i)
scanf("%d%d",&u,&v),add(u,v),add(v,u);
dfs(1);
for(register int i = 1;i <= 3 * n;++i)
ans = (ans + f[1][i]) % mod;
printf("%d\n",ans);
}