const std = @import("../std.zig");
const builtin = @import("builtin");
const Mutex = @This();
const os = std.os;
const assert = std.debug.assert;
const testing = std.testing;
const Atomic = std.atomic.Atomic;
const Thread = std.Thread;
const Futex = Thread.Futex;
/// Backend that actually implements the lock. Its concrete type, `Impl`, is selected at
/// compile time from the build mode, `builtin.single_threaded` and the target OS.
/// The default value is the backend's default-initialized state.
impl: Impl = .{},
/// Attempts to acquire the mutex by delegating to the selected backend's `tryLock`.
/// Returns `true` if the caller now holds the mutex and `false` if it could not be taken
/// (the smoke test in this file expects `false` when the mutex is already held).
/// A successful call must be paired with a later `unlock()`.
pub fn tryLock(self: *Mutex) bool {
return self.impl.tryLock();
}
/// Acquires the mutex by delegating to the selected backend's `lock`.
/// With the debug backend (`DebugImpl`), a thread that locks a mutex it already holds
/// panics with "Deadlock detected"; the single-threaded backend treats locking an
/// already-locked mutex as `unreachable`.
pub fn lock(self: *Mutex) void {
self.impl.lock();
}
/// Releases the mutex by delegating to the selected backend's `unlock`.
/// The debug backend (`DebugImpl`) asserts that the calling thread is the one recorded
/// as holding the lock; the single-threaded and futex backends assert that the mutex
/// was actually locked.
pub fn unlock(self: *Mutex) void {
self.impl.unlock();
}
/// Backend type stored in `Mutex.impl`.
/// Single-threaded builds and every non-Debug build use `ReleaseImpl` directly;
/// multi-threaded Debug builds wrap it in `DebugImpl`.
const Impl = if (builtin.single_threaded or builtin.mode != .Debug)
    ReleaseImpl
else
    DebugImpl;
