百度之星2021复赛游记

400名有T恤穿,我403啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊

T1

题目概述

给一个长度为 nn 的序列,求怎么分割后每一段的 xorxor 值最大,输出这个最大值

思路

显然,每个数为一段最大,答案为每个数的和

Code

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
/*
Name:
Author: xiaruize
Date:
*/
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define ull unsigned long long
#define ALL(a) (a).begin(), (a).end()
#define pb push_back
#define mk make_pair
#define pii pair<int, int>
#define pis pair<int, string>
#define sec second
#define fir first
#define sz(a) int((a).size())
#define rep(i, x, y) for (int i = x; i <= y; i++)
#define repp(i, x, y) for (int i = x; i >= y; i--)
#define Yes cout << "Yes" << endl
#define YES cout << "YES" << endl
#define No cout << "No" << endl
#define NO cout << "NO" << endl
#define debug(x) cerr << #x << ": " << x << endl
#define double long double
const int INF = 0x3f3f3f3f;
const int MOD = 1000000007;
const int N = 2e5 + 10;

// bool st;
int n;
int a[N];
int res = 0;
// bool en;

signed main()
{
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
// cerr<<(&en-&st)/1024.0/1024.0<<endl;
// auto t_1=chrono::high_resolution_clock::now();
cin >> n;
for (int i = 1; i <= n; i++)
cin >> a[i], res += a[i];
cout << res << endl;
// auto t_2=chrono::high_resolution_clock::now();
// cout <<". Elapsed (ms): " << chrono::duration_cast<chrono::milliseconds>(t_2 - t_1).count() << endl;
return 0;
}

花絮

第一眼以为 T1T1 不是签到,所以等600多人 AC\color{green}{AC} 我才 AC\color{green}{AC}

T3

题目概述

长度为nn,且元素为 11 ~ mm 的序列的贡献为等于序列最大值的元素个数,求所有序列的贡献之和

1nm10121 \leq n \cdot m \leq 10^{12}

思路

对于一个受过良好义务教育 的小学生来说,很明显,答案为

res=ni=1min1res=n\cdot \sum^{m}_{i=1}{i^{n-1}}

然后,会发现,这玩意暴力算会 TLE\color{red}{TLE} ,所以,回到数据范围,发现可以分 m<nm<nmnm\geq n讨论

那么,如果mm较小,显然快速幂+暴力就能过

而如果mm较大,我们则希望计算的复杂度为 O(N)O(N) , 此时,可以使用拉格朗日插值来进行计算

Code

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
/*
Name:
Author: xiaruize
Date:
*/
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define ull unsigned long long
#define ALL(a) (a).begin(), (a).end()
#define pb push_back
#define mk make_pair
#define pii pair<int, int>
#define pis pair<int, string>
#define sec second
#define fir first
#define sz(a) int((a).size())
#define rep(i, x, y) for (int i = x; i <= y; i++)
#define repp(i, x, y) for (int i = x; i >= y; i--)
#define Yes cout << "Yes" << endl
#define YES cout << "YES" << endl
#define No cout << "No" << endl
#define NO cout << "NO" << endl
#define debug(x) cerr << #x << ": " << x << endl
#define double long double
const int INF = 0x3f3f3f3f;
const int MOD = 998244353;
const int N = 3e6 + 10;

int k, tab[N], p[N], pcnt, f[N], pre[N], suf[N], fac[N], inv[N], ans;

int quickMod(int a, int b)
{
int ans = 1;
while (b)
{
if (b & 1)
ans = ans * a % MOD;
a = a * a % MOD;
b >>= 1;
}
return ans;
}

void sieve(int lim)
{
f[1] = 1;
for (int i = 2; i <= lim; i++)
{
if (!tab[i])
{
p[++pcnt] = i;
f[i] = quickMod(i, k);
}
for (int j = 1; j <= pcnt && 1LL * i * p[j] <= lim; j++)
{
tab[i * p[j]] = 1;
f[i * p[j]] = 1LL * f[i] * f[p[j]] % MOD;
if (!(i % p[j]))
break;
}
}
for (int i = 2; i <= lim; i++)
f[i] = (f[i - 1] + f[i]) % MOD;
}

