先考虑求 \(W(S)\)。
有一个线性规划对偶然后贪心的解释,不过大概和以下做法本质是一样的。
设 \(f'_u\) 为考虑 \(u\) 子树内的路径的答案,则有转移 \[
\newcommand\fa{ \operatorname{fa} }
\newcommand\lca{ \operatorname{lca} }
\newcommand\child{ \operatorname{child} }
\newcommand\path{ \operatorname{path} }
f'_u = \max\left\{\sum_{v \in \child(u)} f'_v\right\} \bigcup \left\{\left.w + \sum_{\fa(v) \in \path(x,y)\land v \not\in \path(x,y)} f'_v\,\right|\,(x,y,w) \in S \land \lca(x,y)=u\right\}
\]
这样的转移很阴间。不过这样的对「路径下方所挂的点」的求和,提示我们作树上差分。
设 \[
f_u = f'_u - \sum_{v\in \child(u)} f'_v
\]
这样,不难验证转移会变成 \[ f_u = \max\{0\} \bigcup \left\{\left.w - \sum_{ \fa(v) \in \path(x,y) \setminus\{u\} } f_v\,\right|\,(x,y,w) \in S \land \lca(x,y)=u\right\} \]
于是可以使用树状数组维护单点加链求和来做到 \(O((n+m)\log n)\)。
接下来考虑计算 \(f(x,y)\)。观察到,若树的根在 \(\path(x,y)\) 上,那么 \[ f(x,y) = \sum_{u\in\path(x,y)} f_u \]
也就是说问题变成了换根 DP。
设 \(g_u\) 为根在 \(u\) 子树内时 \(\fa(u)\) 的 DP 值,\(h_u\) 为根为 \(u\) 时 \(u\) 的 DP 值。
计算 \(g_u\) 时,考虑一条经过 \(\fa(u)\) 而不经过 \(u\) 的路径 \((x,y,w)\),从 \(\fa(u)\) 出发向下的一段取 \(f\),\(\fa(u)\) 到 \(\lca(x,y)\) 的一段取 \(g\),\(\lca(x,y)\) 往下的另一段也取 \(f\)。
计算 \(h_u\) 时类似,不过考虑的是经过 \(u\) 的路径。
考虑继续在 \(u\) 处枚举 \(\lca(x,y) = u\) 的路径 \((x,y,w)\),我们希望对路径上的点都挂上一个贡献,同时 \(u\) 处要特殊处理,因为有儿子的要求。
进一步,对于 \(u\) 枚举 \(v \in \child(u)\),考虑端点在 \(u\) 子树内而不在 \(v\) 子树内的路径,这样的路径就是对 \(g_v\) 有贡献的。对于 \(h_u\) 则更简单。
于是我们不妨用线段树维护 DFS 序来计算这部分转移。
\(u\) 处的特殊贡献的话,也可以直接对所有儿子建立线段树(或者如果足够闲,甚至可以通过差分和堆计算),这样不转移 \(O(1)\) 个儿子可以转化为对 \(O(1)\) 个区间取 \(\max\),然后单点查询。
时间复杂度 \(O((n+m)\log n)\)。
代码: 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
using ll = long long;
using namespace std;
const int mod = 998244353;
const int N = 3e5;
const ll inf = 0x3f3f3f3f3f3f3f3f;
int n, m;
tuple<int, int, int> path[N + 5];
tuple<int, int, int> pathLCA[N + 5];
vector<int> e[N + 5], ch[N + 5];
int label[N + 5];
int fa[N + 5], dep[N + 5], size[N + 5], son[N + 5], top[N + 5], id[N + 5], rk[N + 5];
void dfs(int u) {
static int tot = 0;
rk[id[u] = ++tot] = u, size[u] = 1;
for (int v: e[u])
if (v != fa[u]) {
fa[v] = u, dep[v] = dep[u] + 1, dfs(v), size[u] += size[v];
if (!son[u] || size[son[u]] < size[v])
son[u] = v;
label[v] = ch[u].size(), ch[u].emplace_back(v);
}
}
inline tuple<int, int, int> lca(int x, int y) {
int xson = 0, yson = 0;
while (top[x] != top[y])
if (dep[top[x]] > dep[top[y]])
xson = top[x], x = fa[top[x]];
else
yson = top[y], y = fa[top[y]];
if (dep[x] < dep[y])
yson = son[x];
else if(dep[x] > dep[y])
xson = son[y];
return {dep[x] < dep[y] ? x : y, xson, yson};
}
vector<int> pathsThrough[N + 5];
struct BinaryIndexedTree {
ll c[N + 5];
inline int lowbit(int x) {
return x & -x;
}
inline void update(int x, ll k) {
for (; x <= n; x += lowbit(x))
c[x] += k;
}
inline void update(int l, int r, ll k) {
update(l, k), update(r + 1, -k);
}
inline ll query(int x) {
ll ret = 0;
for (; x; x -= lowbit(x))
ret += c[x];
return ret;
}
} bit;
struct SegmentTree {
ll seg[N * 4 + 5];
};
struct SegmentTree_RangeUpdate: SegmentTree {
int n;
void build(int p, int tl, int tr) {
if (tl > tr)
return ;
seg[p] = 0;
if (tl == tr)
return ;
int mid = tl + tr >> 1;
build(ls, tl, mid), build(rs, mid + 1, tr);
}
void build(int m) {
build(1, 0, (n = m) - 1);
}
void update(int l, int r, ll k, int p, int tl, int tr) {
if (l > r)
return ;
if (l <= tl && tr <= r) {
seg[p] = max(seg[p], k);
return ;
}
int mid = tl + tr >> 1;
if (l <= mid)
update(l, r, k, ls, tl, mid);
if (r > mid)
update(l, r, k, rs, mid + 1, tr);
}
void update(int l, int r, ll k) {
update(l, r, k, 1, 0, n - 1);
}
ll query(int x, int p, int tl, int tr) {
if (tl == tr)
return seg[p];
int mid = tl + tr >> 1;
return max(seg[p], x <= mid ? query(x, ls, tl, mid) : query(x, rs, mid + 1, tr));
}
ll query(int x) {
return query(x, 1, 0, n - 1);
}
} seg0;
struct SegmentTree_RangeQuery: SegmentTree {
int n;
void build(int p, int tl, int tr) {
if (tl > tr)
return ;
seg[p] = -inf;
if (tl == tr)
return ;
int mid = tl + tr >> 1;
build(ls, tl, mid), build(rs, mid + 1, tr);
}
void build(int m) {
build(1, 1, n = m);
}
void insert(int x, ll k, int p, int tl, int tr) {
seg[p] = max(seg[p], k);
if (tl == tr)
return ;
int mid = tl + tr >> 1;
if (x <= mid)
insert(x, k, ls, tl, mid);
else
insert(x, k, rs, mid + 1, tr);
}
void insert(int x, ll k) {
insert(x, k, 1, 1, n);
}
ll query(int l, int r, int p, int tl, int tr) {
if (l > r)
return -inf;
if (l <= tl && tr <= r)
return seg[p];
int mid = tl + tr >> 1;
ll ret = -inf;
if (l <= mid)
ret = max(ret, query(l, r, ls, tl, mid));
if (r > mid)
ret = max(ret, query(l, r, rs, mid + 1, tr));
return ret;
}
ll query(int l, int r) {
return query(l, r, 1, 1, n);
}
} seg1;
ll f[N + 5], fSum[N + 5];
ll g[N + 5], gSum[N + 5];
ll h[N + 5];
int ans;
int main() {
scanf("%d%d", &n, &m);
for (int i = 2; i <= n; ++i) {
int u, v;
scanf("%d%d", &u, &v);
e[u].emplace_back(v), e[v].emplace_back(u);
}
dfs(1);
for (int i = 1; i <= n; ++i) {
int u = rk[i];
top[u] = u == son[fa[u]] ? top[fa[u]] : u;
}
for (int i = 1; i <= m; ++i) {
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
path[i] = {u, v, w}, pathsThrough[get<0>(pathLCA[i] = lca(u, v))].emplace_back(i);
}
for (int i = n; i; --i) {
int u = rk[i];
for (int j: pathsThrough[u]) {
int x = get<0>(path[j]), y = get<1>(path[j]), w = get<2>(path[j]);
f[u] = max(f[u], w - bit.query(id[x]) - bit.query(id[y]));
}
bit.update(id[u], id[u] + size[u] - 1, f[u]);
h[u] = f[u];
}
for (int i = 1; i <= n; ++i) {
int u = rk[i];
fSum[u] = fSum[fa[u]] + f[u];
}
seg1.build(n);
for (int i = 1; i <= n; ++i) {
int u = rk[i];
gSum[u] = gSum[fa[u]] + g[u];
seg0.build(ch[u].size());
for (int v: ch[u])
g[v] = max(g[v], max(seg1.query(id[u], id[v] - 1), seg1.query(id[v] + size[v], id[u] + size[u] - 1)) + fSum[u] - gSum[u]);
h[u] = max(h[u], seg1.query(id[u], id[u] + size[u] - 1) + fSum[u] - gSum[u]);
for (int j: pathsThrough[u]) {
int x = get<0>(path[j]), y = get<1>(path[j]), w = get<2>(path[j]);
int xson = get<1>(pathLCA[j]), yson = get<2>(pathLCA[j]);
if (dep[x] > dep[y])
swap(x, y), swap(xson, yson);
if (x == y)
seg0.update(0, ch[u].size() - 1, w);
else if(x == u) {
ll v = w - fSum[y] + gSum[x];
seg1.insert(id[y], v);
v = w - fSum[y] + fSum[u];
seg0.update(0, label[yson] - 1, v), seg0.update(label[yson] + 1, ch[u].size() - 1, v);
} else {
ll v = w - fSum[x] - fSum[y] + fSum[u] + gSum[u];
seg1.insert(id[x], v), seg1.insert(id[y], v);
v = w - fSum[x] - fSum[y] + 2 * fSum[u];
int id0 = label[xson], id1 = label[yson];
if (id0 > id1)
swap(id0, id1);
seg0.update(0, id0 - 1, v), seg0.update(id0 + 1, id1 - 1, v), seg0.update(id1 + 1, ch[u].size() - 1, v);
}
}
for (int v: ch[u])
g[v] = max(g[v], seg0.query(label[v]));
}
for (int i = 1; i <= n; ++i)
ans = (ans + h[i] * n) % mod,
ans = (ans + (f[i] + g[i]) % mod * size[i] % mod * (n - size[i])) % mod;
printf("%d\n", ans);
}