JZOJ 6815 树的重心

容易想到对每个点计算其作为重心的贡献。

fi,j,gi,j 分别表示 i 的子树内,包含 i 的大小为 j 的保证或不保证以 i 为重心的连通块的个数。
树形背包计算。

然而要明确计算出贡献,还需要知道往上的连通块方案数。
这可能需要换根 DP,复杂度将达到 O(nk2)

实际上,可以将贡献放在连通块的顶端计算贡献。设 hi,j 表示 i 的子树内,所有包含 i 的大小为 j 的连通块的重心贡献和。
转移时应当结合 f

有一个问题,这个东西的复杂度为何是 O(nk)
这里贴一个 @mrsrz 教我的证明:

设所有未合并至父亲的背包大小的多重集为 S,设其势能函数为 Φ(S)=12xS(3xkx2)

则初始时势能为 Φ(S0)=13(3nkn),最终 Φ(Sn1)=k2
考虑一次合并大小为 a,b 的两个背包的摊还代价: ˆci=ci+Φ(Si)Φ(Si1)=ab+12(3min(a+b,k)kmin(a+b,k)23ak+a23bk+b2)=12(a+b+min(a+b,k))(a+bmin(a+b,k)3k)0

于是总代价为 n1i=1ci=Φ(S0)Φ(Sn1)+n1i=1ˆciΦ(S0)Φ(Sn1)=O(nk)

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 5e4;
const int K = 500;
const int mod = 1e9 + 7;
int n,k,a[N + 5];
int to[(N << 1) + 5],pre[(N << 1) + 5],first[N + 5];
inline void add(int u,int v)
{
static int tot = 0;
to[++tot] = v,pre[tot] = first[u],first[u] = tot;
}
int fa[N + 5],sz[N + 5];
int f[N + 5][K + 5],g[N + 5][K + 5],h[N + 5][K + 5];
int t[K + 5];
void dfs(int p)
{
sz[p] = 1,f[p][1] = g[p][1] = 1;
for(register int i = first[p];i;i = pre[i])
if(to[i] ^ fa[p])
{
fa[to[i]] = p,dfs(to[i]);
memset(t,0,sizeof t);
for(register int j = 1;j <= min(sz[p],k);++j)
if(f[p][j])
for(register int l = 0;l <= min(sz[to[i]],k >> 1) && j + l <= k;++l)
if(g[to[i]][l])
t[j + l] = (t[j + l] + (long long)f[p][j] * g[to[i]][l]) % mod;
memcpy(f[p],t,sizeof t),memset(t,0,sizeof t);
for(register int j = 1;j <= min(sz[p],k);++j)
if(g[p][j])
for(register int l = 0;l <= min(sz[to[i]],k) && j + l <= k;++l)
if(g[to[i]][l])
t[j + l] = (t[j + l] + (long long)h[p][j] * g[to[i]][l]) % mod,
t[j + l] = (t[j + l] + (long long)h[to[i]][l] * g[p][j]) % mod;
memcpy(h[p],t,sizeof t),memset(t,0,sizeof t);
for(register int j = 1;j <= min(sz[p],k);++j)
if(g[p][j])
for(register int l = 0;l <= min(sz[to[i]],k) && j + l <= k;++l)
if(g[to[i]][l])
t[j + l] = (t[j + l] + (long long)g[p][j] * g[to[i]][l]) % mod;
memcpy(g[p],t,sizeof t),sz[p] += sz[to[i]];
}
f[p][0] = g[p][0] = 1;
for(register int i = k + 1 >> 1;i <= k;++i)
if(i == (k >> 1))
fa[p] > p && (h[p][i] = (h[p][i] + (long long)(a[p] - a[fa[p]] + mod) * f[p][i]) % mod);
else
h[p][i] = (h[p][i] + (long long)a[p] * f[p][i]) % mod;
}
int ans;
int main()
{
freopen("centroid.in","r",stdin),freopen("centroid.out","w",stdout);
scanf("%d%d",&n,&k);
for(register int i = 1;i <= n;++i)
scanf("%d",a + i);
int u,v;
for(register int i = 2;i <= n;++i)
scanf("%d%d",&u,&v),add(u,v),add(v,u);
dfs(1);
for(register int i = 1;i <= n;++i)
ans = (ans + h[i][k]) % mod;
printf("%d\n",ans);
}

Related Issues not found

Please contact @Alpha1022 to initialize the comment