Use interior mutability for updating graph order

This commit is contained in:
2021-05-24 14:25:44 +02:00
parent 39b493a871
commit d242ac5bc2

View File

@@ -1,4 +1,5 @@
use std::array::IntoIter;
use std::cell::Cell;
use std::collections::HashMap;
use std::collections::HashSet;
use std::hash::Hash;
@@ -36,7 +37,11 @@ where
{
in_edges: HashSet<V>,
out_edges: HashSet<V>,
ord: Order,
// The "Ord" field is a Cell to ensure we can update it in an immutable context.
// `std::collections::HashMap` doesn't let you have multiple mutable references to elements, but
// this way we can use immutable references and still update `ord`. This saves quite a few
// hashmap lookups in the final reorder function.
ord: Cell<Order>,
}
impl<V> DiGraph<V>
@@ -58,13 +63,13 @@ where
*next_ord = next_ord.checked_add(1).expect("Topological order overflow");
Node {
ord: order,
ord: Cell::new(order),
in_edges: Default::default(),
out_edges: Default::default(),
}
});
(&mut node.in_edges, &mut node.out_edges, node.ord)
(&mut node.in_edges, &mut node.out_edges, node.ord.get())
}
pub(crate) fn remove_node(&mut self, n: V) -> bool {
@@ -112,8 +117,8 @@ where
if lb < ub {
// This edge might introduce a cycle, need to recompute the topological sort
let mut visited = IntoIter::new([x, y]).collect();
let mut delta_f = vec![y];
let mut delta_b = vec![x];
let mut delta_f = Vec::new();
let mut delta_b = Vec::new();
if !self.dfs_f(&self.nodes[&y], ub, &mut visited, &mut delta_f) {
// This edge introduces a cycle, so we want to reject it and remove it from the
@@ -141,23 +146,25 @@ where
}
/// Forwards depth-first-search
fn dfs_f(
&self,
n: &Node<V>,
fn dfs_f<'a>(
&'a self,
n: &'a Node<V>,
ub: Order,
visited: &mut HashSet<V>,
delta_f: &mut Vec<V>,
delta_f: &mut Vec<&'a Node<V>>,
) -> bool {
delta_f.push(n);
n.out_edges.iter().all(|w| {
let node = &self.nodes[w];
let ord = node.ord.get();
if node.ord == ub {
if ord == ub {
// Found a cycle
false
} else if !visited.contains(w) && node.ord < ub {
} else if !visited.contains(w) && ord < ub {
// Need to check recursively
visited.insert(*w);
delta_f.push(*w);
self.dfs_f(node, ub, visited, delta_f)
} else {
// Already seen this one or not interesting
@@ -167,19 +174,26 @@ where
}
/// Backwards depth-first-search
fn dfs_b(&self, n: &Node<V>, lb: Order, visited: &mut HashSet<V>, delta_b: &mut Vec<V>) {
fn dfs_b<'a>(
&'a self,
n: &'a Node<V>,
lb: Order,
visited: &mut HashSet<V>,
delta_b: &mut Vec<&'a Node<V>>,
) {
delta_b.push(n);
for w in &n.in_edges {
let node = &self.nodes[w];
if !visited.contains(w) && lb < node.ord {
if !visited.contains(w) && lb < node.ord.get() {
visited.insert(*w);
delta_b.push(*w);
self.dfs_b(node, lb, visited, delta_b);
}
}
}
fn reorder(&mut self, mut delta_f: Vec<V>, mut delta_b: Vec<V>) {
fn reorder(&self, mut delta_f: Vec<&Node<V>>, mut delta_b: Vec<&Node<V>>) {
self.sort(&mut delta_f);
self.sort(&mut delta_b);
@@ -187,7 +201,7 @@ where
let mut orders = Vec::with_capacity(delta_f.len() + delta_b.len());
for v in delta_b.into_iter().chain(delta_f) {
orders.push(self.nodes[&v].ord);
orders.push(v.ord.get());
l.push(v);
}
@@ -196,13 +210,13 @@ where
orders.sort_unstable();
for (node, order) in l.into_iter().zip(orders) {
self.nodes.get_mut(&node).unwrap().ord = order;
node.ord.set(order);
}
}
fn sort(&self, ids: &mut [V]) {
fn sort(&self, ids: &mut [&Node<V>]) {
// Can use unstable sort because mutex ids should not be equal
ids.sort_unstable_by_key(|v| self.nodes[&v].ord);
ids.sort_unstable_by_key(|v| &v.ord);
}
}