const std = @import("../std.zig");
const builtin = @import("builtin");
const assert = std.debug.assert;
const testing = std.testing;
const Loop = std.event.Loop;
/// Returns a channel type that passes values of type `T` between async frames.
///
/// `put`, `get` and `getOrNull` suspend the calling frame and enqueue it; the
/// private `dispatch` routine matches queued getters with buffered values or
/// queued putters and schedules the matched frames on the global event loop.
/// The ring buffer passed to `init` lets up to `buffer.len` puts complete
/// before any getter arrives. Only usable with event-based I/O: referencing the
/// event loop without it is a compile error.
pub fn Channel(comptime T: type) type {
    return struct {
        /// Frames suspended in `get`/`getOrNull`, in arrival order.
        getters: std.atomic.Queue(GetNode),
        /// Getter nodes that came from `getOrNull`. Any entry still here after a
        /// dispatch pass is resumed with its result left as `null`.
        or_null_queue: std.atomic.Queue(*std.atomic.Queue(GetNode).Node),
        /// Frames suspended in `put`, in arrival order.
        putters: std.atomic.Queue(PutNode),
        /// Outside of `dispatch`: number of entries in `getters` not yet claimed.
        get_count: usize,
        /// Outside of `dispatch`: number of entries in `putters` not yet claimed.
        put_count: usize,
        /// Set while one frame is running the dispatch loop.
        dispatch_lock: bool,
        /// Set when work was queued that the current dispatcher may not have seen.
        need_dispatch: bool,

        // Fixed-size ring buffer: `buffer_index` is the next write slot (wrapping)
        // and `buffer_len` is the number of buffered values.
        buffer_nodes: []T,
        buffer_index: usize,
        buffer_len: usize,

        const SelfChannel = @This();

        /// A suspended getter: the frame to resume and where to store the value.
        const GetNode = struct {
            tick_node: *Loop.NextTickNode,
            data: Data,

            const Data = union(enum) {
                /// Queued by `get`.
                Normal: Normal,
                /// Queued by `getOrNull`.
                OrNull: OrNull,
            };

            const Normal = struct {
                ptr: *T,
            };

            const OrNull = struct {
                ptr: *?T,
                /// This getter's entry in `or_null_queue`; removed once a value is delivered.
                or_null: *std.atomic.Queue(*std.atomic.Queue(GetNode).Node).Node,
            };
        };

        /// A suspended putter: the value it offers and the frame to resume.
        const PutNode = struct {
            data: T,
            tick_node: *Loop.NextTickNode,
        };

        const global_event_loop = Loop.instance orelse
            @compileError("std.event.Channel currently only works with event-based I/O");

        /// Initializes the channel in place, using `buffer` as its ring buffer.
        ///
        /// `buffer.len` must be zero (every `put` then waits for a matching getter)
        /// or a power of two. The channel keeps the `buffer` slice, so the memory
        /// must stay valid until `deinit` is called.
        pub fn init(self: *SelfChannel, buffer: []T) void {
            // The oldest element is read at `(buffer_index -% buffer_len) % len`;
            // that wrapping subtraction is only consistent modulo `len` when `len`
            // is a power of two.
            assert(buffer.len == 0 or @popCount(buffer.len) == 1);

            self.* = SelfChannel{
                .buffer_len = 0,
                .buffer_nodes = buffer,
                .buffer_index = 0,
                .dispatch_lock = false,
                .need_dispatch = false,
                .getters = std.atomic.Queue(GetNode).init(),
                .putters = std.atomic.Queue(PutNode).init(),
                .or_null_queue = std.atomic.Queue(*std.atomic.Queue(GetNode).Node).init(),
                .get_count = 0,
                .put_count = 0,
            };
        }

        /// Resumes every frame still waiting in `get`/`getOrNull`/`put`, then
        /// invalidates the channel.
        ///
        /// Waiting frames are resumed directly, not through the event loop. A
        /// resumed `get` returns an undefined value, a resumed `getOrNull` returns
        /// `null`, and a resumed putter's value is dropped. The channel must not be
        /// used after this call.
        pub fn deinit(self: *SelfChannel) void {
            while (self.getters.get()) |get_node| {
                resume get_node.data.tick_node.data;
            }
            while (self.putters.get()) |put_node| {
                resume put_node.data.tick_node.data;
            }
            self.* = undefined;
        }

        /// Sends `data` on the channel.
        ///
        /// Suspends the caller. It is resumed on a later event-loop tick once the
        /// value has been stored in the buffer or handed directly to a getter, or
        /// immediately by `deinit`.
        pub fn put(self: *SelfChannel, data: T) void {
            var my_tick_node = Loop.NextTickNode{ .data = @frame() };
            var queue_node = std.atomic.Queue(PutNode).Node{
                .data = PutNode{
                    .tick_node = &my_tick_node,
                    .data = data,
                },
            };

            suspend {
                self.putters.put(&queue_node);
                _ = @atomicRmw(usize, &self.put_count, .Add, 1, .SeqCst);

                self.dispatch();
            }
        }

        /// Receives the next value from the channel.
        ///
        /// Suspends the caller until a buffered value or a waiting putter can
        /// satisfy it. If `deinit` resumes it instead, the result is undefined.
        pub fn get(self: *SelfChannel) callconv(.Async) T {
            var result: T = undefined;
            var my_tick_node = Loop.NextTickNode{ .data = @frame() };
            var queue_node = std.atomic.Queue(GetNode).Node{
                .data = GetNode{
                    .tick_node = &my_tick_node,
                    .data = GetNode.Data{
                        .Normal = GetNode.Normal{ .ptr = &result },
                    },
                },
            };

            suspend {
                self.getters.put(&queue_node);
                _ = @atomicRmw(usize, &self.get_count, .Add, 1, .SeqCst);

                self.dispatch();
            }
            return result;
        }

        /// Receives a value if one is available, otherwise returns `null`.
        ///
        /// The request is queued like `get`. If the dispatch pass that processes
        /// it cannot take a value from the buffer or a waiting putter, the frame
        /// is resumed with `null`.
        pub fn getOrNull(self: *SelfChannel) ?T {
            var result: ?T = null;
            var my_tick_node = Loop.NextTickNode{ .data = @frame() };
            var or_null_node = std.atomic.Queue(*std.atomic.Queue(GetNode).Node).Node{ .data = undefined };
            var queue_node = std.atomic.Queue(GetNode).Node{
                .data = GetNode{
                    .tick_node = &my_tick_node,
                    .data = GetNode.Data{
                        .OrNull = GetNode.OrNull{
                            .ptr = &result,
                            .or_null = &or_null_node,
                        },
                    },
                },
            };
            or_null_node.data = &queue_node;

            suspend {
                self.getters.put(&queue_node);
                _ = @atomicRmw(usize, &self.get_count, .Add, 1, .SeqCst);
                self.or_null_queue.put(&or_null_node);

                self.dispatch();
            }
            return result;
        }

        /// Stores `value` into the getter's result slot and schedules the getter
        /// to resume on the next tick.
        fn fulfillGetter(self: *SelfChannel, get_node: *GetNode, value: T) void {
            switch (get_node.data) {
                .Normal => |info| {
                    info.ptr.* = value;
                },
                .OrNull => |info| {
                    // A value was delivered, so this getter must not also be
                    // resumed with `null` at the end of the pass.
                    _ = self.or_null_queue.remove(info.or_null);
                    info.ptr.* = value;
                },
            }
            global_event_loop.onNextTick(get_node.tick_node);
        }

        /// Matches queued getters with buffered values and queued putters, then
        /// resumes every unsatisfied `getOrNull` caller with `null`.
        ///
        /// Only one frame runs the matching loop at a time (`dispatch_lock`). A
        /// frame that finds the lock taken returns immediately. It has already
        /// set `need_dispatch`, which makes the lock holder run another pass.
        fn dispatch(self: *SelfChannel) void {
            // Announce that new work may be queued.
            @atomicStore(bool, &self.need_dispatch, true, .SeqCst);

            lock: while (true) {
                // Try to take the lock; if another frame holds it, that frame will
                // observe `need_dispatch` and process our request.
                if (@atomicRmw(bool, &self.dispatch_lock, .Xchg, true, .SeqCst)) return;

                // Everything queued so far is about to be processed.
                @atomicStore(bool, &self.need_dispatch, false, .SeqCst);

                while (true) {
                    {
                        // `@atomicRmw` returns the previous value. The locals therefore
                        // start at the current counts, while the shared counters are
                        // left one lower (wrapping when zero). Each later `.Sub` claims
                        // one node and also observes concurrent increments from new
                        // get/put calls. The extra subtraction is undone after this
                        // block.
                        var get_count = @atomicRmw(usize, &self.get_count, .Sub, 1, .SeqCst);
                        var put_count = @atomicRmw(usize, &self.put_count, .Sub, 1, .SeqCst);

                        // 1. Hand buffered values, oldest first, to waiting getters.
                        //    Stopping here when getters run out must NOT skip stage 3:
                        //    queued putters may still fit into the buffer.
                        while (self.buffer_len != 0 and get_count != 0) {
                            const get_node = &self.getters.get().?.data;
                            // The oldest value is `buffer_len` slots behind the write
                            // index; the wrapping subtraction is valid because the
                            // buffer length is a power of two (see `init`).
                            const oldest = self.buffer_nodes[(self.buffer_index -% self.buffer_len) % self.buffer_nodes.len];
                            self.fulfillGetter(get_node, oldest);
                            self.buffer_len -= 1;

                            get_count = @atomicRmw(usize, &self.get_count, .Sub, 1, .SeqCst);
                        }

                        // 2. Getters remain only if the buffer is now empty, so pass
                        //    values straight from putters to getters (FIFO preserved).
                        while (get_count != 0 and put_count != 0) {
                            const get_node = &self.getters.get().?.data;
                            const put_node = &self.putters.get().?.data;

                            self.fulfillGetter(get_node, put_node.data);
                            global_event_loop.onNextTick(put_node.tick_node);

                            get_count = @atomicRmw(usize, &self.get_count, .Sub, 1, .SeqCst);
                            put_count = @atomicRmw(usize, &self.put_count, .Sub, 1, .SeqCst);
                        }

                        // 3. Move remaining putters into free buffer slots; they may
                        //    resume before any getter takes their value.
                        while (self.buffer_len != self.buffer_nodes.len and put_count != 0) {
                            const put_node = &self.putters.get().?.data;

                            self.buffer_nodes[self.buffer_index % self.buffer_nodes.len] = put_node.data;
                            global_event_loop.onNextTick(put_node.tick_node);
                            self.buffer_index +%= 1;
                            self.buffer_len += 1;

                            put_count = @atomicRmw(usize, &self.put_count, .Sub, 1, .SeqCst);
                        }
                    }

                    // Undo the extra subtraction made at the start of the pass.
                    _ = @atomicRmw(usize, &self.get_count, .Add, 1, .SeqCst);
                    _ = @atomicRmw(usize, &self.put_count, .Add, 1, .SeqCst);

                    // Every `getOrNull` request still listed was not satisfied in this
                    // pass. Drop it from `getters` and resume it; its result is still
                    // `null`.
                    var remove_count: usize = 0;
                    while (self.or_null_queue.get()) |or_null_node| {
                        remove_count += @boolToInt(self.getters.remove(or_null_node.data));
                        global_event_loop.onNextTick(or_null_node.data.data.tick_node);
                    }
                    if (remove_count != 0) {
                        _ = @atomicRmw(usize, &self.get_count, .Sub, remove_count, .SeqCst);
                    }

                    // New work arrived during this pass: run another one while still
                    // holding the lock.
                    if (@atomicRmw(bool, &self.need_dispatch, .Xchg, false, .SeqCst)) continue;

                    assert(@atomicRmw(bool, &self.dispatch_lock, .Xchg, false, .SeqCst));

                    // Work may have been queued between the check above and the
                    // unlock by a frame that saw the lock still held; re-acquire the
                    // lock to process it.
                    if (@atomicLoad(bool, &self.need_dispatch, .SeqCst)) continue :lock;

                    return;
                }
            }
        }
    };
}

