Simplify lock guard tracking

Instead of implementing the tracking everywhere, create a RAII-guard
that will track the state as it is held and dropped.
This commit is contained in:
2021-03-20 20:40:23 +01:00
parent 9b56deac26
commit 5638ebffd8
2 changed files with 71 additions and 99 deletions

View File

@@ -56,19 +56,38 @@ impl MutexID {
.map(Self) .map(Self)
.expect("Mutex ID wraparound happened, results unreliable") .expect("Mutex ID wraparound happened, results unreliable")
} }
/// Get a borrowed guard for this lock.
///
/// This method adds checks adds this Mutex ID to the dependency graph as needed, and adds the
/// lock to the list of
///
/// # Panics
///
/// This method panics if the new dependency would introduce a cycle.
pub fn get_borrowed(self) -> BorrowedMutex {
let creates_cycle = HELD_LOCKS.with(|locks| {
if let Some(&previous) = locks.borrow().last() {
let mut graph = get_depedency_graph();
graph.add_edge(previous, self) && graph.has_cycles()
} else {
false
}
});
if creates_cycle {
// Panic without holding the lock to avoid needlessly poisoning it
panic!("Mutex order graph should not have cycles");
}
HELD_LOCKS.with(|locks| locks.borrow_mut().push(self));
BorrowedMutex(self)
}
} }
/// Get a reference to the current dependency graph #[derive(Debug)]
fn get_depedency_graph() -> impl DerefMut<Target = DiGraph> { struct BorrowedMutex(MutexID);
DEPENDENCY_GRAPH
.lock()
.unwrap_or_else(PoisonError::into_inner)
}
/// Register that a lock is currently held
fn register_lock(lock: MutexID) {
HELD_LOCKS.with(|locks| locks.borrow_mut().push(lock))
}
/// Drop a lock held by the current thread. /// Drop a lock held by the current thread.
/// ///
@@ -76,7 +95,10 @@ fn register_lock(lock: MutexID) {
/// ///
/// This function panics if the lock did not appear to be handled by this thread. If that happens, /// This function panics if the lock did not appear to be handled by this thread. If that happens,
/// that is an indication of a serious design flaw in this library. /// that is an indication of a serious design flaw in this library.
fn drop_lock(id: MutexID) { impl Drop for BorrowedMutex {
fn drop(&mut self) {
let id = self.0;
HELD_LOCKS.with(|locks| { HELD_LOCKS.with(|locks| {
let mut locks = locks.borrow_mut(); let mut locks = locks.borrow_mut();
@@ -87,33 +109,17 @@ fn drop_lock(id: MutexID) {
} }
} }
panic!("Tried to drop lock for mutex {:?} but it wasn't held", id) // Drop impls shouldn't panic but if this happens something is seriously broken.
unreachable!("Tried to drop lock for mutex {:?} but it wasn't held", id)
}); });
}
} }
/// Register a dependency in the dependency graph /// Get a reference to the current dependency graph
/// fn get_depedency_graph() -> impl DerefMut<Target = DiGraph> {
/// If the dependency is new, check for cycles in the dependency graph. If not, there shouldn't be DEPENDENCY_GRAPH
/// any cycles so we don't need to check. .lock()
/// .unwrap_or_else(PoisonError::into_inner)
/// # Panics
///
/// This function panics if the new dependency would introduce a cycle.
fn register_dependency(lock: MutexID) {
let creates_cycle = HELD_LOCKS.with(|locks| {
if let Some(&previous) = locks.borrow().last() {
let mut graph = get_depedency_graph();
graph.add_edge(previous, lock) && graph.has_cycles()
} else {
false
}
});
if creates_cycle {
// Panic without holding the lock to avoid needlessly poisoning it
panic!("Mutex order graph should not have cycles");
}
} }
#[cfg(test)] #[cfg(test)]

View File

