const std = @import("../std.zig");
const testing = std.testing;
const http = std.http;
const mem = std.mem;
const net = std.net;
const Uri = std.Uri;
const Allocator = mem.Allocator;
const assert = std.debug.assert;

const Server = @This();
const proto = @import("protocol.zig");

allocator: Allocator,

socket: net.StreamServer,

/// Transport wrapper for an accepted connection.
///
/// Every operation dispatches on `protocol`. Only plain streams are implemented today;
/// a TLS variant (backed by a `tls_client`) is anticipated but not wired up.
pub const Connection = struct {
    stream: net.Stream,
    protocol: Protocol,

    /// When true, `Response.reset` closes this connection instead of reusing it.
    closing: bool = true,

    pub const Protocol = enum { plain };

    pub const ReadError = net.Stream.ReadError;
    pub const Reader = std.io.Reader(*Connection, ReadError, read);

    pub const WriteError = net.Stream.WriteError || error{};
    pub const Writer = std.io.Writer(*Connection, WriteError, write);

    /// Read up to `buffer.len` bytes; returns the number of bytes read.
    pub fn read(conn: *Connection, buffer: []u8) !usize {
        return switch (conn.protocol) {
            .plain => conn.stream.read(buffer),
        };
    }

    /// Read until at least `len` bytes are in `buffer` (or the stream ends).
    pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) !usize {
        return switch (conn.protocol) {
            .plain => conn.stream.readAtLeast(buffer, len),
        };
    }

    /// Wrap this connection in a `std.io.Reader`.
    pub fn reader(conn: *Connection) Reader {
        return .{ .context = conn };
    }

    /// Write the whole of `buffer`.
    pub fn writeAll(conn: *Connection, buffer: []const u8) !void {
        return switch (conn.protocol) {
            .plain => conn.stream.writeAll(buffer),
        };
    }

    /// Write some prefix of `buffer`; returns how many bytes were written.
    pub fn write(conn: *Connection, buffer: []const u8) !usize {
        return switch (conn.protocol) {
            .plain => conn.stream.write(buffer),
        };
    }

    /// Wrap this connection in a `std.io.Writer`.
    pub fn writer(conn: *Connection) Writer {
        return .{ .context = conn };
    }

    /// Close the underlying stream.
    pub fn close(conn: *Connection) void {
        conn.stream.close();
    }
};

