diff --git a/2021/src/day07.rs b/2021/src/day07.rs index 899b4fc..c0a5625 100644 --- a/2021/src/day07.rs +++ b/2021/src/day07.rs @@ -4,29 +4,33 @@ use itertools::Itertools; use crate::common::ordered; -fn compute_groups<'a>(it: impl IntoIterator) -> Vec<(usize, usize)> { - let mut it = it.into_iter().copied().dedup_with_count(); +fn compute_cumulative<'a, I, It>(it: I) -> Vec +where + I: IntoIterator, + It: Iterator + ExactSizeIterator, +{ + let mut it = it.into_iter().copied(); + let mut costs = Vec::with_capacity(it.len()); + costs.push(0); let (mut population, mut last_pos) = it.next().unwrap(); let mut last_cost = 0; - let mut costs = vec![(last_pos, 0)]; - for (number, pos) in it { let (first, last) = ordered(last_pos, pos); let new_cost = last_cost + population * (last - first); - costs.push((pos, new_cost)); - population += number; last_pos = pos; last_cost = new_cost; + + costs.push(new_cost); } costs } -fn read_input(input: &mut dyn Read) -> Vec { +fn read_input(input: &mut dyn Read) -> Vec<(usize, usize)> { let mut buf = String::new(); input.read_to_string(&mut buf).unwrap(); @@ -38,23 +42,20 @@ fn read_input(input: &mut dyn Read) -> Vec { crabs.sort_unstable(); - crabs + crabs.into_iter().dedup_with_count().collect() } pub fn part1(input: &mut dyn Read) -> String { let crabs = read_input(input); - let forward_costs = compute_groups(&crabs); - let backwards_costs = compute_groups(crabs.iter().rev()); + let forwards_costs = compute_cumulative(&crabs); + let backwards_costs = compute_cumulative(crabs.iter().rev()); - backwards_costs - .into_iter() - .rev() - .zip(forward_costs) - .map(|((_, cost_b), (_, cost_f))| cost_f + cost_b) - .min() - .unwrap() - .to_string() + // Note: the optimal position can be proven to be one of the original positions. + ternary_search(0, forwards_costs.len() - 1, |idx| { + forwards_costs[idx] + backwards_costs[backwards_costs.len() - 1 - idx] + }) + .to_string() } pub fn sum_until(end: usize) -> usize { @@ -72,13 +73,13 @@ fn cost_at(pos: usize, groups: &[(usize, usize)]) -> usize { .sum() } -fn ternary_search(groups: &[(usize, usize)], mut min: usize, mut max: usize) -> usize { +fn ternary_search(mut min: usize, mut max: usize, callback: impl Fn(usize) -> usize) -> usize { while max - min > 6 { let mid1 = min + (max - min) / 3; let mid2 = max - (max - min) / 3; - let cost1 = cost_at(mid1, groups); - let cost2 = cost_at(mid2, groups); + let cost1 = callback(mid1); + let cost2 = callback(mid2); if cost1 < cost2 { max = mid2 - 1 @@ -88,17 +89,16 @@ fn ternary_search(groups: &[(usize, usize)], mut min: usize, mut max: usize) -> } // Ternary search isn't effective at such small intervals so we iterate the remaining part - (min..=max).map(|pos| cost_at(pos, groups)).min().unwrap() + (min..=max).map(callback).min().unwrap() } pub fn part2(input: &mut dyn Read) -> String { - let crabs = read_input(input); - let groups: Vec<_> = crabs.into_iter().dedup_with_count().collect(); + let groups = read_input(input); let min = groups.first().unwrap().1; let max = groups.last().unwrap().1; - ternary_search(&groups, min, max).to_string() + ternary_search(min, max, |pos| cost_at(pos, &groups)).to_string() } #[cfg(test)]