JZOJ 6506 欢迎来到塞莱斯特山

第一次见这种 DP 方式,好神仙……

考虑设 \(f_{u,i}\) 表示 \(u\) 子树内的点确定了排列中的 \(i\) 个相对位置确定的,不相邻的连续段时这些连续段的贡献。
则考虑使用类似背包的转移方式,逐个儿子地合并答案。

考虑如何合并 \(f_{u,j}\)\(f_{v,k}\)\(f_{u,i}\),其中 \(v\)\(u\) 的一个儿子。
由于相对位置确定且不相邻,所以合并后会有 \(j+k-i\) 对相邻的位置产生贡献;而显然其 LCA 均为 \(u\),故贡献为 \[ f_{u,j} \cdot f_{v,k} \cdot g_{i,j,k} \cdot \mathrm{dep}_u^{j+k-i} \]

其中 \(g_{i,j,k}\) 表示将 \(j\) 个相对位置确定且不相邻的连续段和 \(k\) 个相对位置确定且不相邻的连续段合并为 \(i\) 个相对位置确定且不相邻的连续段的方案数。
这个也可以 DP 出来,具体是枚举这 \(i\) 个连续段的第一个连续段由 \(j,k\) 个连续段中各多少个组成,则 \[ g_{i,j,k} = \sum\limits_{u=1}^{\min\{j,k\}} 2 g_{i-1,j-u,k-u} + \sum\limits_{u=0}^{\min\{j-1,k\}} g_{i-1,j-u-1,k-u} + \sum\limits_{u=0}^{\min\{j,k-1\}} g_{i-1,j-u,k-u-1} \]

然而这样需要 \(O(n^4)\) 预处理,这就很不爽了。
但是注意到转移过来的状态的 \(j - k\) 的值都是确定的,所以考虑用前缀和的方式维护,注意不要计入超出范围的贡献。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include <cstdio>
#define add(x,y) (x + y >= mod ? x + y - mod : x + y)
#define dec(x,y) (x < y ? x - y + mod : x - y)
using namespace std;
const int N = 500;
const int mod = 1e9 + 7;
int n,inv[N + 5];
int to[N + 5],pre[N + 5],first[N + 5];
inline void add_edge(int u,int v)
{
static int tot = 0;
to[++tot] = v,pre[tot] = first[u],first[u] = tot;
}
int f[N + 5][N + 5],g[N + 5][N + 5][N + 5],t[N + 5],s_[N + 5][(N << 1) + 10],*s[N + 5];
int fa[N + 5],dep[N + 5],sz[N + 5];
void dfs(int p)
{
sz[p] = f[p][1] = 1;
for(register int i = first[p];i;i = pre[i])
{
dep[to[i]] = dep[p] + 1,dfs(to[i]);
for(register int j = 1;j <= sz[p] + sz[to[i]];++j)
t[j] = f[p][j],f[p][j] = 0;
for(register int j = 1,pw1 = inv[dep[p]];j <= sz[p] + sz[to[i]];++j,pw1 = (long long)pw1 * inv[dep[p]] % mod)
for(register int k = 0,pw2 = pw1;k <= sz[p];++k,pw2 = (long long)pw2 * dep[p] % mod)
for(register int l = 0,pw3 = pw2;l <= sz[to[i]];++l,pw3 = (long long)pw3 * dep[p] % mod)
f[p][j] = (f[p][j] + (long long)t[k] * f[to[i]][l] % mod * g[j][k][l] % mod * pw3 % mod) % mod;
sz[p] += sz[to[i]];
}
}
int main()
{
freopen("tree.in","r",stdin),freopen("tree.out","w",stdout);
scanf("%d",&n),g[0][0][0] = inv[1] = 1;
for(register int i = 0;i <= n;++i)
s[i] = s_[i] + N + 5;
for(register int i = 2;i <= n;++i)
inv[i] = (long long)(mod - mod / i) * inv[mod % i] % mod;
for(register int i = 1;i <= n;++i)
for(register int j = 0;j <= n;++j)
for(register int k = 0;k <= n - j;++k)
g[i][j][k] = add(g[i][j][k],add(s[i - 1][j - k],s[i - 1][j - k])),
g[i][j][k] = add(g[i][j][k],s[i - 1][j - k - 1]),
g[i][j][k] = add(g[i][j][k],s[i - 1][j - k + 1]),
s[i - 1][j - k] = add(s[i - 1][j - k],g[i - 1][j][k]);
for(register int i = 2;i <= n;++i)
scanf("%d",fa + i),add_edge(fa[i],i);
dep[1] = 1,dfs(1);
printf("%d\n",f[1][1]);
}