/// A buffered (and peekable) Connection.
///
/// Reads are served from an internal `buffer_size`-byte buffer. `peek` exposes the
/// buffered-but-unconsumed bytes and `clear` consumes them. Writes are not buffered;
/// they go straight to the underlying `Connection`.
pub const BufferedConnection = struct {
    pub const buffer_size = 0x2000;

    conn: Connection,
    buf: [buffer_size]u8 = undefined,
    /// Index of the first unconsumed byte in `buf`.
    start: u16 = 0,
    /// One past the last valid byte in `buf`; `buf[start..end]` is the unconsumed data.
    end: u16 = 0,

    /// Refill the buffer from the connection, but only when it is currently empty.
    /// Returns `error.EndOfStream` if the connection yields no more data.
    pub fn fill(bconn: *BufferedConnection) ReadError!void {
        if (bconn.end != bconn.start) return;

        const nread = try bconn.conn.read(bconn.buf[0..]);
        if (nread == 0) return error.EndOfStream;
        bconn.start = 0;
        // `nread <= buffer_size`, which fits in a u16.
        bconn.end = @truncate(u16, nread);
    }

    /// Return the buffered bytes that have not been consumed yet.
    /// The slice is overwritten by the next `fill` that actually reads.
    pub fn peek(bconn: *BufferedConnection) []const u8 {
        return bconn.buf[bconn.start..bconn.end];
    }

    /// Mark `num` bytes from the front of `peek()` as consumed.
    pub fn clear(bconn: *BufferedConnection, num: u16) void {
        bconn.start += num;
    }

    /// Read into `buffer` until at least `len` bytes have been produced. The target is
    /// capped at `buffer.len`. Buffered data is used first. When the buffer is empty and
    /// the remaining output space is larger than the buffer, reads go directly into
    /// `buffer`. Returns the total number of bytes written to `buffer`.
    pub fn readAtLeast(bconn: *BufferedConnection, buffer: []u8, len: usize) ReadError!usize {
        // Never aim for more than the caller's slice can hold. Otherwise nothing could be
        // copied once it is full and the loop would spin forever (e.g. `read` on an empty slice).
        const want = @min(len, buffer.len);
        var out_index: usize = 0;
        while (out_index < want) {
            const available = bconn.end - bconn.start;
            const left = buffer.len - out_index;

            if (available > 0) {
                const can_read = @truncate(u16, @min(available, left));

                std.mem.copy(u8, buffer[out_index..], bconn.buf[bconn.start..][0..can_read]);
                out_index += can_read;
                bconn.start += can_read;

                continue;
            }

            if (left > bconn.buf.len) {
                // Skip the buffer if the output is large enough. Count the bytes already
                // copied above as well as those read now.
                const nread = try bconn.conn.readAtLeast(buffer[out_index..], want - out_index);
                return out_index + nread;
            }

            try bconn.fill();
        }

        return out_index;
    }

    /// Read at least one byte (unless `buffer` is empty); see `readAtLeast`.
    pub fn read(bconn: *BufferedConnection, buffer: []u8) ReadError!usize {
        return bconn.readAtLeast(buffer, 1);
    }

    pub const ReadError = Connection.ReadError || error{EndOfStream};
    pub const Reader = std.io.Reader(*BufferedConnection, ReadError, read);

    pub fn reader(bconn: *BufferedConnection) Reader {
        return Reader{ .context = bconn };
    }

    /// Unbuffered: forwards to the underlying connection.
    pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void {
        return bconn.conn.writeAll(buffer);
    }

    /// Unbuffered: forwards to the underlying connection.
    pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize {
        return bconn.conn.write(buffer);
    }

    pub const WriteError = Connection.WriteError;
    pub const Writer = std.io.Writer(*BufferedConnection, WriteError, write);

    pub fn writer(bconn: *BufferedConnection) Writer {
        return Writer{ .context = bconn };
    }

    pub fn close(bconn: *BufferedConnection) void {
        bconn.conn.close();
    }
};

