首先为了方便,把 \(0\) 当成 \(-1\) 来看。这样只要边权和为 \(0\) 即可算阴阳平衡。
找重心然后求所有点到分治重心的边权和,然后记录每一个点 \(x\) 是否可以在其到重心路径上找到一个点 \(y\) 使得 \(x,y\) 路径阴阳平衡,记为 \(valid_x\)。
考虑统计路径时两个点到重心的路径能拼成合法路径的条件:
- 不在同一棵子树。
- 其中一个点 \(valid\) 值为真。
- 整条路径阴阳平衡。
整条阴阳平衡并且一半阴阳平衡那么另一半肯定也阴阳平衡,所以可以如此统计。
求 \(valid\) 和合并路径时都可以记录每个 \(dis\) 对应多少个点。
注意若干细节特判。
代码: 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
using namespace std;
const int N = 1e5;
int n;
int to[(N << 1) + 5],val[(N << 1) + 5],pre[(N << 1) + 5],first[N + 5];
inline void add(int u,int v,int w)
{
static int tot = 0;
to[++tot] = v,val[tot] = w,pre[tot] = first[u],first[u] = tot;
}
int rt,sum;
int vis[N + 5],sz[N + 5],max_part[N + 5],dis[N + 5];
int temp[3][(N << 1) + 5];
int *f = temp[0] + N,*g = temp[1] + N,*cnt = temp[2] + N;
long long ans;
void get_rt(int p,int fa)
{
sz[p] = 1,max_part[p] = 0;
for(register int i = first[p];i;i = pre[i])
if(!vis[to[i]] && to[i] ^ fa)
get_rt(to[i],p),sz[p] += sz[to[i]],max_part[p] = max(max_part[p],sz[to[i]]);
max_part[p] = max(max_part[p],sum - sz[p]);
if(max_part[p] < max_part[rt])
rt = p;
}
void get_dis(int p,int fa)
{
if(cnt[dis[p]])
++g[dis[p]];
else
++f[dis[p]];
++cnt[dis[p]];
for(register int i = first[p];i;i = pre[i])
if(!vis[to[i]] && to[i] ^ fa)
get_dis(to[i],p);
--cnt[dis[p]];
}
void get_ans(int p,int fa)
{
ans += g[-dis[p]];
if(cnt[dis[p]])
ans += f[-dis[p]];
if(!dis[p])
ans += cnt[0] > 1;
++cnt[dis[p]];
for(register int i = first[p];i;i = pre[i])
if(!vis[to[i]] && to[i] ^ fa)
dis[to[i]] = dis[p] + val[i],get_ans(to[i],p);
--cnt[dis[p]];
}
void clear(int p,int fa)
{
if(cnt[dis[p]])
--g[dis[p]];
else
--f[dis[p]];
++cnt[dis[p]];
for(register int i = first[p];i;i = pre[i])
if(!vis[to[i]] && to[i] ^ fa)
clear(to[i],p);
--cnt[dis[p]];
}
void solve(int p)
{
vis[p] = 1,dis[p] = 0,++cnt[0];
for(register int i = first[p];i;i = pre[i])
if(!vis[to[i]])
dis[to[i]] = dis[p] + val[i],get_ans(to[i],p),get_dis(to[i],p);
for(register int i = first[p];i;i = pre[i])
if(!vis[to[i]])
clear(to[i],p);
--cnt[0];
for(register int i = first[p];i;i = pre[i])
if(!vis[to[i]])
{
rt = 0,sum = sz[to[i]],get_rt(to[i],p);
solve(rt);
}
}
int main()
{
scanf("%d",&n);
int u,v,w;
for(register int i = 1;i < n;++i)
scanf("%d%d%d",&u,&v,&w),!w && (w = -1),add(u,v,w),add(v,u,w);
max_part[0] = 0x3f3f3f3f,sum = n,get_rt(1,0),solve(rt);
printf("%lld\n",ans);
}