test "std.event.Channel buffered puts complete while the buffer has room" {
    if (!std.io.is_async) return error.SkipZigTest;

    var buf: [2]i32 = undefined;
    var channel: Channel(i32) = undefined;
    channel.init(&buf);
    defer channel.deinit();

    // With two free slots, neither put may wait for a getter. Before the dispatch
    // fix, the second put stayed suspended because the buffer was non-empty.
    channel.put(1);
    channel.put(2);
    try testing.expectEqual(@as(i32, 1), channel.get());
    try testing.expectEqual(@as(i32, 2), channel.get());
}
test "std.event.Channel" {
    // TODO provide a way to run tests in evented I/O mode
    if (!std.io.is_async) return error.SkipZigTest;
    // Also skipped on single-threaded builds and FreeBSD (reason not recorded here).
    if (builtin.single_threaded) return error.SkipZigTest;
    if (builtin.os.tag == .freebsd) return error.SkipZigTest;

    // Unbuffered channel: every put waits for a matching get. `init` needs a
    // mutable slice, so the zero-length buffer must be a mutable local rather
    // than the constant literal `&[0]i32{}`.
    var no_buffer: [0]i32 = undefined;
    var channel: Channel(i32) = undefined;
    channel.init(&no_buffer);
    defer channel.deinit();

    var handle = async testChannelGetter(&channel);
    var putter = async testChannelPutter(&channel);

    // The getter reports failed expectations through its error union.
    try await handle;
    await putter;
}

