diff --git a/Cargo.toml b/Cargo.toml index 5dcaeb9..87d4e32 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,3 +7,4 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +lazy_static = "1" diff --git a/src/lib.rs b/src/lib.rs index 3154a3d..6d85d95 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,10 @@ use std::sync::PoisonError; use std::sync::TryLockError; use std::sync::TryLockResult; +use lazy_static::lazy_static; + +use crate::graph::DiGraph; + mod graph; /// Counter for Mutex IDs. Atomic avoids the need for locking. @@ -25,6 +29,10 @@ thread_local! { static HELD_LOCKS: RefCell> = RefCell::new(Vec::new()); } +lazy_static! { + static ref DEPENDENCY_GRAPH: Mutex = Default::default(); +} + /// Wrapper for std::sync::Mutex #[derive(Debug)] pub struct TracingMutex { @@ -45,6 +53,57 @@ fn next_mutex_id() -> usize { .expect("Mutex ID wraparound happened, results unreliable") } +/// Get a reference to the current dependency graph +fn get_depedency_graph() -> impl DerefMut { + DEPENDENCY_GRAPH + .lock() + .unwrap_or_else(PoisonError::into_inner) +} + +/// Register that a lock is currently held +fn register_lock(lock: usize) { + HELD_LOCKS.with(|locks| locks.borrow_mut().push(lock)) +} + +/// Register a dependency in the dependency graph +/// +/// If the dependency is new, check for cycles in the dependency graph. If not, there shouldn't be +/// any cycles so we don't need to check. +fn register_dependency(lock: usize) { + HELD_LOCKS.with(|locks| { + if let Some(&previous) = locks.borrow().last() { + let mut graph = get_depedency_graph(); + + if graph.add_edge(previous, lock) && graph.has_cycles() { + panic!("Mutex order graph should not have cycles"); + } + } + }) +} + +fn map_lockresult(result: LockResult, mapper: F) -> LockResult +where + F: FnOnce(I) -> T, +{ + match result { + Ok(inner) => Ok(mapper(inner)), + Err(poisoned) => Err(PoisonError::new(mapper(poisoned.into_inner()))), + } +} + +fn map_trylockresult(result: TryLockResult, mapper: F) -> TryLockResult +where + F: FnOnce(I) -> T, +{ + match result { + Ok(inner) => Ok(mapper(inner)), + Err(TryLockError::WouldBlock) => Err(TryLockError::WouldBlock), + Err(TryLockError::Poisoned(poisoned)) => { + Err(PoisonError::new(mapper(poisoned.into_inner())).into()) + } + } +} + impl TracingMutex { pub fn new(t: T) -> Self { Self { @@ -53,33 +112,32 @@ impl TracingMutex { } } + #[track_caller] pub fn lock(&self) -> LockResult> { + register_dependency(self.id); let result = self.inner.lock(); - self.register_lock(); + register_lock(self.id); let mapper = |guard| TracingMutexGuard { mutex: self, inner: guard, }; - result - .map(mapper) - .map_err(|poison| PoisonError::new(mapper(poison.into_inner()))) + map_lockresult(result, mapper) } + #[track_caller] pub fn try_lock(&self) -> TryLockResult> { + register_dependency(self.id); let result = self.inner.try_lock(); - self.register_lock(); + register_lock(self.id); let mapper = |guard| TracingMutexGuard { mutex: self, inner: guard, }; - result.map(mapper).map_err(|error| match error { - TryLockError::Poisoned(poison) => PoisonError::new(mapper(poison.into_inner())).into(), - TryLockError::WouldBlock => TryLockError::WouldBlock, - }) + map_trylockresult(result, mapper) } pub fn get_id(&self) -> usize { @@ -93,10 +151,6 @@ impl TracingMutex { pub fn get_mut(&mut self) -> LockResult<&mut T> { self.inner.get_mut() } - - fn register_lock(&self) { - HELD_LOCKS.with(|locks| locks.borrow_mut().push(self.id)) - } } impl Default for TracingMutex { @@ -111,6 +165,14 @@ impl From for TracingMutex { } } +impl Drop for TracingMutex { + fn drop(&mut self) { + let id = self.id; + + get_depedency_graph().remove_node(id); + } +} + impl<'a, T> Deref for TracingMutexGuard<'a, T> { type Target = MutexGuard<'a, T>; @@ -156,6 +218,11 @@ mod tests { use super::*; + lazy_static! { + /// Mutex to isolate tests manipulating the global mutex graph + static ref GRAPH_MUTEX: Mutex<()> = Mutex::new(()); + } + #[test] fn test_next_mutex_id() { let initial = next_mutex_id(); @@ -167,6 +234,8 @@ mod tests { #[test] fn test_mutex_usage() { + let _graph_lock = GRAPH_MUTEX.lock(); + let mutex = Arc::new(TracingMutex::new(())); let mutex_clone = mutex.clone(); @@ -181,4 +250,21 @@ mod tests { handle.join().unwrap(); } + + #[test] + #[should_panic(expected = "Mutex order graph should not have cycles")] + fn test_detect_cycle() { + let _graph_lock = GRAPH_MUTEX.lock(); + + let a = TracingMutex::new(()); + let b = TracingMutex::new(()); + + let hold_a = a.lock().unwrap(); + let _ = b.lock(); + + drop(hold_a); + + let _hold_b = b.lock().unwrap(); + let _ = a.lock(); + } }