const std = @import("../std.zig");
const builtin = @import("builtin");
const assert = std.debug.assert;
const testing = std.testing;
const mem = std.mem;
const Loop = std.event.Loop;
const Allocator = std.mem.Allocator;
/// Async read/write lock for event-based I/O.
///
/// The lock is held either by any number of readers (`shared_state == .ReadLock`) or by
/// a single writer (`shared_state == .WriteLock`). An acquirer always enqueues its own
/// suspended frame first and only then tries to take the lock; whichever actor ends up
/// owning the lock resumes queued frames directly or schedules them with
/// `Loop.onNextTick`. Every atomic operation in this type uses `.SeqCst` ordering.
pub const RwLock = struct {
    /// Current lock mode. After `init` it is only changed via atomic stores / cmpxchg.
    shared_state: State,
    /// Frames suspended in `acquireWrite`, waiting for the write lock.
    writer_queue: Queue,
    /// Frames suspended in `acquireRead`, waiting for a read lock.
    reader_queue: Queue,
    /// Hint flag: stored `false` right after a writer enqueues itself, stored back to
    /// `true` on release paths that found no waiting writer. `commonPostUnlock` reads it
    /// to decide whether to try handing the lock to a writer.
    writer_queue_empty: bool,
    /// Reader counterpart of `writer_queue_empty`.
    reader_queue_empty: bool,
    /// Number of `acquireRead` calls not yet matched by `HeldRead.release`. It is
    /// incremented before the reader enqueues, so it also counts readers still waiting.
    reader_lock_count: usize,

    /// Lock mode stored in `shared_state`.
    const State = enum(u8) {
        /// Nobody holds the lock.
        Unlocked,
        /// Exactly one writer holds the lock.
        WriteLock,
        /// One or more readers hold the lock.
        ReadLock,
    };

    /// Atomic queue of suspended frames. Its nodes are `Loop.NextTickNode` values
    /// declared inside each waiter's own frame (see `acquireRead` / `acquireWrite`).
    const Queue = std.atomic.Queue(anyframe);

    /// Event loop used to schedule woken waiters. Referencing it when `Loop.instance`
    /// is null (no event-based I/O) is a compile error.
    const global_event_loop = Loop.instance orelse
        @compileError("std.event.RwLock currently only works with event-based I/O");

    /// Returned by `acquireRead`; call `release` when done reading.
    pub const HeldRead = struct {
        lock: *RwLock,

        /// Drops this read lock. Only the last reader out unlocks and wakes waiters.
        pub fn release(self: HeldRead) void {
            // Another reader (held or still waiting) remains: nothing more to do.
            if (@atomicRmw(usize, &self.lock.reader_lock_count, .Sub, 1, .SeqCst) != 1) {
                return;
            }

            // NOTE(review): a reader that increments `reader_lock_count` after the decrement
            // above may concurrently store `reader_queue_empty = false` and, seeing
            // `.ReadLock`, resume itself before the cmpxchg below unlocks. Confirm this
            // interleaving cannot lose a wakeup or unlock under an active reader.
            @atomicStore(bool, &self.lock.reader_queue_empty, true, .SeqCst);
            if (@cmpxchgStrong(State, &self.lock.shared_state, .ReadLock, .Unlocked, .SeqCst, .SeqCst) != null) {
                // State was no longer ReadLock, so another actor already moved it on.
                return;
            }

            // The lock is now Unlocked: hand it to any waiter that announced itself meanwhile.
            self.lock.commonPostUnlock();
        }
    };

    /// Returned by `acquireWrite`; call `release` when done writing.
    pub const HeldWrite = struct {
        lock: *RwLock,

        /// Drops the write lock, preferring to pass it straight to the next queued writer.
        pub fn release(self: HeldWrite) void {
            // Hand ownership directly to the next writer; `shared_state` stays `.WriteLock`.
            if (self.lock.writer_queue.get()) |node| {
                global_event_loop.onNextTick(node);
                return;
            }

            // No writer waiting. If readers announced themselves, switch to a read lock
            // and schedule every queued reader.
            if (!@atomicLoad(bool, &self.lock.reader_queue_empty, .SeqCst)) {
                @atomicStore(State, &self.lock.shared_state, .ReadLock, .SeqCst);
                while (self.lock.reader_queue.get()) |node| {
                    global_event_loop.onNextTick(node);
                }
                return;
            }

            // No waiters seen: clear the writer hint and unlock, then re-check for
            // waiters that may have enqueued in the meantime.
            @atomicStore(bool, &self.lock.writer_queue_empty, true, .SeqCst);
            @atomicStore(State, &self.lock.shared_state, .Unlocked, .SeqCst);
            self.lock.commonPostUnlock();
        }
    };

    /// Returns an unlocked lock with empty wait queues.
    pub fn init() RwLock {
        return .{
            .shared_state = .Unlocked,
            .writer_queue = Queue.init(),
            .writer_queue_empty = true,
            .reader_queue = Queue.init(),
            .reader_queue_empty = true,
            .reader_lock_count = 0,
        };
    }

    /// Asserts the lock is not held, then directly resumes any frames still queued.
    /// The assert reads `shared_state` non-atomically, so this presumably must not race
    /// with acquire/release calls (TODO confirm).
    pub fn deinit(self: *RwLock) void {
        assert(self.shared_state == .Unlocked);
        while (self.writer_queue.get()) |node| resume node.data;
        while (self.reader_queue.get()) |node| resume node.data;
    }

    /// Suspends until a read lock is available and returns a handle for it.
    pub fn acquireRead(self: *RwLock) callconv(.Async) HeldRead {
        // Count ourselves before enqueueing so the last `HeldRead.release` sees us.
        _ = @atomicRmw(usize, &self.reader_lock_count, .Add, 1, .SeqCst);
        suspend {
            var my_tick_node = Loop.NextTickNode{
                .data = @frame(),
                .prev = undefined,
                .next = undefined,
            };
            self.reader_queue.put(&my_tick_node);
            // From here on, another actor may dequeue this node and schedule this frame.
            // Announce that the reader queue is non-empty so unlock paths check it.
            @atomicStore(bool, &self.reader_queue_empty, false, .SeqCst);
            // Either we moved Unlocked -> ReadLock ourselves, or readers already hold it;
            // in both cases we may give out read locks.
            const have_read_lock = if (@cmpxchgStrong(State, &self.shared_state, .Unlocked, .ReadLock, .SeqCst, .SeqCst)) |old_state| old_state == .ReadLock else true;
            if (have_read_lock) {
                // Wake every queued reader: schedule all but the first and resume the
                // first directly (it may be this very frame).
                if (self.reader_queue.get()) |first_node| {
                    while (self.reader_queue.get()) |node| {
                        global_event_loop.onNextTick(node);
                    }
                    resume first_node.data;
                }
            }
        }
        return HeldRead{ .lock = self };
    }

    /// Suspends until the write lock is available and returns a handle for it.
    pub fn acquireWrite(self: *RwLock) callconv(.Async) HeldWrite {
        suspend {
            var my_tick_node = Loop.NextTickNode{
                .data = @frame(),
                .prev = undefined,
                .next = undefined,
            };
            self.writer_queue.put(&my_tick_node);
            // From here on, another actor may dequeue this node and schedule this frame.
            // Announce that the writer queue is non-empty so unlock paths check it.
            @atomicStore(bool, &self.writer_queue_empty, false, .SeqCst);
            // Only an Unlocked -> WriteLock transition grants the write lock here.
            if (@cmpxchgStrong(State, &self.shared_state, .Unlocked, .WriteLock, .SeqCst, .SeqCst) == null) {
                // We own the write lock: resume the queue head (this frame or an earlier writer).
                if (self.writer_queue.get()) |node| {
                    resume node.data;
                }
            }
        }
        return HeldWrite{ .lock = self };
    }

    /// Runs after an unlock to hand the lock to a waiter that may have enqueued
    /// concurrently. Writers are tried before readers. It returns once it loses the race
    /// for `shared_state` (another actor now owns the lock and its waiters), after it
    /// hands the lock out, or when neither hint flag reports waiters.
    fn commonPostUnlock(self: *RwLock) void {
        while (true) {
            if (!@atomicLoad(bool, &self.writer_queue_empty, .SeqCst)) {
                if (@cmpxchgStrong(State, &self.shared_state, .Unlocked, .WriteLock, .SeqCst, .SeqCst) != null) {
                    // Someone else holds the lock; the queues are their responsibility.
                    return;
                }
                // We hold the write lock: give it to the queued writer, if any.
                if (self.writer_queue.get()) |node| {
                    global_event_loop.onNextTick(node);
                    return;
                }
                // The hint was stale: clear it, unlock again, and re-check.
                @atomicStore(bool, &self.writer_queue_empty, true, .SeqCst);
                @atomicStore(State, &self.shared_state, .Unlocked, .SeqCst);
                continue;
            }
            if (!@atomicLoad(bool, &self.reader_queue_empty, .SeqCst)) {
                if (@cmpxchgStrong(State, &self.shared_state, .Unlocked, .ReadLock, .SeqCst, .SeqCst) != null) {
                    // Someone else holds the lock; the queues are their responsibility.
                    return;
                }
                // We hold a read lock: schedule every queued reader.
                if (self.reader_queue.get()) |first_node| {
                    global_event_loop.onNextTick(first_node);
                    while (self.reader_queue.get()) |node| {
                        global_event_loop.onNextTick(node);
                    }
                    return;
                }
                // The hint was stale: clear it and try to unlock again.
                @atomicStore(bool, &self.reader_queue_empty, true, .SeqCst);
                if (@cmpxchgStrong(State, &self.shared_state, .ReadLock, .Unlocked, .SeqCst, .SeqCst) != null) {
                    // State changed under us; another actor took over.
                    return;
                }
                continue;
            }
            // No waiters announced.
            return;
        }
    }
};
// Stress test: concurrent readers and writers over `shared_test_data`.
test "std.event.RwLock" {
    // Unconditionally skipped for now; the remaining guards record what the test needs
    // (a multi-threaded build and evented I/O) once it is re-enabled.
    if (true) return error.SkipZigTest;
    if (builtin.single_threaded) return error.SkipZigTest;
    if (!std.io.is_async) return error.SkipZigTest;

    var lock = RwLock.init();
    defer lock.deinit();

    _ = testLock(std.heap.page_allocator, &lock);

    // `shared_it_count` writers each run `shared_test_data.len` passes, and every pass
    // increments every element once.
    const expected_result = [1]i32{shared_it_count * @intCast(i32, shared_test_data.len)} ** shared_test_data.len;
    // `expectEqualSlices` takes slices; only pointers to arrays coerce to slices.
    try testing.expectEqualSlices(i32, &expected_result, &shared_test_data);
}
/// Spawns 100 reader frames and `shared_it_count` writer frames that all contend for
/// `lock`, and schedules each on the event loop. It then awaits every frame (writers
/// first, then readers) and frees it with `allocator`. Allocation failure panics.
fn testLock(allocator: Allocator, lock: *RwLock) callconv(.Async) void {
    var read_nodes: [100]Loop.NextTickNode = undefined;
    var r: usize = 0;
    while (r < read_nodes.len) : (r += 1) {
        const reader_frame = allocator.create(@Frame(readRunner)) catch @panic("memory");
        read_nodes[r].data = reader_frame;
        reader_frame.* = async readRunner(lock);
        // The runner parks in `suspend {}` until the loop resumes it via this node.
        Loop.instance.?.onNextTick(&read_nodes[r]);
    }

    var write_nodes: [shared_it_count]Loop.NextTickNode = undefined;
    var w: usize = 0;
    while (w < write_nodes.len) : (w += 1) {
        const writer_frame = allocator.create(@Frame(writeRunner)) catch @panic("memory");
        write_nodes[w].data = writer_frame;
        writer_frame.* = async writeRunner(lock);
        Loop.instance.?.onNextTick(&write_nodes[w]);
    }

    // Wait for completion and release each heap-allocated frame.
    for (write_nodes) |*node| {
        const finished_writer = @ptrCast(*const @Frame(writeRunner), node.data);
        await finished_writer;
        allocator.destroy(finished_writer);
    }
    for (read_nodes) |*node| {
        const finished_reader = @ptrCast(*const @Frame(readRunner), node.data);
        await finished_reader;
        allocator.destroy(finished_reader);
    }
}
/// Number of writer frames spawned by `testLock` (left untyped so it mixes with `i32`).
const shared_it_count = 10;
/// Written by writers under the write lock and checked by readers under a read lock.
var shared_test_data: [10]i32 = [_]i32{0} ** 10;
/// Cursor a writer moves across `shared_test_data`, rewound to 0 before it unlocks.
var shared_test_index: usize = 0;
/// Total number of write-locked passes completed across all writers.
var shared_count: usize = 0;
/// Test writer: once scheduled, runs `shared_test_data.len` passes. Each pass takes the
/// write lock, increments `shared_count`, and increments every element of
/// `shared_test_data` once.
fn writeRunner(lock: *RwLock) callconv(.Async) void {
    // Park until `testLock` schedules this frame on the event loop.
    suspend {}

    var pass: usize = 0;
    while (pass < shared_test_data.len) : (pass += 1) {
        std.time.sleep(100 * std.time.microsecond);

        const acquire_frame = async lock.acquireWrite();
        const held = await acquire_frame;
        defer held.release(); // runs at the end of every pass

        shared_count += 1;
        // Sweep the shared cursor across the array, then rewind it for readers.
        while (shared_test_index < shared_test_data.len) : (shared_test_index += 1) {
            shared_test_data[shared_test_index] += 1;
        }
        shared_test_index = 0;
    }
}
/// Test reader: once scheduled, performs `shared_test_data.len` read-locked checks. Each
/// check verifies that no writer's sweep is in progress (`shared_test_index == 0`) and
/// that element `i` equals the number of completed write passes.
fn readRunner(lock: *RwLock) callconv(.Async) void {
    // Park until `testLock` schedules this frame on the event loop.
    suspend {}
    std.time.sleep(1);

    var i: usize = 0;
    while (i < shared_test_data.len) : (i += 1) {
        const lock_promise = async lock.acquireRead();
        const handle = await lock_promise;
        defer handle.release();

        // This function returns `void`, so `try` cannot propagate a failed expectation
        // (it is a compile error here); fail loudly instead.
        testing.expect(shared_test_index == 0) catch
            @panic("readRunner: observed a writer mid-update while holding a read lock");
        testing.expect(shared_test_data[i] == @intCast(i32, shared_count)) catch
            @panic("readRunner: shared_test_data does not match shared_count");
    }
}