Simplify lock guard tracking

Instead of implementing the tracking everywhere, create a RAII-guard
that will track the state as it is held and dropped.
This commit is contained in:
2021-03-20 20:40:23 +01:00
parent 9b56deac26
commit 5638ebffd8
2 changed files with 71 additions and 99 deletions

View File

@@ -56,19 +56,38 @@ impl MutexID {
.map(Self)
.expect("Mutex ID wraparound happened, results unreliable")
}
/// Get a borrowed guard for this lock.
///
/// This method adds checks adds this Mutex ID to the dependency graph as needed, and adds the
/// lock to the list of
///
/// # Panics
///
/// This method panics if the new dependency would introduce a cycle.
pub fn get_borrowed(self) -> BorrowedMutex {
let creates_cycle = HELD_LOCKS.with(|locks| {
if let Some(&previous) = locks.borrow().last() {
let mut graph = get_depedency_graph();
graph.add_edge(previous, self) && graph.has_cycles()
} else {
false
}
});
if creates_cycle {
// Panic without holding the lock to avoid needlessly poisoning it
panic!("Mutex order graph should not have cycles");
}
HELD_LOCKS.with(|locks| locks.borrow_mut().push(self));
BorrowedMutex(self)
}
}
/// Get a reference to the current dependency graph
fn get_depedency_graph() -> impl DerefMut<Target = DiGraph> {
DEPENDENCY_GRAPH
.lock()
.unwrap_or_else(PoisonError::into_inner)
}
/// Register that a lock is currently held
fn register_lock(lock: MutexID) {
HELD_LOCKS.with(|locks| locks.borrow_mut().push(lock))
}
#[derive(Debug)]
struct BorrowedMutex(MutexID);
/// Drop a lock held by the current thread.
///
@@ -76,7 +95,10 @@ fn register_lock(lock: MutexID) {
///
/// This function panics if the lock did not appear to be handled by this thread. If that happens,
/// that is an indication of a serious design flaw in this library.
fn drop_lock(id: MutexID) {
impl Drop for BorrowedMutex {
fn drop(&mut self) {
let id = self.0;
HELD_LOCKS.with(|locks| {
let mut locks = locks.borrow_mut();
@@ -87,33 +109,17 @@ fn drop_lock(id: MutexID) {
}
}
panic!("Tried to drop lock for mutex {:?} but it wasn't held", id)
// Drop impls shouldn't panic but if this happens something is seriously broken.
unreachable!("Tried to drop lock for mutex {:?} but it wasn't held", id)
});
}
}
/// Register a dependency in the dependency graph
///
/// If the dependency is new, check for cycles in the dependency graph. If not, there shouldn't be
/// any cycles so we don't need to check.
///
/// # Panics
///
/// This function panics if the new dependency would introduce a cycle.
fn register_dependency(lock: MutexID) {
let creates_cycle = HELD_LOCKS.with(|locks| {
if let Some(&previous) = locks.borrow().last() {
let mut graph = get_depedency_graph();
graph.add_edge(previous, lock) && graph.has_cycles()
} else {
false
}
});
if creates_cycle {
// Panic without holding the lock to avoid needlessly poisoning it
panic!("Mutex order graph should not have cycles");
}
/// Get a reference to the current dependency graph
fn get_depedency_graph() -> impl DerefMut<Target = DiGraph> {
DEPENDENCY_GRAPH
.lock()
.unwrap_or_else(PoisonError::into_inner)
}
#[cfg(test)]

View File

@@ -28,10 +28,8 @@ use std::sync::RwLockWriteGuard;
use std::sync::TryLockError;
use std::sync::TryLockResult;
use crate::drop_lock;
use crate::get_depedency_graph;
use crate::register_dependency;
use crate::register_lock;
use crate::BorrowedMutex;
use crate::MutexID;
/// Wrapper for `std::sync::Mutex`
@@ -45,7 +43,7 @@ pub struct TracingMutex<T> {
#[derive(Debug)]
pub struct TracingMutexGuard<'a, T> {
inner: MutexGuard<'a, T>,
mutex: &'a TracingMutex<T>,
mutex: BorrowedMutex,
}
fn map_lockresult<T, I, F>(result: LockResult<I>, mapper: F) -> LockResult<T>
@@ -81,12 +79,11 @@ impl<T> TracingMutex<T> {
#[track_caller]
pub fn lock(&self) -> LockResult<TracingMutexGuard<T>> {
register_dependency(self.id);
let mutex = self.id.get_borrowed();
let result = self.inner.lock();
register_lock(self.id);
let mapper = |guard| TracingMutexGuard {
mutex: self,
mutex,
inner: guard,
};
@@ -95,12 +92,11 @@ impl<T> TracingMutex<T> {
#[track_caller]
pub fn try_lock(&self) -> TryLockResult<TracingMutexGuard<T>> {
register_dependency(self.id);
let mutex = self.id.get_borrowed();
let result = self.inner.try_lock();
register_lock(self.id);
let mapper = |guard| TracingMutexGuard {
mutex: self,
mutex,
inner: guard,
};
@@ -168,12 +164,6 @@ impl<'a, T: fmt::Display> fmt::Display for TracingMutexGuard<'a, T> {
}
}
impl<'a, T> Drop for TracingMutexGuard<'a, T> {
fn drop(&mut self) {
drop_lock(self.mutex.id);
}
}
/// Wrapper for `std::sync::RwLock`
#[derive(Debug)]
pub struct TracingRwLock<T> {
@@ -185,15 +175,15 @@ pub struct TracingRwLock<T> {
///
/// Please refer to [`TracingReadGuard`] and [`TracingWriteGuard`] for usable types.
#[derive(Debug)]
pub struct TracingRwLockGuard<'a, T, L> {
pub struct TracingRwLockGuard<L> {
inner: L,
mutex: &'a TracingRwLock<T>,
mutex: BorrowedMutex,
}
/// Wrapper around `std::sync::RwLockReadGuard`.
pub type TracingReadGuard<'a, T> = TracingRwLockGuard<'a, T, RwLockReadGuard<'a, T>>;
pub type TracingReadGuard<'a, T> = TracingRwLockGuard<RwLockReadGuard<'a, T>>;
/// Wrapper around `std::sync::RwLockWriteGuard`.
pub type TracingWriteGuard<'a, T> = TracingRwLockGuard<'a, T, RwLockWriteGuard<'a, T>>;
pub type TracingWriteGuard<'a, T> = TracingRwLockGuard<RwLockWriteGuard<'a, T>>;
impl<T> TracingRwLock<T> {
pub fn new(t: T) -> Self {
@@ -205,52 +195,34 @@ impl<T> TracingRwLock<T> {
#[track_caller]
pub fn read(&self) -> LockResult<TracingReadGuard<T>> {
register_dependency(self.id);
let mutex = self.id.get_borrowed();
let result = self.inner.read();
register_lock(self.id);
map_lockresult(result, |lock| TracingRwLockGuard {
inner: lock,
mutex: self,
})
map_lockresult(result, |inner| TracingRwLockGuard { inner, mutex })
}
#[track_caller]
pub fn write(&self) -> LockResult<TracingWriteGuard<T>> {
register_dependency(self.id);
let mutex = self.id.get_borrowed();
let result = self.inner.write();
register_lock(self.id);
let id = self.id;
get_depedency_graph().remove_node(id);
map_lockresult(result, |lock| TracingRwLockGuard {
inner: lock,
mutex: self,
})
map_lockresult(result, |inner| TracingRwLockGuard { inner, mutex })
}
#[track_caller]
pub fn try_read(&self) -> TryLockResult<TracingReadGuard<T>> {
register_dependency(self.id);
let mutex = self.id.get_borrowed();
let result = self.inner.try_read();
register_lock(self.id);
map_trylockresult(result, |lock| TracingRwLockGuard {
inner: lock,
mutex: self,
})
map_trylockresult(result, |inner| TracingRwLockGuard { inner, mutex })
}
#[track_caller]
pub fn try_write(&self) -> TryLockResult<TracingWriteGuard<T>> {
register_dependency(self.id);
let mutex = self.id.get_borrowed();
let result = self.inner.try_write();
register_lock(self.id);
map_trylockresult(result, |lock| TracingRwLockGuard {
inner: lock,
mutex: self,
})
map_trylockresult(result, |inner| TracingRwLockGuard { inner, mutex })
}
pub fn get_mut(&mut self) -> LockResult<&mut T> {
@@ -288,13 +260,7 @@ where
}
}
impl<'a, T, L> Drop for TracingRwLockGuard<'a, T, L> {
fn drop(&mut self) {
drop_lock(self.mutex.id)
}
}
impl<'a, T, L> Deref for TracingRwLockGuard<'a, T, L>
impl<L, T> Deref for TracingRwLockGuard<L>
where
L: Deref<Target = T>,
{
@@ -305,7 +271,7 @@ where
}
}
impl<'a, T, L> DerefMut for TracingRwLockGuard<'a, T, L>
impl<T, L> DerefMut for TracingRwLockGuard<L>
where
L: Deref<Target = T> + DerefMut,
{
@@ -357,7 +323,7 @@ mod tests {
assert!(matches!(write_result, TryLockError::WouldBlock));
// Should be able to get a read lock just fine.
let _ = rwlock_clone.read().unwrap();
let _read_lock = rwlock_clone.read().unwrap();
});
handle.join().unwrap();