Use dedicated iterator instead of range

This commit is contained in:
2021-12-18 18:00:40 +01:00
parent cc81a7012b
commit 101ebee505

View File

@@ -1,8 +1,6 @@
use std::io::Read; use std::io::Read;
use std::ops::RangeInclusive; use std::ops::RangeInclusive;
use itertools::Itertools;
use itertools::MinMaxResult;
use nom::bytes::complete::tag; use nom::bytes::complete::tag;
use nom::combinator::map; use nom::combinator::map;
use nom::sequence::preceded; use nom::sequence::preceded;
@@ -31,7 +29,7 @@ fn position(initial: i32, time: i32) -> i32 {
time * (2 * initial - time + 1) / 2 time * (2 * initial - time + 1) / 2
} }
fn find_hit(initial: i32, range: &RangeInclusive<i32>) -> Option<RangeInclusive<i32>> { fn find_hit(initial: i32, range: &RangeInclusive<i32>) -> impl Iterator<Item = i32> + '_ {
// y position at time x: f(x) = x * (1 + initial + initial - x) / 2 // y position at time x: f(x) = x * (1 + initial + initial - x) / 2
// = -1/2x^2 + (initial + 0.5)x // = -1/2x^2 + (initial + 0.5)x
// //
@@ -40,30 +38,19 @@ fn find_hit(initial: i32, range: &RangeInclusive<i32>) -> Option<RangeInclusive<
// -1/2x^2 + (initial + 0.5)x - (max(box) + min(box)) / 2 = 0 // -1/2x^2 + (initial + 0.5)x - (max(box) + min(box)) / 2 = 0
let middle = (*range.start() + *range.end()) as f64 / 2.; let middle = (*range.start() + *range.end()) as f64 / 2.;
let b = initial as f64 + 0.5; let b = initial as f64 + 0.5;
let hit = solve_quadratic(-0.5, b, -middle)? as i32; let hit = if let Some(hit) = solve_quadratic(-0.5, b, -middle) {
hit as i32
if hit < 0 {
// Should not happen because of the shape but for correctness
None
} else { } else {
let min_hit = (0..=hit) -1
};
(0..=hit)
.rev() .rev()
.take_while(|&n| range.contains(&position(initial, n))) .take_while(move |&n| range.contains(&position(initial, n)))
.min(); .chain(((hit + 1)..).take_while(move |&n| range.contains(&position(initial, n))))
let max_hit = ((hit + 1)..)
.take_while(|&n| range.contains(&position(initial, n)))
.max();
match (min_hit, max_hit) {
(Some(min), Some(max)) => Some(min..=max),
(Some(val), None) | (None, Some(val)) => Some(val..=val),
_ => None,
}
}
} }
fn find_speed(x: i32, range: &RangeInclusive<i32>) -> Option<RangeInclusive<i32>> { fn find_speed(x: i32, range: &RangeInclusive<i32>) -> Option<(i32, i32)> {
if *range.end() <= position(x, x) { if *range.end() <= position(x, x) {
// Can and should come to a full stop // Can and should come to a full stop
let max = solve_quadratic(0.5, 0.5, -*range.end() as f64)? as i32; let max = solve_quadratic(0.5, 0.5, -*range.end() as f64)? as i32;
@@ -73,7 +60,7 @@ fn find_speed(x: i32, range: &RangeInclusive<i32>) -> Option<RangeInclusive<i32>
.take_while(|&n| range.contains(&position(n, n))) .take_while(|&n| range.contains(&position(n, n)))
.last()?; .last()?;
Some(min..=max) Some((min, max))
} else { } else {
// Might hit the target at speed // Might hit the target at speed
let max = (x * x + 2 * *range.end() - x) / (2 * x); let max = (x * x + 2 * *range.end() - x) / (2 * x);
@@ -83,7 +70,7 @@ fn find_speed(x: i32, range: &RangeInclusive<i32>) -> Option<RangeInclusive<i32>
.take_while(|&n| range.contains(&position(n, n.min(x)))) .take_while(|&n| range.contains(&position(n, n.min(x))))
.last()?; .last()?;
Some(min..=max) Some((min, max))
} }
} }
@@ -105,22 +92,16 @@ fn parse_input(input: &[u8]) -> IResult<&[u8], (RangeInclusive<i32>, RangeInclus
pub fn part1(input: &mut dyn Read) -> String { pub fn part1(input: &mut dyn Read) -> String {
let (x_range, y_range) = read_input(input, parse_input); let (x_range, y_range) = read_input(input, parse_input);
let check_value = |y_speed| { let check_value =
let mut time = find_hit(y_speed, &y_range)?; |y_speed| find_hit(y_speed, &y_range).any(|time| find_speed(time, &x_range).is_some());
if time.any(|time| find_speed(time, &x_range).is_some()) {
Some(position(y_speed, y_speed))
} else {
None
}
};
debug_assert!(*y_range.start() < 0); debug_assert!(*y_range.start() < 0);
let y_max = -*y_range.start(); let y_max = -*y_range.start();
(0..y_max) (0..y_max)
.filter_map(check_value) .rev()
.max() .find(|&speed| check_value(speed))
.map(|speed| position(speed, speed))
.unwrap() .unwrap()
.to_string() .to_string()
} }
@@ -129,26 +110,17 @@ pub fn part2(input: &mut dyn Read) -> String {
let (x_range, y_range) = read_input(input, parse_input); let (x_range, y_range) = read_input(input, parse_input);
let num_options = |y_speed| { let num_options = |y_speed| {
let time = find_hit(y_speed, &y_range)?; find_hit(y_speed, &y_range)
let range = time
.filter_map(|time| find_speed(time, &x_range)) .filter_map(|time| find_speed(time, &x_range))
.flat_map(|x| [*x.start(), *x.end()]) .reduce(|(a_min, a_max), (b_min, b_max)| (a_min.min(b_min), a_max.max(b_max)))
.minmax(); .map(|(min, max)| max - min + 1)
.unwrap_or(0)
Some(match range {
MinMaxResult::NoElements => 0,
MinMaxResult::OneElement(_) => 1,
MinMaxResult::MinMax(min, max) => max - min + 1,
})
}; };
debug_assert!(*y_range.start() < 0); debug_assert!(*y_range.start() < 0);
let y_max = -*y_range.start(); let y_max = -*y_range.start();
(-y_max..y_max) (-y_max..y_max).map(num_options).sum::<i32>().to_string()
.filter_map(num_options)
.sum::<i32>()
.to_string()
} }
#[cfg(test)] #[cfg(test)]
@@ -161,9 +133,9 @@ mod tests {
#[test] #[test]
fn test_find_hit() { fn test_find_hit() {
assert_eq!(find_hit(2, &(-10..=-5)), Some(7..=7)); assert_eq!(find_hit(2, &(-10..=-5)).collect::<Vec<_>>(), vec![7]);
assert_eq!(find_hit(3, &(-10..=-5)), Some(9..=9)); assert_eq!(find_hit(3, &(-10..=-5)).collect::<Vec<_>>(), vec![9]);
assert_eq!(find_hit(0, &(-10..=-5)), Some(4..=5)); assert_eq!(find_hit(0, &(-10..=-5)).collect::<Vec<_>>(), vec![4, 5]);
} }
#[test] #[test]