From 5638ebffd850239fe38888273b74b2c9ca6b45c0 Mon Sep 17 00:00:00 2001 From: Bert Peters Date: Sat, 20 Mar 2021 20:40:23 +0100 Subject: [PATCH] Simplify lock guard tracking Instead of implementing the tracking everywhere, create a RAII-guard that will track the state as it is held and dropped. --- src/lib.rs | 94 +++++++++++++++++++++++++++----------------------- src/stdsync.rs | 76 +++++++++++----------------------------- 2 files changed, 71 insertions(+), 99 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 9bf3fa0..065ba58 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -56,19 +56,38 @@ impl MutexID { .map(Self) .expect("Mutex ID wraparound happened, results unreliable") } + + /// Get a borrowed guard for this lock. + /// + /// This method adds checks adds this Mutex ID to the dependency graph as needed, and adds the + /// lock to the list of + /// + /// # Panics + /// + /// This method panics if the new dependency would introduce a cycle. + pub fn get_borrowed(self) -> BorrowedMutex { + let creates_cycle = HELD_LOCKS.with(|locks| { + if let Some(&previous) = locks.borrow().last() { + let mut graph = get_depedency_graph(); + + graph.add_edge(previous, self) && graph.has_cycles() + } else { + false + } + }); + + if creates_cycle { + // Panic without holding the lock to avoid needlessly poisoning it + panic!("Mutex order graph should not have cycles"); + } + + HELD_LOCKS.with(|locks| locks.borrow_mut().push(self)); + BorrowedMutex(self) + } } -/// Get a reference to the current dependency graph -fn get_depedency_graph() -> impl DerefMut { - DEPENDENCY_GRAPH - .lock() - .unwrap_or_else(PoisonError::into_inner) -} - -/// Register that a lock is currently held -fn register_lock(lock: MutexID) { - HELD_LOCKS.with(|locks| locks.borrow_mut().push(lock)) -} +#[derive(Debug)] +struct BorrowedMutex(MutexID); /// Drop a lock held by the current thread. /// @@ -76,44 +95,31 @@ fn register_lock(lock: MutexID) { /// /// This function panics if the lock did not appear to be handled by this thread. If that happens, /// that is an indication of a serious design flaw in this library. -fn drop_lock(id: MutexID) { - HELD_LOCKS.with(|locks| { - let mut locks = locks.borrow_mut(); +impl Drop for BorrowedMutex { + fn drop(&mut self) { + let id = self.0; - for (i, &lock) in locks.iter().enumerate().rev() { - if lock == id { - locks.remove(i); - return; + HELD_LOCKS.with(|locks| { + let mut locks = locks.borrow_mut(); + + for (i, &lock) in locks.iter().enumerate().rev() { + if lock == id { + locks.remove(i); + return; + } } - } - panic!("Tried to drop lock for mutex {:?} but it wasn't held", id) - }); + // Drop impls shouldn't panic but if this happens something is seriously broken. + unreachable!("Tried to drop lock for mutex {:?} but it wasn't held", id) + }); + } } -/// Register a dependency in the dependency graph -/// -/// If the dependency is new, check for cycles in the dependency graph. If not, there shouldn't be -/// any cycles so we don't need to check. -/// -/// # Panics -/// -/// This function panics if the new dependency would introduce a cycle. -fn register_dependency(lock: MutexID) { - let creates_cycle = HELD_LOCKS.with(|locks| { - if let Some(&previous) = locks.borrow().last() { - let mut graph = get_depedency_graph(); - - graph.add_edge(previous, lock) && graph.has_cycles() - } else { - false - } - }); - - if creates_cycle { - // Panic without holding the lock to avoid needlessly poisoning it - panic!("Mutex order graph should not have cycles"); - } +/// Get a reference to the current dependency graph +fn get_depedency_graph() -> impl DerefMut { + DEPENDENCY_GRAPH + .lock() + .unwrap_or_else(PoisonError::into_inner) } #[cfg(test)] diff --git a/src/stdsync.rs b/src/stdsync.rs index 1afed3f..d373f51 100644 --- a/src/stdsync.rs +++ b/src/stdsync.rs @@ -28,10 +28,8 @@ use std::sync::RwLockWriteGuard; use std::sync::TryLockError; use std::sync::TryLockResult; -use crate::drop_lock; use crate::get_depedency_graph; -use crate::register_dependency; -use crate::register_lock; +use crate::BorrowedMutex; use crate::MutexID; /// Wrapper for `std::sync::Mutex` @@ -45,7 +43,7 @@ pub struct TracingMutex { #[derive(Debug)] pub struct TracingMutexGuard<'a, T> { inner: MutexGuard<'a, T>, - mutex: &'a TracingMutex, + mutex: BorrowedMutex, } fn map_lockresult(result: LockResult, mapper: F) -> LockResult @@ -81,12 +79,11 @@ impl TracingMutex { #[track_caller] pub fn lock(&self) -> LockResult> { - register_dependency(self.id); + let mutex = self.id.get_borrowed(); let result = self.inner.lock(); - register_lock(self.id); let mapper = |guard| TracingMutexGuard { - mutex: self, + mutex, inner: guard, }; @@ -95,12 +92,11 @@ impl TracingMutex { #[track_caller] pub fn try_lock(&self) -> TryLockResult> { - register_dependency(self.id); + let mutex = self.id.get_borrowed(); let result = self.inner.try_lock(); - register_lock(self.id); let mapper = |guard| TracingMutexGuard { - mutex: self, + mutex, inner: guard, }; @@ -168,12 +164,6 @@ impl<'a, T: fmt::Display> fmt::Display for TracingMutexGuard<'a, T> { } } -impl<'a, T> Drop for TracingMutexGuard<'a, T> { - fn drop(&mut self) { - drop_lock(self.mutex.id); - } -} - /// Wrapper for `std::sync::RwLock` #[derive(Debug)] pub struct TracingRwLock { @@ -185,15 +175,15 @@ pub struct TracingRwLock { /// /// Please refer to [`TracingReadGuard`] and [`TracingWriteGuard`] for usable types. #[derive(Debug)] -pub struct TracingRwLockGuard<'a, T, L> { +pub struct TracingRwLockGuard { inner: L, - mutex: &'a TracingRwLock, + mutex: BorrowedMutex, } /// Wrapper around `std::sync::RwLockReadGuard`. -pub type TracingReadGuard<'a, T> = TracingRwLockGuard<'a, T, RwLockReadGuard<'a, T>>; +pub type TracingReadGuard<'a, T> = TracingRwLockGuard>; /// Wrapper around `std::sync::RwLockWriteGuard`. -pub type TracingWriteGuard<'a, T> = TracingRwLockGuard<'a, T, RwLockWriteGuard<'a, T>>; +pub type TracingWriteGuard<'a, T> = TracingRwLockGuard>; impl TracingRwLock { pub fn new(t: T) -> Self { @@ -205,52 +195,34 @@ impl TracingRwLock { #[track_caller] pub fn read(&self) -> LockResult> { - register_dependency(self.id); + let mutex = self.id.get_borrowed(); let result = self.inner.read(); - register_lock(self.id); - map_lockresult(result, |lock| TracingRwLockGuard { - inner: lock, - mutex: self, - }) + map_lockresult(result, |inner| TracingRwLockGuard { inner, mutex }) } #[track_caller] pub fn write(&self) -> LockResult> { - register_dependency(self.id); + let mutex = self.id.get_borrowed(); let result = self.inner.write(); - register_lock(self.id); - let id = self.id; - get_depedency_graph().remove_node(id); - map_lockresult(result, |lock| TracingRwLockGuard { - inner: lock, - mutex: self, - }) + map_lockresult(result, |inner| TracingRwLockGuard { inner, mutex }) } #[track_caller] pub fn try_read(&self) -> TryLockResult> { - register_dependency(self.id); + let mutex = self.id.get_borrowed(); let result = self.inner.try_read(); - register_lock(self.id); - map_trylockresult(result, |lock| TracingRwLockGuard { - inner: lock, - mutex: self, - }) + map_trylockresult(result, |inner| TracingRwLockGuard { inner, mutex }) } #[track_caller] pub fn try_write(&self) -> TryLockResult> { - register_dependency(self.id); + let mutex = self.id.get_borrowed(); let result = self.inner.try_write(); - register_lock(self.id); - map_trylockresult(result, |lock| TracingRwLockGuard { - inner: lock, - mutex: self, - }) + map_trylockresult(result, |inner| TracingRwLockGuard { inner, mutex }) } pub fn get_mut(&mut self) -> LockResult<&mut T> { @@ -288,13 +260,7 @@ where } } -impl<'a, T, L> Drop for TracingRwLockGuard<'a, T, L> { - fn drop(&mut self) { - drop_lock(self.mutex.id) - } -} - -impl<'a, T, L> Deref for TracingRwLockGuard<'a, T, L> +impl Deref for TracingRwLockGuard where L: Deref, { @@ -305,7 +271,7 @@ where } } -impl<'a, T, L> DerefMut for TracingRwLockGuard<'a, T, L> +impl DerefMut for TracingRwLockGuard where L: Deref + DerefMut, { @@ -357,7 +323,7 @@ mod tests { assert!(matches!(write_result, TryLockError::WouldBlock)); // Should be able to get a read lock just fine. - let _ = rwlock_clone.read().unwrap(); + let _read_lock = rwlock_clone.read().unwrap(); }); handle.join().unwrap();