Codeforces 1039E Intersection of Permutations

打完比赛之后在比赛公告的评论区发现了一种分块思路。

我一看,数据结构题诶。
但是不会做。
写了个分块 + bitset 上去成功 TLE 飞。

如前文所说,在评论区发现了这样一种思路:
\(sum_{i,j}\) 表示 \(b\) 的第 \(i\) 块与 \(a\) \(j\) 块有多少个数相同。
\(place_{a,i}\) 表示 \(i\)\(a\) 中的编号,设 \(place_{b,i}\) 表示 \(i\)\(b\) 中的编号。

然后查询我们先处理 \(b\) 的完整块,
我们用 \(sum\) 数组计算 \(a\) 的完整块对 \(b\) 的完整块的答案贡献,
再用 \(place_b\) 计算 \(a\) 的角块对 \(b\) 的完整块的答案贡献。

接着处理 \(b\) 的角块,这个就用 \(place_a\) 计算答案的贡献。

由于 \(sum\) 数组第二维是前缀和,所以我们能以修改从 \(O(1)\) 退化成 \(O(\sqrt{n})\) 的代价,保证查询的复杂度为 \(O(\sqrt{n})\) 而没有退化成 \(O(n)\)

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;

const int BUFF_SIZE = 1 << 20;
char BUFF[BUFF_SIZE],*BB,*BE;
#define gc() (BB == BE ? (BE = (BB = BUFF) + fread(BUFF,1,BUFF_SIZE,stdin),BB == BE ? EOF : *BB++) : *BB++)
template<class T>
inline void read(T &x)
{
x = 0;
char ch = 0,w = 0;
while(ch < '0' || ch > '9')
w |= ch == '-',ch = gc();
while(ch >= '0' && ch <= '9')
x = (x << 3) + (x << 1) + (ch ^ '0'),ch = gc();
w ? x = -x : x;
}

const int N = 2e5;
const int BLOCK_NUM = 500;
int n,m;
int ind[2][N + 10];
int a[2][N + 10];
int block,pos[N + 10];
int cnt[BLOCK_NUM + 10][BLOCK_NUM + 10];
int sum[BLOCK_NUM + 10][BLOCK_NUM + 10];
int main()
{
read(n),read(m);
block = sqrt(n);
for(register int i = 1;i <= n;++i)
read(a[0][i]),pos[i] = (i - 1) / block + 1,ind[0][a[0][i]] = i;
for(register int i = 1;i <= n;++i)
read(a[1][i]),ind[1][a[1][i]] = i,++cnt[pos[i]][pos[ind[0][a[1][i]]]];
for(register int i = 1;i <= pos[n];++i)
for(register int j = 1;j <= pos[n];++j)
sum[i][j] = sum[i][j - 1] + cnt[i][j];
int op,l1,r1,l2,r2;
while(m--)
{
read(op),read(l1),read(r1);
if(op == 1)
{
int ans = 0;
read(l2),read(r2);
if(pos[l1] + 1 < pos[r1])
for(register int i = pos[l2] + 1;i < pos[r2];++i)
ans += sum[i][pos[r1] - 1] - sum[i][pos[l1]];
for(register int i = l1;i <= min(pos[l1] * block,r1);++i)
ans += pos[l2] + 1 <= pos[ind[1][a[0][i]]] && pos[ind[1][a[0][i]]] <= pos[r2] - 1;
if(pos[l1] ^ pos[r1])
for(register int i = (pos[r1] - 1) * block + 1;i <= r1;++i)
ans += pos[l2] + 1 <= pos[ind[1][a[0][i]]] && pos[ind[1][a[0][i]]] <= pos[r2] - 1;

for(register int i = l2;i <= min(pos[l2] * block,r2);++i)
ans += l1 <= ind[0][a[1][i]] && ind[0][a[1][i]] <= r1;
if(pos[l2] ^ pos[r2])
for(register int i = (pos[r2] - 1) * block + 1;i <= r2;++i)
ans += l1 <= ind[0][a[1][i]] && ind[0][a[1][i]] <= r1;
printf("%d\n",ans);
}
else
{
--cnt[pos[l1]][pos[ind[0][a[1][l1]]]],--cnt[pos[r1]][pos[ind[0][a[1][r1]]]];
swap(a[1][l1],a[1][r1]);
ind[1][a[1][l1]] = l1,ind[1][a[1][r1]] = r1;
++cnt[pos[l1]][pos[ind[0][a[1][l1]]]],++cnt[pos[r1]][pos[ind[0][a[1][r1]]]];
for(register int i = 1;i <= pos[n];++i)
sum[pos[l1]][i] = sum[pos[l1]][i - 1] + cnt[pos[l1]][i],sum[pos[r1]][i] = sum[pos[r1]][i - 1] + cnt[pos[r1]][i];
}
}
}