LibreOJ 6276 果树

非常有趣的一道扫描线题,巧妙地把路径计数转化成求矩形面积并。
这题是找 LZW 学长要来的。

观察数据范围,有一个特殊的限制: > 同一颜色在树上出现不超过 \(20\) 次。

所以第一个想到的是可以枚举每种颜色,设该颜色的点数有 \(s\) 个,把一条路径不能包含两个同颜色的结点转化为两个同颜色的结点不能同时出现在一条路径上,用 \(O(s^2)\) 找出这样的限制。

对于两个同颜色的点 \(u,v\),分类讨论:

  • \(u,v\) 没有祖孙关系,则相当于所有路径两端分别在 \(u\) 的子树中和 \(v\) 的子树中的路径都不合法。
  • \(u\)\(v\) 的祖先(可以交换),则设 \(u,v\) 路径上除 \(u\) 外距 \(u\) 最近的点为 \(t\),则相当于所有路径两端,其一不在 \(t\) 的子树中,另一个在 \(v\) 的子树中,的路径都不合法。

这么多「子树」「子树」的,第一反应是用 DFS 序转化成一段区间,于是限制条件就变成了 \(n \times n\) 矩阵上的子矩阵。
然后取一下矩阵的并,取补集,加上对角线再除以二就是答案了。
因为最后出来的对角线只有一条,直接除以二是不行的;根据题意显然对角线都是计入答案的,所以直接加上对角线(\(n\) 条路径)除以二即可。

注意扫描线本来是用于求矩形的并而我们求的是矩阵的并,需要做一些变化。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#include <cstdio>
#include <vector>
#include <algorithm>
#define ls (p << 1)
#define rs (ls | 1)
using namespace std;
const int N = 1e5;
const int C = 1e5;
const int LG = 20;
int n;
int c[N + 5];
vector<int> s[C + 5];
int to[(N << 1) + 5],pre[(N << 1) + 5],first[N + 5];
inline void add(int u,int v)
{
static int tot = 0;
to[++tot] = v;
pre[tot] = first[u];
first[u] = tot;
}
int id[N + 5],sz[N + 5],f[N + 5][LG + 5],dep[N + 5];
void dfs(int p)
{
static int tot = 0;
id[p] = ++tot,sz[p] = 1;
for(register int i = 1;i <= LG;++i)
f[p][i] = f[f[p][i - 1]][i - 1];
for(register int i = first[p];i;i = pre[i])
if(to[i] ^ f[p][0])
dep[to[i]] = dep[p] + 1,f[to[i]][0] = p,dfs(to[i]),sz[p] += sz[to[i]];
}
struct node
{
int cnt,len;
} seg[(N << 2) + 10];
struct edge
{
int x,y1,y2,k;
inline bool operator<(const edge &o) const
{
return x < o.x;
}
} e[(N << 7) + 10];
int tot;
long long ans;
inline void up(int p,int tl,int tr)
{
if(seg[p].cnt)
seg[p].len = tr - tl + 1;
else if(tl == tr)
seg[p].len = 0;
else
seg[p].len = seg[ls].len + seg[rs].len;
}
void update(int l,int r,int k,int p,int tl,int tr)
{
if(l <= tl && tr <= r)
{
seg[p].cnt += k;
up(p,tl,tr);
return ;
}
int mid = tl + tr >> 1;
if(l <= mid)
update(l,r,k,ls,tl,mid);
if(r > mid)
update(l,r,k,rs,mid + 1,tr);
up(p,tl,tr);
}
int main()
{
scanf("%d",&n);
for(register int i = 1;i <= n;++i)
scanf("%d",c + i),s[c[i]].push_back(i);
int u,v;
for(register int i = 1;i < n;++i)
scanf("%d%d",&u,&v),add(u,v),add(v,u);
dep[1] = 1,dfs(1);
for(register int i = 1;i <= n;++i)
for(register int j = 0;j + 1 < s[i].size();++j)
for(register int k = j + 1;k < s[i].size();++k)
{
int x = s[i][j],y = s[i][k];
if((id[x] <= id[y] && id[y] < id[x] + sz[x]) || (id[y] <= id[x] && id[x] < id[y] + sz[y]))
{
if(dep[x] > dep[y])
swap(x,y);
int t = y;
for(register int i = LG;~i;--i)
if(f[t][i] && dep[f[t][i]] > dep[x])
t = f[t][i];
if(id[t] > 1)
{
e[++tot] = (edge){1,id[y],id[y] + sz[y] - 1,1};
e[++tot] = (edge){id[t],id[y],id[y] + sz[y] - 1,-1};
e[++tot] = (edge){id[y],1,id[t] - 1,1};
e[++tot] = (edge){id[y] + sz[y],1,id[t] - 1,-1};
}
if(id[t] + sz[t] <= n)
{
e[++tot] = (edge){id[t] + sz[t],id[y],id[y] + sz[y] - 1,1};
e[++tot] = (edge){n + 1,id[y],id[y] + sz[y] - 1,-1};
e[++tot] = (edge){id[y],id[t] + sz[t],n,1};
e[++tot] = (edge){id[y] + sz[y],id[t] + sz[t],n,-1};
}
}
else
{
e[++tot] = (edge){id[x],id[y],id[y] + sz[y] - 1,1};
e[++tot] = (edge){id[x] + sz[x],id[y],id[y] + sz[y] - 1,-1};
e[++tot] = (edge){id[y],id[x],id[x] + sz[x] - 1,1};
e[++tot] = (edge){id[y] + sz[y],id[x],id[x] + sz[x] - 1,-1};
}
}
sort(e + 1,e + tot + 1);
for(register int i = 1;i <= tot;++i)
{
update(e[i].y1,e[i].y2,e[i].k,1,1,n);
if(i < tot)
ans += (long long)(e[i + 1].x - e[i].x) * seg[1].len;
}
printf("%lld\n",((long long)n * (n + 1) - ans) / 2);
}