JZOJ 7405 诡秘的爪钩中心

首先考虑一个显然重要的子问题:
假定已有 \(k_i\) 条定向的长为 \(i\) 的链 \((1 \le i \le n)\),将其任意连接为若干环,求环大小的乘积的和。

事实上我在场上使用多元拉格朗日反演得出了结论,不过在此不表。
考虑一个经典的组合意义:转化为从每个环中选出一个元素的方案数。
枚举 \(c_i\) 表示长度为 \(i\) 的链被选择多少次,则其同时被钦定不在同一环内(这个方案数可以直接通过简单的组合意义得到): \[ \sum_{0\le c_i \le k_i} \left(\prod_{j=1}^n \binom{k_j}{c_j}j^{c_j}\right) \frac{(\sum_j k_j-1)!}{(\sum_j c_j-1)!} \]

也就是 \[ \left(\sum_{j=1}^n k_j-1\right)! \sum_{i\ge 1} \frac1{(i-1)!} [x^i] \prod_{j=1}^n (1+jx)^{k_j} \]

我们把其放到容斥中,那么看起来我们务必要再用一元计量 \(\sum_j k_j\)
也就是从环上每断出长 \(l\) 的链会贡献 \((-1)^{l-1} t(1+lx)\)。为了快速计算,我们再用一元 \(u\) 计量环长。当然这里是当做链来处理了,之后每一项要乘环长再除以链数(也就是 \(t\) 的次数)。

注意到 \[ \sum_{l\ge 0}(-1)^{l-1} t(1+lx)u^l = \frac{tu}{1+u} + \frac{xtu}{(1+u)^2} = \frac{tu(1+x+u)}{(1+u)^2} \]

拼成链 \[ \frac1{1-\frac{tu(1+x+u)}{(1+u)^2}} \]

其任意一项系数是容易提取的(可以按照 \(t,x,u\) 的顺序提取),于是分治并做二维卷积计算系数即可。

为了偷懒我直接写了映射到一维来实现二维的卷积,不过这样会有个问题就是卷积长度达到了 \(2^{22}\),板子里的某个优化就用不了了(不过 DIT DIF 当然还是能写的)。
所以大概常数大了点。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
#include <queue>
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>

using ll = long long;

using namespace std;

const int mod = 998244353;
inline int norm(int x) {
return x >= mod ? x - mod : x;
}
inline int reduce(int x) {
return x < 0 ? x + mod : x;
}
inline int neg(int x) {
return x ? mod - x : 0;
}
inline void add(int &x, int y) {
if ((x += y - mod) < 0)
x += mod;
}
inline void sub(int &x, int y) {
if ((x -= y) < 0)
x += mod;
}
inline void fam(int &x, int y, int z) {
x = (x + (ll)y * z) % mod;
}
inline int qpow(int a, int b) {
int ret = 1;
for (; b; b >>= 1)
(b & 1) && (ret = (ll)ret * a % mod),
a = (ll)a * a % mod;
return ret;
}

