diff --git a/src/stdsync.rs b/src/stdsync.rs
index dedd3ce..2d3dd5f 100644
--- a/src/stdsync.rs
+++ b/src/stdsync.rs
@@ -1,11 +1,30 @@
-//! Tracing Mutex implementations for `std::sync`.
+//! Tracing mutex wrappers for types found in `std::sync`.
+//!
+//! This module provides wrappers for `std::sync` primitives with exactly the same API and
+//! functionality as their counterparts, with the exception that their acquisition order is
+//! tracked.
+//!
+//! ```rust
+//! # use tracing_mutex::stdsync::TracingMutex;
+//! # use tracing_mutex::stdsync::TracingRwLock;
+//! let mutex = TracingMutex::new(());
+//! mutex.lock().unwrap();
+//!
+//! let rwlock = TracingRwLock::new(());
+//! rwlock.read().unwrap();
+//! ```
 
 use std::fmt;
+use std::mem;
 use std::ops::Deref;
 use std::ops::DerefMut;
+use std::ptr;
 use std::sync::LockResult;
 use std::sync::Mutex;
 use std::sync::MutexGuard;
 use std::sync::PoisonError;
+use std::sync::RwLock;
+use std::sync::RwLockReadGuard;
+use std::sync::RwLockWriteGuard;
 use std::sync::TryLockError;
 use std::sync::TryLockResult;
@@ -15,14 +34,14 @@ use crate::next_mutex_id;
 use crate::register_dependency;
 use crate::register_lock;
 
-/// Wrapper for std::sync::Mutex
+/// Wrapper for `std::sync::Mutex`
 #[derive(Debug)]
 pub struct TracingMutex<T> {
     inner: Mutex<T>,
     id: usize,
 }
 
-/// Wrapper for std::sync::MutexGuard
+/// Wrapper for `std::sync::MutexGuard`
 #[derive(Debug)]
 pub struct TracingMutexGuard<'a, T> {
     inner: MutexGuard<'a, T>,
