JZOJ 5251 决战

首先考虑一个朴素的状压 DP:
\(f_{i,S,j}\) 表示考虑了前 \(i\) 行、第 \(i\) 行放了哲学家的位置集合为 \(S\)、前 \(i\) 行一共放了 \(j\) 个哲学家时的方案数。
每一步需要枚举相邻两行的状态,共 \(2^c \cdot 2^c = 4^c\) 种,因此复杂度是 \(O(4^c nm)\),其中 \(c = 3\) 是每行的格子数。

设 \(F_{i,S}(x) = \sum\limits_{j=0}^{3n} f_{i,S,j} x^j\)。
那么转移可以写作 \(F_{i,S}(x) = x^{|S|} \sum\limits_{T \to S \text{ 合法}} F_{i-1,T}(x)\),其中求和取遍所有可以合法转移到 \(S\) 的上一行状态 \(T\)。

考虑基于 \(F\) 的点值表示进行 DP:取 2 的幂 \(L \ge 3n+1\),设 \(F_{i,k,S}\) 为 \(F_{i,S}(\omega_L^k)\)(\(0 \le k < L\)),那么对每个固定的 \(k\),上式都变成了点值之间的乘法与加法。
然而这样的复杂度仍是 \(O(4^c n^2)\) 级别(共 \(L = \Theta(n)\) 个点值,每个都要做 \(n\) 行转移),并没有改进。

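真的没有用吗?注意到对每个固定的 \(k\),\(S\) 这一维的转移变成了一个与 \(i\) 无关的线性变换:把 \(F_{i,k,\cdot}\) 看作长度为 \(2^c\) 的行向量,则 \(F_{i,k,\cdot} = F_{i-1,k,\cdot} A_k\),其中当 \(T\) 可以合法转移到 \(S\) 时 \(A_k[T][S] = \omega_L^{k|S|}\),否则为 \(0\)。由于 \(S\) 的取值很少,可以用矩阵快速幂求出 \(A_k^n\),最后对 \(L\) 个点值做一次逆 NTT 还原出系数,\(x^m\) 的系数即为答案。
复杂度为 \(O(8^c n \log n)\)(\(L = \Theta(n)\) 个点值,每个做 \(O(\log n)\) 次 \(2^c \times 2^c\) 的矩阵乘法)。
(注意多项式的次数最高可达 \(3n\),所以点值个数 \(L\) 至少要取 \(3n+1\):我们无法在转移的同时把点值 \(\bmod x^m\),因此不能只取 \(m+1\) 个点。)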

代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 2500;
const int M = 7500;
const int K = 3;           // number of cells in one row
const int S = 1 << K;      // number of row masks (derived from K instead of a magic 8)
const int mod = 998244353;
const int G = 3;           // generator used to build NTT roots of unity modulo mod
// Modular addition for operands already reduced to [0, mod).
// Inline functions replace the former function-like macros: each argument is
// evaluated exactly once and cannot be torn apart by operator precedence.
inline int add(int x,int y)
{
    return x + y >= mod ? x + y - mod : x + y;
}
// Modular subtraction for operands already reduced to [0, mod).
inline int dec(int x,int y)
{
    return x < y ? x - y + mod : x - y;
}
// n rows, m philosophers; full = mask with all K cells set; a[1..K] = pattern rows as bitmasks.
int n,m,full = (1 << K) - 1,a[K + 5];
// Per row mask: forbidden cells in the row above / below, and whether the mask is self-consistent.
int up[(1 << K) + 5],down[(1 << K) + 5],valid[(1 << K) + 5];
// Square-and-multiply modular exponentiation: a^b reduced modulo `mod`
// (for 0 <= a and b >= 0; b == 0 yields 1).
inline int fpow(int a,int b)
{
    int result = 1;
    while (b != 0)
    {
        if (b & 1)
            result = (long long)result * a % mod;
        a = (long long)a * a % mod;
        b >>= 1;
    }
    return result;
}
namespace NTT
{
// Capacity of every table below; the power of two chosen by init() must not exceed N.
const int N = 15000;

// Fixed-capacity buffer of residues, used both for coefficients and for point values.
struct poly
{
    int a[N + 5];
    inline const int &operator[](int x) const
    {
        return a[x];
    }
    inline int &operator[](int x)
    {
        return a[x];
    }
    // Zero the entries a[x..N].
    inline void clear(int x = 0)
    {
        // sizeof(int) instead of the former "<< 2", which silently assumed 4-byte ints.
        memset(a + x,0,(N - x + 1) * sizeof(int));
    }
} f;

// n: transform length chosen by init(); lg2[i] = floor(log2 i) for 1 <= i <= n.
int n,lg2[N + 5];
// rev: bit-reversal scratch; fac/ifac/inv: i!, 1/i!, 1/i modulo mod for i <= n.
int rev[N + 5],fac[N + 5],ifac[N + 5],inv[N + 5];
// rt[m + j] = w_{2m}^j for every power of two m < n and 0 <= j < m,
// where w_k = G^((mod - 1) / k). irt is not filled by init().
int rt[N + 5],irt[N + 5];

// Prepare all tables for transforms of length n = smallest power of two >= len.
// Assumes n divides mod - 1, n <= N, and G is a primitive root of mod.
inline void init(int len)
{
    for (n = 1; n < len; n <<= 1)
    {
    }
    for (int i = 2; i <= n; ++i)
        lg2[i] = lg2[i >> 1] + 1;
    const int w = fpow(G,(mod - 1) / n);
    // Top level: rt[n/2 + j] = w^j.
    rt[n >> 1] = 1;
    for (int i = (n >> 1) + 1; i <= n; ++i)
        rt[i] = (long long)rt[i - 1] * w % mod;
    // Lower levels reuse every other root (w_{2m}^j = w_{4m}^{2j}).
    // "i > 0" rather than "i": for n == 1 the loop used to start at -1 and run off the array.
    for (int i = (n >> 1) - 1; i > 0; --i)
        rt[i] = rt[i << 1];
    fac[0] = 1;
    for (int i = 1; i <= n; ++i)
        fac[i] = (long long)fac[i - 1] * i % mod;
    ifac[n] = fpow(fac[n],mod - 2);
    for (int i = n; i > 0; --i)
        ifac[i - 1] = (long long)ifac[i] * i % mod;
    for (int i = 1; i <= n; ++i)
        inv[i] = (long long)ifac[i] * fac[i - 1] % mod;
}

// In-place transform of a[0..n-1]; n must be a power of two not larger than NTT::n.
// type == -1: inverse transform (including the division by n); any other value: forward.
inline void ntt(poly &a,int type,int n)
{
    // Inverse = forward transform of a[0], a[n-1], ..., a[1], followed by scaling with 1/n.
    if (type == -1)
        std::reverse(a.a + 1,a.a + n);
    // Bit-reversal permutation. Skipped for n == 1, where it is the identity and
    // "(i & 1) << (lg2[1] - 1)" would shift by a negative amount.
    if (n > 1)
    {
        const int lg = lg2[n] - 1;
        for (int i = 0; i < n; ++i)
        {
            rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << lg);
            if (i < rev[i])
                std::swap(a[i],a[rev[i]]);
        }
    }
    // Iterative radix-2 butterflies: w = block length, m = half block length.
    for (int w = 2,m = 1; w <= n; w <<= 1,m <<= 1)
        for (int i = 0; i < n; i += w)
            for (int j = 0; j < m; ++j)
            {
                const int t = (long long)rt[m | j] * a[i | j | m] % mod;
                a[i | j | m] = dec(a[i | j],t);
                a[i | j] = add(a[i | j],t);
            }
    if (type == -1)
        for (int i = 0; i < n; ++i)
            a[i] = (long long)a[i] * inv[n] % mod;
}
}
using NTT::init;
using NTT::ntt;
// Dense matrix over Z_mod with one row/column per row mask.
// Storage is (S + 5) x (S + 5), but only indices 0..full take part in the identity and in products.
struct Matrix
{
    int a[S + 5][S + 5];
    // Set every entry to 0.
    inline void clear()
    {
        memset(a,0,sizeof a);
    }
    // Zero matrix.
    inline Matrix()
    {
        memset(a,0,sizeof a);
    }
    // Identity matrix on indices 0..full; the int argument is only a tag and its value is ignored.
    // explicit: previously any int silently converted into an identity Matrix (e.g. "mat = 5;").
    inline explicit Matrix(int)
    {
        memset(a,0,sizeof a);
        for (int i = 0; i <= full; ++i)
            a[i][i] = 1;
    }
    inline const int *operator[](int x) const
    {
        return a[x];
    }
    inline int *operator[](int x)
    {
        return a[x];
    }
    // Matrix product (*this) * o, reduced modulo mod.
    inline Matrix operator*(const Matrix &o) const
    {
        Matrix ret;
        for (int i = 0; i <= full; ++i)
            for (int j = 0; j <= full; ++j)
                for (int k = 0; k <= full; ++k)
                    ret[i][j] = (ret[i][j] + (long long)a[i][k] * o[k][j]) % mod;
        return ret;
    }
} mat;
// Binary exponentiation of a matrix: returns a^b, where a^0 is the identity on indices 0..full.
inline Matrix fpow(Matrix a,int b)
{
    Matrix result(1);
    while (b)
    {
        if (b & 1)
            result = result * a;
        a = a * a;
        b >>= 1;
    }
    return result;
}
int pw[4];
int dp(int w)
{
int ret = 0;
pw[0] = 1;
for(register int i = 1;i <= K;++i)
pw[i] = (long long)pw[i - 1] * w % mod;
mat.clear();
for(register int S = 0;S <= full;++S)
for(register int T = 0;T <= full;++T)
if(valid[S] && valid[T] && !(up[T] & S) && !(down[S] & T))
mat[S][T] = pw[__builtin_popcount(T)];
mat = fpow(mat,n);
for(register int i = 0;i <= full;++i)
ret = add(ret,mat[0][i]);
return ret;
}
// Reads n, m and the K x K pattern, then prints the coefficient of x^m of the counting
// polynomial. The polynomial is obtained by evaluating it with dp() at roots of unity
// and interpolating with an inverse NTT.
int main()
{
    if (scanf("%d%d",&n,&m) != 2)
        return 1;
    // Each row holds at most K philosophers; a negative count would also index f out of range.
    if (m < 0 || m > K * n)
    {
        puts("0");
        return 0;
    }
    // a[r] (r = 1..K): entry (r, c) of the pattern is OR-ed into bit c - 1 (entries assumed 0/1).
    for (int i = 1; i <= K; ++i)
        for (int j = 1; j <= K; ++j)
        {
            int x;
            if (scanf("%d",&x) != 1)
                return 1;
            a[i] |= x << (j - 1);
        }
    for (int mask = 0; mask <= full; ++mask)
    {
        valid[mask] = 1;
        for (int j = 1; j <= K; ++j)
        {
            const int bit = 1 << (j - 1);
            if (!(mask & bit))
                continue;
            // "row << 1 >> (K - j)" moves the pattern's centre bit (bit 1) onto column j (bit j - 1) for K = 3.
            up[mask] |= a[1] << 1 >> (K - j);
            down[mask] |= a[3] << 1 >> (K - j);
            valid[mask] &= !((mask ^ bit) & (a[2] << 1 >> (K - j)));
        }
        up[mask] &= full;
        down[mask] &= full;
    }
    // The polynomial has degree <= K * n, so K * n + 1 point values determine it;
    // init() rounds that up to a power of two L.
    NTT::init(K * n + 1);
    const int L = NTT::n;
    // Evaluate at step^0, step^1, ..., step^(L-1), then interpolate with the inverse transform.
    const int step = fpow(G,(mod - 1) / L);
    for (int i = 0,point = 1; i < L; ++i,point = (long long)point * step % mod)
        NTT::f[i] = dp(point);
    NTT::ntt(NTT::f,-1,L);
    printf("%d\n",NTT::f[m]);
    return 0;
}