int cal(int n)
{
ans = 0;
// cerr << k << endl;
if (n <= k + 2)
return f[n];
pre[0] = suf[k + 3] = 1;
for (int i = 1; i <= k + 2; i++)
pre[i] = 1LL * pre[i - 1] * (n - i) % MOD;
for (int i = k + 2; i >= 1; i--)
suf[i] = 1LL * suf[i + 1] * (n - i) % MOD;
fac[0] = inv[0] = fac[1] = inv[1] = 1;
for (int i = 2; i <= k + 2; i++)
{
fac[i] = 1LL * fac[i - 1] * i % MOD;
inv[i] = 1LL * (MOD - MOD / i) * inv[MOD % i] % MOD;
}
for (int i = 2; i <= k + 2; i++)
inv[i] = 1LL * inv[i - 1] * inv[i] % MOD;
for (int i = 1; i <= k + 2; i++)
{
int P = 1LL * pre[i - 1] * suf[i + 1] % MOD;
int Q = 1LL * inv[i - 1] * inv[k + 2 - i] % MOD;
int mul = ((k + 2 - i) & 1) ? -1 : 1;
ans = (ans + 1LL * (Q * mul + MOD) % MOD * P % MOD * f[i] % MOD) % MOD;
}
// cerr << ans << endl;
return ans;
}

signed main()
{
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
// cerr<<(&en-&st)/1024.0/1024.0<<endl;
// auto t_1=chrono::high_resolution_clock::now();
int a, b;
cin >> a >> b;
k = a - 1;
int res = 0;
if (b < a)
{
for (int i = 1; i <= b; i++)
{
(res += quickMod(i, a - 1)) %= MOD;
}
cout << res * a % MOD << endl;
return 0;
}
else
{
sieve(k + 2);
cout << a * cal(b) % MOD << endl;
}
// auto t_2=chrono::high_resolution_clock::now();
// cout <<". Elapsed (ms): " << chrono::duration_cast<chrono::milliseconds>(t_2 - t_1).count() << endl;
return 0;
}

T4

题目概述

给定一个长度为 nn 的序列,和 qq 个查询

每次查询有 k,l,rk,l,r ,求 a[l,...,r]a[l,...,r] 中有多少个子序列 ab1,ab2,,abpa_{b_1},a_{b_2},\cdots ,a_{b_p} 满足 i[1,p)i\in [1,p) , abi<kabi+1a_{b_i}<k \leq a_{b_{i+1}}abi+1<kabia_{b_{i+1}} < k \leq a_{b_{i}}

思路

先转化题意,即对于每个询问,把区间内大于等于 kk 的数记为 11 , 小于的记为 00

很明显可以通过dpdp , O(N)O(N) 的解决每个询问

但是上面的方法并不能满足时间限制,所以可以考虑通过线段树来代替dpdp

线段树上存 44 个数,分别表示当前区间中开头为 0/10/1, 结尾为 0/10/1 的情况,这样就可以做 pushuppushup

但是此时仍然需要 O(N)O(N) 去检查哪些需要 updateupdate , 考虑先对询问的按 kk 排序,并对原数组排序

按顺序执行询问时,每个数最多被修改一次,每次修改的复杂度为 O(logn)O(logn) .

由于事先对原数组排序,所以可以通过二分求出修改的起点和终点

所以最终的复杂度为 O(nlogn)O(nlogn),注意常数问题

Code

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
/*
Name:
Author: xiaruize
Date:
*/
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define ull unsigned long long
#define ALL(a) (a).begin(), (a).end()
#define pb push_back
#define mk make_pair
#define pii pair<int, int>
#define pis pair<int, string>
#define sec second
#define fir first
#define sz(a) int((a).size())
#define rep(i, x, y) for (int i = x; i <= y; i++)
#define repp(i, x, y) for (int i = x; i >= y; i--)
#define Yes cout << "Yes" << endl
#define YES cout << "YES" << endl
#define No cout << "No" << endl
#define NO cout << "NO" << endl
#define debug(x) cerr << #x << ": " << x << endl
#define double long double
const int INF = 0x3f3f3f3f;
const int MOD = 1000000007;
const int N = 2e5 + 10;

