hare

The Hare programming language
git clone https://git.torresjrjr.com/hare.git
Log | Files | Refs | README | LICENSE

commit a6923d1f370a1198a8bb7d849f309287fc1a02a4
parent 081ba3c1923b29b9af5ef5871cdd903c90b0eecb
Author: Drew DeVault <sir@cmpwn.com>
Date:   Wed, 23 Jun 2021 11:37:01 -0400

net::dns: refactor decoding for improved rdata

This abandons the stack-based API, which was only going to get more
complex and brittle (and hard to use) with the introduction of more
rdata formats, and instead just focuses on heap-allocating the decoded
message data. This also makes the rdata field use a more Hairy type
(tagged union of known rdata formats) and fully implements MX decoding.

Signed-off-by: Drew DeVault <sir@cmpwn.com>

Diffstat:
Mnet/dns/decode.ha | 188++++++++++++++++++++++++++++++-------------------------------------------------
Mnet/dns/encode.ha | 10+---------
Mnet/dns/error.ha | 18++++++++++--------
Mnet/dns/types.ha | 65+++++++++++++++++++++++++++++++++++++++++++++--------------------
4 files changed, 127 insertions(+), 154 deletions(-)

diff --git a/net/dns/decode.ha b/net/dns/decode.ha @@ -1,5 +1,6 @@ use ascii; use endian; +use fmt; use net::ip; use strings; @@ -21,16 +22,7 @@ export fn decode(buf: []u8) (*message | format) = { decode_header(&dec, &msg.header)?; for (let i = 0z; i < msg.header.qdcount; i += 1) { - let q = question { ... }; - let names = decode_question(&dec, &q)?; - - for (let i = 0; len(names) != 0; i += 1) { - let ns = decode_name(&dec, names)!; - names = ns.0; - append(q.qname, strings::dup(ns.1)); - }; - - append(msg.questions, q); + append(msg.questions, decode_question(&dec)?); }; decode_rrecords(&dec, msg.header.ancount, &msg.answers)?; @@ -46,26 +38,12 @@ fn decode_rrecords( ) (void | format) = { for (let i = 0z; i < count; i += 1) { let r = rrecord { ... }; - let names = decode_rrecord(dec, &r)?; - let rdata = r.rdata; - r.rdata = []; - append(r.rdata, rdata...); - - for (let i = 0; len(names) != 0; i += 1) { - let ns = decode_name(dec, names)!; - names = ns.0; - append(r.name, strings::dup(ns.1)); - }; - + decode_rrecord(dec, &r)?; append(*out, r); }; }; -// Initializes a DNS message decoder. All storaged used by the decoder is either -// stack-allocated, provided by the caller, or borrowed from the input buffer. -// -// Call [[decode_header]] next. -export fn decoder_init(buf: []u8) decoder = decoder { +fn decoder_init(buf: []u8) decoder = decoder { buf = buf, cur = buf, ... @@ -89,12 +67,7 @@ fn decode_u32(dec: *decoder) (u32 | format) = { return val; }; -// Decodes a DNS message's header and advances the decoder to the -// variable-length section of the message. Following this call, the user should -// call [[decode_question]] for each question given by the header's qdcount, -// then [[decode_rrecord]] for each resource record given by the ancount, -// nscount, and arcount fields, respectively. -export fn decode_header(dec: *decoder, head: *header) (void | format) = { +fn decode_header(dec: *decoder, head: *header) (void | format) = { head.id = decode_u16(dec)?; let rawop = decode_u16(dec)?; op_decode(rawop, &head.op); @@ -118,113 +91,94 @@ fn op_decode(in: u16, out: *op) void = { out.rcode = (in & 0b1111): rcode; }; -// Partially decodes a [[question]] and advances the decoder. Returns a slice -// representing the name field, which can be passed to [[decode_name]] to -// interpret. -export fn decode_question(dec: *decoder, q: *question) ([]u8 | format) = { - let name = extract_name(dec)?; - q.qtype = decode_u16(dec)?: qtype; - q.qclass = decode_u16(dec)?: qclass; - return name; -}; - -// Partially decodes a [[rrecord]] and advances the decoder. Returns a slice -// representing the name field, which can be passed to [[decode_name]] to -// interpret. -export fn decode_rrecord(dec: *decoder, r: *rrecord) ([]u8 | format) = { - let name = extract_name(dec)!; - r.rtype = decode_u16(dec)!: rtype; - r.class = decode_u16(dec)!: class; - r.ttl = decode_u32(dec)!; - let rdz = decode_u16(dec)!; - r.rdata = dec.cur[..rdz]; - dec.cur = dec.cur[rdz..]; - return name; -}; - -fn extract_name(dec: *decoder) ([]u8 | format) = { - if (dec.cur[0] & 0b11000000 == 0b11000000) { - const name = dec.cur[..2]; - dec.cur = dec.cur[2..]; - return name; - }; - for (let i = 0z; i < len(dec.cur); i += 1) { - let z = dec.cur[i]; +fn decode_name(dec: *decoder) ([]str | format) = { + let names: []str = []; + for (true) { + let z = dec.cur[0]; + if (z & 0b11000000 == 0b11000000) { + let offs = decode_u16(dec)? & ~0b1100000000000000u16; + let sub = decoder { + buf = dec.buf, + cur = dec.buf[offs..], + ... + }; + append(names, decode_name(&sub)?...); + break; + }; + dec.cur = dec.cur[1..]; if (z == 0) { - const name = dec.cur[..i + 1]; - dec.cur = dec.cur[i + 1..]; - return name; + break; + }; + + let name = dec.cur[..z]; + dec.cur = dec.cur[z..]; + for (let i = 0z; i < len(name); i += 1) { + if (!ascii::isascii(name[i]: u32: rune)) { + fmt::errorfln("not ascii at index {}", i)!; + return format; + }; }; - i += z; + + append(names, strings::dup(strings::fromutf8(name))); }; - return format; + return names; }; -// Decodes a name from a question or resource record, returning the decoded name -// and the remainder of the buffer. The caller should pass the returned buffer -// into decode_name again to retrieve the next name. When the return value is an -// empty string, all of the names have been decoded. It is a programming error -// to call decode_name again after this, and the program will abort. -export fn decode_name(dec: *decoder, buf: []u8) (([]u8, str) | format) = { - let z = buf[0]; - if (z == 0) { - return ([]: []u8, ""); - }; - if (z & 0b11000000 == 0b11000000) { - let offs = endian::begetu16(buf) & ~0b1100000000000000u16; - return decode_name(dec, dec.buf[offs..]); +fn decode_question(dec: *decoder) (question | format) = { + return question { + qname = decode_name(dec)?, + qtype = decode_u16(dec)?: qtype, + qclass = decode_u16(dec)?: qclass, }; - let name = buf[1..z + 1]; - buf = buf[z + 1..]; - for (let i = 0z; i < len(name); i += 1) { - if (!ascii::isascii(name[i]: u32: rune)) { - return format; - }; - }; - return (buf, strings::fromutf8(name)); }; -// Decodes the rdata field of a [[rrecord]]. The return value is borrowed from -// the rdata buffer. -export fn decode_rdata(rr: *rrecord) (ip::addr | format) = { - return switch (rr.rtype) { - rtype::A => decode_a(rr.rdata)?: ip::addr, - rtype::AAAA => decode_aaaa(rr.rdata)?: ip::addr, - * => format, +fn decode_rrecord(dec: *decoder, r: *rrecord) (void | format) = { + r.name = decode_name(dec)?; + r.rtype = decode_u16(dec)?: rtype; + r.class = decode_u16(dec)?: class; + r.ttl = decode_u32(dec)?; + let rdz = decode_u16(dec)?; + r.rdata = decode_rdata(dec, r.rtype, rdz)?; +}; + +fn decode_rdata(dec: *decoder, rtype: rtype, rlen: size) (rdata | format) = { + return switch (rtype) { + rtype::A => decode_a(dec), + rtype::AAAA => decode_aaaa(dec), + rtype::MX => decode_mx(dec), + * => { + let buf = dec.cur[..rlen]; + dec.cur = dec.cur[rlen..]; + return buf: unknown_rdata; + }, }; }; -// Decodes the rdata field of an A (address) record. The return value is -// borrowed from the rdata buffer. -export fn decode_a(rdata: []u8) (ip::addr4 | format) = { - if (len(rdata) != 4) { +fn decode_a(dec: *decoder) (rdata | format) = { + if (len(dec.cur) != 4) { return format; }; let ip: ip::addr4 = [0...]; - ip[..] = rdata[..]; - return ip; + ip[..] = dec.cur[..4]; + dec.cur = dec.cur[4..]; + return ip: a; }; -// Decodes the rdata field of an AAAA (address) record. The return value is -// borrowed from the rdata buffer. -export fn decode_aaaa(rdata: []u8) (ip::addr6 | format) = { - if (len(rdata) != 8) { +fn decode_aaaa(dec: *decoder) (rdata | format) = { + if (len(dec.cur) != 8) { return format; }; let ip: ip::addr6 = [0...]; - ip[..] = rdata[..]; - return ip; + ip[..] = dec.cur[..8]; + dec.cur = dec.cur[8..]; + return ip: aaaa; }; -// Decodes the rdata field of an MX (mail exchange) record, returning the -// priority and the name. See [[decode_name]] to decode the name. The return -// value is borrowed from the rdata buffer. -export fn decode_mx(rdata: []u8) ((u16, []u8) | format) = { - if (len(rdata) < 2) { - return format; +fn decode_mx(dec: *decoder) (rdata | format) = { + return mx { + priority = decode_u16(dec)?, + name = decode_name(dec)?, }; - let prio = endian::begetu16(rdata); - return (prio, rdata[2..]); }; // TODO: Expand breadth of supported rdata decoders diff --git a/net/dns/encode.ha b/net/dns/encode.ha @@ -85,15 +85,7 @@ fn rrecord_encode(enc: *encoder, r: *rrecord) (void | error) = { encode_u16(enc, r.class)?; encode_u32(enc, r.ttl)?; - assert(len(r.rdata) <= 0xFFFF); - encode_u16(enc, len(r.rdata): u16)?; - - if (len(enc.buf) <= enc.offs + len(r.rdata)) { - return errors::overflow; - }; - - enc.buf[enc.offs..len(r.rdata)] = r.rdata[..]; - enc.offs += len(r.rdata); + abort(); // TODO }; fn op_encode(op: *op) u16 = diff --git a/net/dns/error.ha b/net/dns/error.ha @@ -42,12 +42,14 @@ export fn strerror(err: error) const str = { }; }; -fn check_rcode(rcode: rcode) (void | error) = switch (rcode) { - rcode::NO_ERROR => void, - rcode::FMT_ERROR => format, - rcode::SERVER_FAILURE => server_failure, - rcode::NAME_ERROR => name_error, - rcode::NOT_IMPLEMENTED => not_implemented, - rcode::REFUSED => refused, - * => rcode: unknown_error, +fn check_rcode(rcode: rcode) (void | error) = { + return switch (rcode) { + rcode::NO_ERROR => void, + rcode::FMT_ERROR => format, + rcode::SERVER_FAILURE => server_failure, + rcode::NAME_ERROR => name_error, + rcode::NOT_IMPLEMENTED => not_implemented, + rcode::REFUSED => refused, + * => rcode: unknown_error, + }; }; diff --git a/net/dns/types.ha b/net/dns/types.ha @@ -1,4 +1,6 @@ -// Record type +use net::ip; + +// Record type. export type rtype = enum u16 { A = 1, NS = 2, @@ -12,7 +14,7 @@ export type rtype = enum u16 { DNSKEY = 48, }; -// Question type +// Question type (superset of [[rtype]]). export type qtype = enum u16 { A = 1, NS = 2, @@ -30,7 +32,7 @@ export type qtype = enum u16 { ALL = 255, }; -// Class type +// Class type (e.g. Internet). export type class = enum u16 { IN = 1, CS = 2, @@ -38,7 +40,7 @@ export type class = enum u16 { HS = 4, }; -// Query class +// Query class (superset of [[class]]). export type qclass = enum u16 { IN = 1, CS = 2, @@ -68,14 +70,14 @@ export type qr = enum u8 { RESPONSE = 1, }; -// Operation requested from resolver +// Operation requested from resolver. export type opcode = enum u8 { QUERY = 0, IQUERY = 1, STATUS = 2, }; -// Response code from resolver +// Response code from resolver. export type rcode = enum u8 { NO_ERROR = 0, FMT_ERROR = 1, @@ -116,9 +118,27 @@ export type rrecord = struct { rtype: rtype, class: class, ttl: u32, - rdata: []u8, + rdata: rdata, }; +// An A record. +export type a = ip::addr4; + +// An AAAA record. +export type aaaa = ip::addr6; + +// An MX record. +export type mx = struct { + priority: u16, + name: []str, +}; + +// The raw rdata field for an [[rrecord]] with an unknown [[rtype]]. +export type unknown_rdata = []u8; + +// Tagged union of supported rdata types. +export type rdata = (a | aaaa | mx | unknown_rdata); + // A DNS message, Hare representation. See [[encode]] and [[decode]] for the DNS // representation. export type message = struct { @@ -129,13 +149,6 @@ export type message = struct { additional: []rrecord, }; -fn strings_free(in: []str) void = { - for (let i = 0z; i < len(in); i += 1) { - free(in[i]); - }; - free(in); -}; - // Frees a [[message]] and the resources associated with it. export fn message_free(msg: *message) void = { for (let i = 0z; i < len(msg.questions); i += 1) { @@ -144,22 +157,34 @@ export fn message_free(msg: *message) void = { free(msg.questions); for (let i = 0z; i < len(msg.answers); i += 1) { - strings_free(msg.answers[i].name); - free(msg.answers[i].rdata); + rrecord_finish(&msg.answers[i]); }; free(msg.answers); for (let i = 0z; i < len(msg.authority); i += 1) { - strings_free(msg.authority[i].name); - free(msg.authority[i].rdata); + rrecord_finish(&msg.authority[i]); }; free(msg.authority); for (let i = 0z; i < len(msg.additional); i += 1) { - strings_free(msg.additional[i].name); - free(msg.additional[i].rdata); + rrecord_finish(&msg.additional[i]); }; free(msg.additional); free(msg); }; + +fn strings_free(in: []str) void = { + for (let i = 0z; i < len(in); i += 1) { + free(in[i]); + }; + free(in); +}; + +fn rrecord_finish(rr: *rrecord) void = { + strings_free(rr.name); + match (rr.rdata) { + mx: mx => strings_free(mx.name), + * => void, + }; +};