/// An HTTP request originating from a client.
pub const Request = struct {
    /// The request line plus the header fields this server interprets.
    ///
    /// `target` and `host` are slices of the bytes passed to `parse`. They are only valid
    /// as long as those bytes are.
    pub const Headers = struct {
        method: http.Method,
        target: []const u8,
        version: http.Version,
        content_length: ?u64 = null,
        transfer_encoding: ?http.TransferEncoding = null,
        transfer_compression: ?http.ContentEncoding = null,
        connection: http.Connection = .close,
        host: ?[]const u8 = null,

        pub const ParseError = error{
            ShortHttpStatusLine,
            BadHttpVersion,
            UnknownHttpMethod,
            HttpHeadersInvalid,
            HttpHeaderContinuationsUnsupported,
            HttpTransferEncodingUnsupported,
            HttpConnectionHeaderUnsupported,
            InvalidCharacter,
        };

        /// Parse a complete request head: the request line followed by header fields.
        ///
        /// Each header value has surrounding spaces and tabs removed before it is
        /// interpreted. Unrecognized header fields are ignored. Fails on:
        /// - a malformed request line, an unknown method or an unsupported version;
        /// - obsolete line folding;
        /// - a duplicate `Content-Length`;
        /// - an unsupported `Connection` value;
        /// - unsupported transfer or content codings.
        pub fn parse(bytes: []const u8) !Headers {
            // Tokenizing on CR/LF never yields empty tokens. The blank line ending the head
            // therefore needs no special handling, whatever its exact length. Short inputs
            // also cannot underflow a fixed-size trim.
            var it = mem.tokenize(u8, bytes, "\r\n");

            const first_line = it.next() orelse return error.HttpHeadersInvalid;
            if (first_line.len < 10)
                return error.ShortHttpStatusLine;

            const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid;
            const method_str = first_line[0..method_end];
            const method = std.meta.stringToEnum(http.Method, method_str) orelse return error.UnknownHttpMethod;

            const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid;
            if (version_start == method_end) return error.HttpHeadersInvalid;

            const version_str = first_line[version_start + 1 ..];
            if (version_str.len != 8) return error.HttpHeadersInvalid;
            const version: http.Version = switch (int64(version_str[0..8])) {
                int64("HTTP/1.0") => .@"HTTP/1.0",
                int64("HTTP/1.1") => .@"HTTP/1.1",
                else => return error.BadHttpVersion,
            };

            const target = first_line[method_end + 1 .. version_start];

            var headers: Headers = .{
                .method = method,
                .target = target,
                .version = version,
            };

            while (it.next()) |line| {
                // `line` is non-empty: the tokenizer never yields empty tokens.
                switch (line[0]) {
                    ' ', '\t' => return error.HttpHeaderContinuationsUnsupported,
                    else => {},
                }

                var line_it = mem.tokenize(u8, line, ": ");
                const header_name = line_it.next() orelse return error.HttpHeadersInvalid;
                // Optional whitespace around a field value is not part of the value.
                const header_value = mem.trim(u8, line_it.rest(), " \t");
                if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
                    if (headers.content_length != null) return error.HttpHeadersInvalid;
                    headers.content_length = try std.fmt.parseInt(u64, header_value, 10);
                } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
                    // Codings are listed in the order they were applied, so walk the list
                    // backwards: e.g. "Transfer-Encoding: deflate, chunked" yields
                    // "chunked" first, then "deflate".
                    var iter = mem.splitBackwards(u8, header_value, ",");

                    if (iter.next()) |first| {
                        const trimmed = mem.trim(u8, first, " ");

                        if (std.meta.stringToEnum(http.TransferEncoding, trimmed)) |te| {
                            if (headers.transfer_encoding != null) return error.HttpHeadersInvalid;
                            headers.transfer_encoding = te;
                        } else if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
                            if (headers.transfer_compression != null) return error.HttpHeadersInvalid;
                            headers.transfer_compression = ce;
                        } else {
                            return error.HttpTransferEncodingUnsupported;
                        }
                    }

                    if (iter.next()) |second| {
                        if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported;

                        const trimmed = mem.trim(u8, second, " ");

                        if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
                            headers.transfer_compression = ce;
                        } else {
                            return error.HttpTransferEncodingUnsupported;
                        }
                    }

                    // At most two codings (one framing, one compression) are supported.
                    if (iter.next()) |_| return error.HttpTransferEncodingUnsupported;
                } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) {
                    if (headers.transfer_compression != null) return error.HttpHeadersInvalid;

                    if (std.meta.stringToEnum(http.ContentEncoding, header_value)) |ce| {
                        headers.transfer_compression = ce;
                    } else {
                        return error.HttpTransferEncodingUnsupported;
                    }
                } else if (std.ascii.eqlIgnoreCase(header_name, "connection")) {
                    if (std.ascii.eqlIgnoreCase(header_value, "keep-alive")) {
                        headers.connection = .keep_alive;
                    } else if (std.ascii.eqlIgnoreCase(header_value, "close")) {
                        headers.connection = .close;
                    } else {
                        return error.HttpConnectionHeaderUnsupported;
                    }
                } else if (std.ascii.eqlIgnoreCase(header_name, "host")) {
                    headers.host = header_value;
                }
            }

            return headers;
        }

        /// Reinterpret 8 bytes as a u64 so HTTP version strings can be `switch`ed on.
        inline fn int64(array: *const [8]u8) u64 {
            return @bitCast(u64, array.*);
        }
    };

    headers: Headers = undefined,
    parser: proto.HeadersParser,
    compression: Compression = .none,
};

