JZOJ 5058 采蘑菇

点分治。

DFS 时，对当前点 x，分别统计 x 到当前分治重心的路径上已出现的颜色和未出现的颜色的贡献：已出现的每种颜色对 x 所在子树之外（含重心）的每个点都贡献 1；未出现的颜色 col 的贡献，等于子树之外从重心出发、路径上含有 col 的点的个数。实现并不难。
（事实上也可以用线段树 + 换根来做 qwq）

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#include <cstdio>
#include <algorithm>
using namespace std;

// Buffered stdin reader: BUFF holds one chunk of input, [BB, BE) is the
// unread part of that chunk.
const int BUFF_SIZE = 1 << 20;
char BUFF[BUFF_SIZE],*BB,*BE;

// Returns the next input byte as a non-negative value, or EOF once stdin is
// exhausted. The chunk is refilled with fread() whenever it runs out.
// Returning int (via unsigned char) keeps a 0xFF byte distinct from EOF.
inline int gc()
{
    if(BB == BE)
    {
        BB = BUFF;
        BE = BUFF + fread(BUFF,1,BUFF_SIZE,stdin);
        if(BB == BE)
            return EOF;
    }
    return static_cast<unsigned char>(*BB++);
}

// Reads the next integer from stdin into x.
// Every character before the first digit is skipped; if any of them is '-',
// the value is negated. If input ends before a digit is found, x is left as 0
// instead of looping forever.
template<class T>
inline void read(T &x)
{
    x = 0;
    bool negative = false;
    int ch = gc();
    for(;ch != EOF && (ch < '0' || ch > '9');ch = gc())
        negative |= ch == '-';
    for(;ch >= '0' && ch <= '9';ch = gc())
        x = x * 10 + (ch - '0');
    if(negative)
        x = -x;
}

// Largest supported vertex count; arrays get +5 slack for 1-based indexing.
const int N = 300000;
int n;           // number of vertices
int a[N + 5];    // a[v]: colour of vertex v

// Adjacency lists stored as linked lists of edge ids ("forward star"):
//   first[u] - id of the most recently added edge leaving u (0 = none)
//   pre[e]   - next edge id in the same list (0 ends the list)
//   to[e]    - endpoint of edge e
// Each undirected edge is inserted twice, so 2N edge slots are reserved.
int to[(N << 1) + 5];
int pre[(N << 1) + 5];
int first[N + 5];

// Prepends the directed edge u -> v to u's adjacency list.
inline void add(int u,int v)
{
    static int tot = 0;       // edges stored so far; ids start at 1
    const int id = ++tot;
    to[id] = v;
    pre[id] = first[u];
    first[u] = id;
}
// State shared by the centroid search and solve():
//   vis[v]      - nonzero once v has been used as a centroid (solve() sets it);
//                 such vertices act as cuts and are never entered again
//   sum         - size of the component currently being searched/processed
//   sz[v]       - subtree size of v from the most recent DFS
//   max_part[v] - size of the largest piece left if v were removed
//   rt          - best centroid candidate found so far by get_rt
int vis[N + 5],sum;
int sz[N + 5],max_part[N + 5],rt;

// Searches the component containing p (entered from fa) for its centroid.
// Callers set sum to the component size and rt = 0 beforehand; this relies on
// max_part[0] being a large sentinel so that any real vertex replaces rt = 0.
// On return rt is the vertex with the smallest max_part seen.
// NOTE(review): recursion depth equals the depth of the component (up to n on
// a path-shaped tree) - confirm the judge's stack limit for n = 3e5.
void get_rt(int p,int fa)
{
    sz[p] = 1;
    max_part[p] = 0;
    for(int e = first[p];e;e = pre[e])
    {
        const int q = to[e];
        if(q == fa || vis[q])
            continue;
        get_rt(q,p);
        sz[p] += sz[q];
        max_part[p] = std::max(max_part[p],sz[q]);
    }
    // The piece "above" p is everything in the component outside p's subtree.
    max_part[p] = std::max(max_part[p],sum - sz[p]);
    if(max_part[p] < max_part[rt])
        rt = p;
}
// Recomputes sz[] for the part of the current component below p, treating
// fa as p's parent (pass 0 when p is the root). Vertices with vis set are
// not entered.
void get_sz(int p,int fa)
{
    sz[p] = 1;
    for(int e = first[p];e;e = pre[e])
    {
        const int q = to[e];
        if(vis[q] || q == fa)
            continue;
        get_sz(q,p);
        sz[p] += sz[q];
    }
}
// Working state for one centroid step:
//   cnt[col] - occurrences of colour col on the current DFS root path
//   Cnt      - number of distinct colours on that path (maintained by calc)
//   c[col]   - over the subtrees currently added by update(): total sz[] of
//              the vertices where col first appears on the DFS path
//   csum     - running total that mirrors the changes made to c[]
//   ans[v]   - answer accumulated for vertex v
// Colours a[v] are used directly as indices, so they must lie in [0, N + 4].
int cnt[N + 5],Cnt,c[N + 5];
long long csum;
long long ans[N + 5];