/// Backend without the debug wrapper, picked at compile time.
/// The checks run in priority order: single-threaded builds first, then Windows,
/// then Darwin targets, and the futex-based implementation for everything else.
const ReleaseImpl = blk: {
    if (builtin.single_threaded) break :blk SingleThreadedImpl;
    if (builtin.os.tag == .windows) break :blk WindowsImpl;
    if (builtin.os.tag.isDarwin()) break :blk DarwinImpl;
    break :blk FutexImpl;
};
/// Wrapper around `ReleaseImpl` that remembers which thread holds the lock so that
/// re-locking from the owning thread and unlocking from a non-owning thread are caught.
const DebugImpl = struct {
    /// Id of the thread currently holding the lock; 0 means no thread holds it.
    locking_thread: Atomic(Thread.Id) = Atomic(Thread.Id).init(0),
    /// The wrapped backend that performs the real locking.
    impl: ReleaseImpl = .{},

    const Self = @This();

    /// Tries the wrapped backend; on success records the caller as the owner.
    inline fn tryLock(self: *Self) bool {
        if (!self.impl.tryLock()) return false;
        self.locking_thread.store(Thread.getCurrentId(), .Unordered);
        return true;
    }

    /// Panics if the caller is already recorded as the owner, otherwise locks the
    /// wrapped backend and records the caller as the owner.
    inline fn lock(self: *Self) void {
        const me = Thread.getCurrentId();
        const owner = self.locking_thread.load(.Unordered);
        // An id of 0 is the "unowned" marker, so it can never identify a real owner.
        if (me != 0 and owner == me) @panic("Deadlock detected");

        self.impl.lock();
        self.locking_thread.store(me, .Unordered);
    }

    /// Asserts the caller is the recorded owner, clears the owner, then unlocks the
    /// wrapped backend.
    inline fn unlock(self: *Self) void {
        const owner = self.locking_thread.load(.Unordered);
        assert(owner == Thread.getCurrentId());

        // Clear ownership while the lock is still held, before anyone else can take it.
        self.locking_thread.store(0, .Unordered);
        self.impl.unlock();
    }
};
/// Backend for single-threaded builds: a plain boolean flag, no atomics and no blocking.
const SingleThreadedImpl = struct {
    is_locked: bool = false,

    /// Marks the mutex as locked and reports whether it was free beforehand.
    fn tryLock(self: *@This()) bool {
        const was_locked = self.is_locked;
        self.is_locked = true;
        return !was_locked;
    }

    /// Takes the lock. With only one thread nobody could ever release a held lock,
    /// so finding it already locked is treated as `unreachable`.
    fn lock(self: *@This()) void {
        if (self.tryLock()) return;
        unreachable;
    }

    /// Releases the lock; asserts that it was held.
    fn unlock(self: *@This()) void {
        assert(self.is_locked);
        self.is_locked = false;
    }
};
/// Backend for Windows targets: an `SRWLOCK` that is only ever taken in exclusive mode.
const WindowsImpl = struct {
    srwlock: os.windows.SRWLOCK = .{},

    const windows = os.windows;
    const kernel32 = windows.kernel32;

    /// Returns `true` when `TryAcquireSRWLockExclusive` reports anything other than `FALSE`.
    fn tryLock(self: *@This()) bool {
        const acquired = kernel32.TryAcquireSRWLockExclusive(&self.srwlock);
        return acquired != windows.FALSE;
    }

    /// Acquires the lock through `AcquireSRWLockExclusive`.
    fn lock(self: *@This()) void {
        kernel32.AcquireSRWLockExclusive(&self.srwlock);
    }

    /// Releases the lock through `ReleaseSRWLockExclusive`.
    fn unlock(self: *@This()) void {
        kernel32.ReleaseSRWLockExclusive(&self.srwlock);
    }
};
/// Backend for Darwin targets, a thin wrapper over `os_unfair_lock`.
const DarwinImpl = struct {
    oul: os.darwin.os_unfair_lock = .{},

    const darwin = os.darwin;

    /// Forwards to `os_unfair_lock_trylock` and returns its result unchanged.
    fn tryLock(self: *@This()) bool {
        const acquired = darwin.os_unfair_lock_trylock(&self.oul);
        return acquired;
    }

    /// Acquires the lock through `os_unfair_lock_lock`.
    fn lock(self: *@This()) void {
        darwin.os_unfair_lock_lock(&self.oul);
    }

    /// Releases the lock through `os_unfair_lock_unlock`.
    fn unlock(self: *@This()) void {
        darwin.os_unfair_lock_unlock(&self.oul);
    }
};
/// Default backend: one 32-bit atomic lock word, with `Thread.Futex` used to sleep and
/// wake when the lock is contended.
const FutexImpl = struct {
/// Lock word. Every store in this struct writes `unlocked`, `locked` or `contended`.
state: Atomic(u32) = Atomic(u32).init(unlocked),
/// Nobody holds the mutex.
const unlocked = 0b00;
/// The mutex is held; set by the fast path.
const locked = 0b01;
/// The mutex is held and was taken (or attempted) through the slow path, so the next
/// `unlock()` must issue a wake. The value includes the `locked` bit, which the x86
/// bit-set fast path in `lockFast` relies on to see a contended mutex as taken.
const contended = 0b11;
/// Single acquisition attempt; never goes to the slow path.
fn tryLock(self: *@This()) bool {
// Uses `compareAndSwap` rather than `tryCompareAndSwap` (the variant `lock()` uses).
// NOTE(review): presumably so that a failure here always means the mutex was held,
// assuming `tryCompareAndSwap` is the weak CAS that may fail spuriously - confirm
// against std.atomic.
return self.lockFast("compareAndSwap");
}
/// Acquires the mutex: one fast attempt, then the slow path if that fails.
fn lock(self: *@This()) void {
// A failed `tryCompareAndSwap` is harmless here because the slow path handles any failure.
if (!self.lockFast("tryCompareAndSwap")) {
self.lockSlow();
}
}
/// Fast-path acquisition shared by `tryLock` and `lock`.
/// `casFn` names the `Atomic` method used for the compare-and-swap on non-x86 targets.
/// Returns `true` if the caller now holds the mutex.
inline fn lockFast(self: *@This(), comptime casFn: []const u8) bool {
// On x86 the `locked` bit is set with an atomic bit-set instead of a CAS (so `casFn`
// is unused there). The previous bit value is 0 only when the state was `unlocked`,
// since both `locked` and `contended` have that bit set.
if (comptime builtin.target.cpu.arch.isX86()) {
const locked_bit = @ctz(@as(u32, locked));
return self.state.bitSet(locked_bit, .Acquire) == 0;
}
// Elsewhere: CAS `unlocked` -> `locked`, Acquire ordering on success and Monotonic on
// failure. The call returns null when the swap happened.
return @field(self.state, casFn)(unlocked, locked, .Acquire, .Monotonic) == null;
}
/// Slow path: publishes `contended` and sleeps on the futex until the lock is obtained.
fn lockSlow(self: *@This()) void {
@setCold(true);
// If the state is already `contended`, go straight to waiting instead of performing
// the swap below, which would store the same value again.
if (self.state.load(.Monotonic) == contended) {
Futex.wait(&self.state, contended);
}
// Swap in `contended`. A previous value of `unlocked` means this thread now owns the
// mutex (leaving the state as `contended`, so its `unlock()` will issue a wake).
// Any other previous value means someone else holds it: wait on the futex with
// `contended` as the expected value, then retry.
while (self.state.swap(contended, .Acquire) != unlocked) {
Futex.wait(&self.state, contended);
}
}
/// Releases the mutex and wakes one futex waiter if the state was `contended`.
fn unlock(self: *@This()) void {
// Release ordering on the swap pairs with the Acquire ordering used when locking.
const state = self.state.swap(unlocked, .Release);
// Unlocking a mutex that was not locked is a caller bug.
assert(state != unlocked);
// `contended` is only ever written by `lockSlow`, so it signals a possible sleeper.
if (state == contended) {
Futex.wake(&self.state, 1);
}
}
};
test "Mutex - smoke test" {
    var m: Mutex = .{};

    // A fresh mutex is free: the first attempt succeeds, a second one must fail.
    const first_attempt = m.tryLock();
    try testing.expect(first_attempt);
    const second_attempt = m.tryLock();
    try testing.expect(!second_attempt);
    m.unlock();

    // After unlocking, a blocking lock() succeeds and again excludes tryLock().
    m.lock();
    const attempt_while_held = m.tryLock();
    try testing.expect(!attempt_while_held);
    m.unlock();
}
/// 128-bit counter stored as two `u64` words and updated with two separate volatile
/// 64-bit stores, so an increment is deliberately not a single atomic operation.
/// The tests in this file increment it while holding a `Mutex`.
const NonAtomicCounter = struct {
    value: [2]u64 = [_]u64{ 0, 0 },

    /// Reinterprets the two words as one `u128`.
    fn get(self: NonAtomicCounter) u128 {
        return @bitCast(u128, self.value);
    }

    /// Computes `get() + 1` and writes it back one word at a time.
    fn inc(self: *NonAtomicCounter) void {
        const next = @bitCast([2]u64, self.get() + 1);

        var word: usize = 0;
        while (word < next.len) : (word += 1) {
            // Volatile pointer so each word is written with its own store.
            const slot: *volatile u64 = &self.value[word];
            slot.* = next[word];
        }
    }
};
test "Mutex - many uncontended" {
    // Needs real threads.
    if (builtin.single_threaded) {
        return error.SkipZigTest;
    }

    const num_threads = 4;
    const num_increments = 1000;

    // Every runner owns its own mutex and counter, so no lock is shared between threads.
    const Runner = struct {
        mutex: Mutex = .{},
        thread: Thread = undefined,
        counter: NonAtomicCounter = .{},

        fn run(self: *@This()) void {
            var i: usize = num_increments;
            while (i > 0) : (i -= 1) {
                self.mutex.lock();
                defer self.mutex.unlock();

                self.counter.inc();
            }
        }
    };

    var runners = [_]Runner{.{}} ** num_threads;
    {
        // Join every thread that was actually started, even when a later spawn fails.
        // Otherwise the running threads would keep using `runners` after this stack
        // frame has been torn down by the early return.
        var spawned: usize = 0;
        defer {
            // Iterate by pointer: copying a Runner while its thread is still writing
            // to it would be a data race.
            for (runners[0..spawned]) |*r| r.thread.join();
        }

        for (&runners) |*r| {
            r.thread = try Thread.spawn(.{}, Runner.run, .{r});
            spawned += 1;
        }
    }

    // All threads are joined; each private counter must have seen every increment.
    for (runners) |r| {
        try testing.expectEqual(@as(u128, num_increments), r.counter.get());
    }
}
test "Mutex - many contended" {
    // Needs real threads.
    if (builtin.single_threaded) {
        return error.SkipZigTest;
    }

    const num_threads = 4;
    const num_increments = 1000;

    // One mutex and one counter shared by every thread.
    const Runner = struct {
        mutex: Mutex = .{},
        counter: NonAtomicCounter = .{},

        fn run(self: *@This()) void {
            var i: usize = num_increments;
            while (i > 0) : (i -= 1) {
                // Every 100th iteration, offer the CPU to another thread after unlocking.
                // yield() is only a hint, so a failure is deliberately ignored.
                defer if (i % 100 == 0) Thread.yield() catch {};

                self.mutex.lock();
                defer self.mutex.unlock();

                self.counter.inc();
            }
        }
    };

    var runner = Runner{};

    var threads: [num_threads]Thread = undefined;
    {
        // Join every thread that was actually started, even when a later spawn fails.
        // Otherwise the running threads would keep using `runner` after this stack
        // frame has been torn down by the early return.
        var spawned: usize = 0;
        defer {
            for (threads[0..spawned]) |t| t.join();
        }

        for (&threads) |*t| {
            t.* = try Thread.spawn(.{}, Runner.run, .{&runner});
            spawned += 1;
        }
    }

    // All threads are joined; no increment may have been lost.
    try testing.expectEqual(@as(u128, num_increments * num_threads), runner.counter.get());
}