/// An HTTP response waiting to be sent.
///
///                                  [/ <----------------------------------- \]
/// Order of operations: accept -> wait -> do  [ -> write -> finish][ -> reset /]
///                                   \ -> read /
pub const Response = struct {
    /// The status line and header fields that `do` sends.
    pub const Headers = struct {
        version: http.Version = .@"HTTP/1.1",
        status: http.Status = .ok,
        /// Reason phrase. When null, the status code's standard phrase (if any) is used.
        reason: ?[]const u8 = null,

        /// Value of the `Server` header; null omits the header.
        server: ?[]const u8 = "zig (std.http)",
        connection: http.Connection = .keep_alive,
        /// Body framing announced by `do` and enforced by `write` / `finish`.
        transfer_encoding: RequestTransfer = .none,

        /// Extra header fields written after the built-in ones.
        custom: []const http.CustomHeader = &[_]http.CustomHeader{},
    };

    server: *Server,
    address: net.Address,
    connection: BufferedConnection,

    headers: Headers = .{},
    request: Request,

    /// Reset this response to its initial state. This must be called before handling a
    /// second request on the same connection.
    ///
    /// This always frees the request-body decompressor, if any. The connection is closed
    /// when it is marked as closing or when the request body was not fully read. In that
    /// case the owned header bytes are freed and `res.*` becomes undefined. Otherwise only
    /// the request parser is reset.
    /// NOTE(review): the `Response` allocated by `accept` is not destroyed here; the caller
    /// presumably owns that allocation — confirm.
    pub fn reset(res: *Response) void {
        switch (res.request.compression) {
            .none => {},
            .deflate => |*deflate| deflate.deinit(),
            .gzip => |*gzip| gzip.deinit(),
            .zstd => |*zstd| zstd.deinit(),
        }
        // The decompressor has been freed. Forget it, otherwise an uncompressed keep-alive
        // request would `read` through freed state and the next `reset` would free it again.
        res.request.compression = .none;

        if (!res.request.parser.done) {
            // If the request body wasn't fully read, then we need to close the connection.
            res.connection.conn.closing = true;
        }

        if (res.connection.conn.closing) {
            res.connection.close();

            if (res.request.parser.header_bytes_owned) {
                res.request.parser.header_bytes.deinit(res.server.allocator);
            }

            res.* = undefined;
        } else {
            res.request.parser.reset();
        }
    }

    /// Send the response status line and headers.
    pub fn do(res: *Response) !void {
        var buffered = std.io.bufferedWriter(res.connection.writer());
        const w = buffered.writer();

        try w.writeAll(@tagName(res.headers.version));
        try w.writeByte(' ');
        try w.print("{d}", .{@enumToInt(res.headers.status)});
        try w.writeByte(' ');
        if (res.headers.reason) |reason| {
            try w.writeAll(reason);
        } else if (res.headers.status.phrase()) |phrase| {
            try w.writeAll(phrase);
        }

        if (res.headers.server) |server| {
            try w.writeAll("\r\nServer: ");
            try w.writeAll(server);
        }

        if (res.headers.connection == .close) {
            try w.writeAll("\r\nConnection: close");
        } else {
            try w.writeAll("\r\nConnection: keep-alive");
        }

        switch (res.headers.transfer_encoding) {
            .chunked => try w.writeAll("\r\nTransfer-Encoding: chunked"),
            .content_length => |content_length| try w.print("\r\nContent-Length: {d}", .{content_length}),
            .none => {},
        }

        for (res.headers.custom) |header| {
            try w.writeAll("\r\n");
            try w.writeAll(header.name);
            try w.writeAll(": ");
            try w.writeAll(header.value);
        }

        try w.writeAll("\r\n\r\n");

        try buffered.flush();
    }

    pub const TransferReadError = BufferedConnection.ReadError || proto.HeadersParser.ReadError;

    pub const TransferReader = std.io.Reader(*Response, TransferReadError, transferRead);

    pub fn transferReader(res: *Response) TransferReader {
        return .{ .context = res };
    }

    /// Read request-body bytes through `request.parser` without decompressing them.
    /// Returns 0 once the parser reports the body complete.
    pub fn transferRead(res: *Response, buf: []u8) TransferReadError!usize {
        if (res.request.parser.isComplete()) return 0;

        var index: usize = 0;
        while (index == 0) {
            const amt = try res.request.parser.read(&res.connection, buf[index..], false);
            if (amt == 0 and res.request.parser.isComplete()) break;
            index += amt;
        }

        return index;
    }

    pub const WaitForCompleteHeadError = BufferedConnection.ReadError || proto.HeadersParser.WaitForCompleteHeadError || Request.Headers.ParseError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize } || error{CompressionNotSupported};

    /// Wait for the client to send a complete request head, then parse it.
    ///
    /// Afterwards:
    /// - the connection is kept open only if both sides asked for keep-alive;
    /// - the parser is set up for the request body's framing;
    /// - a decompressor is created if the body is compressed.
    pub fn wait(res: *Response) !void {
        while (true) {
            try res.connection.fill();

            const nchecked = try res.request.parser.checkCompleteHead(res.server.allocator, res.connection.peek());
            res.connection.clear(@intCast(u16, nchecked));

            if (res.request.parser.state.isContent()) break;
        }

        res.request.headers = try Request.Headers.parse(res.request.parser.header_bytes.items);

        if (res.headers.connection == .keep_alive and res.request.headers.connection == .keep_alive) {
            res.connection.conn.closing = false;
        } else {
            res.connection.conn.closing = true;
        }

        if (res.request.headers.transfer_encoding) |te| {
            switch (te) {
                .chunked => {
                    res.request.parser.next_chunk_length = 0;
                    res.request.parser.state = .chunk_head_size;
                },
            }
        } else if (res.request.headers.content_length) |cl| {
            res.request.parser.next_chunk_length = cl;

            if (cl == 0) res.request.parser.done = true;
        } else {
            // Without Transfer-Encoding or Content-Length a request has no body.
            res.request.parser.done = true;
        }

        if (!res.request.parser.done) {
            if (res.request.headers.transfer_compression) |tc| switch (tc) {
                .compress => return error.CompressionNotSupported,
                .deflate => res.request.compression = .{
                    .deflate = try std.compress.zlib.zlibStream(res.server.allocator, res.transferReader()),
                },
                .gzip => res.request.compression = .{
                    .gzip = try std.compress.gzip.decompress(res.server.allocator, res.transferReader()),
                },
                .zstd => res.request.compression = .{
                    .zstd = std.compress.zstd.decompressStream(res.server.allocator, res.transferReader()),
                },
            };
        }
    }

    pub const ReadError = Compression.DeflateDecompressor.Error || Compression.GzipDecompressor.Error || Compression.ZstdDecompressor.Error || WaitForCompleteHeadError;

    pub const Reader = std.io.Reader(*Response, ReadError, read);

    pub fn reader(res: *Response) Reader {
        return .{ .context = res };
    }

    /// Read decoded request-body bytes. A decompressor is used if `wait` created one.
    pub fn read(res: *Response, buffer: []u8) ReadError!usize {
        return switch (res.request.compression) {
            .deflate => |*deflate| try deflate.read(buffer),
            .gzip => |*gzip| try gzip.read(buffer),
            .zstd => |*zstd| try zstd.read(buffer),
            else => try res.transferRead(buffer),
        };
    }

    /// Read until `buffer` is full or the request body ends; returns the bytes read.
    pub fn readAll(res: *Response, buffer: []u8) !usize {
        var index: usize = 0;
        while (index < buffer.len) {
            const amt = try read(res, buffer[index..]);
            if (amt == 0) break;
            index += amt;
        }
        return index;
    }

    pub const WriteError = BufferedConnection.WriteError || error{ NotWriteable, MessageTooLong };

    pub const Writer = std.io.Writer(*Response, WriteError, write);

    pub fn writer(res: *Response) Writer {
        return .{ .context = res };
    }

    /// Write `bytes` of the response body to the client. The response header
    /// `transfer_encoding` determines how the data is framed:
    /// - `.chunked`: each non-empty write becomes one chunk;
    /// - `.content_length`: writes may not exceed the remaining announced length;
    /// - `.none`: writing is rejected with `error.NotWriteable`.
    pub fn write(res: *Response, bytes: []const u8) WriteError!usize {
        switch (res.headers.transfer_encoding) {
            .chunked => {
                // A zero-size chunk is the terminating last-chunk. Emitting one for an empty
                // write would end the body prematurely, so empty writes send nothing.
                if (bytes.len > 0) {
                    try res.connection.writer().print("{x}\r\n", .{bytes.len});
                    try res.connection.writeAll(bytes);
                    try res.connection.writeAll("\r\n");
                }

                return bytes.len;
            },
            .content_length => |*len| {
                if (len.* < bytes.len) return error.MessageTooLong;

                const amt = try res.connection.write(bytes);
                len.* -= amt;
                return amt;
            },
            .none => return error.NotWriteable,
        }
    }

    /// Finish the body of the response. This tells the client that no more data is coming.
    /// For a content-length body, returns `error.MessageNotCompleted` if fewer bytes than
    /// announced were written.
    pub fn finish(res: *Response) !void {
        switch (res.headers.transfer_encoding) {
            // The last-chunk ("0\r\n") must be followed by the CRLF that ends the (empty)
            // trailer section, otherwise the chunked message is never terminated.
            .chunked => try res.connection.writeAll("0\r\n\r\n"),
            .content_length => |len| if (len != 0) return error.MessageNotCompleted,
            .none => {},
        }
    }
};

