diff --git a/src/graph.rs b/src/graph.rs new file mode 100644 index 0000000..19ba2a4 --- /dev/null +++ b/src/graph.rs @@ -0,0 +1,116 @@ +use std::collections::HashMap; +use std::collections::HashSet; + +#[derive(Clone, Default, Debug)] +pub struct DiGraph { + in_edges: HashMap>, + out_edges: HashMap>, +} + +impl DiGraph { + fn add_node(&mut self, node: usize) -> (&mut Vec, &mut Vec) { + let in_edges = self.in_edges.entry(node).or_default(); + let out_edges = self.out_edges.entry(node).or_default(); + + (in_edges, out_edges) + } + + pub fn remove_node(&mut self, node: usize) -> bool { + match self.out_edges.remove(&node) { + None => false, + Some(out_edges) => { + for other in out_edges { + self.in_edges + .get_mut(&other) + .unwrap() + .retain(|c| c != &node); + } + + for other in self.in_edges.remove(&node).unwrap() { + self.out_edges + .get_mut(&other) + .unwrap() + .retain(|c| c != &node); + } + + true + } + } + } + + pub fn add_edge(&mut self, from: usize, to: usize) -> bool { + if from == to { + return false; + } + + let (_, out_edges) = self.add_node(from); + + out_edges.push(to); + + let (in_edges, _) = self.add_node(to); + + // No need for existence check assuming the datastructure is consistent + in_edges.push(from); + + true + } + + pub fn has_cycles(&self) -> bool { + let mut marks = HashSet::new(); + let mut temp = HashSet::new(); + + self.out_edges + .keys() + .copied() + .any(|node| !self.visit(node, &mut marks, &mut temp)) + } + + fn visit(&self, node: usize, marks: &mut HashSet, temp: &mut HashSet) -> bool { + if marks.contains(&node) { + return true; + } + + if !temp.insert(node) { + return false; + } + + if self.out_edges[&node] + .iter() + .copied() + .any(|node| !self.visit(node, marks, temp)) + { + return false; + } + + temp.remove(&node); + + marks.insert(node); + + true + } +} + +#[cfg(test)] +mod tests { + use super::DiGraph; + + #[test] + fn test_digraph() { + let mut graph = DiGraph::default(); + + graph.add_edge(1, 2); + graph.add_edge(2, 3); + graph.add_edge(3, 4); + graph.add_edge(5, 2); + + assert!(!graph.has_cycles()); + + graph.add_edge(4, 2); + + assert!(graph.has_cycles()); + + assert!(graph.remove_node(4)); + + assert!(!graph.has_cycles()) + } +} diff --git a/src/lib.rs b/src/lib.rs index 5b18f0d..3154a3d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,116 +1,184 @@ -use std::collections::HashMap; -use std::collections::HashSet; +use std::cell::RefCell; +use std::fmt; +use std::fmt::Display; +use std::ops::Deref; +use std::ops::DerefMut; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; +use std::sync::LockResult; +use std::sync::Mutex; +use std::sync::MutexGuard; +use std::sync::PoisonError; +use std::sync::TryLockError; +use std::sync::TryLockResult; -#[derive(Clone, Default, Debug)] -struct DiGraph { - in_edges: HashMap>, - out_edges: HashMap>, +mod graph; + +/// Counter for Mutex IDs. Atomic avoids the need for locking. +static ID_SEQUENCE: AtomicUsize = AtomicUsize::new(0); + +thread_local! { + /// Stack to track which locks are held + /// + /// Assuming that locks are roughly released in the reverse order in which they were acquired, + /// a stack should be more efficient to keep track of the current state than a set would be. + static HELD_LOCKS: RefCell> = RefCell::new(Vec::new()); } -impl DiGraph { - fn add_node(&mut self, node: usize) -> (&mut Vec, &mut Vec) { - let in_edges = self.in_edges.entry(node).or_default(); - let out_edges = self.out_edges.entry(node).or_default(); +/// Wrapper for std::sync::Mutex +#[derive(Debug)] +pub struct TracingMutex { + inner: Mutex, + id: usize, +} - (in_edges, out_edges) +/// Wrapper for std::sync::MutexGuard +#[derive(Debug)] +pub struct TracingMutexGuard<'a, T> { + inner: MutexGuard<'a, T>, + mutex: &'a TracingMutex, +} + +fn next_mutex_id() -> usize { + ID_SEQUENCE + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |id| id.checked_add(1)) + .expect("Mutex ID wraparound happened, results unreliable") +} + +impl TracingMutex { + pub fn new(t: T) -> Self { + Self { + inner: Mutex::new(t), + id: next_mutex_id(), + } } - pub fn remove_node(&mut self, node: usize) -> bool { - match self.out_edges.remove(&node) { - None => false, - Some(out_edges) => { - for other in out_edges { - self.in_edges - .get_mut(&other) - .unwrap() - .retain(|c| c != &node); - } + pub fn lock(&self) -> LockResult> { + let result = self.inner.lock(); + self.register_lock(); - for other in self.in_edges.remove(&node).unwrap() { - self.out_edges - .get_mut(&other) - .unwrap() - .retain(|c| c != &node); - } + let mapper = |guard| TracingMutexGuard { + mutex: self, + inner: guard, + }; - true + result + .map(mapper) + .map_err(|poison| PoisonError::new(mapper(poison.into_inner()))) + } + + pub fn try_lock(&self) -> TryLockResult> { + let result = self.inner.try_lock(); + self.register_lock(); + + let mapper = |guard| TracingMutexGuard { + mutex: self, + inner: guard, + }; + + result.map(mapper).map_err(|error| match error { + TryLockError::Poisoned(poison) => PoisonError::new(mapper(poison.into_inner())).into(), + TryLockError::WouldBlock => TryLockError::WouldBlock, + }) + } + + pub fn get_id(&self) -> usize { + self.id + } + + pub fn is_poisoned(&self) -> bool { + self.inner.is_poisoned() + } + + pub fn get_mut(&mut self) -> LockResult<&mut T> { + self.inner.get_mut() + } + + fn register_lock(&self) { + HELD_LOCKS.with(|locks| locks.borrow_mut().push(self.id)) + } +} + +impl Default for TracingMutex { + fn default() -> Self { + Self::new(T::default()) + } +} + +impl From for TracingMutex { + fn from(t: T) -> Self { + Self::new(t) + } +} + +impl<'a, T> Deref for TracingMutexGuard<'a, T> { + type Target = MutexGuard<'a, T>; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl<'a, T> DerefMut for TracingMutexGuard<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +impl<'a, T: Display> fmt::Display for TracingMutexGuard<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.inner.fmt(f) + } +} + +impl<'a, T> Drop for TracingMutexGuard<'a, T> { + fn drop(&mut self) { + HELD_LOCKS.with(|locks| { + let id = self.mutex.id; + let mut locks = locks.borrow_mut(); + + for (i, &lock) in locks.iter().enumerate().rev() { + if lock == id { + locks.remove(i); + return; + } } - } - } - pub fn add_edge(&mut self, from: usize, to: usize) -> bool { - if from == to { - return false; - } - - let (_, out_edges) = self.add_node(from); - - out_edges.push(to); - - let (in_edges, _) = self.add_node(to); - - // No need for existence check assuming the datastructure is consistent - in_edges.push(from); - - true - } - - pub fn has_cycles(&self) -> bool { - let mut marks = HashSet::new(); - let mut temp = HashSet::new(); - - self.out_edges - .keys() - .copied() - .any(|node| !self.visit(node, &mut marks, &mut temp)) - } - - fn visit(&self, node: usize, marks: &mut HashSet, temp: &mut HashSet) -> bool { - if marks.contains(&node) { - return true; - } - - if !temp.insert(node) { - return false; - } - - if self.out_edges[&node] - .iter() - .copied() - .any(|node| !self.visit(node, marks, temp)) - { - return false; - } - - temp.remove(&node); - - marks.insert(node); - - true + panic!("Tried to drop lock for mutex {} but it wasn't held", id) + }); } } #[cfg(test)] mod tests { - use crate::DiGraph; + use std::sync::Arc; + use std::thread; + + use super::*; #[test] - fn test_digraph() { - let mut graph = DiGraph::default(); + fn test_next_mutex_id() { + let initial = next_mutex_id(); + let next = next_mutex_id(); - graph.add_edge(1, 2); - graph.add_edge(2, 3); - graph.add_edge(3, 4); - graph.add_edge(5, 2); + // Can't assert N + 1 because multiple threads running tests + assert!(initial < next); + } - assert!(!graph.has_cycles()); + #[test] + fn test_mutex_usage() { + let mutex = Arc::new(TracingMutex::new(())); + let mutex_clone = mutex.clone(); - graph.add_edge(4, 2); + let _guard = mutex.lock().unwrap(); - assert!(graph.has_cycles()); + // Now try to cause a blocking exception in another thread + let handle = thread::spawn(move || { + let result = mutex_clone.try_lock().unwrap_err(); - assert!(graph.remove_node(4)); + assert!(matches!(result, TryLockError::WouldBlock)); + }); - assert!(!graph.has_cycles()) + handle.join().unwrap(); } }