BZOJ 3653 谈笑风生

首先,非常显然的一点是,\(a,b,c\) 在一条到根的链上。
然后我们可以分别讨论 \(a\)\(b\) 的祖孙关系。

首先记 \(size_p\) 表示以 \(p\) 为根的子树大小,\(dep_p\) 表示 \(p\) 的深度。
那么同一条链上 \(a,b\) 的距离就是 \(|dep_a - dep_b|\)

先考虑 \(b\)\(a\) 的祖先的情况,那么 \(c\)\(size_a - 1\) 种可能(除去 \(a\) 本身),\(b\)\(\min(dep_a - 1,x)\) 种可能(避免距离超出 \(x\) 或超出根)。
乘起来就好了。

再考虑 \(a\)\(b\) 的祖先的情况,这个时候就比较麻烦了。
假如已知 \(b\)\(c\) 的可能性有 \(size_b - 1\) 种。
然鹅现在我们要考虑 \(b\) 的选择要满足的条件: 1. \(b\)\(a\) 的子树内。 2. \(dep_b \in (dep_a,dep_a + x]\)

对于第二个条件,我们可以用权值线段树来维护。
对于第一个条件,我们可以线段树合并。
然后在线的话,需要把线段树合并的过程可持久化一下。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include <cstdio>
#include <algorithm>
#define ls(p) seg[p].ls
#define rs(p) seg[p].rs
using namespace std;

const int BUFF_SIZE = 1 << 20;
char BUFF[BUFF_SIZE],*BB,*BE;
#define gc() (BB == BE ? (BE = (BB = BUFF) + fread(BUFF,1,BUFF_SIZE,stdin),BB == BE ? EOF : *BB++) : *BB++)
template<class T>
inline void read(T &x)
{
x = 0;
char ch = 0,w = 0;
for(;ch < '0' || ch > '9';w |= ch == '-',ch = gc());
for(;ch >= '0' && ch <= '9';x = (x << 3) + (x << 1) + (ch ^ '0'),ch = gc());
w ? x = -x : x;
}

const int N = 3e5;
int n,q;
int to[(N << 1) + 5],pre[(N << 1) + 5],first[N + 5];
inline void add(int u,int v)
{
static int tot = 0;
to[++tot] = v,pre[tot] = first[u],first[u] = tot;
}
struct node
{
long long sum;
int ls,rs;
} seg[(N << 6) + 10];
int rt[N + 5],seg_tot;
void insert(int x,long long k,int &p,int tl,int tr)
{
if(!p)
p = ++seg_tot;
seg[p].sum += k;
if(tl == tr)
return ;
int mid = tl + tr >> 1;
x <= mid ? insert(x,k,ls(p),tl,mid) : insert(x,k,rs(p),mid + 1,tr);
}
long long query(int l,int r,int p,int tl,int tr)
{
if(!p || (l <= tl && tr <= r))
return seg[p].sum;
int mid = tl + tr >> 1;
long long ret = 0;
if(l <= mid)
ret += query(l,r,ls(p),tl,mid);
if(r > mid)
ret += query(l,r,rs(p),mid + 1,tr);
return ret;
}
int merge(int x,int y)
{
if(!x || !y)
return x | y;
int p = ++seg_tot;
seg[p].sum = seg[x].sum + seg[y].sum;
ls(p) = merge(ls(x),ls(y)),rs(p) = merge(rs(x),rs(y));
return p;
}
int fa[N + 5],dep[N + 5],sz[N + 5];
void dfs1(int p)
{
sz[p] = 1;
for(register int i = first[p];i;i = pre[i])
if(to[i] ^ fa[p])
fa[to[i]] = p,dep[to[i]] = dep[p] + 1,dfs1(to[i]),sz[p] += sz[to[i]];
insert(dep[p],sz[p] - 1,rt[p],1,n);
}
void dfs2(int p)
{
for(register int i = first[p];i;i = pre[i])
if(to[i] ^ fa[p])
dfs2(to[i]),rt[p] = merge(rt[p],rt[to[i]]);
}
int main()
{
read(n),read(q);
int u,v;
for(register int i = 1;i < n;++i)
read(u),read(v),add(u,v),add(v,u);
dep[1] = 1,dfs1(1),dfs2(1);
int p,k;
while(q--)
read(p),read(k),printf("%lld\n",query(dep[p] + 1,dep[p] + k,rt[p],1,n) + (long long)min(dep[p] - 1,k) * (sz[p] - 1));
}