/// How a response body is framed on the wire.
pub const RequestTransfer = union(enum) {
    /// The body is exactly this many bytes, announced with `Content-Length`.
    content_length: u64,
    /// The body is sent with `Transfer-Encoding: chunked`.
    chunked,
    /// No framing header is sent and the body cannot be written.
    none,
};

/// Decompression state for a compressed request body. `none` means the body is read
/// as-is through `Response.transferRead`.
pub const Compression = union(enum) {
    deflate: DeflateDecompressor,
    gzip: GzipDecompressor,
    zstd: ZstdDecompressor,
    none,

    /// Each decompressor pulls raw body bytes from `Response.TransferReader`.
    pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Response.TransferReader);
    pub const GzipDecompressor = std.compress.gzip.Decompress(Response.TransferReader);
    pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Response.TransferReader, .{});
};

/// Create a server that allocates per-connection state from `allocator`.
/// The socket is only initialized with `options` here; call `listen` to start accepting.
pub fn init(allocator: Allocator, options: net.StreamServer.Options) Server {
    const socket = net.StreamServer.init(options);
    return Server{
        .socket = socket,
        .allocator = allocator,
    };
}

/// Release the listening socket owned by this server.
pub fn deinit(server: *Server) void {
    net.StreamServer.deinit(&server.socket);
}

