const std = @import("../std.zig");
const assert = std.debug.assert;
const mem = std.mem;
const Allocator = std.mem.Allocator;
/// Allocator that carves allocations out of a chain of buffers obtained from
/// `child_allocator` and releases all of those buffers together in `deinit`
/// or `reset`.
pub const ArenaAllocator = struct {
    child_allocator: Allocator,
    state: State,

    /// The arena's bookkeeping without the child allocator. `promote` pairs
    /// it with a child allocator again to form a complete `ArenaAllocator`.
    pub const State = struct {
        /// Buffers obtained from the child allocator, most recently created
        /// first. Each node's `data` is the byte length of the node's whole
        /// allocation, node header included.
        buffer_list: std.SinglyLinkedList(usize) = .{},
        /// Offset of the first unused byte in the first buffer of
        /// `buffer_list`, counted from the end of that buffer's node header.
        end_index: usize = 0,

        /// Combines this state with `child_allocator` into an `ArenaAllocator`.
        pub fn promote(self: State, child_allocator: Allocator) ArenaAllocator {
            return .{
                .child_allocator = child_allocator,
                .state = self,
            };
        }
    };

    /// Returns the `Allocator` interface backed by this arena. The interface
    /// stores the `self` pointer, so the arena has to stay at the same
    /// address for as long as the returned value is used.
    pub fn allocator(self: *ArenaAllocator) Allocator {
        return .{
            .ptr = self,
            .vtable = &.{
                .alloc = alloc,
                .resize = resize,
                .free = free,
            },
        };
    }

    const BufNode = std.SinglyLinkedList(usize).Node;

    /// Creates an empty arena that obtains its buffers from `child_allocator`.
    /// Nothing is allocated until the first allocation request.
    pub fn init(child_allocator: Allocator) ArenaAllocator {
        return (State{}).promote(child_allocator);
    }

    /// Returns every buffer of the arena to the child allocator.
    pub fn deinit(self: ArenaAllocator) void {
        const align_bits = std.math.log2_int(usize, @alignOf(BufNode));
        var it = self.state.buffer_list.first;
        while (it) |node| {
            // The node header lives inside the buffer being freed, so the
            // link has to be read before the free.
            const next_it = node.next;
            const alloc_buf = @ptrCast([*]u8, node)[0..node.data];
            self.child_allocator.rawFree(alloc_buf, align_bits, @returnAddress());
            it = next_it;
        }
    }

    /// Selects what `reset` does with the buffers the arena currently holds.
    pub const ResetMode = union(enum) {
        /// Return all buffers to the child allocator.
        free_all,
        /// Keep a single buffer whose usable size equals the arena's current
        /// capacity as reported by `queryCapacity`.
        retain_capacity,
        /// Like `retain_capacity`, but keep at most this many usable bytes.
        /// A limit of 0 behaves like `free_all`.
        retain_with_limit: usize,
    };

    /// Returns the number of bytes in all buffers that are usable for
    /// allocations, i.e. the buffer sizes without their node headers.
    pub fn queryCapacity(self: ArenaAllocator) usize {
        var size: usize = 0;
        var it = self.state.buffer_list.first;
        while (it) |node| : (it = node.next) {
            // The header at the start of each buffer is not available to callers.
            size += node.data - @sizeOf(BufNode);
        }
        return size;
    }

    /// Discards every allocation made from the arena and, depending on
    /// `mode`, keeps one buffer around for future allocations.
    ///
    /// Returns `false` only when the retained buffer could neither be resized
    /// in place nor replaced by a new allocation of the requested size. In
    /// that case the arena is still reset and keeps its last buffer at its
    /// previous size. With `.free_all` the result is always `true`.
    pub fn reset(self: *ArenaAllocator, mode: ResetMode) bool {
        // Usable bytes (node header excluded) to keep after the reset.
        const requested_capacity = switch (mode) {
            .retain_capacity => self.queryCapacity(),
            .retain_with_limit => |limit| std.math.min(limit, self.queryCapacity()),
            .free_all => 0,
        };
        if (requested_capacity == 0) {
            // Nothing to retain: release every buffer and start from scratch.
            self.deinit();
            self.state = State{};
            return true;
        }
        // The header is added only after the limit was applied, so the
        // retained buffer can never be smaller than its own header.
        const total_size = requested_capacity + @sizeOf(BufNode);
        const align_bits = std.math.log2_int(usize, @alignOf(BufNode));
        // Free all nodes except the last one in the list.
        var it = self.state.buffer_list.first;
        const maybe_first_node = while (it) |node| {
            // Read the link before the free; the free releases `node` itself.
            const next_it = node.next;
            if (next_it == null)
                break node;
            const alloc_buf = @ptrCast([*]u8, node)[0..node.data];
            self.child_allocator.rawFree(alloc_buf, align_bits, @returnAddress());
            it = next_it;
        } else null;
        std.debug.assert(maybe_first_node == null or maybe_first_node.?.next == null);
        // Reset the fill level first so the arena is empty on every path below.
        self.state.end_index = 0;
        if (maybe_first_node) |first_node| {
            // Every node in front of this one has just been freed; re-point
            // the list head so it never refers to released memory, whichever
            // way the function returns.
            self.state.buffer_list.first = first_node;
            // Already the right size, no need to involve the child allocator.
            if (first_node.data == total_size)
                return true;
            const first_alloc_buf = @ptrCast([*]u8, first_node)[0..first_node.data];
            if (self.child_allocator.rawResize(first_alloc_buf, align_bits, total_size, @returnAddress())) {
                first_node.data = total_size;
            } else {
                // In-place resize refused: allocate the replacement before
                // releasing the old buffer, so a failure leaves the old
                // buffer in the list.
                const new_ptr = self.child_allocator.rawAlloc(total_size, align_bits, @returnAddress()) orelse {
                    return false;
                };
                self.child_allocator.rawFree(first_alloc_buf, align_bits, @returnAddress());
                const node = @ptrCast(*BufNode, @alignCast(@alignOf(BufNode), new_ptr));
                node.* = .{ .data = total_size };
                self.state.buffer_list.first = node;
            }
        }
        return true;
    }

    /// Allocates a new buffer from the child allocator, makes it the first
    /// node of the list and resets `end_index` to 0. The buffer is 1.5 times
    /// the sum of `prev_len`, `minimum_size`, the node header and 16 extra
    /// bytes. Returns `null` when the child allocator fails.
    fn createNode(self: *ArenaAllocator, prev_len: usize, minimum_size: usize) ?*BufNode {
        const actual_min_size = minimum_size + (@sizeOf(BufNode) + 16);
        const big_enough_len = prev_len + actual_min_size;
        const len = big_enough_len + big_enough_len / 2;
        const log2_align = comptime std.math.log2_int(usize, @alignOf(BufNode));
        const ptr = self.child_allocator.rawAlloc(len, log2_align, @returnAddress()) orelse
            return null;
        const buf_node = @ptrCast(*BufNode, @alignCast(@alignOf(BufNode), ptr));
        buf_node.* = .{ .data = len };
        self.state.buffer_list.prepend(buf_node);
        self.state.end_index = 0;
        return buf_node;
    }

    /// `Allocator.VTable.alloc` implementation: returns `n` bytes aligned to
    /// `1 << log2_ptr_align` taken from the first buffer in the list, growing
    /// that buffer in place or creating a new one when it is too small.
    /// Returns `null` when a new buffer is needed and cannot be allocated.
    fn alloc(ctx: *anyopaque, n: usize, log2_ptr_align: u8, ra: usize) ?[*]u8 {
        const self = @ptrCast(*ArenaAllocator, @alignCast(@alignOf(ArenaAllocator), ctx));
        _ = ra;

        const ptr_align = @as(usize, 1) << @intCast(Allocator.Log2Align, log2_ptr_align);
        var cur_node = if (self.state.buffer_list.first) |first_node|
            first_node
        else
            (self.createNode(0, n + ptr_align) orelse return null);
        while (true) {
            const cur_alloc_buf = @ptrCast([*]u8, cur_node)[0..cur_node.data];
            // Usable part of the buffer: everything behind the node header.
            const cur_buf = cur_alloc_buf[@sizeOf(BufNode)..];
            const addr = @ptrToInt(cur_buf.ptr) + self.state.end_index;
            const adjusted_addr = mem.alignForward(addr, ptr_align);
            const adjusted_index = self.state.end_index + (adjusted_addr - addr);
            const new_end_index = adjusted_index + n;

            if (new_end_index <= cur_buf.len) {
                const result = cur_buf[adjusted_index..new_end_index];
                self.state.end_index = new_end_index;
                return result.ptr;
            }

            // Does not fit: first ask the child allocator to grow the current
            // buffer in place, and only fall back to a new buffer when that
            // is refused. Either way the loop retries the allocation.
            const bigger_buf_size = @sizeOf(BufNode) + new_end_index;
            const log2_align = comptime std.math.log2_int(usize, @alignOf(BufNode));
            if (self.child_allocator.rawResize(cur_alloc_buf, log2_align, bigger_buf_size, @returnAddress())) {
                cur_node.data = bigger_buf_size;
            } else {
                cur_node = self.createNode(cur_buf.len, n + ptr_align) orelse return null;
            }
        }
    }

    /// `Allocator.VTable.resize` implementation. Only the allocation that
    /// ends exactly at `end_index` of the first buffer can change size in
    /// place; for any other allocation a shrink is reported as successful
    /// without reclaiming memory and a grow is refused.
    fn resize(ctx: *anyopaque, buf: []u8, log2_buf_align: u8, new_len: usize, ret_addr: usize) bool {
        const self = @ptrCast(*ArenaAllocator, @alignCast(@alignOf(ArenaAllocator), ctx));
        _ = log2_buf_align;
        _ = ret_addr;

        const cur_node = self.state.buffer_list.first orelse return false;
        const cur_buf = @ptrCast([*]u8, cur_node)[@sizeOf(BufNode)..cur_node.data];
        if (@ptrToInt(cur_buf.ptr) + self.state.end_index != @ptrToInt(buf.ptr) + buf.len) {
            // Not the most recent allocation.
            return new_len <= buf.len;
        }

        if (buf.len >= new_len) {
            self.state.end_index -= buf.len - new_len;
            return true;
        } else if (cur_buf.len - self.state.end_index >= new_len - buf.len) {
            self.state.end_index += new_len - buf.len;
            return true;
        } else {
            return false;
        }
    }

    /// `Allocator.VTable.free` implementation. Memory is reclaimed only when
    /// `buf` ends exactly at `end_index` of the first buffer; any other free
    /// is a no-op.
    fn free(ctx: *anyopaque, buf: []u8, log2_buf_align: u8, ret_addr: usize) void {
        _ = log2_buf_align;
        _ = ret_addr;

        const self = @ptrCast(*ArenaAllocator, @alignCast(@alignOf(ArenaAllocator), ctx));

        const cur_node = self.state.buffer_list.first orelse return;
        const cur_buf = @ptrCast([*]u8, cur_node)[@sizeOf(BufNode)..cur_node.data];

        if (@ptrToInt(cur_buf.ptr) + self.state.end_index == @ptrToInt(buf.ptr) + buf.len) {
            self.state.end_index -= buf.len;
        }
    }
};
// Repeatedly resets an arena with `.retain_capacity` and refills it with
// randomly sized, 32-byte aligned allocations, checking the alignment and
// length of every returned slice.
test "ArenaAllocator (reset with preheating)" {
    var arena = ArenaAllocator.init(std.testing.allocator);
    defer arena.deinit();

    // Fixed seed, so every run draws the same sequence of sizes.
    var prng = std.rand.DefaultPrng.init(19930913);
    const rand = prng.random();

    var remaining_rounds: usize = 25;
    while (remaining_rounds > 0) : (remaining_rounds -= 1) {
        // The result is deliberately ignored: the loop below checks the
        // allocations themselves, whatever the outcome of the reset.
        _ = arena.reset(.retain_capacity);

        const target_bytes = rand.intRangeAtMost(usize, 256, 16384);
        var bytes_so_far: usize = 0;
        while (bytes_so_far < target_bytes) {
            const len = rand.intRangeAtMost(usize, 16, 256);
            const required_alignment = 32;
            const chunk = try arena.allocator().alignedAlloc(u8, required_alignment, len);
            try std.testing.expect(std.mem.isAligned(@ptrToInt(chunk.ptr), required_alignment));
            try std.testing.expectEqual(len, chunk.len);
            bytes_so_far += chunk.len;
        }
    }
}