diff --git a/src/graph.rs b/src/graph.rs index 79e7f74..8a16af5 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -22,8 +22,8 @@ type Order = usize; /// [paper]: https://whileydave.com/publications/pk07_jea/ #[derive(Clone, Default, Debug)] pub struct DiGraph { - in_edges: HashMap>, - out_edges: HashMap>, + in_edges: HashMap>, + out_edges: HashMap>, /// Next topological sort order next_ord: Order, /// Poison flag, set if a cycle is detected when adding a new edge and @@ -42,7 +42,7 @@ impl DiGraph { /// order. /// /// New nodes are appended to the end of the topological order when added. - fn add_node(&mut self, n: MutexId) -> (&mut Vec, &mut Vec, Order) { + fn add_node(&mut self, n: MutexId) -> (&mut HashSet, &mut HashSet, Order) { let next_ord = &mut self.next_ord; let in_edges = self.in_edges.entry(n).or_default(); let out_edges = self.out_edges.entry(n).or_default(); @@ -61,11 +61,11 @@ impl DiGraph { None => false, Some(out_edges) => { for other in out_edges { - self.in_edges.get_mut(&other).unwrap().retain(|c| c != &n); + self.in_edges.get_mut(&other).unwrap().remove(&n); } for other in self.in_edges.remove(&n).unwrap() { - self.out_edges.get_mut(&other).unwrap().retain(|c| c != &n); + self.out_edges.get_mut(&other).unwrap().remove(&n); } if self.contains_cycle { @@ -89,16 +89,14 @@ impl DiGraph { let (_, out_edges, ub) = self.add_node(x); - if out_edges.contains(&y) { + if !out_edges.insert(y) { // Edge already exists, nothing to be done return false; } - out_edges.push(y); - let (in_edges, _, lb) = self.add_node(y); - in_edges.push(x); + in_edges.insert(x); if !self.contains_cycle && lb < ub { // This edge might introduce a cycle, need to recompute the topological sort