diff --git a/src/lib.rs b/src/lib.rs index 6d85d95..71d7642 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,22 +1,16 @@ use std::cell::RefCell; -use std::fmt; -use std::fmt::Display; -use std::ops::Deref; use std::ops::DerefMut; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering; -use std::sync::LockResult; use std::sync::Mutex; -use std::sync::MutexGuard; use std::sync::PoisonError; -use std::sync::TryLockError; -use std::sync::TryLockResult; use lazy_static::lazy_static; use crate::graph::DiGraph; mod graph; +pub mod stdsync; /// Counter for Mutex IDs. Atomic avoids the need for locking. static ID_SEQUENCE: AtomicUsize = AtomicUsize::new(0); @@ -33,20 +27,6 @@ lazy_static! { static ref DEPENDENCY_GRAPH: Mutex = Default::default(); } -/// Wrapper for std::sync::Mutex -#[derive(Debug)] -pub struct TracingMutex { - inner: Mutex, - id: usize, -} - -/// Wrapper for std::sync::MutexGuard -#[derive(Debug)] -pub struct TracingMutexGuard<'a, T> { - inner: MutexGuard<'a, T>, - mutex: &'a TracingMutex, -} - fn next_mutex_id() -> usize { ID_SEQUENCE .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |id| id.checked_add(1)) @@ -65,6 +45,21 @@ fn register_lock(lock: usize) { HELD_LOCKS.with(|locks| locks.borrow_mut().push(lock)) } +fn drop_lock(id: usize) { + HELD_LOCKS.with(|locks| { + let mut locks = locks.borrow_mut(); + + for (i, &lock) in locks.iter().enumerate().rev() { + if lock == id { + locks.remove(i); + return; + } + } + + panic!("Tried to drop lock for mutex {} but it wasn't held", id) + }); +} + /// Register a dependency in the dependency graph /// /// If the dependency is new, check for cycles in the dependency graph. If not, there shouldn't be @@ -81,146 +76,13 @@ fn register_dependency(lock: usize) { }) } -fn map_lockresult(result: LockResult, mapper: F) -> LockResult -where - F: FnOnce(I) -> T, -{ - match result { - Ok(inner) => Ok(mapper(inner)), - Err(poisoned) => Err(PoisonError::new(mapper(poisoned.into_inner()))), - } -} - -fn map_trylockresult(result: TryLockResult, mapper: F) -> TryLockResult -where - F: FnOnce(I) -> T, -{ - match result { - Ok(inner) => Ok(mapper(inner)), - Err(TryLockError::WouldBlock) => Err(TryLockError::WouldBlock), - Err(TryLockError::Poisoned(poisoned)) => { - Err(PoisonError::new(mapper(poisoned.into_inner())).into()) - } - } -} - -impl TracingMutex { - pub fn new(t: T) -> Self { - Self { - inner: Mutex::new(t), - id: next_mutex_id(), - } - } - - #[track_caller] - pub fn lock(&self) -> LockResult> { - register_dependency(self.id); - let result = self.inner.lock(); - register_lock(self.id); - - let mapper = |guard| TracingMutexGuard { - mutex: self, - inner: guard, - }; - - map_lockresult(result, mapper) - } - - #[track_caller] - pub fn try_lock(&self) -> TryLockResult> { - register_dependency(self.id); - let result = self.inner.try_lock(); - register_lock(self.id); - - let mapper = |guard| TracingMutexGuard { - mutex: self, - inner: guard, - }; - - map_trylockresult(result, mapper) - } - - pub fn get_id(&self) -> usize { - self.id - } - - pub fn is_poisoned(&self) -> bool { - self.inner.is_poisoned() - } - - pub fn get_mut(&mut self) -> LockResult<&mut T> { - self.inner.get_mut() - } -} - -impl Default for TracingMutex { - fn default() -> Self { - Self::new(T::default()) - } -} - -impl From for TracingMutex { - fn from(t: T) -> Self { - Self::new(t) - } -} - -impl Drop for TracingMutex { - fn drop(&mut self) { - let id = self.id; - - get_depedency_graph().remove_node(id); - } -} - -impl<'a, T> Deref for TracingMutexGuard<'a, T> { - type Target = MutexGuard<'a, T>; - - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -impl<'a, T> DerefMut for TracingMutexGuard<'a, T> { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.inner - } -} - -impl<'a, T: Display> fmt::Display for TracingMutexGuard<'a, T> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.inner.fmt(f) - } -} - -impl<'a, T> Drop for TracingMutexGuard<'a, T> { - fn drop(&mut self) { - HELD_LOCKS.with(|locks| { - let id = self.mutex.id; - let mut locks = locks.borrow_mut(); - - for (i, &lock) in locks.iter().enumerate().rev() { - if lock == id { - locks.remove(i); - return; - } - } - - panic!("Tried to drop lock for mutex {} but it wasn't held", id) - }); - } -} - #[cfg(test)] mod tests { - use std::sync::Arc; - use std::thread; - use super::*; lazy_static! { /// Mutex to isolate tests manipulating the global mutex graph - static ref GRAPH_MUTEX: Mutex<()> = Mutex::new(()); + pub(crate) static ref GRAPH_MUTEX: Mutex<()> = Mutex::new(()); } #[test] @@ -231,40 +93,4 @@ mod tests { // Can't assert N + 1 because multiple threads running tests assert!(initial < next); } - - #[test] - fn test_mutex_usage() { - let _graph_lock = GRAPH_MUTEX.lock(); - - let mutex = Arc::new(TracingMutex::new(())); - let mutex_clone = mutex.clone(); - - let _guard = mutex.lock().unwrap(); - - // Now try to cause a blocking exception in another thread - let handle = thread::spawn(move || { - let result = mutex_clone.try_lock().unwrap_err(); - - assert!(matches!(result, TryLockError::WouldBlock)); - }); - - handle.join().unwrap(); - } - - #[test] - #[should_panic(expected = "Mutex order graph should not have cycles")] - fn test_detect_cycle() { - let _graph_lock = GRAPH_MUTEX.lock(); - - let a = TracingMutex::new(()); - let b = TracingMutex::new(()); - - let hold_a = a.lock().unwrap(); - let _ = b.lock(); - - drop(hold_a); - - let _hold_b = b.lock().unwrap(); - let _ = a.lock(); - } } diff --git a/src/stdsync.rs b/src/stdsync.rs new file mode 100644 index 0000000..dedd3ce --- /dev/null +++ b/src/stdsync.rs @@ -0,0 +1,193 @@ +//! Tracing Mutex implementations for `std::sync`. +use std::fmt; +use std::ops::Deref; +use std::ops::DerefMut; +use std::sync::LockResult; +use std::sync::Mutex; +use std::sync::MutexGuard; +use std::sync::PoisonError; +use std::sync::TryLockError; +use std::sync::TryLockResult; + +use crate::drop_lock; +use crate::get_depedency_graph; +use crate::next_mutex_id; +use crate::register_dependency; +use crate::register_lock; + +/// Wrapper for std::sync::Mutex +#[derive(Debug)] +pub struct TracingMutex { + inner: Mutex, + id: usize, +} + +/// Wrapper for std::sync::MutexGuard +#[derive(Debug)] +pub struct TracingMutexGuard<'a, T> { + inner: MutexGuard<'a, T>, + mutex: &'a TracingMutex, +} + +fn map_lockresult(result: LockResult, mapper: F) -> LockResult +where + F: FnOnce(I) -> T, +{ + match result { + Ok(inner) => Ok(mapper(inner)), + Err(poisoned) => Err(PoisonError::new(mapper(poisoned.into_inner()))), + } +} + +fn map_trylockresult(result: TryLockResult, mapper: F) -> TryLockResult +where + F: FnOnce(I) -> T, +{ + match result { + Ok(inner) => Ok(mapper(inner)), + Err(TryLockError::WouldBlock) => Err(TryLockError::WouldBlock), + Err(TryLockError::Poisoned(poisoned)) => { + Err(PoisonError::new(mapper(poisoned.into_inner())).into()) + } + } +} + +impl TracingMutex { + pub fn new(t: T) -> Self { + Self { + inner: Mutex::new(t), + id: next_mutex_id(), + } + } + + #[track_caller] + pub fn lock(&self) -> LockResult> { + register_dependency(self.id); + let result = self.inner.lock(); + register_lock(self.id); + + let mapper = |guard| TracingMutexGuard { + mutex: self, + inner: guard, + }; + + map_lockresult(result, mapper) + } + + #[track_caller] + pub fn try_lock(&self) -> TryLockResult> { + register_dependency(self.id); + let result = self.inner.try_lock(); + register_lock(self.id); + + let mapper = |guard| TracingMutexGuard { + mutex: self, + inner: guard, + }; + + map_trylockresult(result, mapper) + } + + pub fn get_id(&self) -> usize { + self.id + } + + pub fn is_poisoned(&self) -> bool { + self.inner.is_poisoned() + } + + pub fn get_mut(&mut self) -> LockResult<&mut T> { + self.inner.get_mut() + } +} + +impl Default for TracingMutex { + fn default() -> Self { + Self::new(T::default()) + } +} + +impl From for TracingMutex { + fn from(t: T) -> Self { + Self::new(t) + } +} + +impl Drop for TracingMutex { + fn drop(&mut self) { + let id = self.id; + + get_depedency_graph().remove_node(id); + } +} + +impl<'a, T> Deref for TracingMutexGuard<'a, T> { + type Target = MutexGuard<'a, T>; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl<'a, T> DerefMut for TracingMutexGuard<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +impl<'a, T: fmt::Display> fmt::Display for TracingMutexGuard<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.inner.fmt(f) + } +} + +impl<'a, T> Drop for TracingMutexGuard<'a, T> { + fn drop(&mut self) { + drop_lock(self.mutex.id); + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + use std::thread; + + use super::*; + use crate::tests::GRAPH_MUTEX; + + #[test] + fn test_mutex_usage() { + let _graph_lock = GRAPH_MUTEX.lock(); + + let mutex = Arc::new(TracingMutex::new(())); + let mutex_clone = mutex.clone(); + + let _guard = mutex.lock().unwrap(); + + // Now try to cause a blocking exception in another thread + let handle = thread::spawn(move || { + let result = mutex_clone.try_lock().unwrap_err(); + + assert!(matches!(result, TryLockError::WouldBlock)); + }); + + handle.join().unwrap(); + } + + #[test] + #[should_panic(expected = "Mutex order graph should not have cycles")] + fn test_detect_cycle() { + let _graph_lock = GRAPH_MUTEX.lock(); + + let a = TracingMutex::new(()); + let b = TracingMutex::new(()); + + let hold_a = a.lock().unwrap(); + let _ = b.lock(); + + drop(hold_a); + + let _hold_b = b.lock().unwrap(); + let _ = a.lock(); + } +}