// Adds (op = 1) or removes (op = -1) the subtree of p, entered from fa,
// to/from c[] and csum. At each vertex where its colour has not yet been seen
// on the path (cnt was 0), that vertex's whole subtree size is credited to
// the colour. cnt[] is restored on return.
void update(int p,int fa,int op = 1)
{
    const int col = a[p];
    if(cnt[col]++ == 0)
    {
        c[col] += op * sz[p];
        csum += op * sz[p];
    }
    for(int e = first[p];e;e = pre[e])
    {
        const int q = to[e];
        if(!vis[q] && q != fa)
            update(q,p,op);
    }
    --cnt[col];
}
// Visits the subtree of p (entered from fa) and, for each vertex x in it,
// adds to ans[x] the contribution of paths from x through the centroid into
// the `sum` vertices outside x's child subtree:
//   - each of the Cnt distinct colours on the path centroid..x reaches all of
//     them: Cnt * sum;
//   - csum accounts for the remaining colours. When a colour enters the path,
//     its c[] is taken out of csum, and it is put back when the colour leaves.
// It also adds Cnt to ans[rt].
// NOTE: this relies on the global rt still being the centroid chosen for the
// enclosing solve() call; nothing between that choice and calc changes it.
void calc(int p,int fa)
{
    const int col = a[p];
    if(cnt[col]++ == 0)
    {
        ++Cnt;
        csum -= c[col];
    }
    ans[p] += static_cast<long long>(Cnt) * sum + csum;
    ans[rt] += Cnt;
    for(int e = first[p];e;e = pre[e])
    {
        const int q = to[e];
        if(!vis[q] && q != fa)
            calc(q,p);
    }
    if(--cnt[col] == 0)
    {
        --Cnt;
        csum += c[col];
    }
}
// Handles the component whose centroid is p (sum must hold its size), then
// recurses into the pieces left after removing p.
// ans[v] accumulates, per the updates below, the sum over all vertices u of
// the number of distinct colours on path(v, u), with u = v counted as 1.
// Presumably this is the required output - confirm against the problem.
void solve(int p)
{
    vis[p] = 1;
    ++ans[p];          // the single-vertex path p..p
    get_sz(p,0);       // sz[] measured with p as the root

    // Phase 1: record the colour coverage of every child subtree in c[]/csum.
    for(int e = first[p];e;e = pre[e])
        if(!vis[to[e]])
            update(to[e],p);

    // Phase 2: p's colour lies on every path through p, so mark it as present.
    ++cnt[a[p]];
    ++Cnt;
    csum -= c[a[p]];

    // Phase 3: answer each child subtree against the rest of the component.
    // Temporarily remove the subtree's own coverage, and let sum count only
    // the vertices outside it (p included).
    for(int e = first[p];e;e = pre[e])
    {
        const int q = to[e];
        if(vis[q])
            continue;
        update(q,p,-1);
        sum -= sz[q];
        calc(q,p);
        update(q,p);
        sum += sz[q];
    }

    // Undo phase 2, then phase 1, leaving cnt/Cnt/c/csum as they were before.
    --cnt[a[p]];
    --Cnt;
    csum += c[a[p]];
    for(int e = first[p];e;e = pre[e])
        if(!vis[to[e]])
            update(to[e],p,-1);

    // Recurse: find and process the centroid of each remaining piece.
    for(int e = first[p];e;e = pre[e])
    {
        const int q = to[e];
        if(vis[q])
            continue;
        rt = 0;
        sum = sz[q];
        get_rt(q,p);
        solve(rt);
    }
}
int main()
{
freopen("mushroom.in","r",stdin),freopen("mushroom.out","w",stdout);
max_part[0] = 0x3f3f3f3f;
read(n);
for(register int i = 1;i <= n;++i)
read(a[i]);
int u,v;
for(register int i = 2;i <= n;++i)
read(u),read(v),add(u,v),add(v,u);
sum = n,get_rt(1,0),solve(rt);
for(register int i = 1;i <= n;++i)
printf("%lld\n",ans[i]);
}