To be honest, this really is an easy problem, and a very formulaic one.
First, it is clear that \(f = \mu^2\).
Then \[
\begin{aligned}
\sum\limits_{i=1}^n \sum\limits_{j=1}^n (i+j)^k \mu^2(\gcd(i,j))\gcd(i,j)
&= \sum\limits_{d=1}^n \mu^2(d) d \sum\limits_{i=1}^n \sum\limits_{j=1}^n [\gcd(i,j)=d] (i+j)^k \\
&= \sum\limits_{d=1}^n \mu^2(d) d^{k+1} \sum\limits_{i=1}^{\left\lfloor\frac nd\right\rfloor} \sum\limits_{j=1}^{\left\lfloor\frac nd\right\rfloor} [\gcd(i,j)=1] (i+j)^k \\
&= \sum\limits_{d=1}^n \mu^2(d) d^{k+1} \sum\limits_{i=1}^{\left\lfloor\frac nd\right\rfloor} \sum\limits_{j=1}^{\left\lfloor\frac nd\right\rfloor} (i+j)^k \sum\limits_{u|\gcd(i,j)} \mu(u) \\
&= \sum\limits_{d=1}^n \mu^2(d) d^{k+1} \sum\limits_{u=1}^{\left\lfloor\frac nd\right\rfloor} \mu(u)u^k \sum\limits_{i=1}^{\left\lfloor\frac n{du}\right\rfloor} \sum\limits_{j=1}^{\left\lfloor\frac n{du}\right\rfloor} (i+j)^k
\end{aligned}
\]
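Notice that the sum over \(d\) and the sum over \(u\) can each be handled by number-theoretic block decomposition, one nested inside the other, so this part costs \(O(n^{3/4})\).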
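As a sanity check on the last line of the derivation, the following sketch compares both sides by brute force for small \(n\). It is not part of the solution and makes no attempt at efficiency; the helper names (`powmod`, `mobius`, `gcdInt`, `pairPowerSum`, `lhsBrute`, `rhsFormula`) are made up for illustration, and the modulus is assumed to be the 998244353 used in the code at the end.

```cpp
#include <cassert>
#include <cstdint>

static const int64_t MOD = 998244353;

// a^e mod MOD by binary exponentiation.
int64_t powmod(int64_t a, int64_t e) {
    int64_t r = 1;
    a %= MOD;
    for (; e > 0; e >>= 1) {
        if (e & 1) r = r * a % MOD;
        a = a * a % MOD;
    }
    return r;
}

// Moebius function by trial division; only meant for tiny arguments.
int mobius(int x) {
    int m = 1;
    for (int p = 2; p * p <= x; ++p) {
        if (x % p == 0) {
            x /= p;
            if (x % p == 0) return 0;  // p^2 divides the original x
            m = -m;
        }
    }
    if (x > 1) m = -m;  // one remaining prime factor
    return m;
}

int gcdInt(int a, int b) {
    while (b != 0) { int t = a % b; a = b; b = t; }
    return a;
}

// sum_{i=1..m} sum_{j=1..m} (i+j)^k mod MOD, computed directly.
int64_t pairPowerSum(int m, int64_t k) {
    int64_t s = 0;
    for (int i = 1; i <= m; ++i)
        for (int j = 1; j <= m; ++j)
            s = (s + powmod(i + j, k)) % MOD;
    return s;
}

// Left-hand side: sum (i+j)^k * mu^2(gcd(i,j)) * gcd(i,j).
int64_t lhsBrute(int n, int64_t k) {
    assert(n >= 0);
    int64_t s = 0;
    for (int i = 1; i <= n; ++i)
        for (int j = 1; j <= n; ++j) {
            int g = gcdInt(i, j);
            if (mobius(g) != 0) s = (s + powmod(i + j, k) * g) % MOD;
        }
    return s;
}

// Right-hand side: sum_d mu^2(d) d^{k+1} sum_u mu(u) u^k T(n / (d*u)).
int64_t rhsFormula(int n, int64_t k) {
    assert(n >= 0);
    int64_t s = 0;
    for (int d = 1; d <= n; ++d) {
        if (mobius(d) == 0) continue;
        int64_t inner = 0;
        for (int u = 1; u <= n / d; ++u) {
            int mu = mobius(u);
            if (mu == 0) continue;
            int64_t term = powmod(u, k) * pairPowerSum(n / (d * u), k) % MOD;
            inner = (inner + (mu > 0 ? term : MOD - term)) % MOD;
        }
        s = (s + powmod(d, k + 1) * inner) % MOD;
    }
    return s;
}
```

For example, `lhsBrute(2, 1)` and `rhsFormula(2, 1)` both evaluate to 16.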
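So we need the prefix sums of \(\mu^2(i)i^{k+1}\) and \(\mu(i)i^k\), and, for every \(n' \le n\), the value of \(\sum\limits_{i=1}^{n'} \sum\limits_{j=1}^{n'} (i+j)^k\).
The first two follow once a linear sieve has produced all \(\mu(i)\) and \(i^k\). For the last one, look at its difference with respect to \(n'\): \[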
\begin{aligned}
\sum\limits_{i=1}^{n+1} \sum\limits_{j=1}^{n+1} (i+j)^k - \sum\limits_{i=1}^n \sum\limits_{j=1}^n (i+j)^k
&= \sum\limits_{i=1}^n \sum\limits_{j=1}^{n+1} (i+j)^k + \sum\limits_{j=1}^{n+1} (n+1+j)^k - \sum\limits_{i=1}^n \sum\limits_{j=1}^n (i+j)^k \\
&= \sum\limits_{i=1}^n \left(\sum\limits_{j=1}^{n+1} (i+j)^k - \sum\limits_{j=1}^n (i+j)^k\right) + \sum\limits_{j=1}^{n+1} (n+1+j)^k \\
&= \sum\limits_{i=1}^n (i+n+1)^k + \sum\limits_{j=1}^{n+1} (n+1+j)^k
\end{aligned}
\]
So, given the prefix sums of \(i^k\) up to \(2n\), this can also be computed by an \(O(n)\) recurrence.
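The recurrence can be sketched as follows, assuming a prefix-sum array of \(i^k\) over \(1 \ldots 2n\). The names `modPow`, `pairPowerTable` and `pairPowerBrute` are hypothetical, and for brevity each \(i^k\) is computed by fast exponentiation rather than by a sieve, so the sketch costs \(O(n \log k)\) instead of \(O(n)\).

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

static const int64_t kMod = 998244353;

// a^e mod kMod by binary exponentiation.
int64_t modPow(int64_t a, int64_t e) {
    int64_t r = 1;
    a %= kMod;
    for (; e > 0; e >>= 1) {
        if (e & 1) r = r * a % kMod;
        a = a * a % kMod;
    }
    return r;
}

// Returns T[0..n] with T[m] = sum_{i=1..m} sum_{j=1..m} (i+j)^k mod kMod.
std::vector<int64_t> pairPowerTable(int n, int64_t k) {
    assert(n >= 0);
    // prefix[t] = 1^k + 2^k + ... + t^k
    std::vector<int64_t> prefix(2 * n + 1, 0);
    for (int t = 1; t <= 2 * n; ++t)
        prefix[t] = (prefix[t - 1] + modPow(t, k)) % kMod;

    std::vector<int64_t> T(n + 1, 0);
    for (int m = 1; m <= n; ++m) {
        // sum_{i=1}^{m-1} (i+m)^k = prefix[2m-1] - prefix[m]
        int64_t a = (prefix[2 * m - 1] - prefix[m] + kMod) % kMod;
        // sum_{j=1}^{m} (m+j)^k = prefix[2m] - prefix[m]
        int64_t b = (prefix[2 * m] - prefix[m] + kMod) % kMod;
        T[m] = (T[m - 1] + a + b) % kMod;
    }
    return T;
}

// Direct O(m^2) evaluation, used only to cross-check the table.
int64_t pairPowerBrute(int m, int64_t k) {
    int64_t s = 0;
    for (int i = 1; i <= m; ++i)
        for (int j = 1; j <= m; ++j)
            s = (s + modPow(i + j, k)) % kMod;
    return s;
}
```

With \(k = 1\), `pairPowerTable(2, 1)[2]` is 12, that is, \(2 + 3 + 3 + 4\).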
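The total time complexity is \(O(n)\): the linear preprocessing dominates, and the nested block decomposition only adds \(O(n^{3/4})\).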
Code:
#include <cstdio>
using namespace std;
const int N = 1e7;
const int mod = 998244353;
// Modular add/subtract for operands in [0, mod) (the usual definitions, assumed here).
inline int add(int a,int b)
{
    return a + b >= mod ? a + b - mod : a + b;
}
inline int dec(int a,int b)
{
    return a - b < 0 ? a - b + mod : a - b;
}
// Fast exponentiation: a^b mod `mod`.
inline int fpow(int a,int b)
{
    int ret = 1;
    for(;b;b >>= 1)
    {
        if(b & 1)
            ret = (long long)ret * a % mod;
        a = (long long)a * a % mod;
    }
    return ret;
}
int n,k;
int vis[N + 5],cnt,prime[N + 5],mu[N + 5],pw[N + 5];
int sum[3][N + 5],ans;
int main()
{
    long long x;
    scanf("%d%lld",&n,&x);
    // Every base is below mod, so the exponent can be reduced modulo mod - 1.
    k = x % (mod - 1);
    // Sieve up to 2n, because i^k is needed for every value of i + j.
    n <<= 1;
    mu[1] = pw[1] = 1;
    for(int i = 2;i <= n;++i)
    {
        if(!vis[i])
        {
            prime[++cnt] = i;
            mu[i] = mod - 1; // mu(p) = -1, stored modulo mod
            pw[i] = fpow(i,k);
        }
        for(int j = 1;j <= cnt && i * prime[j] <= n;++j)
        {
            vis[i * prime[j]] = 1;
            pw[i * prime[j]] = (long long)pw[i] * pw[prime[j]] % mod;
            if(!(i % prime[j]))
                break;
            mu[i * prime[j]] = dec(0,mu[i]);
        }
    }
    n >>= 1;
    // sum[0]: prefix sums of mu^2(i) * i^(k+1); sum[1]: prefix sums of mu(i) * i^k.
    for(int i = 1;i <= n;++i)
    {
        sum[0][i] = (sum[0][i - 1] + (long long)mu[i] * mu[i] % mod * pw[i] % mod * i % mod) % mod;
        sum[1][i] = (sum[1][i - 1] + (long long)mu[i] * pw[i] % mod) % mod;
    }
    // Turn pw into prefix sums of i^k over 1..2n.
    n <<= 1;
    for(int i = 1;i <= n;++i)
        pw[i] = (pw[i - 1] + pw[i]) % mod;
    n >>= 1;
    // sum[2][i] = sum of (a + b)^k over 1 <= a, b <= i, built from the difference formula.
    for(int i = 1;i <= n;++i)
    {
        sum[2][i] = add(sum[2][i - 1],dec(pw[i << 1],pw[i]));
        sum[2][i] = add(sum[2][i],dec(pw[(i << 1) - 1],pw[i]));
    }
    // Outer blocks over d, inner blocks over u.
    for(int L = 1,R,s;L <= n;L = R + 1)
    {
        R = n / (n / L);
        s = 0;
        for(int l = 1,r,m = n / L;l <= m;l = r + 1)
        {
            r = m / (m / l);
            s = (s + (long long)(sum[1][r] - sum[1][l - 1] + mod) * sum[2][m / l] % mod) % mod;
        }
        ans = (ans + (long long)(sum[0][R] - sum[0][L - 1] + mod) * s % mod) % mod;
    }
    printf("%d\n",ans);
    return 0;
}