容易想到对每个点计算其作为重心的贡献。
设 \(f_{i,j},g_{i,j}\) 分别表示 \(i\) 的子树内,包含 \(i\) 的大小为 \(j\) 的保证或不保证以 \(i\) 为重心的连通块的个数。
树形背包计算。
然而要明确计算出贡献,还需要知道往上的连通块方案数。
这可能需要换根 DP,复杂度将达到 \(O(nk^2)\)。
实际上,可以将贡献放在连通块的顶端计算贡献。设 \(h_{i,j}\) 表示 \(i\) 的子树内,所有包含 \(i\) 的大小为 \(j\) 的连通块的重心贡献和。
转移时应当结合 \(f\)。
有一个问题,这个东西的复杂度为何是 \(O(nk)\)?
这里贴一个 @mrsrz 教我的证明:
设所有未合并至父亲的背包大小的多重集为 \(S\),设其势能函数为 \[ \Phi(S) = \frac12 \sum\limits_{x\in S} (3xk - x^2) \]
则初始时势能为 \(\Phi(S_0) = \frac13(3nk-n)\),最终 \(\Phi(S_{n-1}) = k^2\)。
考虑一次合并大小为 \(a,b\) 的两个背包的摊还代价: \[
\begin{aligned}
\hat c_i
& = c_i + \Phi(S_i) - \Phi(S_{i-1}) \\
& = ab + \frac12(3\min(a+b,k)k - \min(a+b,k)^2 - 3ak + a^2 - 3bk + b^2) \\
& = \frac12(a+b+\min(a+b,k))(a+b-\min(a+b,k)-3k) \\
& \le 0
\end{aligned}
\]
于是总代价为 \[ \sum\limits_{i=1}^{n-1} c_i = \Phi(S_0) - \Phi(S_{n-1}) + \sum\limits_{i=1}^{n-1} \hat c_i \le \Phi(S_0) - \Phi(S_{n-1}) = O(nk) \]
代码: 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
using namespace std;
const int N = 5e4;
const int K = 500;
const int mod = 1e9 + 7;
int n,k,a[N + 5];
int to[(N << 1) + 5],pre[(N << 1) + 5],first[N + 5];
inline void add(int u,int v)
{
static int tot = 0;
to[++tot] = v,pre[tot] = first[u],first[u] = tot;
}
int fa[N + 5],sz[N + 5];
int f[N + 5][K + 5],g[N + 5][K + 5],h[N + 5][K + 5];
int t[K + 5];
void dfs(int p)
{
sz[p] = 1,f[p][1] = g[p][1] = 1;
for(register int i = first[p];i;i = pre[i])
if(to[i] ^ fa[p])
{
fa[to[i]] = p,dfs(to[i]);
memset(t,0,sizeof t);
for(register int j = 1;j <= min(sz[p],k);++j)
if(f[p][j])
for(register int l = 0;l <= min(sz[to[i]],k >> 1) && j + l <= k;++l)
if(g[to[i]][l])
t[j + l] = (t[j + l] + (long long)f[p][j] * g[to[i]][l]) % mod;
memcpy(f[p],t,sizeof t),memset(t,0,sizeof t);
for(register int j = 1;j <= min(sz[p],k);++j)
if(g[p][j])
for(register int l = 0;l <= min(sz[to[i]],k) && j + l <= k;++l)
if(g[to[i]][l])
t[j + l] = (t[j + l] + (long long)h[p][j] * g[to[i]][l]) % mod,
t[j + l] = (t[j + l] + (long long)h[to[i]][l] * g[p][j]) % mod;
memcpy(h[p],t,sizeof t),memset(t,0,sizeof t);
for(register int j = 1;j <= min(sz[p],k);++j)
if(g[p][j])
for(register int l = 0;l <= min(sz[to[i]],k) && j + l <= k;++l)
if(g[to[i]][l])
t[j + l] = (t[j + l] + (long long)g[p][j] * g[to[i]][l]) % mod;
memcpy(g[p],t,sizeof t),sz[p] += sz[to[i]];
}
f[p][0] = g[p][0] = 1;
for(register int i = k + 1 >> 1;i <= k;++i)
if(i == (k >> 1))
fa[p] > p && (h[p][i] = (h[p][i] + (long long)(a[p] - a[fa[p]] + mod) * f[p][i]) % mod);
else
h[p][i] = (h[p][i] + (long long)a[p] * f[p][i]) % mod;
}
int ans;
int main()
{
freopen("centroid.in","r",stdin),freopen("centroid.out","w",stdout);
scanf("%d%d",&n,&k);
for(register int i = 1;i <= n;++i)
scanf("%d",a + i);
int u,v;
for(register int i = 2;i <= n;++i)
scanf("%d%d",&u,&v),add(u,v),add(v,u);
dfs(1);
for(register int i = 1;i <= n;++i)
ans = (ans + h[i][k]) % mod;
printf("%d\n",ans);
}