test "std.event.Channel wraparound" {
    // TODO provide a way to run tests in evented I/O mode
    if (!std.io.is_async) return error.SkipZigTest;

    const channel_size = 2;

    var buf: [channel_size]i32 = undefined;
    var channel: Channel(i32) = undefined;
    channel.init(&buf);
    defer channel.deinit();

    // Push and pull items until the ring buffer index wraps past the buffer
    // length, checking each value comes back intact.
    channel.put(5);
    try testing.expectEqual(@as(i32, 5), channel.get());
    channel.put(6);
    try testing.expectEqual(@as(i32, 6), channel.get());
    channel.put(7);
    try testing.expectEqual(@as(i32, 7), channel.get());
}

/// Async helper for the "std.event.Channel" test.
///
/// Expects 1234 and 4567 from `get`, then `null` from `getOrNull` while no putter
/// is waiting, then 4444 from `getOrNull` once `testPut` is queued. Returns an
/// error if any expectation fails. The previous `void` return type made the
/// `try` statements a compile error.
fn testChannelGetter(channel: *Channel(i32)) callconv(.Async) !void {
    const value1 = channel.get();
    try testing.expectEqual(@as(i32, 1234), value1);

    const value2 = channel.get();
    try testing.expectEqual(@as(i32, 4567), value2);

    const value3 = channel.getOrNull();
    try testing.expect(value3 == null);

    // Start a putter first so the following getOrNull has a value to take.
    var last_put = async testPut(channel, 4444);
    const value4 = channel.getOrNull();
    try testing.expect(value4.? == 4444);
    await last_put;
}
/// Async helper for the "std.event.Channel" test: sends 1234, then 4567,
/// waiting for each `put` to return before sending the next value.
fn testChannelPutter(channel: *Channel(i32)) callconv(.Async) void {
    for ([_]i32{ 1234, 4567 }) |value| {
        channel.put(value);
    }
}
/// Async helper that sends a single `value` on `channel` and finishes when
/// `put` returns.
fn testPut(channel: *Channel(i32), value: i32) callconv(.Async) void {
    return channel.put(value);
}