diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d2b662..5e1fd6f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ### Added - Build [docs.rs] documentation with all features enabled for completeness. +- Add support for `std::sync::Condvar` ### Fixed diff --git a/src/stdsync.rs b/src/stdsync.rs index c6050fc..3cdb4d3 100644 --- a/src/stdsync.rs +++ b/src/stdsync.rs @@ -16,6 +16,7 @@ use std::fmt; use std::ops::Deref; use std::ops::DerefMut; +use std::sync::Condvar; use std::sync::LockResult; use std::sync::Mutex; use std::sync::MutexGuard; @@ -27,6 +28,8 @@ use std::sync::RwLockReadGuard; use std::sync::RwLockWriteGuard; use std::sync::TryLockError; use std::sync::TryLockResult; +use std::sync::WaitTimeoutResult; +use std::time::Duration; use crate::BorrowedMutex; use crate::LazyMutexId; @@ -48,6 +51,14 @@ pub type DebugMutexGuard<'a, T> = TracingMutexGuard<'a, T>; #[cfg(not(debug_assertions))] pub type DebugMutexGuard<'a, T> = MutexGuard<'a, T>; +/// Debug-only `Condvar` +/// +/// Type alias that accepts the mutex guard emitted from [`DebugMutex`]. +#[cfg(debug_assertions)] +pub type DebugCondvar = TracingCondvar; +#[cfg(not(debug_assertions))] +pub type DebugCondvar = Condvar; + /// Debug-only tracing `RwLock`. /// /// Type alias that resolves to [`TracingRwLock`] when debug assertions are enabled and to @@ -214,6 +225,123 @@ impl<'a, T: fmt::Display> fmt::Display for TracingMutexGuard<'a, T> { } } +/// Wrapper around [`std::sync::Condvar`]. +/// +/// Allows `TracingMutexGuard` to be used with a `Condvar`. Unlike other structs in this module, +/// this wrapper does not add any additional dependency tracking or other overhead on top of the +/// primitive it wraps. All dependency tracking happens through the mutexes itself. +/// +/// # Panics +/// +/// This struct does not add any panics over the base implementation of `Condvar`, but panics due to +/// dependency tracking may poison associated mutexes. +/// +/// # Examples +/// +/// ``` +/// use std::sync::Arc; +/// use std::thread; +/// +/// use tracing_mutex::stdsync::{TracingCondvar, TracingMutex}; +/// +/// let pair = Arc::new((TracingMutex::new(false), TracingCondvar::new())); +/// let pair2 = Arc::clone(&pair); +/// +/// // Spawn a thread that will unlock the condvar +/// thread::spawn(move || { +/// let (lock, condvar) = &*pair2; +/// *lock.lock().unwrap() = true; +/// condvar.notify_one(); +/// }); +/// +/// // Wait until the thread unlocks the condvar +/// let (lock, condvar) = &*pair; +/// let guard = lock.lock().unwrap(); +/// let guard = condvar.wait_while(guard, |started| !*started).unwrap(); +/// +/// // Guard should read true now +/// assert!(*guard); +/// ``` +#[derive(Debug, Default)] +pub struct TracingCondvar(Condvar); + +impl TracingCondvar { + /// Creates a new condition variable which is ready to be waited on and notified. + pub fn new() -> Self { + Default::default() + } + + /// Wrapper for [`std::sync::Condvar::wait`]. + pub fn wait<'a, T>( + &self, + guard: TracingMutexGuard<'a, T>, + ) -> LockResult> { + let TracingMutexGuard { _mutex, inner } = guard; + + map_lockresult(self.0.wait(inner), |inner| TracingMutexGuard { + _mutex, + inner, + }) + } + + /// Wrapper for [`std::sync::Condvar::wait_while`]. + pub fn wait_while<'a, T, F>( + &self, + guard: TracingMutexGuard<'a, T>, + condition: F, + ) -> LockResult> + where + F: FnMut(&mut T) -> bool, + { + let TracingMutexGuard { _mutex, inner } = guard; + + map_lockresult(self.0.wait_while(inner, condition), |inner| { + TracingMutexGuard { _mutex, inner } + }) + } + + /// Wrapper for [`std::sync::Condvar::wait_timeout`]. + pub fn wait_timeout<'a, T>( + &self, + guard: TracingMutexGuard<'a, T>, + dur: Duration, + ) -> LockResult<(TracingMutexGuard<'a, T>, WaitTimeoutResult)> { + let TracingMutexGuard { _mutex, inner } = guard; + + map_lockresult(self.0.wait_timeout(inner, dur), |(inner, result)| { + (TracingMutexGuard { _mutex, inner }, result) + }) + } + + /// Wrapper for [`std::sync::Condvar::wait_timeout_while`]. + pub fn wait_timeout_while<'a, T, F>( + &self, + guard: TracingMutexGuard<'a, T>, + dur: Duration, + condition: F, + ) -> LockResult<(TracingMutexGuard<'a, T>, WaitTimeoutResult)> + where + F: FnMut(&mut T) -> bool, + { + let TracingMutexGuard { _mutex, inner } = guard; + + map_lockresult( + self.0.wait_timeout_while(inner, dur, condition), + |(inner, result)| (TracingMutexGuard { _mutex, inner }, result), + ) + } + + /// Wrapper for [`std::sync::Condvar::notify_one`]. + pub fn notify_one(&self) { + self.0.notify_one(); + } + + /// Wrapper for [`std::sync::Condvar::notify_all`]. + pub fn notify_all(&self) { + self.0.notify_all(); + } +} + /// Wrapper for [`std::sync::RwLock`]. #[derive(Debug, Default)] pub struct TracingRwLock {