From b21a63e74b792d173b048e0da3a4ce239aabca85 Mon Sep 17 00:00:00 2001
From: Bert Peters
Date: Thu, 27 May 2021 19:20:45 +0200
Subject: [PATCH] Implement RwLock-based traits for lockapi wrapper.

---
 src/lockapi.rs | 233 +++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 227 insertions(+), 6 deletions(-)

diff --git a/src/lockapi.rs b/src/lockapi.rs
index b957d6d..c9f82f4 100644
--- a/src/lockapi.rs
+++ b/src/lockapi.rs
@@ -3,6 +3,16 @@ use lock_api::GuardNoSend;
 use lock_api::RawMutex;
 use lock_api::RawMutexFair;
 use lock_api::RawMutexTimed;
+use lock_api::RawRwLock;
+use lock_api::RawRwLockDowngrade;
+use lock_api::RawRwLockFair;
+use lock_api::RawRwLockRecursive;
+use lock_api::RawRwLockRecursiveTimed;
+use lock_api::RawRwLockTimed;
+use lock_api::RawRwLockUpgrade;
+use lock_api::RawRwLockUpgradeDowngrade;
+use lock_api::RawRwLockUpgradeFair;
+use lock_api::RawRwLockUpgradeTimed;
 
 use crate::LazyMutexId;
 
@@ -13,6 +23,7 @@ use crate::LazyMutexId;
 #[derive(Debug, Default)]
 pub struct TracingWrapper<T> {
     inner: T,
+    // Need to use a lazy mutex ID to initialize statically.
     id: LazyMutexId,
 }
 
@@ -32,10 +43,26 @@ impl<T> TracingWrapper<T> {
         self.id.mark_released();
     }
 
+    /// First mark ourselves as held, then call the locking function.
+    fn lock(&self, f: impl FnOnce()) {
+        self.mark_held();
+        f();
+    }
+
+    /// First call the unlocking function, then mark ourselves as released.
+    unsafe fn unlock(&self, f: impl FnOnce()) {
+        f();
+        self.mark_released();
+    }
+
     /// Conditionally lock the mutex.
     ///
     /// First acquires the lock, then runs the provided function. If that function returns true,
     /// then the lock is kept, otherwise the mutex is immediately marked as relased.
+    ///
+    /// # Returns
+    ///
+    /// The value returned from the callback.
     fn conditionally_lock(&self, f: impl FnOnce() -> bool) -> bool {
         // Mark as locked while we try to do the thing
         self.mark_held();
@@ -65,8 +92,7 @@ where
     type GuardMarker = GuardNoSend;
 
     fn lock(&self) {
-        self.mark_held();
-        self.inner.lock();
+        self.lock(|| self.inner.lock());
     }
 
     fn try_lock(&self) -> bool {
@@ -74,8 +100,7 @@ where
     }
 
     unsafe fn unlock(&self) {
-        self.inner.unlock();
-        self.mark_released();
+        self.unlock(|| self.inner.unlock());
     }
 
     fn is_locked(&self) -> bool {
@@ -89,8 +114,7 @@ where
     T: RawMutexFair,
 {
     unsafe fn unlock_fair(&self) {
-        self.inner.unlock_fair();
-        self.mark_released();
+        self.unlock(|| self.inner.unlock_fair())
     }
 
     unsafe fn bump(&self) {
@@ -116,3 +140,200 @@ where
         self.conditionally_lock(|| self.inner.try_lock_until(timeout))
     }
 }
+
+unsafe impl<T> RawRwLock for TracingWrapper<T>
+where
+    T: RawRwLock,
+{
+    const INIT: Self = Self {
+        inner: T::INIT,
+        id: LazyMutexId::new(),
+    };
+
+    /// Always equal to [`GuardNoSend`], as an implementation detail in the tracking system requires
+    /// this behaviour. May change in the future to reflect the actual guard type from the wrapped
+    /// primitive.
+    type GuardMarker = GuardNoSend;
+
+    fn lock_shared(&self) {
+        self.lock(|| self.inner.lock_shared());
+    }
+
+    fn try_lock_shared(&self) -> bool {
+        self.conditionally_lock(|| self.inner.try_lock_shared())
+    }
+
+    unsafe fn unlock_shared(&self) {
+        self.unlock(|| self.inner.unlock_shared());
+    }
+
+    fn lock_exclusive(&self) {
+        self.lock(|| self.inner.lock_exclusive());
+    }
+
+    fn try_lock_exclusive(&self) -> bool {
+        self.conditionally_lock(|| self.inner.try_lock_exclusive())
+    }
+
+    unsafe fn unlock_exclusive(&self) {
+        self.unlock(|| self.inner.unlock_exclusive());
+    }
+
+    fn is_locked(&self) -> bool {
+        self.inner.is_locked()
+    }
+}
+
+unsafe impl<T> RawRwLockDowngrade for TracingWrapper<T>
+where
+    T: RawRwLockDowngrade,
+{
+    unsafe fn downgrade(&self) {
+        // Downgrading does not require tracking
+        self.inner.downgrade()
+    }
+}
+
+unsafe impl<T> RawRwLockUpgrade for TracingWrapper<T>
+where
+    T: RawRwLockUpgrade,
+{
+    fn lock_upgradable(&self) {
+        self.lock(|| self.inner.lock_upgradable());
+    }
+
+    fn try_lock_upgradable(&self) -> bool {
+        self.conditionally_lock(|| self.inner.try_lock_upgradable())
+    }
+
+    unsafe fn unlock_upgradable(&self) {
+        self.unlock(|| self.inner.unlock_upgradable());
+    }
+
+    unsafe fn upgrade(&self) {
+        self.inner.upgrade();
+    }
+
+    unsafe fn try_upgrade(&self) -> bool {
+        self.inner.try_upgrade()
+    }
+}
+
+unsafe impl<T> RawRwLockFair for TracingWrapper<T>
+where
+    T: RawRwLockFair,
+{
+    unsafe fn unlock_shared_fair(&self) {
+        self.unlock(|| self.inner.unlock_shared_fair());
+    }
+
+    unsafe fn unlock_exclusive_fair(&self) {
+        self.unlock(|| self.inner.unlock_exclusive_fair());
+    }
+
+    unsafe fn bump_shared(&self) {
+        self.inner.bump_shared();
+    }
+
+    unsafe fn bump_exclusive(&self) {
+        self.inner.bump_exclusive();
+    }
+}
+
+unsafe impl<T> RawRwLockRecursive for TracingWrapper<T>
+where
+    T: RawRwLockRecursive,
+{
+    fn lock_shared_recursive(&self) {
+        self.lock(|| self.inner.lock_shared_recursive());
+    }
+
+    fn try_lock_shared_recursive(&self) -> bool {
+        self.conditionally_lock(|| self.inner.try_lock_shared_recursive())
+    }
+}
+
+unsafe impl<T> RawRwLockRecursiveTimed for TracingWrapper<T>
+where
+    T: RawRwLockRecursiveTimed,
+{
+    fn try_lock_shared_recursive_for(&self, timeout: Self::Duration) -> bool {
+        self.conditionally_lock(|| self.inner.try_lock_shared_recursive_for(timeout))
+    }
+
+    fn try_lock_shared_recursive_until(&self, timeout: Self::Instant) -> bool {
+        self.conditionally_lock(|| self.inner.try_lock_shared_recursive_until(timeout))
+    }
+}
+
+unsafe impl<T> RawRwLockTimed for TracingWrapper<T>
+where
+    T: RawRwLockTimed,
+{
+    type Duration = T::Duration;
+
+    type Instant = T::Instant;
+
+    fn try_lock_shared_for(&self, timeout: Self::Duration) -> bool {
+        self.conditionally_lock(|| self.inner.try_lock_shared_for(timeout))
+    }
+
+    fn try_lock_shared_until(&self, timeout: Self::Instant) -> bool {
+        self.conditionally_lock(|| self.inner.try_lock_shared_until(timeout))
+    }
+
+    fn try_lock_exclusive_for(&self, timeout: Self::Duration) -> bool {
+        self.conditionally_lock(|| self.inner.try_lock_exclusive_for(timeout))
+    }
+
+    fn try_lock_exclusive_until(&self, timeout: Self::Instant) -> bool {
+        self.conditionally_lock(|| self.inner.try_lock_exclusive_until(timeout))
+    }
+}
+
+unsafe impl<T> RawRwLockUpgradeDowngrade for TracingWrapper<T>
+where
+    T: RawRwLockUpgradeDowngrade,
+{
+    unsafe fn downgrade_upgradable(&self) {
+        self.inner.downgrade_upgradable()
+    }
+
+    unsafe fn downgrade_to_upgradable(&self) {
+        self.inner.downgrade_to_upgradable()
+    }
+}
+
+unsafe impl<T> RawRwLockUpgradeFair for TracingWrapper<T>
+where
+    T: RawRwLockUpgradeFair,
+{
+    unsafe fn unlock_upgradable_fair(&self) {
+        self.unlock(|| self.inner.unlock_upgradable_fair())
+    }
+
+    unsafe fn bump_upgradable(&self) {
+        self.inner.bump_upgradable()
+    }
+}
+
+unsafe impl<T> RawRwLockUpgradeTimed for TracingWrapper<T>
+where
+    T: RawRwLockUpgradeTimed,
+{
+    fn try_lock_upgradable_for(&self, timeout: Self::Duration) -> bool {
+        self.conditionally_lock(|| self.inner.try_lock_upgradable_for(timeout))
+    }
+
+    fn try_lock_upgradable_until(&self, timeout: Self::Instant) -> bool {
+        self.conditionally_lock(|| self.inner.try_lock_upgradable_until(timeout))
+    }
+
+    unsafe fn try_upgrade_for(&self, timeout: Self::Duration) -> bool {
+        self.conditionally_lock(|| self.inner.try_upgrade_for(timeout))
+    }
+
+    unsafe fn try_upgrade_until(&self, timeout: Self::Instant) -> bool {
+        self.conditionally_lock(|| self.inner.try_upgrade_until(timeout))
+    }
+}
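
Usage sketch (not part of the diff above): with these impls in place, TracingWrapper can back lock_api::RwLock the same way it already backs lock_api::Mutex, so the dependency tracking extends to reader-writer locks. The crate path, the type alias, and the parking_lot backend below are illustrative assumptions, not something this patch defines.

    // Hypothetical example; assumes the wrapper is exported as
    // tracing_mutex::lockapi::TracingWrapper and that parking_lot provides the
    // inner RawRwLock implementation.
    use lock_api::RwLock;
    use tracing_mutex::lockapi::TracingWrapper;

    // Every acquisition is recorded via the wrapper's LazyMutexId before the
    // call is delegated to parking_lot's raw rwlock.
    type TracingRwLock<T> = RwLock<TracingWrapper<parking_lot::RawRwLock>, T>;

    fn main() {
        let lock: TracingRwLock<u32> = TracingRwLock::new(42);

        // Shared access goes through lock_shared/unlock_shared on the wrapper.
        assert_eq!(*lock.read(), 42);

        // Exclusive access exercises lock_exclusive/unlock_exclusive.
        *lock.write() += 1;
        assert_eq!(*lock.read(), 43);
    }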