AtCoderBeginnerContest312 F 問題 500 点 「Cans and Openers」
問題リンク
問題概要
制約
解法
できるだけ大きい$X_i$を選べば最大になるので,貪欲法で溶けそうなことがわかる。缶切りが必要なものをどれだけ採用するかが問題になる。
缶切りもそれぞれ$X_i$回使用できるものがあるので,$X_i$が大きいものから順に使用することを考える。
缶切り不要のものと缶切りを n 個使用したときに選べるものの中から大きい順に M-n 個選べば良いことから,缶切りの数ごとに答えを求めていくと良い。
選べる値は SortedMultiSet で管理すると簡単になるが,都度合計を計算していくと TLE するので,SMS とは別に n 個の缶切りを使用した場合の合計値も管理しておくと,$\mathcal{O}(N)$で答えを求める事ができる。
お気持ち
コンテスト中にはできなかったが,これ 40 分位あれば自力でできたのでは…かなしい。
が,今回解くにあたって Rust で SortedMultiSet もどきを作ったので,今後活用して機能を増やしていきたい。しかし min/max 取るときに一回 Itertor にキャストしてたりしてどうなんかなと思う。1.64 あたりになればBTreeSet.first()
みたいのが使えるらしいけど,AtCoder はまだ 1.48 なのでしょうがないのか?
<追記>Rust で SortedMultiSet を Generics 使っていい感じにするやつに修正した。ChatGPT 使ったけど,値の返り値は Option<&T>にするのが良さそう。それはそうか。
## AC コード
```python
# https://github.com/tatyam-prime/SortedSet/blob/main/SortedMultiset.py
N, M = map(int, input().split())
TX = [list(map(int, input().split())) for _ in range(N)]
mst = SortedMultiset([])
lst = []
cnt = 0
cuts = []
needs = SortedMultiset([])
for t, x in TX:
if t == 0:
mst.add(x)
# M個超えたら小さいの消す
if len(mst) > M:
mst.discard(mst[0])
elif t == 1:
needs.add(x)
else:
cuts.append(x)
cuts.sort(reverse=True)
lst.sort(reverse=True)
ans = sum(mst)
tmp = ans
for idx, c in enumerate(cuts, start=1):
if M == idx:
break
# M - (缶切りの数)を超えてたら引いてmstからけしとく
if len(mst) > M - idx:
tmp -= mst[0]
mst.discard(mst[0])
# 缶切りの使用回数分追加する
for _ in range(c):
if len(needs) == 0:
break
v = needs[-1]
needs.discard(v)
mst.add(v)
if len(mst) > M - idx:
u = mst[0]
mst.discard(u)
tmp -= u
tmp += v
ans = max(ans, tmp)
print(ans)
```
```rust
#![allow(non_snake_case)]
#![allow(unused_imports)]
#![allow(unused_macros)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::comparison_chain)]
#![allow(clippy::nonminimal_bool)]
#![allow(clippy::neg_multiply)]
#![allow(dead_code)]
#![recursion_limit = "1024"]
use std::collections::{vec_deque, BTreeSet, BinaryHeap, HashMap, HashSet, VecDeque};
use std::f64::consts::PI;
use std::hash::Hash;
use std::mem::swap;
use std::ops::Index;
use std::ops::Mul;
use itertools::{CombinationsWithReplacement, Itertools};
use num::{integer::Roots, traits::Pow, ToPrimitive};
use num_integer::{div_ceil, Integer};
use num_traits::{abs_sub, Float};
use permutohedron::Heap;
use proconio::input;
use proconio::marker::{Chars, Usize1};
use std::io::*;
use whiteread::parse_line;
//
struct SortedMultiSet
{
st: BTreeSet,
cnt: HashMap<T, usize>,
num: usize,
}
impl<T: Ord + Clone + Hash + Mul<Output = T> + From + Into> SortedMultiSet {
fn new() -> Self {
SortedMultiSet {
st: BTreeSet::new(),
cnt: HashMap::new(),
num: 0,
}
}
fn add(&mut self, v: T) {
self.st.insert(v.clone());
let c = *self.cnt.get(&v).unwrap_or(&0usize);
self.cnt.insert(v, c + 1);
self.num += 1;
}
fn discard(&mut self, v: T) {
let c = *self.cnt.get(&v).unwrap_or(&0usize);
if c == 1usize {
self.st.remove(&v);
self.cnt.remove(&v);
} else {
self.cnt.insert(v, c - 1);
}
self.num -= 1;
}
fn vmin(&self) -> Option<&T> {
self.st.iter().next()
}
fn vmax(&self) -> Option<&T> {
self.st.iter().last()
}
fn len(&self) -> usize {
self.num
}
fn sum(&self) -> usize {
let mut ret = 0;
for (k, v) in &self.cnt {
ret += *v * (*k).clone().into();
}
ret
}
}
fn solve() {
#[rustfmt::skip]
input! {
N:usize, M:usize,
TX:[(usize, usize); N]
}
let mut mst = SortedMultiSet::new();
let mut need = SortedMultiSet::new();
let mut cuts = vec![];
for &(t, x) in TX.iter() {
match t {
0 => {
mst.add(x);
if mst.len() > M {
let &v = mst.vmin().unwrap();
mst.discard(v);
}
}
1 => {
need.add(x);
}
2 => {
cuts.push(x);
}
_ => {}
}
}
cuts.sort();
// cuts.reverse();
let mut ans = mst.sum();
let mut tmp = ans.clone();
for (idx, &c) in cuts.iter().rev().enumerate() {
let idx = idx + 1;
if M == idx {
break;
}
if mst.len() > M - idx {
let &v = mst.vmin().unwrap();
tmp -= v;
mst.discard(v);
}
if need.len() == 0 {
break;
}
for _ in 0..c {
let &v = need.vmax().unwrap();
need.discard(v);
mst.add(v);
tmp += v;
while mst.len() > M - idx {
let &u = mst.vmin().unwrap();
mst.discard(u);
tmp -= u;
}
// println!("{}", mst.len());
// println!("{}", ans);
// println!("{:?}", mst.st);
// println!("{:?}", mst.cnt);
ans = ans.max(tmp);
if need.len() == 0 {
break;
}
}
}
println!("{}", ans);
}
fn main() {
// input! {N:usize}
// for _ in 0..N {
// solve();
// }
solve();
}
```
追記>