考虑序列上的问题。
首先考虑枚举数字 \(i\) 划分为 \(a_i\) 段,所有方案的权值之和为 \(\binom{c_i+a_i-1}{2a_i-1}\)。
这不难通过组合意义证明:权值可以视作将 \(c_i\) 个数字划分为 \(a_i\) 段,再在每一段中选择一个的方案数。
那么这些数字段之间的顺序安排的方案数,即有 \(a_i\) 个数字 \(i\),相同数字不得相邻的多重集排列数。
对于数字 \(i\),\(a_i\) 之间没有顺序影响,于是不相邻的限制可以看做 \(a_i-1\) 对数字 \(i\) 不得相邻的限制。
考虑容斥,枚举 \(b_i\) 表示打破多少限制,那么相当于数字 \(i\) 被强制缩为 \(a_i-b_i\) 个段,有 \[
\sum\limits_{0\le b_i \le a_i-1} \left[\sum\limits_{i=1}^n (a_i-b_i)\right]! \prod\limits_{i=1}^n (-1)^{b_i} \binom{a_i-1}{b_i} \frac{1}{(a_i-b_i)!}
\]
综合 \(a_i\) 的枚举,并改为枚举 \(d_i=a_i-b_i\),可得 \[ \begin{aligned} & \sum\limits_{1 \le a_i \le c_i} \sum\limits_{1 \le d_i \le a_i} \left(\sum\limits_{i=1}^n d_i\right)! \prod\limits_{i=1}^n \binom{c_i+a_i-1}{2a_i-1} (-1)^{a_i-d_i} \frac{(a_i-1)!}{(a_i-d_i)!(d_i-1)!d_i!} \\ =& \sum\limits_{1 \le d_i \le c_i} \left(\sum\limits_{i=1}^n d_i\right)! \prod\limits_{i=1}^n \sum\limits_{a_i=d_i}^{c_i} (-1)^{a_i-d_i} \binom{c_i+a_i-1}{2a_i-1} \frac{(a_i-1)!}{(a_i-d_i)!(d_i-1)!d_i!} \end{aligned} \]
考虑构造生成函数以通过卷积关于 \(c_i\) 之和统计答案。 \[ \begin{aligned} F_i(x) &= \sum\limits_{j=1}^{c_i} \frac{x^j}{(j-1)!j!} \sum\limits_{k=j}^{c_i} \binom{c_i+k-1}{2k-1} (-1)^{k-j} \frac{(k-1)!}{(k-j)!} \\ &= \sum\limits_{j=1}^{c_i} \frac{x^j}{(j-1)!j!} \sum\limits_{k=0}^{c_i-j} \binom{c_i+(j+k)-1}{2(j+k)-1}(j+k-1)! \cdot (-1)^{k} \frac1{k!} \end{aligned} \]
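其中内层关于 \(k\) 的求和,是 \(\frac{(-1)^k}{k!}\) 与序列 \(s_u=\binom{c_i+u-1}{2u-1}(u-1)!\) 反转后的卷积(取第 \(c_i-j\) 项),因此每个 \(F_i(x)\) 都可以用一次 NTT 求出。再用分治 NTT 求出 \(\prod_{i=1}^n F_i(x)\),序列问题的答案即为 \(\sum_k k!\,[x^k]\prod_{i=1}^n F_i(x)\)。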
再考虑环上的问题。可以强制使开头为 \(1\),结尾不为 \(1\),即开头为 \(1\) 的减去开头结尾均为 \(1\) 的。
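这两部分分别相当于预先固定了 \(1\) 段、\(2\) 段数字 \(1\) 的位置,可以把 \(F_1(x)\) 的 \(x^j\) 项分别移到 \(x^{j-1}\)、\(x^{j-2}\) 来实现。注意系数中的 \(\frac1{j!}\) 要相应改为 \(\frac1{(j-1)!}\)、\(\frac1{(j-2)!}\)(而 \(\frac1{(j-1)!}\) 那一项不变),因为被固定的段位置已经确定,不再参与排列。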
然后循环移位也要计算贡献,故乘上 \(\sum\limits_{i=1}^n c_i\)。
然而这样一来,恰有 \(a_1\) 段数字 \(1\) 的环会被计算 \(a_1\) 次(每一段 \(1\) 都可以被当作开头)。因此构造 \(F_1(x)\) 时,在内层关于 \(k=a_1\) 的求和中对每一项再乘上 \(\frac1{a_1}\) 即可。
代码:
using namespace std;
// Upper bound on the number of distinct values n; sizes the global arrays c[] and f[].
const int N = 200000;
// NTT-friendly prime modulus: 998244353 = 119 * 2^23 + 1.
const int mod = 998244353;
// Returns a^b modulo `mod` by square-and-multiply.
// b is expected to be non-negative.
inline int fpow(int a,int b)
{
    long long base = a;
    long long result = 1;
    while(b != 0)
    {
        if(b & 1)
            result = result * base % mod;
        base = base * base % mod;
        b >>= 1;
    }
    return static_cast<int>(result);
}
// n: number of distinct values.
int n;
// c[i]: multiplicity of value i (1-indexed).
int c[N + 5];
// sum: total multiplicity, i.e. the sum of c[1..n].
int sum;
namespace Poly
{
const int N = 1 << 19;
const int G = 3;
int lg2[N + 5];
int rev[N + 5],fac[N + 5],ifac[N + 5],inv[N + 5];
int rt[N + 5],irt[N + 5];
inline void init()
{
for(register int i = 2;i <= N;++i)
lg2[i] = lg2[i >> 1] + 1;
int w = fpow(G,(mod - 1) / N);
rt[N >> 1] = 1;
for(register int i = (N >> 1) + 1;i <= N;++i)
rt[i] = (long long)rt[i - 1] * w % mod;
for(register int i = (N >> 1) - 1;i;--i)
rt[i] = rt[i << 1];
fac[0] = 1;
for(register int i = 1;i <= N;++i)
fac[i] = (long long)fac[i - 1] * i % mod;
ifac[N] = fpow(fac[N],mod - 2);
for(register int i = N;i;--i)
ifac[i - 1] = (long long)ifac[i] * i % mod;
for(register int i = 1;i <= N;++i)
inv[i] = (long long)ifac[i] * fac[i - 1] % mod;
}
struct poly
{
vector<int> a;
inline poly(int x = 0)
{
x && (a.push_back(x),1);
}
inline poly(const vector<int> &o)
{
a = o,shrink();
}
inline void shrink()
{
for(;!a.empty() && !a.back();a.pop_back());
}
inline int size() const
{
return a.size();
}
inline void resize(int x)
{
a.resize(x);
}
inline int operator[](int x) const
{
if(x < 0 || x >= size())
return 0;
return a[x];
}
inline int &operator[](int x)
{
return a[x];
}
inline void clear()
{
vector<int>().swap(a);
}
inline void ntt(int type = 1)
{
int n = size();
type == -1 && (reverse(a.begin() + 1,a.end()),1);
int lg = lg2[n] - 1;
for(register int i = 0;i < n;++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << lg),
i < rev[i] && (swap(a[i],a[rev[i]]),1);
for(register int w = 2,m = 1;w <= n;w <<= 1,m <<= 1)
for(register int i = 0;i < n;i += w)
for(register int j = 0;j < m;++j)
{
int t = (long long)rt[m | j] * a[i | j | m] % mod;
a[i | j | m] = dec(a[i | j],t),a[i | j] = add(a[i | j],t);
}
if(type == -1)
for(register int i = 0;i < n;++i)
a[i] = (long long)a[i] * inv[n] % mod;
}
friend inline poly operator+(const poly &a,const poly &b)
{
vector<int> ret(max(a.size(),b.size()));
for(register int i = 0;i < ret.size();++i)
ret[i] = add(a[i],b[i]);
return poly(ret);
}
friend inline poly operator-(const poly &a,const poly &b)
{
vector<int> ret(max(a.size(),b.size()));
for(register int i = 0;i < ret.size();++i)
ret[i] = dec(a[i],b[i]);
return poly(ret);
}
friend inline poly operator*(poly a,poly b)
{
if(a.a.empty() || b.a.empty())
return poly();
int lim = 1,tot = a.size() + b.size() - 1;
for(;lim < tot;lim <<= 1);
a.resize(lim),b.resize(lim);
a.ntt(),b.ntt();
for(register int i = 0;i < lim;++i)
a[i] = (long long)a[i] * b[i] % mod;
a.ntt(-1),a.shrink();
return a;
}
poly &operator+=(const poly &o)
{
resize(max(size(),o.size()));
for(register int i = 0;i < o.size();++i)
a[i] = add(a[i],o[i]);
return *this;
}
poly &operator-=(const poly &o)
{
resize(max(size(),o.size()));
for(register int i = 0;i < o.size();++i)
a[i] = dec(a[i],o[i]);
return *this;
}
poly &operator*=(poly o)
{
return (*this) = (*this) * o;
}
poly deriv() const
{
if(a.empty())
return poly();
vector<int> ret(size() - 1);
for(register int i = 0;i < size() - 1;++i)
ret[i] = (long long)(i + 1) * a[i + 1] % mod;
return poly(ret);
}
poly integ() const
{
if(a.empty())
return poly();
vector<int> ret(size() + 1);
for(register int i = 0;i < size();++i)
ret[i + 1] = (long long)a[i] * inv[i + 1] % mod;
return poly(ret);
}
inline poly modxn(int n) const
{
n = min(n,size());
return poly(vector<int>(a.begin(),a.begin() + n));
}
inline poly inver(int m) const
{
poly ret(fpow(a[0],mod - 2));
for(register int k = 1;k < m;)
k <<= 1,ret = (ret * (2 - modxn(k) * ret)).modxn(k);
return ret.modxn(m);
}
inline poly log(int m) const
{
return (deriv() * inver(m)).integ(),modxn(m);
}
inline poly exp(int m) const
{
poly ret(1);
for(register int k = 1;k < m;)
k <<= 1,ret = (ret * (1 - ret.log(k) + modxn(k))).modxn(k);
return ret.modxn(m);
}
};
}
using Poly::init;
using Poly::poly;
// f[i]: generating function F_i built for value i.
poly f[N + 5];
// g, h, p: scratch polynomials reused by main.
poly g;
poly h;
poly p;
// Final answer modulo mod.
int ans;
// Binomial coefficient C(n, m) modulo `mod`, taken from the Poly factorial tables.
// Returns 0 when m < 0 or m > n. The negative-m guard prevents reading ifac[-1].
// Requires n <= Poly::N so the tables cover it, and Poly::init() to have run.
inline int C(int n,int m)
{
    if(m < 0 || n < m)
        return 0;
    return (long long)Poly::fac[n] * Poly::ifac[m] % mod * Poly::ifac[n - m] % mod;
}
// Divide-and-conquer product f[l] * f[l+1] * ... * f[r].
// An empty range (l > r) yields the empty product, the constant 1. Without this guard,
// solve(2,1) (reached when n == 1) never satisfies l == r and recurses forever.
poly solve(int l,int r)
{
    if(l > r)
        return poly(1);
    if(l == r)
        return f[l];
    int mid = (l + r) >> 1;
    return solve(l,mid) * solve(mid + 1,r);
}
int main()
{
Poly::init();
scanf("%d",&n);
for(register int i = 1;i <= n;++i)
scanf("%d",c + i),sum += c[i];
for(register int i = 1;i <= n;++i)
{
f[i].resize(c[i] + 1),g.resize(c[i] + 1),h.resize(c[i] + 1);
for(register int j = 0;j <= c[i];++j)
g[j] = (long long)(j & 1 ? mod - 1 : 1) * Poly::ifac[j] % mod;
h[0] = 0;
for(register int j = 1;j <= c[i];++j)
h[c[i] - j] = (long long)C(c[i] + j - 1,2 * j - 1) * Poly::fac[j - 1] % mod;
if(i == 1)
for(register int j = 1;j <= c[i];++j)
h[c[i] - j] = (long long)h[c[i] - j] * Poly::inv[j] % mod;
g *= h;
for(register int j = 1;j <= c[i];++j)
f[i][j] = (long long)g[c[i] - j] * Poly::ifac[j - 1] % mod * Poly::ifac[j] % mod;
}
p = solve(2,n),g.resize(c[1]);
for(register int i = 0;i < c[1];++i)
g[i] = (long long)f[1][i + 1] * (i + 1) % mod;
h = g * p,g.resize(c[1] - 1);
for(register int i = 0;i < c[1] - 1;++i)
g[i] = (long long)f[1][i + 2] * (i + 2) % mod * (i + 1) % mod;
h -= g * p;
for(register int i = 0;i < h.size();++i)
ans = (ans + (long long)h[i] * Poly::fac[i]) % mod;
ans = (long long)ans * sum % mod;
printf("%d\n",ans);
}