const std = @import("../../std.zig");
const tls = std.crypto.tls;
const Client = @This();
const net = std.net;
const mem = std.mem;
const crypto = std.crypto;
const assert = std.debug.assert;
const Certificate = std.crypto.Certificate;
const max_ciphertext_len = tls.max_ciphertext_len;
const hkdfExpandLabel = tls.hkdfExpandLabel;
const int2 = tls.int2;
const int3 = tls.int3;
const array = tls.array;
const enum_array = tls.enum_array;
/// Sequence number of the next record to decrypt; XORed into the server IV to
/// form the read nonce. Reset to 0 when the server's keys are updated.
read_seq: u64,
/// Sequence number of the next record to encrypt; XORed into the client IV to
/// form the write nonce. Reset to 0 when the client's keys are updated.
write_seq: u64,
/// Start index of buffered cleartext inside `partially_read_buffer`.
partial_cleartext_idx: u15,
/// End index of buffered cleartext, and start index of buffered ciphertext,
/// inside `partially_read_buffer`.
partial_ciphertext_idx: u15,
/// End index of buffered ciphertext inside `partially_read_buffer`.
partial_ciphertext_end: u15,
/// Set when a close_notify alert has been decrypted, or when the stream ended
/// while `allow_truncation_attacks` is true. Buffered data may still remain;
/// use `eof` to check whether everything has been consumed.
received_close_notify: bool,
/// When false (the default), reaching end-of-stream without a close_notify
/// makes reads fail with `error.TlsConnectionTruncated`. When true, it is
/// treated like a close_notify, leaving truncation detection to the protocol
/// running on top of TLS.
allow_truncation_attacks: bool = false,
/// Traffic secrets, keys and IVs for the negotiated cipher suite, in both
/// directions.
application_cipher: tls.ApplicationCipher,
/// Sized for one full ciphertext record. Holds leftover cleartext followed by
/// leftover ciphertext, delimited by the three `partial_*` indices above.
partially_read_buffer: [tls.max_ciphertext_record_len]u8,
/// Documents the operations that the `stream: anytype` parameters in this file
/// are expected to provide. This type is a description only; none of its
/// functions are implemented and it is never meant to be instantiated.
pub const StreamInterface = struct {
    /// Can be any error set.
    pub const ReadError = error{};

    /// Reads into `iovecs` and returns the number of bytes read, which may be
    /// less than the space provided. A return value of 0 means end-of-stream.
    pub fn readv(this: @This(), iovecs: []std.os.iovec) ReadError!usize {
        _ = .{ this, iovecs };
        @panic("unimplemented");
    }

    /// Can be any error set.
    pub const WriteError = error{};

    /// Writes from `iovecs` and returns the number of bytes written, which may
    /// be less than the total provided. A short write is not an error.
    pub fn writev(this: @This(), iovecs: []const std.os.iovec_const) WriteError!usize {
        _ = .{ this, iovecs };
        @panic("unimplemented");
    }

    /// Writes every byte of every entry in `iovecs`. The slice is mutable so an
    /// implementation may advance entries in place after partial writes.
    /// Returns nothing: `init` invokes it as `try stream.writevAll(...)` and
    /// discards no value, which only compiles for a `!void` function.
    pub fn writevAll(this: @This(), iovecs: []std.os.iovec_const) WriteError!void {
        // This can be implemented in terms of writev, or specialized if desired.
        _ = .{ this, iovecs };
        @panic("unimplemented");
    }

    /// Writes every byte of `bytes`. Used by `init` to send the client's
    /// ChangeCipherSpec and Finished messages.
    pub fn writeAll(this: @This(), bytes: []const u8) WriteError!void {
        _ = .{ this, bytes };
        @panic("unimplemented");
    }
};
/// Computes the error set returned by `init` for a particular stream type.
/// It combines allocation failure, the stream's own `ReadError` and
/// `WriteError`, and a fixed list of environment, TLS protocol, certificate
/// and cryptographic-primitive errors.
pub fn InitError(comptime Stream: type) type {
    const environment_errors = error{
        InsufficientEntropy,
        DiskQuota,
        LockViolation,
        NotOpenForWriting,
    };
    const protocol_errors = error{
        TlsAlert,
        TlsUnexpectedMessage,
        TlsIllegalParameter,
        TlsDecryptFailure,
        TlsRecordOverflow,
        TlsBadRecordMac,
        TlsCertificateNotVerified,
        TlsBadSignatureScheme,
        TlsBadRsaSignatureBitCount,
        TlsDecryptError,
        TlsConnectionTruncated,
        TlsDecodeError,
    };
    const certificate_errors = error{
        CertificateFieldHasInvalidLength,
        CertificateHostMismatch,
        CertificatePublicKeyInvalid,
        CertificateExpired,
        CertificateFieldHasWrongDataType,
        CertificateIssuerMismatch,
        CertificateNotYetValid,
        CertificateSignatureAlgorithmMismatch,
        CertificateSignatureAlgorithmUnsupported,
        CertificateSignatureInvalid,
        CertificateSignatureInvalidLength,
        CertificateSignatureNamedCurveUnsupported,
        CertificateSignatureUnsupportedBitCount,
        UnsupportedCertificateVersion,
        CertificateTimeInvalid,
        CertificateHasUnrecognizedObjectId,
        CertificateHasInvalidBitString,
    };
    const primitive_errors = error{
        InvalidEncoding,
        IdentityElement,
        SignatureVerificationFailed,
        MessageTooLong,
        NegativeIntoUnsigned,
        TargetTooSmall,
        BufferTooSmall,
        InvalidSignature,
        NotSquare,
        NonCanonical,
    };
    const stream_errors = Stream.WriteError || Stream.ReadError;
    return std.mem.Allocator.Error || stream_errors ||
        environment_errors || protocol_errors || certificate_errors || primitive_errors;
}
/// Performs a TLS 1.3 client handshake over `stream` and returns a `Client`
/// holding the negotiated application traffic keys.
///
/// - `stream`: written with `writevAll` (ClientHello) and `writeAll` (client
///   ChangeCipherSpec + Finished); read through `tls.Decoder`'s helpers
///   (NOTE(review): the exact read methods required live in `tls.Decoder`,
///   which is not visible here).
/// - `ca_bundle`: trust anchors. Certificates from the server are checked one
///   by one; a CertificateVerify that arrives before any of them is verified
///   by the bundle fails with `error.TlsCertificateNotVerified`.
/// - `host`: sent in the server_name extension and checked against the leaf
///   certificate with `verifyHostName`. Its length must fit in a `u16`.
///
/// A HelloRetryRequest yields `error.TlsUnexpectedMessage` (not supported).
/// Bytes already read past the server's Finished record are carried over into
/// the returned client's `partially_read_buffer`.
pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) InitError(@TypeOf(stream))!Client {
// @intCast: host.len must fit in u16 (only checked in safe build modes).
const host_len = @intCast(u16, host.len);
// One CSPRNG call supplies the hello random, the legacy session id, and the
// seeds for both ECDH key pairs.
var random_buffer: [128]u8 = undefined;
crypto.random.bytes(&random_buffer);
const hello_rand = random_buffer[0..32].*;
const legacy_session_id = random_buffer[32..64].*;
const x25519_kp_seed = random_buffer[64..96].*;
const secp256r1_kp_seed = random_buffer[96..128].*;
// IdentityElement from key-pair creation is reported as InsufficientEntropy.
const x25519_kp = crypto.dh.X25519.KeyPair.create(x25519_kp_seed) catch |err| switch (err) {
error.IdentityElement => return error.InsufficientEntropy,
};
const secp256r1_kp = crypto.sign.ecdsa.EcdsaP256Sha256.KeyPair.create(secp256r1_kp_seed) catch |err| switch (err) {
error.IdentityElement => return error.InsufficientEntropy,
};
// NOTE(review): `catch {}` assumes Kyber768 KeyPair.create(null) cannot fail
// (presumably a null seed means "generate randomly") - confirm.
const kyber768_kp = crypto.kem.kyber_d00.Kyber768.KeyPair.create(null) catch {};
// ClientHello extensions. The host name bytes themselves are NOT part of this
// array: they are sent as a second iovec below, which is why host_len is added
// to every enclosing length field (extensions, handshake, record).
const extensions_payload =
tls.extension(.supported_versions, [_]u8{
// byte length of the version list, then 0x0304 (TLS 1.3)
0x02,
0x03, 0x04,
}) ++ tls.extension(.signature_algorithms, enum_array(tls.SignatureScheme, &.{
.ecdsa_secp256r1_sha256,
.ecdsa_secp384r1_sha384,
.ecdsa_secp521r1_sha512,
.rsa_pss_rsae_sha256,
.rsa_pss_rsae_sha384,
.rsa_pss_rsae_sha512,
.rsa_pkcs1_sha256,
.rsa_pkcs1_sha384,
.rsa_pkcs1_sha512,
.ed25519,
})) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{
.x25519_kyber768d00,
.secp256r1,
.x25519,
})) ++ tls.extension(
// One key share is offered for each supported group; the hybrid share is
// the X25519 public key followed by the Kyber768 public key.
.key_share,
array(1, int2(@enumToInt(tls.NamedGroup.x25519)) ++
array(1, x25519_kp.public_key) ++
int2(@enumToInt(tls.NamedGroup.secp256r1)) ++
array(1, secp256r1_kp.public_key.toUncompressedSec1()) ++
int2(@enumToInt(tls.NamedGroup.x25519_kyber768d00)) ++
array(1, x25519_kp.public_key ++ kyber768_kp.public_key.toBytes())),
) ++
// server_name extension header: extension length (host_len + 5), list
// length (host_len + 3), name_type 0x00, then the name length.
int2(@enumToInt(tls.ExtensionType.server_name)) ++
int2(host_len + 5) ++
int2(host_len + 3) ++
[1]u8{0x00} ++
int2(host_len);
const extensions_header =
int2(@intCast(u16, extensions_payload.len + host_len)) ++
extensions_payload;
// Encodes as bytes 0x01 0x00: one compression method, method 0.
const legacy_compression_methods = 0x0100;
const client_hello =
int2(@enumToInt(tls.ProtocolVersion.tls_1_2)) ++
hello_rand ++
[1]u8{32} ++ legacy_session_id ++
cipher_suites ++
int2(legacy_compression_methods) ++
extensions_header;
const out_handshake =
[_]u8{@enumToInt(tls.HandshakeType.client_hello)} ++
int3(@intCast(u24, client_hello.len + host_len)) ++
client_hello;
// Record header: handshake content type, legacy record version 0x0301.
const plaintext_header = [_]u8{
@enumToInt(tls.ContentType.handshake),
0x03, 0x01,
} ++ int2(@intCast(u16, out_handshake.len + host_len)) ++ out_handshake;
{
// Send the header/ClientHello prefix and the host name as two iovecs.
var iovecs = [_]std.os.iovec_const{
.{
.iov_base = &plaintext_header,
.iov_len = plaintext_header.len,
},
.{
.iov_base = host.ptr,
.iov_len = host.len,
},
};
try stream.writevAll(&iovecs);
}
// The ClientHello handshake bytes without the 5-byte record header. The
// transcript hash below appends `host` to complete the message.
const client_hello_bytes1 = plaintext_header[5..];
var handshake_cipher: tls.HandshakeCipher = undefined;
// Every record read during the handshake is decoded from this buffer.
var handshake_buffer: [8000]u8 = undefined;
var d: tls.Decoder = .{ .buf = &handshake_buffer };
{
// Read and process the (unencrypted) ServerHello record.
try d.readAtLeastOurAmt(stream, tls.record_header_len);
const ct = d.decode(tls.ContentType);
d.skip(2);
const record_len = d.decode(u16);
try d.readAtLeast(stream, record_len);
const server_hello_fragment = d.buf[d.idx..][0..record_len];
var ptd = try d.sub(record_len);
switch (ct) {
.alert => {
try ptd.ensure(2);
const level = ptd.decode(tls.AlertLevel);
const desc = ptd.decode(tls.AlertDescription);
_ = level;
_ = desc;
return error.TlsAlert;
},
.handshake => {
try ptd.ensure(4);
const handshake_type = ptd.decode(tls.HandshakeType);
if (handshake_type != .server_hello) return error.TlsUnexpectedMessage;
const length = ptd.decode(u24);
var hsd = try ptd.sub(length);
try hsd.ensure(2 + 32 + 1 + 32 + 2 + 1 + 2);
const legacy_version = hsd.decode(u16);
const random = hsd.array(32);
// A HelloRetryRequest is signalled by this fixed random value; this
// client does not support it.
if (mem.eql(u8, random, &tls.hello_retry_request_sequence)) {
return error.TlsUnexpectedMessage;
}
// The server must echo our 32-byte legacy session id exactly.
const legacy_session_id_echo_len = hsd.decode(u8);
if (legacy_session_id_echo_len != 32) return error.TlsIllegalParameter;
const legacy_session_id_echo = hsd.array(32);
if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id))
return error.TlsIllegalParameter;
const cipher_suite_tag = hsd.decode(tls.CipherSuite);
// legacy_compression_method
hsd.skip(1);
const extensions_size = hsd.decode(u16);
var all_extd = try hsd.sub(extensions_size);
var supported_version: u16 = 0;
var shared_key: []const u8 = undefined;
var have_shared_key = false;
// Each of supported_versions and key_share may appear at most once;
// other extensions are ignored.
while (!all_extd.eof()) {
try all_extd.ensure(2 + 2);
const et = all_extd.decode(tls.ExtensionType);
const ext_size = all_extd.decode(u16);
var extd = try all_extd.sub(ext_size);
switch (et) {
.supported_versions => {
if (supported_version != 0) return error.TlsIllegalParameter;
try extd.ensure(2);
supported_version = extd.decode(u16);
},
.key_share => {
if (have_shared_key) return error.TlsIllegalParameter;
have_shared_key = true;
try extd.ensure(4);
const named_group = extd.decode(tls.NamedGroup);
const key_size = extd.decode(u16);
try extd.ensure(key_size);
switch (named_group) {
.x25519_kyber768d00 => {
// Server share = X25519 public key ++ Kyber768 ciphertext;
// the shared secret is the X25519 result ++ Kyber768 shared secret.
const xksl = crypto.dh.X25519.public_length;
const hksl = xksl + crypto.kem.kyber_d00.Kyber768.ciphertext_length;
if (key_size != hksl)
return error.TlsIllegalParameter;
const server_ks = extd.array(hksl);
// NOTE(review): shared_key (here and in the other arms) is set to
// the address of a temporary array; confirm its lifetime covers the
// key-schedule use after this loop.
shared_key = &((crypto.dh.X25519.scalarmult(
x25519_kp.secret_key,
server_ks[0..xksl].*,
) catch return error.TlsDecryptFailure) ++ (kyber768_kp.secret_key.decaps(
server_ks[xksl..hksl],
) catch return error.TlsDecryptFailure));
},
.x25519 => {
const ksl = crypto.dh.X25519.public_length;
if (key_size != ksl) return error.TlsIllegalParameter;
const server_pub_key = extd.array(ksl);
shared_key = &(crypto.dh.X25519.scalarmult(
x25519_kp.secret_key,
server_pub_key.*,
) catch return error.TlsDecryptFailure);
},
.secp256r1 => {
// ECDH: the shared secret is the x coordinate of our secret scalar
// times the server's point, big-endian.
const server_pub_key = extd.slice(key_size);
const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey;
const pk = PublicKey.fromSec1(server_pub_key) catch {
return error.TlsDecryptFailure;
};
const mul = pk.p.mulPublic(secp256r1_kp.secret_key.bytes, .Big) catch {
return error.TlsDecryptFailure;
};
shared_key = &mul.affineCoordinates().x.toBytes(.Big);
},
else => {
return error.TlsIllegalParameter;
},
}
},
else => {},
}
}
if (!have_shared_key) return error.TlsIllegalParameter;
// Only TLS 1.3 is accepted (from supported_versions when present).
const tls_version = if (supported_version == 0) legacy_version else supported_version;
if (tls_version != @enumToInt(tls.ProtocolVersion.tls_1_3))
return error.TlsIllegalParameter;
// Instantiate the handshake cipher for the chosen suite and derive the
// handshake-traffic keys, IVs and finished keys from the shared secret
// and the ClientHello..ServerHello transcript hash.
switch (cipher_suite_tag) {
inline .AES_128_GCM_SHA256,
.AES_256_GCM_SHA384,
.CHACHA20_POLY1305_SHA256,
.AEGIS_256_SHA384,
.AEGIS_128L_SHA256,
=> |tag| {
const P = std.meta.TagPayloadByName(tls.HandshakeCipher, @tagName(tag));
handshake_cipher = @unionInit(tls.HandshakeCipher, @tagName(tag), .{
.handshake_secret = undefined,
.master_secret = undefined,
.client_handshake_key = undefined,
.server_handshake_key = undefined,
.client_finished_key = undefined,
.server_finished_key = undefined,
.client_handshake_iv = undefined,
.server_handshake_iv = undefined,
.transcript_hash = P.Hash.init(.{}),
});
const p = &@field(handshake_cipher, @tagName(tag));
p.transcript_hash.update(client_hello_bytes1);
p.transcript_hash.update(host);
p.transcript_hash.update(server_hello_fragment);
const hello_hash = p.transcript_hash.peek();
const zeroes = [1]u8{0} ** P.Hash.digest_length;
const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes);
const empty_hash = tls.emptyHash(P.Hash);
const hs_derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length);
p.handshake_secret = P.Hkdf.extract(&hs_derived_secret, shared_key);
const ap_derived_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "derived", &empty_hash, P.Hash.digest_length);
p.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes);
const client_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length);
const server_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length);
p.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length);
p.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length);
p.client_handshake_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length);
p.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length);
p.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length);
p.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length);
},
else => {
return error.TlsIllegalParameter;
},
}
},
else => return error.TlsUnexpectedMessage,
}
}
// Encrypted handshake phase: EncryptedExtensions, Certificate(s),
// CertificateVerify, Finished, enforced in that order by handshake_state.
var cert_index: usize = 0;
var read_seq: u64 = 0;
var prev_cert: Certificate.Parsed = undefined;
const HandshakeState = enum {
encrypted_extensions,
certificate,
trust_chain_established,
finished,
};
var handshake_state: HandshakeState = .encrypted_extensions;
// Two cleartext buffers selected by cert_index parity; presumably so the
// previously parsed certificate (prev_cert) stays readable while the next
// record is decrypted - NOTE(review): confirm.
var cleartext_bufs: [2][8000]u8 = undefined;
// Leaf certificate's public key, saved for checking CertificateVerify.
var main_cert_pub_key_algo: Certificate.AlgorithmCategory = undefined;
var main_cert_pub_key_buf: [300]u8 = undefined;
var main_cert_pub_key_len: u16 = undefined;
const now_sec = std.time.timestamp();
while (true) {
try d.readAtLeastOurAmt(stream, tls.record_header_len);
const record_header = d.buf[d.idx..][0..5];
const ct = d.decode(tls.ContentType);
d.skip(2);
const record_len = d.decode(u16);
try d.readAtLeast(stream, record_len);
var record_decoder = try d.sub(record_len);
switch (ct) {
.change_cipher_spec => {
// Middlebox-compatibility ChangeCipherSpec: must be the single byte 0x01.
try record_decoder.ensure(1);
if (record_decoder.decode(u8) != 0x01) return error.TlsIllegalParameter;
},
.application_data => {
const cleartext_buf = &cleartext_bufs[cert_index % 2];
const cleartext = switch (handshake_cipher) {
inline else => |*p| c: {
const P = @TypeOf(p.*);
// NOTE(review): record_len < tag_length underflows here.
const ciphertext_len = record_len - P.AEAD.tag_length;
try record_decoder.ensure(ciphertext_len + P.AEAD.tag_length);
const ciphertext = record_decoder.slice(ciphertext_len);
if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow;
const cleartext = cleartext_buf[0..ciphertext.len];
const auth_tag = record_decoder.array(P.AEAD.tag_length).*;
// Nonce = server handshake IV XOR big-endian read_seq left-padded with zeros.
const V = @Vector(P.AEAD.nonce_length, u8);
const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
const operand: V = pad ++ @bitCast([8]u8, big(read_seq));
read_seq += 1;
const nonce = @as(V, p.server_handshake_iv) ^ operand;
// The record header is the AEAD associated data.
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, p.server_handshake_key) catch
return error.TlsBadRecordMac;
break :c cleartext;
},
};
// The last cleartext byte is taken as the inner content type (padding is
// not stripped). NOTE(review): an empty cleartext underflows here.
const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]);
if (inner_ct != .handshake) return error.TlsUnexpectedMessage;
var ctd = tls.Decoder.fromTheirSlice(cleartext[0 .. cleartext.len - 1]);
// A single record may carry several handshake messages.
while (true) {
try ctd.ensure(4);
const handshake_type = ctd.decode(tls.HandshakeType);
const handshake_len = ctd.decode(u24);
var hsd = try ctd.sub(handshake_len);
// Full message including its 4-byte header (for the transcript) and body.
const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx];
const handshake = ctd.buf[ctd.idx - handshake_len .. ctd.idx];
switch (handshake_type) {
.encrypted_extensions => {
if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage;
handshake_state = .certificate;
switch (handshake_cipher) {
inline else => |*p| p.transcript_hash.update(wrapped_handshake),
}
// Extensions are length-checked but otherwise ignored.
try hsd.ensure(2);
const total_ext_size = hsd.decode(u16);
var all_extd = try hsd.sub(total_ext_size);
while (!all_extd.eof()) {
try all_extd.ensure(4);
const et = all_extd.decode(tls.ExtensionType);
const ext_size = all_extd.decode(u16);
var extd = try all_extd.sub(ext_size);
_ = extd;
switch (et) {
.server_name => {},
else => {},
}
}
},
.certificate => cert: {
switch (handshake_cipher) {
inline else => |*p| p.transcript_hash.update(wrapped_handshake),
}
// Once the trust chain is established, further Certificate messages are
// hashed but not examined.
switch (handshake_state) {
.certificate => {},
.trust_chain_established => break :cert,
else => return error.TlsUnexpectedMessage,
}
try hsd.ensure(1 + 4);
const cert_req_ctx_len = hsd.decode(u8);
if (cert_req_ctx_len != 0) return error.TlsIllegalParameter;
const certs_size = hsd.decode(u24);
var certs_decoder = try hsd.sub(certs_size);
while (!certs_decoder.eof()) {
try certs_decoder.ensure(3);
const cert_size = certs_decoder.decode(u24);
var certd = try certs_decoder.sub(cert_size);
const subject_cert: Certificate = .{
.buffer = certd.buf,
.index = @intCast(u32, certd.idx),
};
const subject = try subject_cert.parse();
if (cert_index == 0) {
// Leaf: must match `host`; keep its public key for CertificateVerify.
try subject.verifyHostName(host);
main_cert_pub_key_algo = subject.pub_key_algo;
const pub_key = subject.pubKey();
if (pub_key.len > main_cert_pub_key_buf.len)
return error.CertificatePublicKeyInvalid;
@memcpy(main_cert_pub_key_buf[0..pub_key.len], pub_key);
main_cert_pub_key_len = @intCast(@TypeOf(main_cert_pub_key_len), pub_key.len);
} else {
// Each later certificate must verify the one before it.
try prev_cert.verify(subject, now_sec);
}
// Stop at the first certificate the CA bundle can verify.
if (ca_bundle.verify(subject, now_sec)) |_| {
handshake_state = .trust_chain_established;
break :cert;
} else |err| switch (err) {
error.CertificateIssuerNotFound => {},
else => |e| return e,
}
prev_cert = subject;
cert_index += 1;
// Per-certificate extensions are skipped.
try certs_decoder.ensure(2);
const total_ext_size = certs_decoder.decode(u16);
var all_extd = try certs_decoder.sub(total_ext_size);
_ = all_extd;
}
},
.certificate_verify => {
switch (handshake_state) {
.trust_chain_established => handshake_state = .finished,
.certificate => return error.TlsCertificateNotVerified,
else => return error.TlsUnexpectedMessage,
}
try hsd.ensure(4);
const scheme = hsd.decode(tls.SignatureScheme);
const sig_len = hsd.decode(u16);
try hsd.ensure(sig_len);
const encoded_sig = hsd.slice(sig_len);
// Signed content: 64 spaces, the context string with a NUL, then the
// transcript hash up to (not including) this message.
const max_digest_len = 64;
var verify_buffer =
([1]u8{0x20} ** 64) ++
"TLS 1.3, server CertificateVerify\x00".* ++
@as([max_digest_len]u8, undefined);
const verify_bytes = switch (handshake_cipher) {
inline else => |*p| v: {
const transcript_digest = p.transcript_hash.peek();
verify_buffer[verify_buffer.len - max_digest_len ..][0..transcript_digest.len].* = transcript_digest;
p.transcript_hash.update(wrapped_handshake);
break :v verify_buffer[0 .. verify_buffer.len - max_digest_len + transcript_digest.len];
},
};
const main_cert_pub_key = main_cert_pub_key_buf[0..main_cert_pub_key_len];
// Only ECDSA P-256/P-384 and RSA-PSS-SHA256 are verified; other schemes
// are rejected even though more are advertised in the ClientHello.
switch (scheme) {
inline .ecdsa_secp256r1_sha256,
.ecdsa_secp384r1_sha384,
=> |comptime_scheme| {
if (main_cert_pub_key_algo != .X9_62_id_ecPublicKey)
return error.TlsBadSignatureScheme;
const Ecdsa = SchemeEcdsa(comptime_scheme);
const sig = try Ecdsa.Signature.fromDer(encoded_sig);
const key = try Ecdsa.PublicKey.fromSec1(main_cert_pub_key);
try sig.verify(verify_bytes, key);
},
.rsa_pss_rsae_sha256 => {
if (main_cert_pub_key_algo != .rsaEncryption)
return error.TlsBadSignatureScheme;
const Hash = crypto.hash.sha2.Sha256;
const rsa = Certificate.rsa;
const components = try rsa.PublicKey.parseDer(main_cert_pub_key);
const exponent = components.exponent;
const modulus = components.modulus;
// Scratch memory for the RSA computation; no heap allocation.
var rsa_mem_buf: [512 * 32]u8 = undefined;
var fba = std.heap.FixedBufferAllocator.init(&rsa_mem_buf);
const ally = fba.allocator();
switch (modulus.len) {
inline 128, 256, 512 => |modulus_len| {
const key = try rsa.PublicKey.fromBytes(exponent, modulus, ally);
const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig);
try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash, ally);
},
else => {
return error.TlsBadRsaSignatureBitCount;
},
}
},
else => {
return error.TlsBadSignatureScheme;
},
}
},
.finished => {
// Verify the server Finished, send ChangeCipherSpec + client Finished, then
// derive the application traffic secrets and return the client.
if (handshake_state != .finished) return error.TlsUnexpectedMessage;
const client_change_cipher_spec_msg = [_]u8{
@enumToInt(tls.ContentType.change_cipher_spec),
0x03, 0x03,
0x00, 0x01,
0x01,
};
const app_cipher = switch (handshake_cipher) {
inline else => |*p, tag| c: {
const P = @TypeOf(p.*);
const finished_digest = p.transcript_hash.peek();
p.transcript_hash.update(wrapped_handshake);
const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, p.server_finished_key);
if (!mem.eql(u8, &expected_server_verify_data, handshake))
return error.TlsDecryptError;
const handshake_hash = p.transcript_hash.finalResult();
const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key);
// Client Finished message plus trailing inner content type byte.
const out_cleartext = [_]u8{
@enumToInt(tls.HandshakeType.finished),
0, 0, verify_data.len,
} ++ verify_data ++ [1]u8{@enumToInt(tls.ContentType.handshake)};
const wrapped_len = out_cleartext.len + P.AEAD.tag_length;
var finished_msg = [_]u8{
@enumToInt(tls.ContentType.application_data),
0x03, 0x03,
0, wrapped_len,
} ++ @as([wrapped_len]u8, undefined);
const ad = finished_msg[0..5];
const ciphertext = finished_msg[5..][0..out_cleartext.len];
const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..];
// First record under the client handshake key: sequence 0, so the
// nonce is the IV itself.
const nonce = p.client_handshake_iv;
P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key);
const both_msgs = client_change_cipher_spec_msg ++ finished_msg;
try stream.writeAll(&both_msgs);
const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
break :c @unionInit(tls.ApplicationCipher, @tagName(tag), .{
.client_secret = client_secret,
.server_secret = server_secret,
.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length),
.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length),
.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length),
.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length),
});
},
};
// Bytes already read beyond the Finished record become buffered ciphertext.
const leftover = d.rest();
var client: Client = .{
.read_seq = 0,
.write_seq = 0,
.partial_cleartext_idx = 0,
.partial_ciphertext_idx = 0,
.partial_ciphertext_end = @intCast(u15, leftover.len),
.received_close_notify = false,
.application_cipher = app_cipher,
.partially_read_buffer = undefined,
};
mem.copy(u8, &client.partially_read_buffer, leftover);
return client;
},
else => {
return error.TlsUnexpectedMessage;
},
}
if (ctd.eof()) break;
}
},
else => {
return error.TlsUnexpectedMessage;
},
}
}
}
/// Encrypts and sends a prefix of `bytes`, returning how many bytes were
/// consumed; this may be fewer than `bytes.len`. Never sends close_notify.
pub fn write(c: *Client, stream: anytype, bytes: []const u8) !usize {
    const send_close_notify = false;
    return c.writeEnd(stream, bytes, send_close_notify);
}
/// Calls `write` repeatedly until every byte of `bytes` has been sent.
pub fn writeAll(c: *Client, stream: anytype, bytes: []const u8) !void {
    var remaining = bytes;
    while (remaining.len != 0) {
        const sent = try c.write(stream, remaining);
        remaining = remaining[sent..];
    }
}
/// Calls `writeEnd` repeatedly until every byte of `bytes` has been sent,
/// passing `end` through so that a close_notify alert is requested when `end`
/// is true. An empty `bytes` with `end` set still requests close_notify;
/// previously the loop never ran in that case and the alert was silently
/// skipped.
pub fn writeAllEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !void {
    if (bytes.len == 0) {
        // Nothing to send except, possibly, the close_notify alert.
        if (end) _ = try c.writeEnd(stream, bytes, true);
        return;
    }
    var index: usize = 0;
    while (index < bytes.len) {
        index += try c.writeEnd(stream, bytes[index..], end);
    }
}
/// Encrypts a prefix of `bytes` into application_data records and writes all
/// of those records to `stream`. Returns how many bytes of `bytes` were sent;
/// this may be less than `bytes.len` when the records do not fit in this
/// call's stack buffer, and callers retry with the remainder (see `writeAll`).
///
/// When `end` is true and all of `bytes` fits in this call, a close_notify
/// alert record is appended after the data. If not everything fits, the alert
/// is held back so the peer never receives it before the remaining data.
///
/// Every prepared record has already consumed a `write_seq` number, so this
/// function keeps calling `stream.writev` until all of them are written.
/// Returning with a prepared record unsent would desynchronize the record
/// sequence numbers from the peer's.
pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usize {
    var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined;
    var iovecs_buf: [6]std.os.iovec_const = undefined;
    var prepared = prepareCiphertextRecord(c, &iovecs_buf, &ciphertext_buf, bytes, .application_data);

    // Each record is its cleartext plus exactly `overhead_len` extra bytes, so
    // the amount of `bytes` consumed follows from the totals.
    const cleartext_len = prepared.ciphertext_end - prepared.iovec_end * prepared.overhead_len;

    if (end and cleartext_len == bytes.len) {
        prepared.iovec_end += prepareCiphertextRecord(
            c,
            iovecs_buf[prepared.iovec_end..],
            ciphertext_buf[prepared.ciphertext_end..],
            &tls.close_notify_alert,
            .alert,
        ).iovec_end;
    }

    const iovec_end = prepared.iovec_end;
    var i: usize = 0;
    while (i < iovec_end) {
        var amt = try stream.writev(iovecs_buf[i..iovec_end]);
        // Skip every record that was written completely.
        while (i < iovec_end and amt >= iovecs_buf[i].iov_len) {
            amt -= iovecs_buf[i].iov_len;
            i += 1;
        }
        // Resume a partially written record where the stream stopped.
        if (i < iovec_end) {
            iovecs_buf[i].iov_base += amt;
            iovecs_buf[i].iov_len -= amt;
        }
    }
    return cleartext_len;
}
/// Encrypts `bytes` as consecutive TLS 1.3 records whose inner content type is
/// `inner_content_type`. Each record is written into `ciphertext_buf` and gets
/// one entry in `iovecs`, and each consumes one `write_seq` number. Stops when
/// `bytes` is exhausted or `ciphertext_buf` has no room left.
///
/// Callers can compute how many bytes were consumed as
/// `ciphertext_end - iovec_end * overhead_len`.
///
/// For non-alert content, room for one close_notify record is kept free at the
/// end of `ciphertext_buf` so that one can be appended by a follow-up call.
/// An alert record does not reserve that room again. Reserving it again would
/// make the capacity computation underflow whenever the preceding application
/// data filled the buffer up to the reservation.
fn prepareCiphertextRecord(
    c: *Client,
    iovecs: []std.os.iovec_const,
    ciphertext_buf: []u8,
    bytes: []const u8,
    inner_content_type: tls.ContentType,
) struct {
    iovec_end: usize,
    ciphertext_end: usize,
    /// How many bytes each record adds on top of its cleartext.
    overhead_len: usize,
} {
    // The inner content type byte is appended to the cleartext before
    // encryption, so the cleartext needs its own staging buffer.
    var cleartext_buf: [max_ciphertext_len]u8 = undefined;
    var ciphertext_end: usize = 0;
    var iovec_end: usize = 0;
    var bytes_i: usize = 0;
    switch (c.application_cipher) {
        inline else => |*p| {
            const P = @TypeOf(p.*);
            const V = @Vector(P.AEAD.nonce_length, u8);
            const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1;
            const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len;
            const reserved_len: usize = if (inner_content_type == .alert) 0 else close_notify_alert_reserved;
            while (true) {
                const encrypted_content_len = @intCast(u16, @min(
                    @min(bytes.len - bytes_i, max_ciphertext_len - 1),
                    ciphertext_buf.len -| (reserved_len + overhead_len + ciphertext_end),
                ));
                if (encrypted_content_len == 0) return .{
                    .iovec_end = iovec_end,
                    .ciphertext_end = ciphertext_end,
                    .overhead_len = overhead_len,
                };

                mem.copy(u8, &cleartext_buf, bytes[bytes_i..][0..encrypted_content_len]);
                cleartext_buf[encrypted_content_len] = @enumToInt(inner_content_type);
                bytes_i += encrypted_content_len;
                const ciphertext_len = encrypted_content_len + 1;
                const cleartext = cleartext_buf[0..ciphertext_len];

                // Record header; it doubles as the AEAD associated data.
                const record_start = ciphertext_end;
                const ad = ciphertext_buf[ciphertext_end..][0..5];
                ad.* =
                    [_]u8{@enumToInt(tls.ContentType.application_data)} ++
                    int2(@enumToInt(tls.ProtocolVersion.tls_1_2)) ++
                    int2(ciphertext_len + P.AEAD.tag_length);
                ciphertext_end += ad.len;
                const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len];
                ciphertext_end += ciphertext_len;
                const auth_tag = ciphertext_buf[ciphertext_end..][0..P.AEAD.tag_length];
                ciphertext_end += auth_tag.len;

                // Nonce = client IV XOR big-endian write_seq left-padded with zeros.
                const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
                const operand: V = pad ++ @bitCast([8]u8, big(c.write_seq));
                // NOTE(review): no KeyUpdate is sent before write_seq overflows.
                c.write_seq += 1;
                const nonce = @as(V, p.client_iv) ^ operand;
                P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key);

                const record = ciphertext_buf[record_start..ciphertext_end];
                iovecs[iovec_end] = .{
                    .iov_base = record.ptr,
                    .iov_len = record.len,
                };
                iovec_end += 1;
            }
        },
    }
}
/// Reports whether the connection is finished for the reader: the peer's close
/// has been observed and neither buffered cleartext nor buffered ciphertext
/// remains in `partially_read_buffer`.
pub fn eof(c: Client) bool {
    if (!c.received_close_notify) return false;
    const cleartext_drained = c.partial_cleartext_idx >= c.partial_ciphertext_idx;
    const ciphertext_drained = c.partial_ciphertext_idx >= c.partial_ciphertext_end;
    return cleartext_drained and ciphertext_drained;
}
/// Reads into `buffer` until at least `len` bytes have arrived or the
/// connection reaches end-of-stream; returns the number of bytes read.
pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize {
    var single_vec = [_]std.os.iovec{.{ .iov_base = buffer.ptr, .iov_len = buffer.len }};
    return c.readvAtLeast(stream, &single_vec, len);
}
/// Reads whatever is available into `buffer`, waiting for at least one byte
/// unless the connection reaches end-of-stream first (then 0 is returned).
pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize {
    const minimum: usize = 1;
    return c.readAtLeast(stream, buffer, minimum);
}
/// Fills `buffer` completely unless end-of-stream is reached first; returns the
/// number of bytes read.
pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize {
    const wanted = buffer.len;
    return c.readAtLeast(stream, buffer, wanted);
}
/// Reads into `iovecs`, waiting for at least one byte unless end-of-stream is
/// reached (then 0 is returned). This is the vectored counterpart of `read`.
/// The original called `readvAtLeast` without its required `len` argument,
/// which failed to compile as soon as `readv` was used.
pub fn readv(c: *Client, stream: anytype, iovecs: []std.os.iovec) !usize {
    return readvAtLeast(c, stream, iovecs, 1);
}
/// Reads into `iovecs` until at least `len` bytes have been received, the
/// connection reaches end-of-stream, or every vector is full. Returns the
/// total number of bytes read.
///
/// Entries of `iovecs` are advanced in place past the bytes already filled.
/// When `len` exceeds the total capacity of `iovecs` (e.g. `read` with an
/// empty buffer), the function returns once the vectors are full. It no
/// longer indexes past the end of the slice.
pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.os.iovec, len: usize) !usize {
    if (c.eof()) return 0;

    var off_i: usize = 0;
    var vec_i: usize = 0;
    while (true) {
        var amt = try c.readvAdvanced(stream, iovecs[vec_i..]);
        off_i += amt;
        if (c.eof() or off_i >= len) return off_i;
        // Skip the vectors this read filled completely.
        while (vec_i < iovecs.len and amt >= iovecs[vec_i].iov_len) : (vec_i += 1) {
            amt -= iovecs[vec_i].iov_len;
        }
        // No space left to read into.
        if (vec_i == iovecs.len) return off_i;
        iovecs[vec_i].iov_base += amt;
        iovecs[vec_i].iov_len -= amt;
    }
}
/// Decrypts records from `stream` into `iovecs` and returns the number of
/// cleartext bytes written.
///
/// Cleartext buffered by a previous call is handed out first, without reading
/// from the stream. The result may be 0 when only handshake records
/// (NewSessionTicket, KeyUpdate) were processed or when a close_notify
/// arrived; use `eof` to tell these apart. Cleartext that does not fit in
/// `iovecs` is kept in `partially_read_buffer`, as are the bytes of a record
/// that has not been fully received.
///
/// Errors:
/// - `error.TlsConnectionTruncated` when the stream ends without a
///   close_notify, unless `allow_truncation_attacks` is set.
/// - `error.TlsAlert` for alerts other than close_notify.
/// - `error.TlsBadRecordMac` on authentication failure.
/// - `error.TlsRecordOverflow`, `error.TlsBadLength`, `error.TlsDecodeError`
///   for malformed records.
pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec) !usize {
    var vp: VecPut = .{ .iovecs = iovecs };

    // Give away the buffered cleartext we have, if any.
    const partial_cleartext = c.partially_read_buffer[c.partial_cleartext_idx..c.partial_ciphertext_idx];
    if (partial_cleartext.len > 0) {
        const amt = @intCast(u15, vp.put(partial_cleartext));
        c.partial_cleartext_idx += amt;

        // Only reset when the cleartext has been fully handed out AND no
        // ciphertext is pending. Checking the ciphertext alone discarded
        // unread cleartext whenever the caller's buffer was too small.
        if (c.partial_cleartext_idx == c.partial_ciphertext_idx and
            c.partial_ciphertext_end == c.partial_ciphertext_idx)
        {
            c.partial_cleartext_idx = 0;
            c.partial_ciphertext_idx = 0;
            c.partial_ciphertext_end = 0;
        }

        if (c.received_close_notify) {
            c.partial_ciphertext_end = 0;
            assert(vp.total == amt);
            return amt;
        } else if (amt <= partial_cleartext.len) {
            // Always taken: while buffered cleartext exists, the stream is not
            // read in this call.
            assert(vp.total == amt);
            return amt;
        }
    }

    assert(!c.received_close_notify);

    // Used when the caller's buffer is too small for a record's cleartext.
    var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined;
    // Receives ciphertext beyond what fits into `partially_read_buffer`.
    var in_stack_buffer: [max_ciphertext_len * 4]u8 = undefined;
    const free_size = vp.freeSize();
    const ciphertext_buf_len = (free_size / 2) -| in_stack_buffer.len;
    const cleartext_buf_len = free_size - ciphertext_buf_len;

    // Move pending ciphertext to the front of `partially_read_buffer`. Then
    // `frag0` below starts at index 0 and can hold one complete record when a
    // record straddles `frag0` and `frag1`. No cleartext is buffered at this
    // point (it would have been returned above), so nothing is lost.
    // `mem.copy` copies forward, which is correct for this overlap.
    mem.copy(u8, &c.partially_read_buffer, c.partially_read_buffer[c.partial_ciphertext_idx..c.partial_ciphertext_end]);
    c.partial_ciphertext_end -= c.partial_ciphertext_idx;
    c.partial_cleartext_idx = 0;
    c.partial_ciphertext_idx = 0;

    const first_iov = c.partially_read_buffer[c.partial_ciphertext_end..];
    var ask_iovecs_buf: [2]std.os.iovec = .{
        .{
            .iov_base = first_iov.ptr,
            .iov_len = first_iov.len,
        },
        .{
            .iov_base = &in_stack_buffer,
            .iov_len = in_stack_buffer.len,
        },
    };

    // Cleartext capacity of the output buffer, in records (rounded up).
    const buf_cap = (cleartext_buf_len +| (max_ciphertext_len - 1)) / max_ciphertext_len;
    const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len);
    const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len);
    const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len);
    const actual_read_len = try stream.readv(ask_iovecs);
    if (actual_read_len == 0) {
        // End of stream without close_notify.
        if (c.allow_truncation_attacks) {
            c.received_close_notify = true;
        } else {
            return error.TlsConnectionTruncated;
        }
    }

    // frag0: ciphertext in `partially_read_buffer`; frag1: the overflow that
    // landed in `in_stack_buffer`. A record may straddle the two.
    const frag0_end = @min(c.partially_read_buffer.len, c.partial_ciphertext_end + actual_read_len);
    const frag0 = c.partially_read_buffer[c.partial_ciphertext_idx..frag0_end];
    var frag1 = in_stack_buffer[0..actual_read_len -| first_iov.len];
    var frag = frag0;
    var in: usize = 0;
    while (true) {
        if (in == frag.len) {
            // Perfect split.
            if (frag.ptr == frag1.ptr) {
                c.partial_ciphertext_end = c.partial_ciphertext_idx;
                return vp.total;
            }
            frag = frag1;
            in = 0;
            continue;
        }

        if (in + tls.record_header_len > frag.len) {
            if (frag.ptr == frag1.ptr)
                return finishRead(c, frag, in, vp.total);

            const first = frag[in..];

            if (frag1.len < tls.record_header_len)
                return finishRead2(c, first, frag1, vp.total);

            // The record header itself straddles the two fragments.
            const record_len_byte_0: u16 = straddleByte(frag, frag1, in + 3);
            const record_len_byte_1: u16 = straddleByte(frag, frag1, in + 4);
            const record_len = (record_len_byte_0 << 8) | record_len_byte_1;
            if (record_len > max_ciphertext_len) return error.TlsRecordOverflow;

            const full_record_len = record_len + tls.record_header_len;
            const second_len = full_record_len - first.len;
            if (frag1.len < second_len)
                return finishRead2(c, first, frag1, vp.total);

            // Reassemble the record at the front of `frag`. The destination must
            // be at least `first.len` long; `frag[0..in]` was too short and
            // tripped mem.copy's length assertion. The forward copy handles the
            // overlap. `frag` is `frag0`, which starts at index 0 of
            // `partially_read_buffer`, so it can hold a full record.
            mem.copy(u8, frag, first);
            mem.copy(u8, frag[first.len..], frag1[0..second_len]);
            frag = frag[0..full_record_len];
            frag1 = frag1[second_len..];
            in = 0;
            continue;
        }
        const ct = @intToEnum(tls.ContentType, frag[in]);
        in += 1;
        const legacy_version = mem.readIntBig(u16, frag[in..][0..2]);
        in += 2;
        _ = legacy_version;
        const record_len = mem.readIntBig(u16, frag[in..][0..2]);
        if (record_len > max_ciphertext_len) return error.TlsRecordOverflow;
        in += 2;
        const end = in + record_len;
        if (end > frag.len) {
            // The header is needed again on the next iteration.
            in -= tls.record_header_len;

            if (frag.ptr == frag1.ptr)
                return finishRead(c, frag, in, vp.total);

            // A record straddles the two fragments; reassemble it as above.
            const first = frag[in..];
            const full_record_len = record_len + tls.record_header_len;
            const second_len = full_record_len - first.len;
            if (frag1.len < second_len)
                return finishRead2(c, first, frag1, vp.total);

            mem.copy(u8, frag, first);
            mem.copy(u8, frag[first.len..], frag1[0..second_len]);
            frag = frag[0..full_record_len];
            frag1 = frag1[second_len..];
            in = 0;
            continue;
        }
        switch (ct) {
            .alert => {
                if (in + 2 > frag.len) return error.TlsDecodeError;
                const level = @intToEnum(tls.AlertLevel, frag[in]);
                const desc = @intToEnum(tls.AlertDescription, frag[in + 1]);
                _ = level;
                _ = desc;
                return error.TlsAlert;
            },
            .application_data => {
                const cleartext = switch (c.application_cipher) {
                    inline else => |*p| c: {
                        const P = @TypeOf(p.*);
                        const V = @Vector(P.AEAD.nonce_length, u8);
                        // A record must hold at least the auth tag plus the inner
                        // content type byte; shorter ones (from the network) would
                        // make the subtractions below underflow.
                        if (record_len < P.AEAD.tag_length + 1) return error.TlsBadLength;
                        const ad = frag[in - 5 ..][0..5];
                        const ciphertext_len = record_len - P.AEAD.tag_length;
                        const ciphertext = frag[in..][0..ciphertext_len];
                        in += ciphertext_len;
                        const auth_tag = frag[in..][0..P.AEAD.tag_length].*;
                        // Nonce = server IV XOR big-endian read_seq left-padded with zeros.
                        const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
                        const operand: V = pad ++ @bitCast([8]u8, big(c.read_seq));
                        const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_iv) ^ operand;
                        // Decrypt straight into the caller's buffer when it fits.
                        const out_buf = vp.peek();
                        const cleartext_buf = if (ciphertext.len <= out_buf.len)
                            out_buf
                        else
                            &cleartext_stack_buffer;
                        const cleartext = cleartext_buf[0..ciphertext.len];
                        P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_key) catch
                            return error.TlsBadRecordMac;
                        break :c cleartext;
                    },
                };

                c.read_seq = try std.math.add(u64, c.read_seq, 1);

                // The last byte is the inner content type (padding is not stripped).
                const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]);
                switch (inner_ct) {
                    .alert => {
                        // Alert level + description + inner content type byte.
                        if (cleartext.len < 3) return error.TlsDecodeError;
                        const level = @intToEnum(tls.AlertLevel, cleartext[0]);
                        const desc = @intToEnum(tls.AlertDescription, cleartext[1]);
                        if (desc == .close_notify) {
                            c.received_close_notify = true;
                            c.partial_ciphertext_end = c.partial_ciphertext_idx;
                            return vp.total;
                        }
                        _ = level;
                        return error.TlsAlert;
                    },
                    .handshake => {
                        var ct_i: usize = 0;
                        while (true) {
                            // Each handshake message needs a 4-byte header.
                            if (ct_i + 4 > cleartext.len - 1)
                                return error.TlsBadLength;
                            const handshake_type = @intToEnum(tls.HandshakeType, cleartext[ct_i]);
                            ct_i += 1;
                            const handshake_len = mem.readIntBig(u24, cleartext[ct_i..][0..3]);
                            ct_i += 3;
                            const next_handshake_i = ct_i + handshake_len;
                            if (next_handshake_i > cleartext.len - 1)
                                return error.TlsBadLength;
                            const handshake = cleartext[ct_i..next_handshake_i];
                            switch (handshake_type) {
                                .new_session_ticket => {
                                    // Session tickets are ignored.
                                },
                                .key_update => {
                                    if (handshake.len < 1) return error.TlsBadLength;
                                    // Advance the server's traffic secret and keys.
                                    switch (c.application_cipher) {
                                        inline else => |*p| {
                                            const P = @TypeOf(p.*);
                                            const server_secret = hkdfExpandLabel(P.Hkdf, p.server_secret, "traffic upd", "", P.Hash.digest_length);
                                            p.server_secret = server_secret;
                                            p.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length);
                                            p.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length);
                                        },
                                    }
                                    c.read_seq = 0;

                                    switch (@intToEnum(tls.KeyUpdateRequest, handshake[0])) {
                                        .update_requested => {
                                            // Also advance our own traffic secret and keys.
                                            switch (c.application_cipher) {
                                                inline else => |*p| {
                                                    const P = @TypeOf(p.*);
                                                    const client_secret = hkdfExpandLabel(P.Hkdf, p.client_secret, "traffic upd", "", P.Hash.digest_length);
                                                    p.client_secret = client_secret;
                                                    p.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length);
                                                    p.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length);
                                                },
                                            }
                                            c.write_seq = 0;
                                        },
                                        .update_not_requested => {},
                                        _ => return error.TlsIllegalParameter,
                                    }
                                },
                                else => {
                                    return error.TlsUnexpectedMessage;
                                },
                            }
                            ct_i = next_handshake_i;
                            if (ct_i >= cleartext.len - 1) break;
                        }
                    },
                    .application_data => {
                        if (cleartext.ptr == &cleartext_stack_buffer) {
                            // Decrypted into the stack buffer: copy what fits to the
                            // caller and keep the rest in `partially_read_buffer`.
                            // NOTE(review): that cleartext shares `partially_read_buffer`
                            // with `frag0`'s ciphertext; confirm a later straddle copy
                            // cannot overwrite it.
                            const msg = cleartext[0 .. cleartext.len - 1];
                            if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
                                // Already out of room in `iovecs`; keep appending.
                                const dest = c.partially_read_buffer[c.partial_ciphertext_idx..];
                                mem.copy(u8, dest, msg);
                                c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), c.partial_ciphertext_idx + msg.len);
                            } else {
                                const amt = vp.put(msg);
                                if (amt < msg.len) {
                                    const rest = msg[amt..];
                                    c.partial_cleartext_idx = 0;
                                    c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), rest.len);
                                    mem.copy(u8, &c.partially_read_buffer, rest);
                                }
                            }
                        } else {
                            // Decrypted in place in the caller's buffer.
                            vp.next(cleartext.len - 1);
                        }
                    },
                    else => {
                        return error.TlsUnexpectedMessage;
                    },
                }
            },
            else => {
                return error.TlsUnexpectedMessage;
            },
        }
        in = end;
    }
}
/// Save the unconsumed tail of a ciphertext fragment, `frag[in..]`, into
/// `partially_read_buffer` so a later read can resume from it, then return `out`.
///
/// When `partial_ciphertext_idx > partial_cleartext_idx` the buffer still holds
/// not-yet-delivered cleartext in `[partial_cleartext_idx, partial_ciphertext_idx)`,
/// so the ciphertext is appended right after it. Otherwise all indexes are reset
/// and the ciphertext is stored from offset 0. In both cases
/// `partial_ciphertext_end` marks the end of the saved ciphertext.
/// NOTE(review): there is no explicit capacity check; an oversized tail trips the
/// length assertion inside `mem.copy` (or the `@intCast`).
fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize {
    const leftover = frag[in..];
    if (c.partial_ciphertext_idx <= c.partial_cleartext_idx) {
        // No pending cleartext to preserve: start over at the buffer's beginning.
        c.partial_cleartext_idx = 0;
        c.partial_ciphertext_idx = 0;
    }
    // Ciphertext is stored immediately after any pending cleartext (0 after a reset).
    const start: usize = c.partial_ciphertext_idx;
    c.partial_ciphertext_end = @intCast(@TypeOf(c.partial_ciphertext_end), start + leftover.len);
    mem.copy(u8, c.partially_read_buffer[start..], leftover);
    return out;
}
/// Like `finishRead`, but the unconsumed ciphertext arrives in two pieces,
/// `first` followed by `frag1`, which are stored back-to-back in
/// `partially_read_buffer`. Returns `out` unchanged.
///
/// Pending cleartext (`partial_ciphertext_idx > partial_cleartext_idx`) is kept
/// and the ciphertext is appended after it; otherwise the buffer is reset and
/// the ciphertext starts at offset 0. `partial_ciphertext_end` is set to the end
/// of the combined ciphertext.
fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) usize {
    if (c.partial_ciphertext_idx <= c.partial_cleartext_idx) {
        // Nothing buffered worth keeping: restart from the beginning.
        c.partial_cleartext_idx = 0;
        c.partial_ciphertext_idx = 0;
    }
    const start: usize = c.partial_ciphertext_idx;
    c.partial_ciphertext_end = @intCast(@TypeOf(c.partial_ciphertext_end), start + first.len + frag1.len);
    const dest = c.partially_read_buffer[start..];
    mem.copy(u8, dest, first);
    mem.copy(u8, dest[first.len..], frag1);
    return out;
}
/// Read byte `index` of the logical concatenation `s1 ++ s2` without copying.
/// Indexes below `s1.len` come from `s1`; the rest are taken from `s2`.
fn straddleByte(s1: []const u8, s2: []const u8, index: usize) u8 {
    return if (index >= s1.len) s2[index - s1.len] else s1[index];
}
const builtin = @import("builtin");
const native_endian = builtin.cpu.arch.endian();
/// Convert `x` from native byte order to big-endian byte order.
/// On big-endian targets this is the identity; on little-endian targets the
/// bytes are reversed with `@byteSwap`. `native_endian` is comptime-known, so
/// only the branch for the current target is analyzed.
inline fn big(x: anytype) @TypeOf(x) {
    if (native_endian == .Little) return @byteSwap(x);
    return x;
}
/// Map a TLS ECDSA signature scheme to the matching `std.crypto.sign.ecdsa`
/// implementation. Evaluated entirely at comptime.
///
/// `ecdsa_secp521r1_sha512` produces an explicit compile error. The previous
/// branch referenced `crypto.sign.ecdsa.EcdsaP512Sha512`, which is not a std
/// declaration (the curve would be P-521, not "P512"). Instantiating that scheme
/// therefore failed with an obscure "no member" error instead of a clear message.
/// Any non-ECDSA scheme is rejected with "bad scheme".
fn SchemeEcdsa(comptime scheme: tls.SignatureScheme) type {
    return switch (scheme) {
        .ecdsa_secp256r1_sha256 => crypto.sign.ecdsa.EcdsaP256Sha256,
        .ecdsa_secp384r1_sha384 => crypto.sign.ecdsa.EcdsaP384Sha384,
        .ecdsa_secp521r1_sha512 => @compileError("unsupported scheme: ecdsa_secp521r1_sha512 (no P-521 ECDSA in std.crypto.sign.ecdsa)"),
        else => @compileError("bad scheme"),
    };
}
/// Write cursor over a list of iovecs.
///
/// `idx` is the vector currently being filled and `off` the position inside it.
/// `total` counts every byte accounted for through `put` or `next`.
const VecPut = struct {
    iovecs: []const std.os.iovec,
    idx: usize = 0,
    off: usize = 0,
    total: usize = 0,

    /// Copy as much of `bytes` as fits into the remaining iovec space, starting
    /// at the cursor, and advance the cursor past the written bytes.
    /// Returns the number of bytes copied (0 when every vector is already full).
    /// That count is also added to `total`.
    fn put(vp: *VecPut, bytes: []const u8) usize {
        if (vp.idx >= vp.iovecs.len) return 0;
        const copied = copy: {
            var n: usize = 0;
            while (true) {
                const vec = vp.iovecs[vp.idx];
                const room = vec.iov_base[vp.off..vec.iov_len];
                const chunk_len = @min(room.len, bytes.len - n);
                mem.copy(u8, room, bytes[n..][0..chunk_len]);
                n += chunk_len;
                vp.off += chunk_len;
                if (vp.off >= vec.iov_len) {
                    // This vector is full; continue at the start of the next one.
                    vp.off = 0;
                    vp.idx += 1;
                    if (vp.idx >= vp.iovecs.len) break :copy n;
                }
                if (n >= bytes.len) break :copy n;
            }
        };
        vp.total += copied;
        return copied;
    }

    /// Return the unwritten remainder of the current vector, or an empty slice
    /// once all vectors are used up. Does not move the cursor.
    fn peek(vp: VecPut) []u8 {
        if (vp.idx < vp.iovecs.len) {
            const vec = vp.iovecs[vp.idx];
            return vec.iov_base[vp.off..vec.iov_len];
        }
        return &.{};
    }

    /// Account for `len` bytes written in place (e.g. into the slice from `peek`).
    /// This adds `len` to `total` and advances within the current vector, moving
    /// to the next vector once the current one is full.
    /// NOTE(review): `iovecs[idx]` is indexed without a bounds guard, so callers
    /// must ensure a current vector exists and `len <= peek().len`.
    fn next(vp: *VecPut, len: usize) void {
        vp.total += len;
        vp.off += len;
        if (vp.off < vp.iovecs[vp.idx].iov_len) return;
        vp.off = 0;
        vp.idx += 1;
    }

    /// Number of bytes still writable: the rest of the current vector plus the
    /// full length of every vector after it.
    fn freeSize(vp: VecPut) usize {
        if (vp.idx >= vp.iovecs.len) return 0;
        var sum: usize = vp.iovecs[vp.idx].iov_len - vp.off;
        for (vp.iovecs[vp.idx + 1 ..]) |vec| sum += vec.iov_len;
        return sum;
    }
};
/// Shorten `iovecs` so that together they describe at most `len` bytes.
///
/// Returns the leading vectors needed to cover `len` bytes. If `len` ends part
/// way through a vector, that vector's `iov_len` is reduced in place and the
/// vector is still included in the result. If `len` is at least the combined
/// length, `iovecs` is returned unchanged. An empty `iovecs` yields an empty
/// slice. For `len == 0` the result is a single zero-length vector.
fn limitVecs(iovecs: []std.os.iovec, len: usize) []std.os.iovec {
    var bytes_left: usize = len;
    var vec_i: usize = 0;
    while (vec_i < iovecs.len) : (vec_i += 1) {
        const vec = &iovecs[vec_i];
        if (bytes_left <= vec.iov_len) {
            // This vector holds the last requested byte: trim it and keep it in
            // the returned slice (dropping it would lose the partial length).
            vec.iov_len = bytes_left;
            return iovecs[0 .. vec_i + 1];
        }
        bytes_left -= vec.iov_len;
    }
    return iovecs;
}
/// Cipher-suite list encoded with `enum_array`.
/// The order depends on `crypto.core.aes.has_hardware_support`. With hardware
/// AES, the AEGIS and AES-GCM suites come first and CHACHA20_POLY1305_SHA256
/// last. Without it, CHACHA20_POLY1305_SHA256 is listed first.
/// Presumably this is the client's preference order; confirm with the
/// ClientHello construction.
const cipher_suites = suites: {
    const preference: []const tls.CipherSuite = if (crypto.core.aes.has_hardware_support) &.{
        .AEGIS_128L_SHA256,
        .AEGIS_256_SHA384,
        .AES_128_GCM_SHA256,
        .AES_256_GCM_SHA384,
        .CHACHA20_POLY1305_SHA256,
    } else &.{
        .CHACHA20_POLY1305_SHA256,
        .AEGIS_128L_SHA256,
        .AEGIS_256_SHA384,
        .AES_128_GCM_SHA256,
        .AES_256_GCM_SHA384,
    };
    break :suites enum_array(tls.CipherSuite, preference);
};
test {
    // Force semantic analysis of every public declaration of StreamInterface
    // (readv, writev, writevAll and their error sets). A bare
    // `_ = StreamInterface;` only references the type, and Zig analyzes
    // declarations lazily, so a broken stub signature would go unnoticed.
    std.testing.refAllDecls(StreamInterface);
}