AtCoderBeginnerContest312 F 問題 500 点 「Cans and Openers」

問題リンク

問題概要

制約

解法

できるだけ大きい$X_i$を選べば最大になるので,貪欲法で溶けそうなことがわかる。缶切りが必要なものをどれだけ採用するかが問題になる。 缶切りもそれぞれ$X_i$回使用できるものがあるので,$X_i$が大きいものから順に使用することを考える。 缶切り不要のものと缶切りを n 個使用したときに選べるものの中から大きい順に M-n 個選べば良いことから,缶切りの数ごとに答えを求めていくと良い。 選べる値は SortedMultiSet で管理すると簡単になるが,都度合計を計算していくと TLE するので,SMS とは別に n 個の缶切りを使用した場合の合計値も管理しておくと,$\mathcal{O}(N)$で答えを求める事ができる。

お気持ち

コンテスト中にはできなかったが,これ 40 分位あれば自力でできたのでは…かなしい。 が,今回解くにあたって Rust で SortedMultiSet もどきを作ったので,今後活用して機能を増やしていきたい。しかし min/max 取るときに一回 Itertor にキャストしてたりしてどうなんかなと思う。1.64 あたりになればBTreeSet.first()みたいのが使えるらしいけど,AtCoder はまだ 1.48 なのでしょうがないのか?

<追記>Rust で SortedMultiSet を Generics 使っていい感じにするやつに修正した。ChatGPT 使ったけど,値の返り値は Option<&T>にするのが良さそう。それはそうか。 ## AC コード ```python # https://github.com/tatyam-prime/SortedSet/blob/main/SortedMultiset.py N, M = map(int, input().split()) TX = [list(map(int, input().split())) for _ in range(N)] mst = SortedMultiset([]) lst = [] cnt = 0 cuts = [] needs = SortedMultiset([]) for t, x in TX: if t == 0: mst.add(x) # M個超えたら小さいの消す if len(mst) > M: mst.discard(mst[0]) elif t == 1: needs.add(x) else: cuts.append(x) cuts.sort(reverse=True) lst.sort(reverse=True) ans = sum(mst) tmp = ans for idx, c in enumerate(cuts, start=1): if M == idx: break # M - (缶切りの数)を超えてたら引いてmstからけしとく if len(mst) > M - idx: tmp -= mst[0] mst.discard(mst[0]) # 缶切りの使用回数分追加する for _ in range(c): if len(needs) == 0: break v = needs[-1] needs.discard(v) mst.add(v) if len(mst) > M - idx: u = mst[0] mst.discard(u) tmp -= u tmp += v ans = max(ans, tmp) print(ans) ``` ```rust #![allow(non_snake_case)] #![allow(unused_imports)] #![allow(unused_macros)] #![allow(clippy::needless_range_loop)] #![allow(clippy::comparison_chain)] #![allow(clippy::nonminimal_bool)] #![allow(clippy::neg_multiply)] #![allow(dead_code)] #![recursion_limit = "1024"] use std::collections::{vec_deque, BTreeSet, BinaryHeap, HashMap, HashSet, VecDeque}; use std::f64::consts::PI; use std::hash::Hash; use std::mem::swap; use std::ops::Index; use std::ops::Mul; use itertools::{CombinationsWithReplacement, Itertools}; use num::{integer::Roots, traits::Pow, ToPrimitive}; use num_integer::{div_ceil, Integer}; use num_traits::{abs_sub, Float}; use permutohedron::Heap; use proconio::input; use proconio::marker::{Chars, Usize1}; use std::io::*; use whiteread::parse_line; // struct SortedMultiSet { st: BTreeSet, cnt: HashMap<T, usize>, num: usize, } impl<T: Ord + Clone + Hash + Mul<Output = T> + From + Into> SortedMultiSet { fn new() -> Self { SortedMultiSet { st: BTreeSet::new(), cnt: HashMap::new(), num: 0, } } fn add(&mut self, v: T) { self.st.insert(v.clone()); let c = *self.cnt.get(&v).unwrap_or(&0usize); self.cnt.insert(v, c + 1); self.num += 1; } fn discard(&mut self, v: T) { let c = *self.cnt.get(&v).unwrap_or(&0usize); if c == 1usize { self.st.remove(&v); self.cnt.remove(&v); } else { self.cnt.insert(v, c - 1); } self.num -= 1; } fn vmin(&self) -> Option<&T> { self.st.iter().next() } fn vmax(&self) -> Option<&T> { self.st.iter().last() } fn len(&self) -> usize { self.num } fn sum(&self) -> usize { let mut ret = 0; for (k, v) in &self.cnt { ret += *v * (*k).clone().into(); } ret } } fn solve() { #[rustfmt::skip] input! { N:usize, M:usize, TX:[(usize, usize); N] } let mut mst = SortedMultiSet::new(); let mut need = SortedMultiSet::new(); let mut cuts = vec![]; for &(t, x) in TX.iter() { match t { 0 => { mst.add(x); if mst.len() > M { let &v = mst.vmin().unwrap(); mst.discard(v); } } 1 => { need.add(x); } 2 => { cuts.push(x); } _ => {} } } cuts.sort(); // cuts.reverse(); let mut ans = mst.sum(); let mut tmp = ans.clone(); for (idx, &c) in cuts.iter().rev().enumerate() { let idx = idx + 1; if M == idx { break; } if mst.len() > M - idx { let &v = mst.vmin().unwrap(); tmp -= v; mst.discard(v); } if need.len() == 0 { break; } for _ in 0..c { let &v = need.vmax().unwrap(); need.discard(v); mst.add(v); tmp += v; while mst.len() > M - idx { let &u = mst.vmin().unwrap(); mst.discard(u); tmp -= u; } // println!("{}", mst.len()); // println!("{}", ans); // println!("{:?}", mst.st); // println!("{:?}", mst.cnt); ans = ans.max(tmp); if need.len() == 0 { break; } } } println!("{}", ans); } fn main() { // input! {N:usize} // for _ in 0..N { // solve(); // } solve(); } ```