mirror of
https://github.com/bertptrs/tracing-mutex.git
synced 2025-12-25 20:50:32 +01:00
Implement Mutex wrappers
This commit is contained in:
116
src/graph.rs
Normal file
116
src/graph.rs
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
|
#[derive(Clone, Default, Debug)]
|
||||||
|
pub struct DiGraph {
|
||||||
|
in_edges: HashMap<usize, Vec<usize>>,
|
||||||
|
out_edges: HashMap<usize, Vec<usize>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DiGraph {
|
||||||
|
fn add_node(&mut self, node: usize) -> (&mut Vec<usize>, &mut Vec<usize>) {
|
||||||
|
let in_edges = self.in_edges.entry(node).or_default();
|
||||||
|
let out_edges = self.out_edges.entry(node).or_default();
|
||||||
|
|
||||||
|
(in_edges, out_edges)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn remove_node(&mut self, node: usize) -> bool {
|
||||||
|
match self.out_edges.remove(&node) {
|
||||||
|
None => false,
|
||||||
|
Some(out_edges) => {
|
||||||
|
for other in out_edges {
|
||||||
|
self.in_edges
|
||||||
|
.get_mut(&other)
|
||||||
|
.unwrap()
|
||||||
|
.retain(|c| c != &node);
|
||||||
|
}
|
||||||
|
|
||||||
|
for other in self.in_edges.remove(&node).unwrap() {
|
||||||
|
self.out_edges
|
||||||
|
.get_mut(&other)
|
||||||
|
.unwrap()
|
||||||
|
.retain(|c| c != &node);
|
||||||
|
}
|
||||||
|
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add_edge(&mut self, from: usize, to: usize) -> bool {
|
||||||
|
if from == to {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let (_, out_edges) = self.add_node(from);
|
||||||
|
|
||||||
|
out_edges.push(to);
|
||||||
|
|
||||||
|
let (in_edges, _) = self.add_node(to);
|
||||||
|
|
||||||
|
// No need for existence check assuming the datastructure is consistent
|
||||||
|
in_edges.push(from);
|
||||||
|
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn has_cycles(&self) -> bool {
|
||||||
|
let mut marks = HashSet::new();
|
||||||
|
let mut temp = HashSet::new();
|
||||||
|
|
||||||
|
self.out_edges
|
||||||
|
.keys()
|
||||||
|
.copied()
|
||||||
|
.any(|node| !self.visit(node, &mut marks, &mut temp))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit(&self, node: usize, marks: &mut HashSet<usize>, temp: &mut HashSet<usize>) -> bool {
|
||||||
|
if marks.contains(&node) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !temp.insert(node) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.out_edges[&node]
|
||||||
|
.iter()
|
||||||
|
.copied()
|
||||||
|
.any(|node| !self.visit(node, marks, temp))
|
||||||
|
{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
temp.remove(&node);
|
||||||
|
|
||||||
|
marks.insert(node);
|
||||||
|
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::DiGraph;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_digraph() {
|
||||||
|
let mut graph = DiGraph::default();
|
||||||
|
|
||||||
|
graph.add_edge(1, 2);
|
||||||
|
graph.add_edge(2, 3);
|
||||||
|
graph.add_edge(3, 4);
|
||||||
|
graph.add_edge(5, 2);
|
||||||
|
|
||||||
|
assert!(!graph.has_cycles());
|
||||||
|
|
||||||
|
graph.add_edge(4, 2);
|
||||||
|
|
||||||
|
assert!(graph.has_cycles());
|
||||||
|
|
||||||
|
assert!(graph.remove_node(4));
|
||||||
|
|
||||||
|
assert!(!graph.has_cycles())
|
||||||
|
}
|
||||||
|
}
|
||||||
250
src/lib.rs
250
src/lib.rs
@@ -1,116 +1,184 @@
|
|||||||
use std::collections::HashMap;
|
use std::cell::RefCell;
|
||||||
use std::collections::HashSet;
|
use std::fmt;
|
||||||
|
use std::fmt::Display;
|
||||||
|
use std::ops::Deref;
|
||||||
|
use std::ops::DerefMut;
|
||||||
|
use std::sync::atomic::AtomicUsize;
|
||||||
|
use std::sync::atomic::Ordering;
|
||||||
|
use std::sync::LockResult;
|
||||||
|
use std::sync::Mutex;
|
||||||
|
use std::sync::MutexGuard;
|
||||||
|
use std::sync::PoisonError;
|
||||||
|
use std::sync::TryLockError;
|
||||||
|
use std::sync::TryLockResult;
|
||||||
|
|
||||||
#[derive(Clone, Default, Debug)]
|
mod graph;
|
||||||
struct DiGraph {
|
|
||||||
in_edges: HashMap<usize, Vec<usize>>,
|
/// Counter for Mutex IDs. Atomic avoids the need for locking.
|
||||||
out_edges: HashMap<usize, Vec<usize>>,
|
static ID_SEQUENCE: AtomicUsize = AtomicUsize::new(0);
|
||||||
|
|
||||||
|
thread_local! {
|
||||||
|
/// Stack to track which locks are held
|
||||||
|
///
|
||||||
|
/// Assuming that locks are roughly released in the reverse order in which they were acquired,
|
||||||
|
/// a stack should be more efficient to keep track of the current state than a set would be.
|
||||||
|
static HELD_LOCKS: RefCell<Vec<usize>> = RefCell::new(Vec::new());
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DiGraph {
|
/// Wrapper for std::sync::Mutex
|
||||||
fn add_node(&mut self, node: usize) -> (&mut Vec<usize>, &mut Vec<usize>) {
|
#[derive(Debug)]
|
||||||
let in_edges = self.in_edges.entry(node).or_default();
|
pub struct TracingMutex<T> {
|
||||||
let out_edges = self.out_edges.entry(node).or_default();
|
inner: Mutex<T>,
|
||||||
|
id: usize,
|
||||||
|
}
|
||||||
|
|
||||||
(in_edges, out_edges)
|
/// Wrapper for std::sync::MutexGuard
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct TracingMutexGuard<'a, T> {
|
||||||
|
inner: MutexGuard<'a, T>,
|
||||||
|
mutex: &'a TracingMutex<T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn next_mutex_id() -> usize {
|
||||||
|
ID_SEQUENCE
|
||||||
|
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |id| id.checked_add(1))
|
||||||
|
.expect("Mutex ID wraparound happened, results unreliable")
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> TracingMutex<T> {
|
||||||
|
pub fn new(t: T) -> Self {
|
||||||
|
Self {
|
||||||
|
inner: Mutex::new(t),
|
||||||
|
id: next_mutex_id(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn remove_node(&mut self, node: usize) -> bool {
|
pub fn lock(&self) -> LockResult<TracingMutexGuard<T>> {
|
||||||
match self.out_edges.remove(&node) {
|
let result = self.inner.lock();
|
||||||
None => false,
|
self.register_lock();
|
||||||
Some(out_edges) => {
|
|
||||||
for other in out_edges {
|
|
||||||
self.in_edges
|
|
||||||
.get_mut(&other)
|
|
||||||
.unwrap()
|
|
||||||
.retain(|c| c != &node);
|
|
||||||
}
|
|
||||||
|
|
||||||
for other in self.in_edges.remove(&node).unwrap() {
|
let mapper = |guard| TracingMutexGuard {
|
||||||
self.out_edges
|
mutex: self,
|
||||||
.get_mut(&other)
|
inner: guard,
|
||||||
.unwrap()
|
};
|
||||||
.retain(|c| c != &node);
|
|
||||||
}
|
|
||||||
|
|
||||||
true
|
result
|
||||||
|
.map(mapper)
|
||||||
|
.map_err(|poison| PoisonError::new(mapper(poison.into_inner())))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn try_lock(&self) -> TryLockResult<TracingMutexGuard<T>> {
|
||||||
|
let result = self.inner.try_lock();
|
||||||
|
self.register_lock();
|
||||||
|
|
||||||
|
let mapper = |guard| TracingMutexGuard {
|
||||||
|
mutex: self,
|
||||||
|
inner: guard,
|
||||||
|
};
|
||||||
|
|
||||||
|
result.map(mapper).map_err(|error| match error {
|
||||||
|
TryLockError::Poisoned(poison) => PoisonError::new(mapper(poison.into_inner())).into(),
|
||||||
|
TryLockError::WouldBlock => TryLockError::WouldBlock,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_id(&self) -> usize {
|
||||||
|
self.id
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_poisoned(&self) -> bool {
|
||||||
|
self.inner.is_poisoned()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_mut(&mut self) -> LockResult<&mut T> {
|
||||||
|
self.inner.get_mut()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn register_lock(&self) {
|
||||||
|
HELD_LOCKS.with(|locks| locks.borrow_mut().push(self.id))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: ?Sized + Default> Default for TracingMutex<T> {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new(T::default())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> From<T> for TracingMutex<T> {
|
||||||
|
fn from(t: T) -> Self {
|
||||||
|
Self::new(t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Deref for TracingMutexGuard<'a, T> {
|
||||||
|
type Target = MutexGuard<'a, T>;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.inner
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> DerefMut for TracingMutexGuard<'a, T> {
|
||||||
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||||
|
&mut self.inner
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T: Display> fmt::Display for TracingMutexGuard<'a, T> {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
self.inner.fmt(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Drop for TracingMutexGuard<'a, T> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
HELD_LOCKS.with(|locks| {
|
||||||
|
let id = self.mutex.id;
|
||||||
|
let mut locks = locks.borrow_mut();
|
||||||
|
|
||||||
|
for (i, &lock) in locks.iter().enumerate().rev() {
|
||||||
|
if lock == id {
|
||||||
|
locks.remove(i);
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn add_edge(&mut self, from: usize, to: usize) -> bool {
|
panic!("Tried to drop lock for mutex {} but it wasn't held", id)
|
||||||
if from == to {
|
});
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
let (_, out_edges) = self.add_node(from);
|
|
||||||
|
|
||||||
out_edges.push(to);
|
|
||||||
|
|
||||||
let (in_edges, _) = self.add_node(to);
|
|
||||||
|
|
||||||
// No need for existence check assuming the datastructure is consistent
|
|
||||||
in_edges.push(from);
|
|
||||||
|
|
||||||
true
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn has_cycles(&self) -> bool {
|
|
||||||
let mut marks = HashSet::new();
|
|
||||||
let mut temp = HashSet::new();
|
|
||||||
|
|
||||||
self.out_edges
|
|
||||||
.keys()
|
|
||||||
.copied()
|
|
||||||
.any(|node| !self.visit(node, &mut marks, &mut temp))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn visit(&self, node: usize, marks: &mut HashSet<usize>, temp: &mut HashSet<usize>) -> bool {
|
|
||||||
if marks.contains(&node) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if !temp.insert(node) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.out_edges[&node]
|
|
||||||
.iter()
|
|
||||||
.copied()
|
|
||||||
.any(|node| !self.visit(node, marks, temp))
|
|
||||||
{
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
temp.remove(&node);
|
|
||||||
|
|
||||||
marks.insert(node);
|
|
||||||
|
|
||||||
true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::DiGraph;
|
use std::sync::Arc;
|
||||||
|
use std::thread;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_digraph() {
|
fn test_next_mutex_id() {
|
||||||
let mut graph = DiGraph::default();
|
let initial = next_mutex_id();
|
||||||
|
let next = next_mutex_id();
|
||||||
|
|
||||||
graph.add_edge(1, 2);
|
// Can't assert N + 1 because multiple threads running tests
|
||||||
graph.add_edge(2, 3);
|
assert!(initial < next);
|
||||||
graph.add_edge(3, 4);
|
}
|
||||||
graph.add_edge(5, 2);
|
|
||||||
|
|
||||||
assert!(!graph.has_cycles());
|
#[test]
|
||||||
|
fn test_mutex_usage() {
|
||||||
|
let mutex = Arc::new(TracingMutex::new(()));
|
||||||
|
let mutex_clone = mutex.clone();
|
||||||
|
|
||||||
graph.add_edge(4, 2);
|
let _guard = mutex.lock().unwrap();
|
||||||
|
|
||||||
assert!(graph.has_cycles());
|
// Now try to cause a blocking exception in another thread
|
||||||
|
let handle = thread::spawn(move || {
|
||||||
|
let result = mutex_clone.try_lock().unwrap_err();
|
||||||
|
|
||||||
assert!(graph.remove_node(4));
|
assert!(matches!(result, TryLockError::WouldBlock));
|
||||||
|
});
|
||||||
|
|
||||||
assert!(!graph.has_cycles())
|
handle.join().unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user