@@ -28,10 +28,8 @@ use std::sync::RwLockWriteGuard;
use std::sync::TryLockError; use std::sync::TryLockError;
use std::sync::TryLockResult; use std::sync::TryLockResult;
use crate::drop_lock;
use crate::get_depedency_graph; use crate::get_depedency_graph;
use crate::register_dependency; use crate::BorrowedMutex;
use crate::register_lock;
use crate::MutexID; use crate::MutexID;
/// Wrapper for `std::sync::Mutex` /// Wrapper for `std::sync::Mutex`
@@ -45,7 +43,7 @@ pub struct TracingMutex<T> {
#[derive(Debug)] #[derive(Debug)]
pub struct TracingMutexGuard<'a, T> { pub struct TracingMutexGuard<'a, T> {
inner: MutexGuard<'a, T>, inner: MutexGuard<'a, T>,
mutex: &'a TracingMutex<T>, mutex: BorrowedMutex,
} }
fn map_lockresult<T, I, F>(result: LockResult<I>, mapper: F) -> LockResult<T> fn map_lockresult<T, I, F>(result: LockResult<I>, mapper: F) -> LockResult<T>
@@ -81,12 +79,11 @@ impl<T> TracingMutex<T> {
#[track_caller] #[track_caller]
pub fn lock(&self) -> LockResult<TracingMutexGuard<T>> { pub fn lock(&self) -> LockResult<TracingMutexGuard<T>> {
register_dependency(self.id); let mutex = self.id.get_borrowed();
let result = self.inner.lock(); let result = self.inner.lock();
register_lock(self.id);
let mapper = |guard| TracingMutexGuard { let mapper = |guard| TracingMutexGuard {
mutex: self, mutex,
inner: guard, inner: guard,
}; };
@@ -95,12 +92,11 @@ impl<T> TracingMutex<T> {
#[track_caller] #[track_caller]
pub fn try_lock(&self) -> TryLockResult<TracingMutexGuard<T>> { pub fn try_lock(&self) -> TryLockResult<TracingMutexGuard<T>> {
register_dependency(self.id); let mutex = self.id.get_borrowed();
let result = self.inner.try_lock(); let result = self.inner.try_lock();
register_lock(self.id);
let mapper = |guard| TracingMutexGuard { let mapper = |guard| TracingMutexGuard {
mutex: self, mutex,
inner: guard, inner: guard,
}; };
@@ -168,12 +164,6 @@ impl<'a, T: fmt::Display> fmt::Display for TracingMutexGuard<'a, T> {
} }
} }
impl<'a, T> Drop for TracingMutexGuard<'a, T> {
fn drop(&mut self) {
drop_lock(self.mutex.id);
}
}
/// Wrapper for `std::sync::RwLock` /// Wrapper for `std::sync::RwLock`
#[derive(Debug)] #[derive(Debug)]
pub struct TracingRwLock<T> { pub struct TracingRwLock<T> {
@@ -185,15 +175,15 @@ pub struct TracingRwLock<T> {
/// ///
/// Please refer to [`TracingReadGuard`] and [`TracingWriteGuard`] for usable types. /// Please refer to [`TracingReadGuard`] and [`TracingWriteGuard`] for usable types.
#[derive(Debug)] #[derive(Debug)]
pub struct TracingRwLockGuard<'a, T, L> { pub struct TracingRwLockGuard<L> {
inner: L, inner: L,
mutex: &'a TracingRwLock<T>, mutex: BorrowedMutex,
} }
/// Wrapper around `std::sync::RwLockReadGuard`. /// Wrapper around `std::sync::RwLockReadGuard`.
pub type TracingReadGuard<'a, T> = TracingRwLockGuard<'a, T, RwLockReadGuard<'a, T>>; pub type TracingReadGuard<'a, T> = TracingRwLockGuard<RwLockReadGuard<'a, T>>;
/// Wrapper around `std::sync::RwLockWriteGuard`. /// Wrapper around `std::sync::RwLockWriteGuard`.
pub type TracingWriteGuard<'a, T> = TracingRwLockGuard<'a, T, RwLockWriteGuard<'a, T>>; pub type TracingWriteGuard<'a, T> = TracingRwLockGuard<RwLockWriteGuard<'a, T>>;
impl<T> TracingRwLock<T> { impl<T> TracingRwLock<T> {
pub fn new(t: T) -> Self { pub fn new(t: T) -> Self {
@@ -205,52 +195,34 @@ impl<T> TracingRwLock<T> {
#[track_caller] #[track_caller]
pub fn read(&self) -> LockResult<TracingReadGuard<T>> { pub fn read(&self) -> LockResult<TracingReadGuard<T>> {
register_dependency(self.id); let mutex = self.id.get_borrowed();
let result = self.inner.read(); let result = self.inner.read();
register_lock(self.id);
map_lockresult(result, |lock| TracingRwLockGuard { map_lockresult(result, |inner| TracingRwLockGuard { inner, mutex })
inner: lock,
mutex: self,
})
} }
#[track_caller] #[track_caller]
pub fn write(&self) -> LockResult<TracingWriteGuard<T>> { pub fn write(&self) -> LockResult<TracingWriteGuard<T>> {
register_dependency(self.id); let mutex = self.id.get_borrowed();
let result = self.inner.write(); let result = self.inner.write();
register_lock(self.id);
let id = self.id;
get_depedency_graph().remove_node(id); map_lockresult(result, |inner| TracingRwLockGuard { inner, mutex })
map_lockresult(result, |lock| TracingRwLockGuard {
inner: lock,
mutex: self,
})
} }
#[track_caller] #[track_caller]
pub fn try_read(&self) -> TryLockResult<TracingReadGuard<T>> { pub fn try_read(&self) -> TryLockResult<TracingReadGuard<T>> {
register_dependency(self.id); let mutex = self.id.get_borrowed();
let result = self.inner.try_read(); let result = self.inner.try_read();
register_lock(self.id);
map_trylockresult(result, |lock| TracingRwLockGuard { map_trylockresult(result, |inner| TracingRwLockGuard { inner, mutex })
inner: lock,
mutex: self,
})
} }
#[track_caller] #[track_caller]
pub fn try_write(&self) -> TryLockResult<TracingWriteGuard<T>> { pub fn try_write(&self) -> TryLockResult<TracingWriteGuard<T>> {
register_dependency(self.id); let mutex = self.id.get_borrowed();
let result = self.inner.try_write(); let result = self.inner.try_write();
register_lock(self.id);
map_trylockresult(result, |lock| TracingRwLockGuard { map_trylockresult(result, |inner| TracingRwLockGuard { inner, mutex })
inner: lock,
mutex: self,
})
} }
pub fn get_mut(&mut self) -> LockResult<&mut T> { pub fn get_mut(&mut self) -> LockResult<&mut T> {
@@ -288,13 +260,7 @@ where
} }
} }
impl<'a, T, L> Drop for TracingRwLockGuard<'a, T, L> { impl<L, T> Deref for TracingRwLockGuard<L>
fn drop(&mut self) {
drop_lock(self.mutex.id)
}
}
impl<'a, T, L> Deref for TracingRwLockGuard<'a, T, L>
where where
L: Deref<Target = T>, L: Deref<Target = T>,
{ {
@@ -305,7 +271,7 @@ where
} }
} }
impl<'a, T, L> DerefMut for TracingRwLockGuard<'a, T, L> impl<T, L> DerefMut for TracingRwLockGuard<L>
where where
L: Deref<Target = T> + DerefMut, L: Deref<Target = T> + DerefMut,
{ {
@@ -357,7 +323,7 @@ mod tests {
assert!(matches!(write_result, TryLockError::WouldBlock)); assert!(matches!(write_result, TryLockError::WouldBlock));
// Should be able to get a read lock just fine. // Should be able to get a read lock just fine.
let _ = rwlock_clone.read().unwrap(); let _read_lock = rwlock_clone.read().unwrap();
}); });
handle.join().unwrap(); handle.join().unwrap();