#define ls u << 1
#define rs u << 1 | 1

// bool st;
int n, q;
int a[N];
pii b[N];
// bool en;
struct Node
{
int l, r;
int zz = 0, zo = 0, oz = 0, oo = 0;
Node operator+(Node y) const
{
Node res;
res.zz = zz + y.zz + zz * y.oz + zo * y.zz;
res.zo = zo + y.zo + zz * y.oo + zo * y.zo;
res.oo = oo + y.oo + oz * y.oo + oo * y.zo;
res.oz = oz + y.oz + oz * y.oz + oo * y.zz;
return res;
}
} tr[N * 4];

void pushup(int u)
{
tr[u].zz = tr[ls].zz + tr[rs].zz + tr[ls].zz * tr[rs].oz + tr[ls].zo * tr[rs].zz;
tr[u].zo = tr[ls].zo + tr[rs].zo + tr[ls].zz * tr[rs].oo + tr[ls].zo * tr[rs].zo;
tr[u].oo = tr[ls].oo + tr[rs].oo + tr[ls].oz * tr[rs].oo + tr[ls].oo * tr[rs].zo;
tr[u].oz = tr[ls].oz + tr[rs].oz + tr[ls].oz * tr[rs].oz + tr[ls].oo * tr[rs].zz;
}

void build(int u, int l, int r)
{
if (l == r)
{
tr[u].l = l;
tr[u].r = r;
tr[u].oo = 1;
tr[u].oz = 0;
tr[u].zo = 0;
tr[u].zz = 0;
}
else
{
tr[u].l = l;
tr[u].r = r;
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}

void update(int u, int p, int d)
{
if (tr[u].l == tr[u].r)
{
tr[u].oo = tr[u].oz = tr[u].zo = tr[u].zz = 0;
if (d == 1)
tr[u].oo = 1;
else
tr[u].zz = 1;
return;
}
else
{
int mid = tr[u].l + tr[u].r >> 1;
if (p <= mid)
update(u << 1, p, d);
if (p > mid)
update(u << 1 | 1, p, d);
pushup(u);
}
}

Node query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r)
{
return tr[u];
}
else
{
int mid = tr[u].l + tr[u].r >> 1;
Node res = {1, 1, 0, 0, 0, 0};
if (l <= mid)
res = res + query(u << 1, l, r);
if (r > mid)
res = res + query(u << 1 | 1, l, r);
// cerr << l << ' ' << r << ' ' << res.oo << ' ' << res.zz << ' ' << res.oz << ' ' << res.zo << endl;
return res;
}
}
struct que
{
int k, l, r;
int id;
} s[N];

int res[N];

bool cmp(que a, que b)
{
return a.k < b.k;
}

signed main()
{
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
// cerr<<(&en-&st)/1024.0/1024.0<<endl;
// auto t_1=chrono::high_resolution_clock::now();
cin >> n >> q;
for (int i = 1; i <= n; i++)
{
cin >> a[i];
b[i] = {a[i], i};
}
sort(b + 1, b + n + 1);
build(1, 1, n);
for (int i = 1; i <= q; i++)
{
cin >> s[i].k >> s[i].l >> s[i].r;
s[i].id = i;
}
sort(s + 1, s + q + 1, cmp);
int now = 1;
for (int i = 1; i <= q; i++)
{
if (s[i].k != s[i - 1].k)
{
pii x = {s[i].k, -1};
int en = upper_bound(b + 1, b + n + 1, x) - b;
// cerr << tmp << ' ' << en << endl;
for (int j = now; j < en; j++)
{
// cerr << b[j].sec << endl;
update(1, b[j].sec, 0);
}
now = en;
}
Node x = query(1, s[i].l, s[i].r);
res[s[i].id] = x.oo + x.oz + x.zo + x.zz;
}
for (int i = 1; i <= q; i++)
cout << res[i] << endl;
// auto t_2=chrono::high_resolution_clock::now();
// cout <<". Elapsed (ms): " << chrono::duration_cast<chrono::milliseconds>(t_2 - t_1).count() << endl;
return 0;
}

花絮

赛后一分钟调完qwq