Refactor MutexID to be self tracking

This avoids the need to implement Drop on every wrapped mutex, and
removes the need for unsafe code in this crate.
This commit is contained in:
2021-05-02 11:51:18 +02:00
parent 24c8453496
commit 050ee27af6
3 changed files with 57 additions and 93 deletions

View File

@@ -1,7 +1,6 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::collections::HashSet; use std::collections::HashSet;
use std::hash::Hash;
use crate::MutexId;
type Order = usize; type Order = usize;
@@ -20,19 +19,25 @@ type Order = usize;
/// ///
/// [paper]: https://whileydave.com/publications/pk07_jea/ /// [paper]: https://whileydave.com/publications/pk07_jea/
#[derive(Clone, Default, Debug)] #[derive(Clone, Default, Debug)]
pub struct DiGraph { pub struct DiGraph<V>
in_edges: HashMap<MutexId, HashSet<MutexId>>, where
out_edges: HashMap<MutexId, HashSet<MutexId>>, V: Eq + Hash + Copy,
{
in_edges: HashMap<V, HashSet<V>>,
out_edges: HashMap<V, HashSet<V>>,
/// Next topological sort order /// Next topological sort order
next_ord: Order, next_ord: Order,
/// Poison flag, set if a cycle is detected when adding a new edge and /// Poison flag, set if a cycle is detected when adding a new edge and
/// unset when removing a node successfully removed the cycle. /// unset when removing a node successfully removed the cycle.
contains_cycle: bool, contains_cycle: bool,
/// Topological sort order. Order is not guaranteed to be contiguous /// Topological sort order. Order is not guaranteed to be contiguous
ord: HashMap<MutexId, Order>, ord: HashMap<V, Order>,
} }
impl DiGraph { impl<V> DiGraph<V>
where
V: Eq + Hash + Copy,
{
/// Add a new node to the graph. /// Add a new node to the graph.
/// ///
/// If the node already existed, this function does not add it and uses the existing node data. /// If the node already existed, this function does not add it and uses the existing node data.
@@ -40,7 +45,7 @@ impl DiGraph {
/// the node in the topological order. /// the node in the topological order.
/// ///
/// New nodes are appended to the end of the topological order when added. /// New nodes are appended to the end of the topological order when added.
fn add_node(&mut self, n: MutexId) -> (&mut HashSet<MutexId>, &mut HashSet<MutexId>, Order) { fn add_node(&mut self, n: V) -> (&mut HashSet<V>, &mut HashSet<V>, Order) {
let next_ord = &mut self.next_ord; let next_ord = &mut self.next_ord;
let in_edges = self.in_edges.entry(n).or_default(); let in_edges = self.in_edges.entry(n).or_default();
let out_edges = self.out_edges.entry(n).or_default(); let out_edges = self.out_edges.entry(n).or_default();
@@ -54,7 +59,7 @@ impl DiGraph {
(in_edges, out_edges, order) (in_edges, out_edges, order)
} }
pub(crate) fn remove_node(&mut self, n: MutexId) -> bool { pub(crate) fn remove_node(&mut self, n: V) -> bool {
match self.out_edges.remove(&n) { match self.out_edges.remove(&n) {
None => false, None => false,
Some(out_edges) => { Some(out_edges) => {
@@ -79,7 +84,7 @@ impl DiGraph {
/// Add an edge to the graph /// Add an edge to the graph
/// ///
/// Nodes, both from and to, are created as needed when creating new edges. /// Nodes, both from and to, are created as needed when creating new edges.
pub(crate) fn add_edge(&mut self, x: MutexId, y: MutexId) -> bool { pub(crate) fn add_edge(&mut self, x: V, y: V) -> bool {
if x == y { if x == y {
// self-edges are not considered cycles // self-edges are not considered cycles
return false; return false;
@@ -120,13 +125,7 @@ impl DiGraph {
} }
/// Forwards depth-first-search /// Forwards depth-first-search
fn dfs_f( fn dfs_f(&self, n: V, ub: Order, visited: &mut HashSet<V>, delta_f: &mut Vec<V>) -> bool {
&self,
n: MutexId,
ub: Order,
visited: &mut HashSet<MutexId>,
delta_f: &mut Vec<MutexId>,
) -> bool {
visited.insert(n); visited.insert(n);
delta_f.push(n); delta_f.push(n);
@@ -147,13 +146,7 @@ impl DiGraph {
} }
/// Backwards depth-first-search /// Backwards depth-first-search
fn dfs_b( fn dfs_b(&self, n: V, lb: Order, visited: &mut HashSet<V>, delta_b: &mut Vec<V>) {
&self,
n: MutexId,
lb: Order,
visited: &mut HashSet<MutexId>,
delta_b: &mut Vec<MutexId>,
) {
visited.insert(n); visited.insert(n);
delta_b.push(n); delta_b.push(n);
@@ -164,7 +157,7 @@ impl DiGraph {
} }
} }
fn reorder(&mut self, mut delta_f: Vec<MutexId>, mut delta_b: Vec<MutexId>) { fn reorder(&mut self, mut delta_f: Vec<V>, mut delta_b: Vec<V>) {
self.sort(&mut delta_f); self.sort(&mut delta_f);
self.sort(&mut delta_b); self.sort(&mut delta_b);
@@ -190,7 +183,7 @@ impl DiGraph {
} }
} }
fn sort(&self, ids: &mut [MutexId]) { fn sort(&self, ids: &mut [V]) {
// Can use unstable sort because mutex ids should not be equal // Can use unstable sort because mutex ids should not be equal
ids.sort_unstable_by_key(|v| self.ord[v]); ids.sort_unstable_by_key(|v| self.ord[v]);
} }
@@ -239,10 +232,10 @@ impl DiGraph {
/// Helper function for `Self::recompute_topological_order`. /// Helper function for `Self::recompute_topological_order`.
fn visit( fn visit(
&self, &self,
v: MutexId, v: V,
permanent_marks: &mut HashSet<MutexId>, permanent_marks: &mut HashSet<V>,
temporary_marks: &mut HashSet<MutexId>, temporary_marks: &mut HashSet<V>,
rev_order: &mut Vec<MutexId>, rev_order: &mut Vec<V>,
) -> bool { ) -> bool {
if permanent_marks.contains(&v) { if permanent_marks.contains(&v) {
return true; return true;
@@ -271,28 +264,26 @@ impl DiGraph {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::MutexId;
#[test] #[test]
fn test_digraph() { fn test_digraph() {
let id: Vec<MutexId> = (0..5).map(|_| MutexId::new()).collect();
let mut graph = DiGraph::default(); let mut graph = DiGraph::default();
// Add some safe edges // Add some safe edges
graph.add_edge(id[0], id[1]); graph.add_edge(0, 1);
graph.add_edge(id[1], id[2]); graph.add_edge(1, 2);
graph.add_edge(id[2], id[3]); graph.add_edge(2, 3);
graph.add_edge(id[4], id[2]); graph.add_edge(4, 2);
// Should not have a cycle yet // Should not have a cycle yet
assert!(!graph.has_cycles()); assert!(!graph.has_cycles());
// Introduce cycle 3 → 1 → 2 → 3 // Introduce cycle 3 → 1 → 2 → 3
graph.add_edge(id[3], id[1]); graph.add_edge(3, 1);
assert!(graph.has_cycles()); assert!(graph.has_cycles());
// Removing 3 should remove that cycle // Removing 3 should remove that cycle
assert!(graph.remove_node(id[3])); assert!(graph.remove_node(3));
assert!(!graph.has_cycles()) assert!(!graph.has_cycles())
} }
} }

View File

@@ -70,11 +70,11 @@ thread_local! {
/// ///
/// Assuming that locks are roughly released in the reverse order in which they were acquired, /// Assuming that locks are roughly released in the reverse order in which they were acquired,
/// a stack should be more efficient to keep track of the current state than a set would be. /// a stack should be more efficient to keep track of the current state than a set would be.
static HELD_LOCKS: RefCell<Vec<MutexId>> = RefCell::new(Vec::new()); static HELD_LOCKS: RefCell<Vec<usize>> = RefCell::new(Vec::new());
} }
lazy_static! { lazy_static! {
static ref DEPENDENCY_GRAPH: Mutex<DiGraph> = Default::default(); static ref DEPENDENCY_GRAPH: Mutex<DiGraph<usize>> = Default::default();
} }
/// Dedicated ID type for Mutexes /// Dedicated ID type for Mutexes
@@ -86,7 +86,6 @@ lazy_static! {
/// ///
/// One possible alteration is to make this type not `Copy` but `Drop`, and handle deregistering /// One possible alteration is to make this type not `Copy` but `Drop`, and handle deregistering
/// the lock from there. /// the lock from there.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
struct MutexId(usize); struct MutexId(usize);
impl MutexId { impl MutexId {
@@ -105,6 +104,10 @@ impl MutexId {
.expect("Mutex ID wraparound happened, results unreliable") .expect("Mutex ID wraparound happened, results unreliable")
} }
pub fn value(&self) -> usize {
self.0
}
/// Get a borrowed guard for this lock. /// Get a borrowed guard for this lock.
/// ///
/// This method adds checks adds this Mutex ID to the dependency graph as needed, and adds the /// This method adds checks adds this Mutex ID to the dependency graph as needed, and adds the
@@ -113,12 +116,12 @@ impl MutexId {
/// # Panics /// # Panics
/// ///
/// This method panics if the new dependency would introduce a cycle. /// This method panics if the new dependency would introduce a cycle.
pub fn get_borrowed(self) -> BorrowedMutex { pub fn get_borrowed(&self) -> BorrowedMutex {
let creates_cycle = HELD_LOCKS.with(|locks| { let creates_cycle = HELD_LOCKS.with(|locks| {
if let Some(&previous) = locks.borrow().last() { if let Some(&previous) = locks.borrow().last() {
let mut graph = get_depedency_graph(); let mut graph = get_depedency_graph();
graph.add_edge(previous, self) && graph.has_cycles() graph.add_edge(previous, self.value()) && graph.has_cycles()
} else { } else {
false false
} }
@@ -129,7 +132,7 @@ impl MutexId {
panic!("Mutex order graph should not have cycles"); panic!("Mutex order graph should not have cycles");
} }
HELD_LOCKS.with(|locks| locks.borrow_mut().push(self)); HELD_LOCKS.with(|locks| locks.borrow_mut().push(self.value()));
BorrowedMutex(self) BorrowedMutex(self)
} }
} }
@@ -140,8 +143,14 @@ impl fmt::Debug for MutexId {
} }
} }
impl Drop for MutexId {
fn drop(&mut self) {
get_depedency_graph().remove_node(self.value());
}
}
#[derive(Debug)] #[derive(Debug)]
struct BorrowedMutex(MutexId); struct BorrowedMutex<'a>(&'a MutexId);
/// Drop a lock held by the current thread. /// Drop a lock held by the current thread.
/// ///
@@ -149,7 +158,7 @@ struct BorrowedMutex(MutexId);
/// ///
/// This function panics if the lock did not appear to be handled by this thread. If that happens, /// This function panics if the lock did not appear to be handled by this thread. If that happens,
/// that is an indication of a serious design flaw in this library. /// that is an indication of a serious design flaw in this library.
impl Drop for BorrowedMutex { impl<'a> Drop for BorrowedMutex<'a> {
fn drop(&mut self) { fn drop(&mut self) {
let id = self.0; let id = self.0;
@@ -157,7 +166,7 @@ impl Drop for BorrowedMutex {
let mut locks = locks.borrow_mut(); let mut locks = locks.borrow_mut();
for (i, &lock) in locks.iter().enumerate().rev() { for (i, &lock) in locks.iter().enumerate().rev() {
if lock == id { if lock == id.value() {
locks.remove(i); locks.remove(i);
return; return;
} }
@@ -170,7 +179,7 @@ impl Drop for BorrowedMutex {
} }
/// Get a reference to the current dependency graph /// Get a reference to the current dependency graph
fn get_depedency_graph() -> impl DerefMut<Target = DiGraph> { fn get_depedency_graph() -> impl DerefMut<Target = DiGraph<usize>> {
DEPENDENCY_GRAPH DEPENDENCY_GRAPH
.lock() .lock()
.unwrap_or_else(PoisonError::into_inner) .unwrap_or_else(PoisonError::into_inner)

View File

@@ -14,10 +14,8 @@
//! rwlock.read().unwrap(); //! rwlock.read().unwrap();
//! ``` //! ```
use std::fmt; use std::fmt;
use std::mem;
use std::ops::Deref; use std::ops::Deref;
use std::ops::DerefMut; use std::ops::DerefMut;
use std::ptr;
use std::sync::LockResult; use std::sync::LockResult;
use std::sync::Mutex; use std::sync::Mutex;
use std::sync::MutexGuard; use std::sync::MutexGuard;
@@ -28,7 +26,6 @@ use std::sync::RwLockWriteGuard;
use std::sync::TryLockError; use std::sync::TryLockError;
use std::sync::TryLockResult; use std::sync::TryLockResult;
use crate::get_depedency_graph;
use crate::BorrowedMutex; use crate::BorrowedMutex;
use crate::MutexId; use crate::MutexId;
@@ -69,7 +66,7 @@ pub struct TracingMutex<T> {
#[derive(Debug)] #[derive(Debug)]
pub struct TracingMutexGuard<'a, T> { pub struct TracingMutexGuard<'a, T> {
inner: MutexGuard<'a, T>, inner: MutexGuard<'a, T>,
mutex: BorrowedMutex, mutex: BorrowedMutex<'a>,
} }
fn map_lockresult<T, I, F>(result: LockResult<I>, mapper: F) -> LockResult<T> fn map_lockresult<T, I, F>(result: LockResult<I>, mapper: F) -> LockResult<T>
@@ -138,17 +135,7 @@ impl<T> TracingMutex<T> {
} }
pub fn into_inner(self) -> LockResult<T> { pub fn into_inner(self) -> LockResult<T> {
self.deregister(); self.inner.into_inner()
// Safety: we forget the original immediately after
let inner = unsafe { ptr::read(&self.inner) };
mem::forget(self);
inner.into_inner()
}
fn deregister(&self) {
get_depedency_graph().remove_node(self.id);
} }
} }
@@ -164,12 +151,6 @@ impl<T> From<T> for TracingMutex<T> {
} }
} }
impl<T> Drop for TracingMutex<T> {
fn drop(&mut self) {
self.deregister();
}
}
impl<'a, T> Deref for TracingMutexGuard<'a, T> { impl<'a, T> Deref for TracingMutexGuard<'a, T> {
type Target = MutexGuard<'a, T>; type Target = MutexGuard<'a, T>;
@@ -201,15 +182,15 @@ pub struct TracingRwLock<T> {
/// ///
/// Please refer to [`TracingReadGuard`] and [`TracingWriteGuard`] for usable types. /// Please refer to [`TracingReadGuard`] and [`TracingWriteGuard`] for usable types.
#[derive(Debug)] #[derive(Debug)]
pub struct TracingRwLockGuard<L> { pub struct TracingRwLockGuard<'a, L> {
inner: L, inner: L,
mutex: BorrowedMutex, mutex: BorrowedMutex<'a>,
} }
/// Wrapper around [`std::sync::RwLockReadGuard`]. /// Wrapper around [`std::sync::RwLockReadGuard`].
pub type TracingReadGuard<'a, T> = TracingRwLockGuard<RwLockReadGuard<'a, T>>; pub type TracingReadGuard<'a, T> = TracingRwLockGuard<'a, RwLockReadGuard<'a, T>>;
/// Wrapper around [`std::sync::RwLockWriteGuard`]. /// Wrapper around [`std::sync::RwLockWriteGuard`].
pub type TracingWriteGuard<'a, T> = TracingRwLockGuard<RwLockWriteGuard<'a, T>>; pub type TracingWriteGuard<'a, T> = TracingRwLockGuard<'a, RwLockWriteGuard<'a, T>>;
impl<T> TracingRwLock<T> { impl<T> TracingRwLock<T> {
pub fn new(t: T) -> Self { pub fn new(t: T) -> Self {
@@ -256,24 +237,7 @@ impl<T> TracingRwLock<T> {
} }
pub fn into_inner(self) -> LockResult<T> { pub fn into_inner(self) -> LockResult<T> {
self.deregister(); self.inner.into_inner()
// Grab our contents and then forget ourselves
// Safety: we immediately forget the mutex after copying
let inner = unsafe { ptr::read(&self.inner) };
mem::forget(self);
inner.into_inner()
}
fn deregister(&self) {
get_depedency_graph().remove_node(self.id);
}
}
impl<T> Drop for TracingRwLock<T> {
fn drop(&mut self) {
self.deregister();
} }
} }
@@ -286,7 +250,7 @@ where
} }
} }
impl<L, T> Deref for TracingRwLockGuard<L> impl<'a, L, T> Deref for TracingRwLockGuard<'a, L>
where where
L: Deref<Target = T>, L: Deref<Target = T>,
{ {
@@ -297,7 +261,7 @@ where
} }
} }
impl<T, L> DerefMut for TracingRwLockGuard<L> impl<'a, T, L> DerefMut for TracingRwLockGuard<'a, L>
where where
L: Deref<Target = T> + DerefMut, L: Deref<Target = T> + DerefMut,
{ {