pub const ListenError = std.os.SocketError || std.os.BindError || std.os.ListenError || std.os.SetSockOptError || std.os.GetSockNameError;

/// Start the HTTP server listening on the given address.
/// Delegates to `net.StreamServer.listen` on the server's socket.
pub fn listen(server: *Server, address: net.Address) !void {
    return net.StreamServer.listen(&server.socket, address);
}

/// Errors `accept` can return: failure to accept a socket, or failure to allocate the `Response`.
pub const AcceptError = net.StreamServer.AcceptError || Allocator.Error;

/// How the request parser stores the raw header bytes of each request.
/// `accept` maps this to `proto.HeadersParser.initDynamic` / `initStatic`.
pub const HeaderStrategy = union(enum) {
    /// The server's Allocator (`Server.allocator`) is used to store the entire HTTP
    /// header. This value is the maximum total size of HTTP headers allowed. Larger
    /// headers presumably produce error.HttpHeadersExceededSizeLimit from
    /// `Response.wait()` (via `HeadersParser.checkCompleteHead`) — confirm in protocol.zig.
    dynamic: usize,
    /// This buffer is used to store the entire HTTP header. If the header is too big to
    /// fit, `error.HttpHeadersExceededSizeLimit` is presumably returned from
    /// `Response.wait()`. In this mode the parser presumably never needs to allocate for
    /// header bytes, so `error.OutOfMemory` should not arise from that path — confirm
    /// in protocol.zig.
    static: []u8,
};

/// Accept a new connection and allocate a Response for it.
///
/// The `Response` is allocated with `server.allocator`. It starts with a plain
/// (non-TLS) connection, and its request parser is configured according to `options`.
/// NOTE(review): nothing in this file destroys the allocation; the caller presumably
/// owns it — confirm.
pub fn accept(server: *Server, options: HeaderStrategy) AcceptError!*Response {
    const in = try server.socket.accept();
    // If allocating the Response fails, the accepted socket would otherwise leak.
    errdefer in.stream.close();

    const res = try server.allocator.create(Response);
    res.* = .{
        .server = server,
        .address = in.address,
        .connection = .{ .conn = .{
            .stream = in.stream,
            .protocol = .plain,
        } },
        .request = .{
            .parser = switch (options) {
                .dynamic => |max| proto.HeadersParser.initDynamic(max),
                .static => |buf| proto.HeadersParser.initStatic(buf),
            },
        },
    };

    return res;
}