LibreOJ 2542 「PKUWC2018」随机游走

看到这种东西,首先考虑 min-max 容斥(虽然好像也有不用 min-max 容斥的做法)。
\(\max(S)\) 表示走完 \(S\) 内所有点的步数(即走到 \(S\) 内每个点步数的最大值),\(\min(S)\) 表示走到 \(S\) 中任意点的步数(即走到 \(S\) 内每个点步数的最小值)。
那么有 \[ E(\max(S)) = \sum\limits_{T\subseteq S} (-1)^{|T|+1} E(\min(T)) \] 这个可以在处理完所有 \(E(\min(S))\) 后在 \(O(n 2^n)\) 复杂度内使用 FWT 计算。

于是考虑 DP,设 \(f(u)\) 表示从结点 \(u\) 出发走到 \(S\) 中任意点的期望步数。
那么有 \[ \newcommand\fa{ {\rm fa} } \newcommand\son{ {\rm son} } f(u) = \frac 1{\deg_u} (f(\fa_u) + \sum\limits_{v \in \son_u} f(v)) + 1 \]

看起来转移有环,但是实际上在树上可以用待定系数解决这样一个 DP。
\[ f(u) = k_u f(\fa_u) + b_u \]

那么有 \[ \begin{aligned} f(u) &= \frac 1{\deg_u} (f(\fa_u) + \sum\limits_{v \in \son_u} f(v)) + 1 \\ &= \frac 1{\deg_u} (f(\fa_u) + \sum\limits_{v \in \son_u} (k_v f(u) + b_v)) + 1 \\ \deg_u f(u) &= f(\fa_u) + \sum\limits_{v \in \son_u} (k_v f(u) + b_v) + \deg_u \\ &= f(\fa_u) + f(u) \sum\limits_{v \in \son_u} k_v + \sum\limits_{v \in \son_u} b_v + \deg_u \\ (\deg_u - \sum\limits_{v \in \son_u} k_v) f(u) &= f(\fa_u) + \sum\limits_{v \in \son_u} b_v + \deg_u \\ f(u) &= \frac 1{\deg_u - \sum\limits_{v \in \son_u} k_v} \cdot f(\fa_u) + \frac{\deg_u + \sum\limits_{v \in \son_u} b_v}{\deg_u - \sum\limits_{v \in \son_u} k_v} \end{aligned} \]

于是这样枚举所有 \(S\) DP 即可,复杂度 \(O(n 2^n \log p)\),其中 \(p = 998244353\)

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include <cstdio>
using namespace std;
const int N = 18;
const int mod = 998244353;
inline int fpow(int a,int b)
{
int ret = 1;
for(;b;b >>= 1)
(b & 1) && (ret = (long long)ret * a % mod),a = (long long)a * a % mod;
return ret;
}
int n,q,rt,s;
int to[(N << 1) + 5],pre[(N << 1) + 5],first[N + 5];
inline void add(int u,int v)
{
static int tot = 0;
to[++tot] = v,pre[tot] = first[u],first[u] = tot;
}
struct note
{
int k,b;
} f[N + 5];
int fa[N + 5],deg[N + 5];
int ans[(1 << N) + 5];
void dfs(int p)
{
f[p].k = f[p].b = deg[p];
for(register int i = first[p];i;i = pre[i])
if(to[i] ^ fa[p])
fa[to[i]] = p,dfs(to[i]),f[p].k = (f[p].k - f[to[i]].k + mod) % mod,f[p].b = (f[p].b + f[to[i]].b) % mod;
if(s & (1 << p - 1))
f[p].k = f[p].b = 0;
else
f[p].b = (long long)f[p].b * (f[p].k = fpow(f[p].k,mod - 2)) % mod;
}
inline void fwt(int *a,int n)
{
for(register int w = 2,m = 1;w <= n;w <<= 1,m <<= 1)
for(register int i = 0;i < n;i += w)
for(register int j = 0;j < m;++j)
a[i | j | m] = (a[i | j | m] + a[i | j]) % mod;
}
int main()
{
scanf("%d%d%d",&n,&q,&rt);
int u,v;
for(register int i = 2;i <= n;++i)
scanf("%d%d",&u,&v),add(u,v),add(v,u),++deg[u],++deg[v];
for(s = 0;s < (1 << n);++s)
dfs(rt),ans[s] = (__builtin_popcount(s) & 1) ? f[rt].b : (mod - f[rt].b) % mod;
fwt(ans,1 << n);
for(int k;q;--q)
{
scanf("%d",&k),s = 0;
for(int x;k;--k)
scanf("%d",&x),s |= 1 << x - 1;
printf("%d\n",ans[s]);
}
}