JZOJ 5069 蛋糕

首先可以将每种方案与每条两端点为葡萄干且不经过其余葡萄干的线段一一对应,那么问题转化为数这样的线段条数。

斜率为 \(0\) 或无穷的线段显然有 \(2n(n-1)\) 条(每行、每列各有 \(n-1\) 条相邻两点之间的线段)。其余的线段中,斜率为正的有 \[ \begin{aligned} & \sum\limits_{x_1 = 1}^n \sum\limits_{y_1 = 1}^n \sum\limits_{x_2 = x_1 + 1}^n \sum\limits_{y_2 = y_1 + 1}^n [\gcd(x_2 - x_1,y_2 - y_1) = 1] \\ = & \sum\limits_{i=1}^n\sum\limits_{j=1}^n [\gcd(i,j)=1](n-i)(n-j) \end{aligned} \]

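综合一下(斜率为负的线段与斜率为正的线段一一对称,故乘以 \(2\)),再加以推导,有 \[ \begin{aligned} &\ 2n(n-1)+2\sum\limits_{i=1}^n\sum\limits_{j=1}^n [\gcd(i,j)=1](n-i)(n-j) \\ = &\ 2n(n-1)+2\left((n-1)^2+2\sum\limits_{i=2}^n (n-i)\sum\limits_{j=1}^{i-1}[\gcd(i,j)=1](n-j)\right) \\ = &\ 4n^2-6n+2+4\sum\limits_{i=2}^n (n-i)\left(n\varphi(i) - \frac12i\varphi(i)\right) \\ = &\ 4n^2-6n+2+\sum\limits_{i=2}^n (4n^2\varphi(i)-6n\varphi(i)i+2\varphi(i)i^2) \\ = &\ 4n^2\sum\limits_{i=1}^n \varphi(i)-6n\sum\limits_{i=1}^n \varphi(i)i+2\sum\limits_{i=1}^n \varphi(i)i^2 \end{aligned} \] 其中第二步利用了 \(i,j\) 的对称性:\(i=j\) 时只有 \(i=j=1\) 满足互质,其余 \(i\ne j\) 的项两两成对;第三步利用了 \(i\ge 2\) 时 \(\sum\limits_{j=1}^{i-1}[\gcd(i,j)=1]\,j=\frac12 i\varphi(i)\)(因为 \(j\) 与 \(i-j\) 同时与 \(i\) 互质);最后一步把 \(4n^2-6n+2\) 看作 \(i=1\) 时的项并入求和。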

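三个前缀和 \(\sum\varphi(i)\)、\(\sum\varphi(i)i\)、\(\sum\varphi(i)i^2\) 分别利用 \(\varphi * 1=\mathrm{id}\)、\((\varphi\cdot\mathrm{id}) * \mathrm{id}=\mathrm{id}^2\)、\((\varphi\cdot\mathrm{id}^2) * \mathrm{id}^2=\mathrm{id}^3\),用杜教筛求出即可。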

代码:

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
// Sieve bound: phi and the prefix tables are precomputed for indices up to MX;
// larger arguments are handled by the Du-sieve recursions (calc1/calc2/calc3).
const int MX = 1000000;
// Input grid size, the modulus, and the accumulated answer.
int n, mod, ans;
// Linear-sieve state: composite flags, number of primes found, the primes.
int vis[MX + 5];
int cnt;
int prime[MX + 5];
// Euler's totient; main() later overwrites it with its prefix sums (mod `mod`).
int phi[MX + 5];
// f1[i] = sum phi(k)*k and f2[i] = sum phi(k)*k^2 over k <= i (mod `mod`).
int f1[MX + 5];
int f2[MX + 5];
// Returns (1 + 2 + ... + n) mod `mod`.
// n(n+1) is formed in long long -- including the +1, which overflows int when
// n == INT_MAX -- and halved exactly before reducing, so any modulus works.
inline int calc_sum(int n)
{
	const long long m = n;
	return static_cast<int>(m * (m + 1) / 2 % mod);
}
// Returns (1^2 + 2^2 + ... + n^2) mod `mod`, i.e. n(n+1)(2n+1)/6.
// The divisions by 2 and 3 are applied exactly to the factors before any
// reduction, so the result is correct for any modulus (prime or not).
// Factors are kept in long long: 2n+1 overflows int as soon as n >= 2^30.
inline int calc_sqrsum(int n)
{
	long long factor[3] = {n, n + 1LL, 2LL * n + 1};
	// Exactly one of n, n+1 is even.
	for(int i = 0; i < 3; ++i)
		if(factor[i] % 2 == 0)
		{
			factor[i] /= 2;
			break;
		}
	// Exactly one of n, n+1, 2n+1 is a multiple of 3.
	for(int i = 0; i < 3; ++i)
		if(factor[i] % 3 == 0)
		{
			factor[i] /= 3;
			break;
		}
	// Reduce each factor first so every product stays below mod^2.
	return static_cast<int>(factor[0] % mod * (factor[1] % mod) % mod * (factor[2] % mod) % mod);
}
// Returns (1^3 + 2^3 + ... + n^3) mod `mod`, using 1^3 + ... + n^3 = (n(n+1)/2)^2,
// i.e. the square of the triangular number from calc_sum.
inline int calc_cubesum(int n)
{
	const long long tri = calc_sum(n);
	return tri * tri % mod;
}
int mem[3][MX + 5];
inline int id(int x)
{
return n / x;
}
// Returns Phi(n) = sum_{i<=n} phi(i) mod `mod`.
// Small n read the sieved prefix table; larger n use Du's sieve with
// phi * 1 = id:  Phi(n) = n(n+1)/2 - sum_{d=2}^{n} Phi(n / d).
// Results for n > MX are memoized in mem[0].
int calc1(int n)
{
	if(n <= MX)
		return phi[n];
	if(~mem[0][id(n)])
		return mem[0][id(n)];
	long long ret = calc_sum(n);
	// l and r are long long so that l = r + 1 cannot overflow when r == INT_MAX.
	for(long long l = 2, r; l <= n; l = r + 1)
	{
		r = n / (n / l);
		// (r - l + 1) < 2^31 and calc1(...) < mod, so the product fits in long long.
		ret = (ret - (r - l + 1) * calc1(static_cast<int>(n / l)) % mod + mod) % mod;
	}
	return mem[0][id(n)] = static_cast<int>(ret);
}
// Returns S1(n) = sum_{i<=n} phi(i) * i mod `mod`.
// Small n read the f1 table; larger n use Du's sieve with g(d) = d, since
// sum_{d | m} phi(d) d (m/d) = m^2:
//   S1(n) = (1^2 + ... + n^2) - sum_{d=2}^{n} d * S1(n / d).
// Results for n > MX are memoized in mem[1].
int calc2(int n)
{
	if(n <= MX)
		return f1[n];
	if(~mem[1][id(n)])
		return mem[1][id(n)];
	long long ret = calc_sqrsum(n);
	// l and r are long long so that l = r + 1 cannot overflow when r == INT_MAX.
	for(long long l = 2, r; l <= n; l = r + 1)
	{
		r = n / (n / l);
		// sum_{d=l}^{r} d, computed in long long and reduced into [0, mod):
		// in int, a - b + mod overflows once mod > 2^30, and an unreduced
		// value below 2*mod times calc2(...) can overflow long long.
		const long long weight = (static_cast<long long>(calc_sum(static_cast<int>(r)))
			- calc_sum(static_cast<int>(l - 1)) + mod) % mod;
		ret = (ret - weight * calc2(static_cast<int>(n / l)) % mod + mod) % mod;
	}
	return mem[1][id(n)] = static_cast<int>(ret);
}
// Returns S2(n) = sum_{i<=n} phi(i) * i^2 mod `mod`.
// Small n read the f2 table; larger n use Du's sieve with g(d) = d^2, since
// sum_{d | m} phi(d) d^2 (m/d)^2 = m^3:
//   S2(n) = (1^3 + ... + n^3) - sum_{d=2}^{n} d^2 * S2(n / d).
// Results for n > MX are memoized in mem[2].
int calc3(int n)
{
	if(n <= MX)
		return f2[n];
	if(~mem[2][id(n)])
		return mem[2][id(n)];
	long long ret = calc_cubesum(n);
	// l and r are long long so that l = r + 1 cannot overflow when r == INT_MAX.
	for(long long l = 2, r; l <= n; l = r + 1)
	{
		r = n / (n / l);
		// sum_{d=l}^{r} d^2, computed in long long and reduced into [0, mod):
		// in int, a - b + mod overflows once mod > 2^30, and an unreduced
		// value below 2*mod times calc3(...) can overflow long long.
		const long long weight = (static_cast<long long>(calc_sqrsum(static_cast<int>(r)))
			- calc_sqrsum(static_cast<int>(l - 1)) + mod) % mod;
		ret = (ret - weight * calc3(static_cast<int>(n / l)) % mod + mod) % mod;
	}
	return mem[2][id(n)] = static_cast<int>(ret);
}
// Reads n and mod from cake.in and writes
//   4n^2 * sum phi(i) - 6n * sum phi(i) i + 2 * sum phi(i) i^2   (i = 1..n, mod `mod`)
// to cake.out.
int main()
{
	freopen("cake.in","r",stdin);
	freopen("cake.out","w",stdout);
	memset(mem,-1,sizeof mem);
	// mod is used as a divisor everywhere and n indexes the tables directly,
	// so reject unreadable input, a non-positive modulus, and negative n.
	if(scanf("%d%d",&n,&mod) != 2 || n < 0 || mod <= 0)
		return 1;

	// Linear sieve for Euler's totient up to MX.
	phi[1] = 1;
	for(int i = 2; i <= MX; ++i)
	{
		if(!vis[i])
			phi[prime[++cnt] = i] = i - 1;
		for(int j = 1; j <= cnt && i * prime[j] <= MX; ++j)
		{
			vis[i * prime[j]] = 1;
			if(!(i % prime[j]))
			{
				// prime[j] already divides i: phi(i*p) = phi(i) * p.
				phi[i * prime[j]] = phi[i] * prime[j];
				break;
			}
			// prime[j] is a new prime factor: phi(i*p) = phi(i) * (p - 1).
			phi[i * prime[j]] = phi[i] * (prime[j] - 1);
		}
	}

	// Prefix tables mod `mod`. f1/f2 must read the raw phi[i] before phi[i]
	// itself is replaced by its prefix sum; long long avoids int overflow
	// when mod is close to INT_MAX.
	for(int i = 1; i <= MX; ++i)
	{
		f1[i] = (f1[i - 1] + static_cast<long long>(phi[i]) * i) % mod;
		f2[i] = (f2[i - 1] + static_cast<long long>(phi[i]) * i % mod * i) % mod;
		phi[i] = (static_cast<long long>(phi[i]) + phi[i - 1]) % mod;
	}

	ans = (ans + 4LL * n % mod * n % mod * calc1(n)) % mod;
	ans = (ans - 6LL * n % mod * calc2(n) % mod + mod) % mod;
	ans = (ans + 2LL * calc3(n)) % mod;
	printf("%d\n",ans);
}