洛谷 4983 忘情

化式子先。 \[\begin{align*} & \dfrac{(\overline x \sum\limits_{i = 1}^n x_i + \overline x)^2}{\overline x^2} \\ = & \dfrac{\overline x^2(\sum\limits_{i = 1}^n x_i)^2 + 2\overline x^2 \sum\limits_{i = 1}^n x_i + \overline x^2}{\overline x^2} \\ = & (\sum\limits_{i = 1}^n x_i)^2 + 2\sum\limits_{i = 1}^n x_i + 1 \\ = & (\sum\limits_{i = 1}^n x_i + 1)^2 \end{align*}\]

朴素方程有 \(F_{i,k} = \min\limits_{0 \le j < i}(F_{j,k - 1} + (sum_i - sum_j + 1)^2)\)
其中 \(sum_i = \sum\limits_{j = 1}^i x_j\)

首先不考虑段数的限制,方程即 \(f_i = \min\limits_{0 \le j < i}(f_j + (sum_i - sum_j + 1)^2)\)
假设决策 \(0 \le k < j < i\) 使得 \(j\) 优于 \(k\),即 \[\begin{align*} f_j + (sum_i - sum_j + 1)^2 & < f_k + (sum_i - sum_k + 1)^2 \\ f_j - 2sum_i(sum_j - 1) + (sum_j - 1)^2 & < f_k - 2sum_i(sum_k - 1) + (sum_k - 1)^2 \\ (f_j + (sum_j - 1)^2) - (f_k + (sum_k - 1)^2) & < 2sum_i(sum_j - sum_k) \\ \dfrac{(f_j + (sum_j - 1)^2) - (f_k + (sum_k - 1)^2)}{sum_j - sum_k} & < 2sum_i \end{align*}\] 注意这个地方如果要使用 WQS 二分的话不能包含斜率等于 \(2sum_i\) 的情况,否则会错(可能没有单调性)。

然后来考虑一下段数的限制,如果直接 \(O(nm)\) DP 肯定不行,但是我们发现如果没有这个段数就可以 \(O(n)\) 过去。
所以我们就有了 WQS 二分。
主要思想是,如果转移方程是 \(f_i = \min\limits_{0 \le j < i}(f_j + (sum_i - sum_j + 1)^2 + C)\),那么 \(C\) 越大影响就越大,段数就越少;反之亦然。
所以我们就二分这个 \(C\),找到第一个使得段数 \(\le m\)整数 \(C\)

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include <cstdio>
using namespace std;
const int N = 1e5;
int n,m;
int g[N + 5];
long long sum[N + 5],f[N + 5];
int q[N + 5],head,tail;
long long l,r,mid,ans;
inline double slope(int x,int y)
{
return (double)(f[x] + (sum[x] - 1) * (sum[x] - 1) - f[y] - (sum[y] - 1) * (sum[y] - 1)) / (sum[x] - sum[y]);
}
int check()
{
q[head = tail = 1] = 0;
for(register int i = 1;i <= n;++i)
{
for(;head < tail && slope(q[head],q[head + 1]) < 2 * sum[i];++head);
f[i] = f[q[head]] + (sum[i] - sum[q[head]] + 1) * (sum[i] - sum[q[head]] + 1) + mid,g[i] = g[q[head]] + 1;
for(;head < tail && slope(q[tail - 1],q[tail]) > slope(q[tail],i);--tail);
q[++tail] = i;
}
return g[n] <= m;
}
int main()
{
scanf("%d%d",&n,&m);
for(register int i = 1;i <= n;++i)
scanf("%lld",sum + i),sum[i] += sum[i - 1];
l = 0,r = 0x3f3f3f3f3f3f3f3f;
while(l <= r)
{
mid = l + r >> 1;
if(check())
r = mid - 1,ans = mid;
else
l = mid + 1;
}
mid = ans,check();
printf("%lld\n",f[n] - ans * m);
}