BZOJ 4999 This Problem Is Too Simple!

题目名字持续嘲讽……
非常套路。

首先考虑对每个颜色开一棵线段树,然后树剖维护即可。
这个想法非常显然(\(O(n \log ^2 n)\)),但是有更好的做法。

发现每种操作最多涉及两种颜色(删掉原来的,添加新的),所以我们可以对每种颜色独立讨论。
那么问题变成树上单点修改,路径查询。
这个用 DFS 序 + 树状数组就可以做到 \(O(n \log n)\) 了。

以下代码为某 naive 爆的两只 \(\log\) 做法。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#include <cstdio>
#include <algorithm>
#include <tr1/unordered_map>
#define ls seg[p].lson
#define rs seg[p].rson
using namespace std;
using namespace tr1;

const int BUFF_SIZE = 3 << 20;
char BUFF[BUFF_SIZE],*BB,*BE;
#define gc() (BB == BE ? (BE = (BB = BUFF) + fread(BUFF,1,BUFF_SIZE,stdin),BB == BE ? EOF : *BB++) : *BB++)
template<class T>
inline void read(T &x)
{
x = 0;
char ch = 0,w = 0;
while(ch < '0' || ch > '9')
w |= ch == '-',ch = gc();
while(ch >= '0' && ch <= '9')
x = (x << 3) + (x << 1) + (ch ^ '0'),ch = gc();
w ? x = -x : x;
}

const int N = 1e5;
const int Q = 2e5;
int n,q;
int to[(N << 1) + 10],pre[(N << 1) + 10],first[N + 10];
inline void add(int u,int v)
{
static int tot = 0;
to[++tot] = v;
pre[tot] = first[u];
first[u] = tot;
}
unordered_map<int,int> col_id;
int col_tot;
int rt[N + Q + 10],a[N + 10];
struct segnode
{
int sum,lson,rson;
} seg[((N + Q) << 5) + 10];
void modify(int x,int k,int &p,int tl,int tr)
{
static int tot = 0;
if(!p)
p = ++tot;
if(tl == tr)
{
seg[p].sum += k;
return ;
}
int mid = tl + tr >> 1;
if(x <= mid)
modify(x,k,ls,tl,mid);
else
modify(x,k,rs,mid + 1,tr);
seg[p].sum = seg[ls].sum + seg[rs].sum;
}
int query(int l,int r,int p,int tl,int tr)
{
if(!p)
return 0;
if(l <= tl && tr <= r)
return seg[p].sum;
int mid = tl + tr >> 1;
int ret = 0;
if(l <= mid)
ret += query(l,r,ls,tl,mid);
if(r > mid)
ret += query(l,r,rs,mid + 1,tr);
return ret;
}
int fa[N + 10],dep[N + 10],sz[N + 10],son[N + 10],top[N + 10],id[N + 10],rk[N + 10];
void dfs1(int p)
{
sz[p] = 1;
for(register int i = first[p];i;i = pre[i])
if(to[i] ^ fa[p])
{
fa[to[i]] = p,dep[to[i]] = dep[p] + 1,dfs1(to[i]),sz[p] += sz[to[i]];
if(!son[p] || sz[to[i]] > sz[son[p]])
son[p] = to[i];
}
}
void dfs2(int p)
{
static int tot = 0;
rk[id[p] = ++tot] = p;
if(!son[p])
return ;
top[son[p]] = top[p];
dfs2(son[p]);
for(register int i = first[p];i;i = pre[i])
if(!top[to[i]])
top[to[i]] = to[i],dfs2(to[i]);
}
int answer(int x,int y,int k)
{
int ret = 0;
while(top[x] ^ top[y])
dep[top[x]] > dep[top[y]] ? (ret += query(id[top[x]],id[x],rt[col_id[k]],1,n),x = fa[top[x]]) : (ret += query(id[top[y]],id[y],rt[col_id[k]],1,n),y = fa[top[y]]);
if(dep[x] > dep[y])
swap(x,y);
ret += query(id[x],id[y],rt[col_id[k]],1,n);
return ret;
}
int main()
{
read(n),read(q);
for(register int i = 1;i <= n;++i)
read(a[i]);
int u,v;
for(register int i = 1;i < n;++i)
read(u),read(v),add(u,v),add(v,u);
dep[1] = top[1] = 1,dfs1(1),dfs2(1);
for(register int i = 1;i <= n;++i)
{
if(!col_id.count(a[i]))
col_id[a[i]] = ++col_tot;
modify(id[i],1,rt[col_id[a[i]]],1,n);
}
while(q--)
{
char op = 0;
while((op ^ 'Q') && (op ^ 'C'))
op = gc();
int x,y,z;
if(op == 'Q')
{
read(x),read(y),read(z);
if(!col_id.count(z))
puts("0");
else
printf("%d\n",answer(x,y,z));
}
else
{
read(x),read(y);
modify(id[x],-1,rt[col_id[a[x]]],1,n);
a[x] = y;
if(!col_id.count(a[x]))
col_id[a[x]] = ++col_tot;
modify(id[x],1,rt[col_id[a[x]]],1,n);
}
}
}