namespace Poly {
const int LG = 21;
const int N = 1 << LG + 1;
const int G = 3;

int lg2[N + 5];
int fac[N + 5], ifac[N + 5], inv[N + 5];
int rt[N + 5];

inline void init() {
for (int i = 2; i <= N; ++i)
lg2[i] = lg2[i >> 1] + 1;
int w = qpow(G, (mod - 1) >> LG + 1);
rt[N >> 1] = 1;
for (int i = (N >> 1) + 1; i <= N; ++i)
rt[i] = (ll)rt[i - 1] * w % mod;
for (int i = (N >> 1) - 1; i; --i)
rt[i] = rt[i << 1];
fac[0] = 1;
for (int i = 1; i <= N; ++i)
fac[i] = (ll)fac[i - 1] * i % mod;
ifac[N] = qpow(fac[N], mod - 2);
for (int i = N; i; --i)
ifac[i - 1] = (ll)ifac[i] * i % mod;
for (int i = 1; i <= N; ++i)
inv[i] = (ll)ifac[i] * fac[i - 1] % mod;
}

struct poly {

vector<int> a;
inline poly(int x = 0) {
if (x)
a.push_back(x);
}
inline poly(const vector<int> &o) {
a = o;
}
inline poly(const poly &o) {
a = o.a;
}

inline int size() const {
return a.size();
}
inline bool empty() const {
return a.empty();
}
inline void resize(int x) {
a.resize(x);
}
inline int operator[](int x) const {
if (x < 0 || x >= size())
return 0;
return a[x];
}
inline void clear() {
vector<int>().swap(a);
}
inline poly modxn(int n) const {
if (a.empty())
return poly();
n = min(n, size());
return poly(vector<int>(a.begin(), a.begin() + n));
}
inline poly rever() const {
return poly(vector<int>(a.rbegin(), a.rend()));
}

inline void dif() {
int n = size();
for (int len = n >> 1; len; len >>= 1)
for (int j = 0; j < n; j += len << 1)
for (int k = j, *w = rt + len; k < j + len; ++k, ++w) {
int R = norm(a[k] + a[k + len]);
a[k + len] = (ll)*w * (a[k] - a[k + len] + mod) % mod,
a[k] = R;
}
}
inline void dit() {
int n = size();
for (int len = 1; len < n; len <<= 1)
for (int j = 0; j < n; j += len << 1)
for (int k = j, *w = rt + len; k < j + len; ++k, ++w) {
int R = (ll)*w * a[k + len] % mod;
a[k + len] = reduce(a[k] - R),
add(a[k], R);
}
reverse(a.begin() + 1, a.end());
for (int i = 0; i < n; ++i)
a[i] = (ll)a[i] * inv[n] % mod;
}
inline void ntt(int type = 1) {
type == 1 ? dif() : dit();
}
};

struct poly2D {
int n, m;
vector<int> a;
inline poly2D(int r = 0, int s = 0): n(r), m(s) {
a.resize(n * m);
}
inline poly2D(int r, int s, vector<int> vec): n(r), m(s), a(vec) {}

inline int size() const {
return a.size();
}
inline bool empty() const {
return a.empty();
}
inline void resize(int r, int s) {
n = r, m = s, a.resize(n * m);
}
inline int operator[](int x) const {
if (x < 0 || x >= size())
return 0;
return a[x];
}

friend inline poly2D operator*(const poly2D &a, const poly2D &b) {
int n = a.n + b.n - 1, m = a.m + b.m - 1;
poly aBuf, bBuf, resBuf;
poly2D ret(n, m);
int tot = n * m, lim = 1;
for (; lim < tot; lim <<= 1);
aBuf.resize(lim), bBuf.resize(lim), resBuf.resize(lim);
for (int i = 0; i < a.n * a.m; ++i) {
int x = i / a.m, y = i % a.m;
aBuf.a[x * m + y] = a[i];
}
for (int i = 0; i < b.n * b.m; ++i) {
int x = i / b.m, y = i % b.m;
bBuf.a[x * m + y] = b[i];
}
aBuf.ntt(), bBuf.ntt();
for (int i = 0; i < lim; ++i)
fam(resBuf.a[i], aBuf[i], bBuf[i]);
resBuf.ntt(-1);
for (int i = 0; i < tot; ++i)
ret.a[i] = resBuf[i];
return ret;
}
};
}
using Poly::fac;
using Poly::ifac;
using Poly::inv;
using Poly::init;
using Poly::poly;
using Poly::poly2D;

inline int binom(int n, int m) {
return n < m || m < 0 ? 0 : (ll)fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}

const int N = 2e3;

int n;
int p[N + 5];

struct UnionFind {
int fa[N + 5], size[N + 5];
inline UnionFind() {
for (int i = 1; i <= N; ++i)
size[i] = 1;
}
inline bool isRoot(int x) {
return !fa[x];
}
inline int find(int x) {
return isRoot(x) ? x : fa[x] = find(fa[x]);
}
inline void merge(int x, int y) {
int fx = find(x), fy = find(y);
if (fx != fy)
fa[fx] = fy, size[fy] += size[fx];
}
} uf;

struct comparer {
inline bool operator()(const poly2D &x, const poly2D &y) {
return x.size() > y.size();
}
};
priority_queue<poly2D, vector<poly2D>, comparer> q;

int ans = 1;

int main() {
freopen("C.in", "r", stdin), freopen("C.out", "w", stdout);
init();

scanf("%d", &n);
for (int i = 1; i <= n; ++i)
scanf("%d", p + i), uf.merge(i, p[i]);
for (int i = 1; i <= n; ++i)
if (uf.isRoot(i)) {
int size = uf.size[i];
ans = (ll)ans * size % mod;
poly2D buf(size + 1, size + 1);
for (int k = 1; k <= size; ++k)
for (int t = 0; t <= k; ++t) {
buf.a[k * (size + 1) + t] = (ll)binom(size + t - 1, t + k - 1) * binom(k, t) % mod * size % mod * inv[k] % mod;
if ((size - k) & 1)
buf.a[k * (size + 1) + t] = neg(buf[k * (size + 1) + t]);
}
buf.a[0] = size & 1 ? neg(size) : size;
q.push(buf);
}
if (n & 1)
ans = neg(ans);
while (q.size() > 1) {
poly2D x = q.top();
q.pop(), x = x * q.top(), q.pop(), q.push(x);
}
poly2D f = q.top();
for (int k = 1; k <= n; ++k)
for (int t = 1; t <= k; ++t)
ans = (ans + (ll)fac[k - 1] * ifac[t - 1] % mod * f[k * (n + 1) + t]) % mod;
printf("%d\n", ans);
}