From d242ac5bc2bbf71edcce9e9dd118c506ef6f93e0 Mon Sep 17 00:00:00 2001 From: Bert Peters Date: Mon, 24 May 2021 14:25:44 +0200 Subject: [PATCH] Use interior mutability for updating graph order --- src/graph.rs | 54 +++++++++++++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 20 deletions(-) diff --git a/src/graph.rs b/src/graph.rs index 3f0b2d2..0a55729 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -1,4 +1,5 @@ use std::array::IntoIter; +use std::cell::Cell; use std::collections::HashMap; use std::collections::HashSet; use std::hash::Hash; @@ -36,7 +37,11 @@ where { in_edges: HashSet, out_edges: HashSet, - ord: Order, + // The "Ord" field is a Cell to ensure we can update it in an immutable context. + // `std::collections::HashMap` doesn't let you have multiple mutable references to elements, but + // this way we can use immutable references and still update `ord`. This saves quite a few + // hashmap lookups in the final reorder function. + ord: Cell, } impl DiGraph @@ -58,13 +63,13 @@ where *next_ord = next_ord.checked_add(1).expect("Topological order overflow"); Node { - ord: order, + ord: Cell::new(order), in_edges: Default::default(), out_edges: Default::default(), } }); - (&mut node.in_edges, &mut node.out_edges, node.ord) + (&mut node.in_edges, &mut node.out_edges, node.ord.get()) } pub(crate) fn remove_node(&mut self, n: V) -> bool { @@ -112,8 +117,8 @@ where if lb < ub { // This edge might introduce a cycle, need to recompute the topological sort let mut visited = IntoIter::new([x, y]).collect(); - let mut delta_f = vec![y]; - let mut delta_b = vec![x]; + let mut delta_f = Vec::new(); + let mut delta_b = Vec::new(); if !self.dfs_f(&self.nodes[&y], ub, &mut visited, &mut delta_f) { // This edge introduces a cycle, so we want to reject it and remove it from the @@ -141,23 +146,25 @@ where } /// Forwards depth-first-search - fn dfs_f( - &self, - n: &Node, + fn dfs_f<'a>( + &'a self, + n: &'a Node, ub: Order, visited: &mut HashSet, - delta_f: &mut Vec, + delta_f: &mut Vec<&'a Node>, ) -> bool { + delta_f.push(n); + n.out_edges.iter().all(|w| { let node = &self.nodes[w]; + let ord = node.ord.get(); - if node.ord == ub { + if ord == ub { // Found a cycle false - } else if !visited.contains(w) && node.ord < ub { + } else if !visited.contains(w) && ord < ub { // Need to check recursively visited.insert(*w); - delta_f.push(*w); self.dfs_f(node, ub, visited, delta_f) } else { // Already seen this one or not interesting @@ -167,19 +174,26 @@ where } /// Backwards depth-first-search - fn dfs_b(&self, n: &Node, lb: Order, visited: &mut HashSet, delta_b: &mut Vec) { + fn dfs_b<'a>( + &'a self, + n: &'a Node, + lb: Order, + visited: &mut HashSet, + delta_b: &mut Vec<&'a Node>, + ) { + delta_b.push(n); + for w in &n.in_edges { let node = &self.nodes[w]; - if !visited.contains(w) && lb < node.ord { + if !visited.contains(w) && lb < node.ord.get() { visited.insert(*w); - delta_b.push(*w); self.dfs_b(node, lb, visited, delta_b); } } } - fn reorder(&mut self, mut delta_f: Vec, mut delta_b: Vec) { + fn reorder(&self, mut delta_f: Vec<&Node>, mut delta_b: Vec<&Node>) { self.sort(&mut delta_f); self.sort(&mut delta_b); @@ -187,7 +201,7 @@ where let mut orders = Vec::with_capacity(delta_f.len() + delta_b.len()); for v in delta_b.into_iter().chain(delta_f) { - orders.push(self.nodes[&v].ord); + orders.push(v.ord.get()); l.push(v); } @@ -196,13 +210,13 @@ where orders.sort_unstable(); for (node, order) in l.into_iter().zip(orders) { - self.nodes.get_mut(&node).unwrap().ord = order; + node.ord.set(order); } } - fn sort(&self, ids: &mut [V]) { + fn sort(&self, ids: &mut [&Node]) { // Can use unstable sort because mutex ids should not be equal - ids.sort_unstable_by_key(|v| self.nodes[&v].ord); + ids.sort_unstable_by_key(|v| &v.ord); } }