LibreOJ 3045 「ZJOI2019」开关

S=ni=1pi

由于不同开关的操作之间有顺序,即有标号,考虑对每个开关的操作构造 EGF 并乘起来。
Fi(x)=n=0[nsi(mod2)]pniSxnn!

F(x)=ni=1Fi(x)
F(x) 即序列 {fn} 的 EGF,其中 fn=n![xn]F(x) 表示 n 次达到指定状态的概率。

但是容易发现题目要求第一次达到指定状态的期望次数,考虑再构造一些东西。

gn 表示 n 次关闭全部开关的概率,hn 表示 n 次达到期望状态且是首次达到的概率。
f(x),g(x),h(x) 分别为 {fn},{gn},{hn} 的 OGF,则容易发现 f(x)=h(x)g(x)h(x)=f(x)g(x)

再根据一些基本知识,易知所谓期望步数即 n=0nhn=n=0[xn]h(x)=h(1)

考虑如何求答案。
首先易知 F(x)=ni=1exp(piSx)+(1)siexp(piSx)2G(x)=ni=1exp(piSx)+exp(piSx)2

考虑把 F(x),G(x) 看做关于 exp(1Sx) 的多项式。
F(x)=Si=SFiexp(iSx)G(x)=Si=SGiexp(iSx)

系数可以通过背包 DP 求出。

易得 f(x)=Si=SFi1iSxg(x)=Si=SGi1iSx

那么根据基本知识 h(x)=f(x)g(x)+f(x)g(x)g2(x),考虑求 f(1),f(1),g(1),g(1)
但是可惜的是它们并不收敛……

考虑把 f,g 乘上 (1x),再推一推,可得答案为 1G2iS1i=S(FiGSFSGi)SiS

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include <cstdio>
using namespace std;
const int N = 100;
const int S = 5e4;
const int mod = 998244353;
const int inv = 499122177;
int n,s[N + 5],p[N + 5],sum,isum,ans;
int fr[N + 5][(S << 1) + 5],gr[N + 5][(S << 1) + 5];
int *f[N + 5],*g[N + 5];
inline int fpow(int a,int b)
{
int ret = 1;
for(;b;b >>= 1)
(b & 1) && (ret = (long long)ret * a % mod),a = (long long)a * a % mod;
return ret;
}
int main()
{
scanf("%d",&n);
for(register int i = 1;i <= n;++i)
scanf("%d",s + i),s[i] = s[i] ? mod - 1 : 1;
for(register int i = 1;i <= n;++i)
scanf("%d",p + i),sum += p[i];
for(register int i = 0;i <= n;++i)
f[i] = fr[i] + S,g[i] = gr[i] + S;
f[0][0] = g[0][0] = 1;
for(register int i = 1;i <= n;++i)
for(register int j = -sum;j <= sum;++j)
j + p[i] <= sum && (f[i][j + p[i]] = (f[i][j + p[i]] + (long long)f[i - 1][j] * inv) % mod,g[i][j + p[i]] = (g[i][j + p[i]] + (long long)g[i - 1][j] * inv) % mod),
j - p[i] >= -sum && (f[i][j - p[i]] = (f[i][j - p[i]] + (long long)s[i] * f[i - 1][j] % mod * inv) % mod,g[i][j - p[i]] = (g[i][j - p[i]] + (long long)g[i - 1][j] * inv) % mod);
isum = fpow(sum,mod - 2);
for(register int i = -sum;i < sum;++i)
ans = (ans + ((long long)f[n][i] * g[n][sum] % mod - (long long)f[n][sum] * g[n][i] % mod + mod) * sum % mod * fpow((i - sum + mod) % mod,mod - 2)) % mod;
ans = (long long)ans * fpow(g[n][sum],mod - 3) % mod;
printf("%d\n",ans);
}

Related Issues not found

Please contact @Alpha1022 to initialize the comment