@@ -99,6 +118,20 @@ impl<T> TracingMutex<T> {
     pub fn get_mut(&mut self) -> LockResult<&mut T> {
         self.inner.get_mut()
     }
+
+    pub fn into_inner(self) -> LockResult<T> {
+        self.deregister();
+
+        // Safety: we forget the original immediately after
+        let inner = unsafe { ptr::read(&self.inner) };
+        mem::forget(self);
+
+        inner.into_inner()
+    }
+
+    fn deregister(&self) {
+        get_depedency_graph().remove_node(self.id);
+    }
 }
 
 impl<T: Default> Default for TracingMutex<T> {
@@ -115,9 +148,7 @@ impl<T> From<T> for TracingMutex<T> {
 
 impl<T> Drop for TracingMutex<T> {
     fn drop(&mut self) {
-        let id = self.id;
-
-        get_depedency_graph().remove_node(id);
+        self.deregister();
     }
 }
 
@@ -147,6 +178,148 @@ impl<'a, T> Drop for TracingMutexGuard<'a, T> {
     }
 }
 
+/// Wrapper for `std::sync::RwLock`
+#[derive(Debug)]
+pub struct TracingRwLock<T> {
+    inner: RwLock<T>,
+    id: usize,
+}
+
+/// Hybrid wrapper for both `std::sync::RwLockReadGuard` and `std::sync::RwLockWriteGuard`.
+///
+/// Please refer to [`TracingReadGuard`] and [`TracingWriteGuard`] for usable types.
+#[derive(Debug)]
+pub struct TracingRwLockGuard<'a, T, L> {
+    inner: L,
+    mutex: &'a TracingRwLock<T>,
+}
+
+/// Wrapper around `std::sync::RwLockReadGuard`.
+pub type TracingReadGuard<'a, T> = TracingRwLockGuard<'a, T, RwLockReadGuard<'a, T>>;
+/// Wrapper around `std::sync::RwLockWriteGuard`.
+pub type TracingWriteGuard<'a, T> = TracingRwLockGuard<'a, T, RwLockWriteGuard<'a, T>>;
+
+impl<T> TracingRwLock<T> {
+    pub fn new(t: T) -> Self {
+        Self {
+            inner: RwLock::new(t),
+            id: next_mutex_id(),
+        }
+    }
+
+    pub fn get_id(&self) -> usize {
+        self.id
+    }
+
+    #[track_caller]
+    pub fn read(&self) -> LockResult<TracingReadGuard<T>> {
+        register_dependency(self.id);
+        let result = self.inner.read();
+        register_lock(self.id);
+
+        map_lockresult(result, |lock| TracingRwLockGuard {
+            inner: lock,
+            mutex: self,
+        })
+    }
+
+    #[track_caller]
+    pub fn write(&self) -> LockResult<TracingWriteGuard<T>> {
+        register_dependency(self.id);
+        let result = self.inner.write();
+        register_lock(self.id);
+
+        map_lockresult(result, |lock| TracingRwLockGuard {
+            inner: lock,
+            mutex: self,
+        })
+    }
+
+    #[track_caller]
+    pub fn try_read(&self) -> TryLockResult<TracingReadGuard<T>> {
+        register_dependency(self.id);
+        let result = self.inner.try_read();
+        register_lock(self.id);
+
+        map_trylockresult(result, |lock| TracingRwLockGuard {
+            inner: lock,
+            mutex: self,
+        })
+    }
+
+    #[track_caller]
+    pub fn try_write(&self) -> TryLockResult<TracingWriteGuard<T>> {
+        register_dependency(self.id);
+        let result = self.inner.try_write();
+        register_lock(self.id);
+
+        map_trylockresult(result, |lock| TracingRwLockGuard {
+            inner: lock,
+            mutex: self,
+        })
+    }
+
+    pub fn get_mut(&mut self) -> LockResult<&mut T> {
+        self.inner.get_mut()
+    }
+
+    pub fn into_inner(self) -> LockResult<T> {
+        self.deregister();
+
+        // Grab our contents and then forget ourselves
+        // Safety: we immediately forget the mutex after copying
+        let inner = unsafe { ptr::read(&self.inner) };
+        mem::forget(self);
+
+        inner.into_inner()
+    }
+
+    fn deregister(&self) {
+        get_depedency_graph().remove_node(self.id);
+    }
+}
+
+impl<T> Drop for TracingRwLock<T> {
+    fn drop(&mut self) {
+        self.deregister();
+    }
+}
+
+impl<T> Default for TracingRwLock<T>
+where
+    T: Default,
+{
+    fn default() -> Self {
+        Self::new(T::default())
+    }
+}
+
+impl<'a, T, L> Drop for TracingRwLockGuard<'a, T, L> {
+    fn drop(&mut self) {
+        drop_lock(self.mutex.id)
+    }
+}
+
+impl<'a, T, L> Deref for TracingRwLockGuard<'a, T, L>
+where
+    L: Deref<Target = T>,
+{
+    type Target = T;
+
+    fn deref(&self) -> &Self::Target {
+        self.inner.deref()
+    }
+}
+
+impl<'a, T, L> DerefMut for TracingRwLockGuard<'a, T, L>
+where
+    L: Deref<Target = T> + DerefMut,
+{
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.inner.deref_mut()
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::sync::Arc;
@@ -174,6 +347,28 @@ mod tests {
         handle.join().unwrap();
     }
 
+    #[test]
+    fn test_rwlock_usage() {
+        let _graph_lock = GRAPH_MUTEX.lock();
+
+        let rwlock = Arc::new(TracingRwLock::new(()));
+        let rwlock_clone = rwlock.clone();
+
+        let _read_lock = rwlock.read().unwrap();
+
+        // Now try to cause a blocking exception in another thread
+        let handle = thread::spawn(move || {
+            let write_result = rwlock_clone.try_write().unwrap_err();
+
+            assert!(matches!(write_result, TryLockError::WouldBlock));
+
+            // Should be able to get a read lock just fine.
+            let _ = rwlock_clone.read().unwrap();
+        });
+
+        handle.join().unwrap();
+    }
+
     #[test]
     #[should_panic(expected = "Mutex order graph should not have cycles")]
     fn test_detect_cycle() {