Slightly more efficient search

This commit is contained in:
2023-01-28 13:45:58 +01:00
parent 787e215f84
commit 5045f83df8

View File

@@ -1,7 +1,6 @@
use std::array;
use std::cmp::Ordering; use std::cmp::Ordering;
use std::collections::BinaryHeap; use std::collections::BinaryHeap;
use std::ops::Add;
use std::ops::Sub;
use anyhow::Result; use anyhow::Result;
use nom::bytes::complete::tag; use nom::bytes::complete::tag;
@@ -45,81 +44,146 @@ impl TryFrom<&'_ [u8]> for Mineral {
} }
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
struct Resources([u8; 4]);
impl Resources {
fn enough_for(self, other: Self) -> bool {
self.0.iter().zip(&other.0).all(|(a, b)| a >= b)
}
}
impl Sub for Resources {
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
Self(std::array::from_fn(|i| self.0[i] - rhs.0[i]))
}
}
impl Add<[u8; 4]> for Resources {
type Output = Self;
fn add(self, rhs: [u8; 4]) -> Self::Output {
Self(std::array::from_fn(|i| self.0[i] + rhs[i]))
}
}
#[derive(Debug)] #[derive(Debug)]
struct BluePrint { struct BluePrint {
id: u32, id: u32,
costs: [Resources; 4], costs: [[u8; 3]; 4],
} }
impl BluePrint { impl BluePrint {
pub fn max_geodes(&self, time: u32) -> u32 { pub fn max_geodes(&self, time: u8) -> u8 {
self.max_geodes_recursive(time, 0, [1, 0, 0, 0], Resources::default()) as u32 /// How much would we produce if all we did was produce geode robots for the remaining time
} fn ideal(remaining: u32) -> u32 {
if remaining <= 1 {
fn max_geodes_recursive( 0
&self, } else {
time_left: u32, (remaining - 1) * remaining / 2
// forbidden is a bitset for convenience
forbidden: u8,
machines: [u8; 4],
resources: Resources,
) -> u8 {
if time_left <= 1 || forbidden.count_ones() == 4 {
return resources.0[3] + machines[3] * (time_left as u8);
}
let resources_after = resources + machines;
let mut best = 0;
let mut can_buy = 0;
for (i, &cost) in self.costs.iter().enumerate() {
if ((1 << i) & forbidden) == 0 && resources.enough_for(cost) {
can_buy |= 1 << i;
let mut new_machines = machines;
new_machines[i] += 1;
best = best.max(self.max_geodes_recursive(
time_left - 1,
0,
new_machines,
resources_after - cost,
))
} }
} }
best.max(self.max_geodes_recursive( #[derive(Eq, PartialEq)]
time_left - 1, struct State {
forbidden | can_buy, missed: u32,
got: u8,
time_left: u8,
resources: [u8; 3],
machines: [u8; 3],
}
impl Ord for State {
fn cmp(&self, other: &Self) -> Ordering {
Ordering::Equal
.then(other.missed.cmp(&self.missed))
.then(self.got.cmp(&other.got))
.then(self.time_left.cmp(&other.time_left))
.then(self.machines.cmp(&other.machines))
}
}
impl PartialOrd for State {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
let max_needed = self.max_needed();
let mut todo = BinaryHeap::new();
let mut best = 0;
todo.push(State {
missed: 0,
got: 0,
time_left: time,
resources: [0; 3],
machines: [1, 0, 0],
});
while let Some(State {
missed,
got,
time_left,
resources,
machines, machines,
resources_after, }) = todo.pop()
)) {
let ideal_from_now = ideal(time_left as u32);
if u32::from(best - got) > ideal_from_now {
continue;
}
if todo.len() > 1_000_000 {
panic!(
"Safety: got a todo list of len {}, best: {best}",
todo.len()
);
}
'element: for element in 0..4 {
let mut min_to_build = 0;
for ((&cost, &avail), &machine) in
self.costs[element].iter().zip(&resources).zip(&machines)
{
if cost > avail {
if machine == 0 {
continue 'element;
} else {
min_to_build = min_to_build.max((cost - avail + machine - 1) / machine);
}
}
}
// +1 because we need a turn to build
let built_after = min_to_build + 1;
if built_after >= time_left {
continue;
}
let resources_after = array::from_fn(|i| {
resources[i] + machines[i] * built_after - self.costs[element][i]
});
let time_after = time_left - built_after;
if element == Mineral::Geode as usize {
let new_got = got + time_after;
todo.push(State {
missed,
got: new_got,
time_left: time_after,
resources: resources_after,
machines,
});
best = best.max(new_got);
} else {
if machines[element] >= max_needed[element] {
continue;
}
let mut new_machines = machines;
new_machines[element] += 1;
let new_missed = ideal_from_now - ideal(time_after as u32);
todo.push(State {
missed: new_missed,
got,
time_left: time_after,
resources: resources_after,
machines: new_machines,
})
}
}
}
best
}
fn max_needed(&self) -> [u8; 3] {
let mut max_needed = [0; 3];
for cost in &self.costs {
for (max, &new) in max_needed.iter_mut().zip(cost) {
*max = (*max).max(new);
}
}
max_needed
} }
} }
@@ -137,7 +201,7 @@ fn parse_blueprint(input: &[u8]) -> IResult<&[u8], BluePrint> {
let (mut input, id) = let (mut input, id) =
terminated(delimited(tag("Blueprint "), u32, tag(":")), multispace1)(input)?; terminated(delimited(tag("Blueprint "), u32, tag(":")), multispace1)(input)?;
let mut costs: [Resources; 4] = Default::default(); let mut costs: [[u8; 3]; 4] = Default::default();
let mut parse_robot = terminated( let mut parse_robot = terminated(
tuple(( tuple((
@@ -152,10 +216,10 @@ fn parse_blueprint(input: &[u8]) -> IResult<&[u8], BluePrint> {
let (remaining, (element, (amount1, req1), cost2)) = parse_robot(input)?; let (remaining, (element, (amount1, req1), cost2)) = parse_robot(input)?;
input = remaining; input = remaining;
costs[element as usize].0[req1 as usize] = amount1; costs[element as usize][req1 as usize] = amount1;
if let Some((amount2, req2)) = cost2 { if let Some((amount2, req2)) = cost2 {
costs[element as usize].0[req2 as usize] = amount2; costs[element as usize][req2 as usize] = amount2;
} }
} }
@@ -190,8 +254,25 @@ mod tests {
const SAMPLE: &[u8] = include_bytes!("./samples/19.txt"); const SAMPLE: &[u8] = include_bytes!("./samples/19.txt");
fn get_samples() -> Vec<BluePrint> {
parse_input(SAMPLE, many1(parse_blueprint)).unwrap()
}
#[test] #[test]
fn sample_part1() { fn sample_part1() {
let samples = get_samples();
assert_eq!(samples[0].max_geodes(24), 9);
assert_eq!(samples[1].max_geodes(24), 12);
assert_eq!(part1(SAMPLE).unwrap(), "33"); assert_eq!(part1(SAMPLE).unwrap(), "33");
} }
#[test]
fn sample_part2() {
let samples = get_samples();
assert_eq!(samples[0].max_geodes(32), 56);
assert_eq!(samples[1].max_geodes(32), 62);
}
} }