commit a6923d1f370a1198a8bb7d849f309287fc1a02a4
parent 081ba3c1923b29b9af5ef5871cdd903c90b0eecb
Author: Drew DeVault <sir@cmpwn.com>
Date: Wed, 23 Jun 2021 11:37:01 -0400
net::dns: refactor decoding for improved rdata
This abandons the stack-based API, which was only going to get more
complex and brittle (and hard to use) with the introduction of more
rdata formats, and instead just focuses on heap-allocating the decoded
message data. This also makes the rdata field use a more Hairy type
(tagged union of known rdata formats) and fully implements MX decoding.
Signed-off-by: Drew DeVault <sir@cmpwn.com>
Diffstat:
4 files changed, 127 insertions(+), 154 deletions(-)
diff --git a/net/dns/decode.ha b/net/dns/decode.ha
@@ -1,5 +1,6 @@
use ascii;
use endian;
+use fmt;
use net::ip;
use strings;
@@ -21,16 +22,7 @@ export fn decode(buf: []u8) (*message | format) = {
decode_header(&dec, &msg.header)?;
for (let i = 0z; i < msg.header.qdcount; i += 1) {
- let q = question { ... };
- let names = decode_question(&dec, &q)?;
-
- for (let i = 0; len(names) != 0; i += 1) {
- let ns = decode_name(&dec, names)!;
- names = ns.0;
- append(q.qname, strings::dup(ns.1));
- };
-
- append(msg.questions, q);
+ append(msg.questions, decode_question(&dec)?);
};
decode_rrecords(&dec, msg.header.ancount, &msg.answers)?;
@@ -46,26 +38,12 @@ fn decode_rrecords(
) (void | format) = {
for (let i = 0z; i < count; i += 1) {
let r = rrecord { ... };
- let names = decode_rrecord(dec, &r)?;
- let rdata = r.rdata;
- r.rdata = [];
- append(r.rdata, rdata...);
-
- for (let i = 0; len(names) != 0; i += 1) {
- let ns = decode_name(dec, names)!;
- names = ns.0;
- append(r.name, strings::dup(ns.1));
- };
-
+ decode_rrecord(dec, &r)?;
append(*out, r);
};
};
-// Initializes a DNS message decoder. All storaged used by the decoder is either
-// stack-allocated, provided by the caller, or borrowed from the input buffer.
-//
-// Call [[decode_header]] next.
-export fn decoder_init(buf: []u8) decoder = decoder {
+fn decoder_init(buf: []u8) decoder = decoder {
buf = buf,
cur = buf,
...
@@ -89,12 +67,7 @@ fn decode_u32(dec: *decoder) (u32 | format) = {
return val;
};
-// Decodes a DNS message's header and advances the decoder to the
-// variable-length section of the message. Following this call, the user should
-// call [[decode_question]] for each question given by the header's qdcount,
-// then [[decode_rrecord]] for each resource record given by the ancount,
-// nscount, and arcount fields, respectively.
-export fn decode_header(dec: *decoder, head: *header) (void | format) = {
+fn decode_header(dec: *decoder, head: *header) (void | format) = {
head.id = decode_u16(dec)?;
let rawop = decode_u16(dec)?;
op_decode(rawop, &head.op);
@@ -118,113 +91,94 @@ fn op_decode(in: u16, out: *op) void = {
out.rcode = (in & 0b1111): rcode;
};
-// Partially decodes a [[question]] and advances the decoder. Returns a slice
-// representing the name field, which can be passed to [[decode_name]] to
-// interpret.
-export fn decode_question(dec: *decoder, q: *question) ([]u8 | format) = {
- let name = extract_name(dec)?;
- q.qtype = decode_u16(dec)?: qtype;
- q.qclass = decode_u16(dec)?: qclass;
- return name;
-};
-
-// Partially decodes a [[rrecord]] and advances the decoder. Returns a slice
-// representing the name field, which can be passed to [[decode_name]] to
-// interpret.
-export fn decode_rrecord(dec: *decoder, r: *rrecord) ([]u8 | format) = {
- let name = extract_name(dec)!;
- r.rtype = decode_u16(dec)!: rtype;
- r.class = decode_u16(dec)!: class;
- r.ttl = decode_u32(dec)!;
- let rdz = decode_u16(dec)!;
- r.rdata = dec.cur[..rdz];
- dec.cur = dec.cur[rdz..];
- return name;
-};
-
-fn extract_name(dec: *decoder) ([]u8 | format) = {
- if (dec.cur[0] & 0b11000000 == 0b11000000) {
- const name = dec.cur[..2];
- dec.cur = dec.cur[2..];
- return name;
- };
- for (let i = 0z; i < len(dec.cur); i += 1) {
- let z = dec.cur[i];
+fn decode_name(dec: *decoder) ([]str | format) = {
+ let names: []str = [];
+ for (true) {
+ let z = dec.cur[0];
+ if (z & 0b11000000 == 0b11000000) {
+ let offs = decode_u16(dec)? & ~0b1100000000000000u16;
+ let sub = decoder {
+ buf = dec.buf,
+ cur = dec.buf[offs..],
+ ...
+ };
+ append(names, decode_name(&sub)?...);
+ break;
+ };
+ dec.cur = dec.cur[1..];
if (z == 0) {
- const name = dec.cur[..i + 1];
- dec.cur = dec.cur[i + 1..];
- return name;
+ break;
+ };
+
+ let name = dec.cur[..z];
+ dec.cur = dec.cur[z..];
+ for (let i = 0z; i < len(name); i += 1) {
+ if (!ascii::isascii(name[i]: u32: rune)) {
+ fmt::errorfln("not ascii at index {}", i)!;
+ return format;
+ };
};
- i += z;
+
+ append(names, strings::dup(strings::fromutf8(name)));
};
- return format;
+ return names;
};
-// Decodes a name from a question or resource record, returning the decoded name
-// and the remainder of the buffer. The caller should pass the returned buffer
-// into decode_name again to retrieve the next name. When the return value is an
-// empty string, all of the names have been decoded. It is a programming error
-// to call decode_name again after this, and the program will abort.
-export fn decode_name(dec: *decoder, buf: []u8) (([]u8, str) | format) = {
- let z = buf[0];
- if (z == 0) {
- return ([]: []u8, "");
- };
- if (z & 0b11000000 == 0b11000000) {
- let offs = endian::begetu16(buf) & ~0b1100000000000000u16;
- return decode_name(dec, dec.buf[offs..]);
+fn decode_question(dec: *decoder) (question | format) = {
+ return question {
+ qname = decode_name(dec)?,
+ qtype = decode_u16(dec)?: qtype,
+ qclass = decode_u16(dec)?: qclass,
};
- let name = buf[1..z + 1];
- buf = buf[z + 1..];
- for (let i = 0z; i < len(name); i += 1) {
- if (!ascii::isascii(name[i]: u32: rune)) {
- return format;
- };
- };
- return (buf, strings::fromutf8(name));
};
-// Decodes the rdata field of a [[rrecord]]. The return value is borrowed from
-// the rdata buffer.
-export fn decode_rdata(rr: *rrecord) (ip::addr | format) = {
- return switch (rr.rtype) {
- rtype::A => decode_a(rr.rdata)?: ip::addr,
- rtype::AAAA => decode_aaaa(rr.rdata)?: ip::addr,
- * => format,
+fn decode_rrecord(dec: *decoder, r: *rrecord) (void | format) = {
+ r.name = decode_name(dec)?;
+ r.rtype = decode_u16(dec)?: rtype;
+ r.class = decode_u16(dec)?: class;
+ r.ttl = decode_u32(dec)?;
+ let rdz = decode_u16(dec)?;
+ r.rdata = decode_rdata(dec, r.rtype, rdz)?;
+};
+
+fn decode_rdata(dec: *decoder, rtype: rtype, rlen: size) (rdata | format) = {
+ return switch (rtype) {
+ rtype::A => decode_a(dec),
+ rtype::AAAA => decode_aaaa(dec),
+ rtype::MX => decode_mx(dec),
+ * => {
+ let buf = dec.cur[..rlen];
+ dec.cur = dec.cur[rlen..];
+ return buf: unknown_rdata;
+ },
};
};
-// Decodes the rdata field of an A (address) record. The return value is
-// borrowed from the rdata buffer.
-export fn decode_a(rdata: []u8) (ip::addr4 | format) = {
- if (len(rdata) != 4) {
+fn decode_a(dec: *decoder) (rdata | format) = {
+ if (len(dec.cur) != 4) {
return format;
};
let ip: ip::addr4 = [0...];
- ip[..] = rdata[..];
- return ip;
+ ip[..] = dec.cur[..4];
+ dec.cur = dec.cur[4..];
+ return ip: a;
};
-// Decodes the rdata field of an AAAA (address) record. The return value is
-// borrowed from the rdata buffer.
-export fn decode_aaaa(rdata: []u8) (ip::addr6 | format) = {
- if (len(rdata) != 8) {
+fn decode_aaaa(dec: *decoder) (rdata | format) = {
+ if (len(dec.cur) != 8) {
return format;
};
let ip: ip::addr6 = [0...];
- ip[..] = rdata[..];
- return ip;
+ ip[..] = dec.cur[..8];
+ dec.cur = dec.cur[8..];
+ return ip: aaaa;
};
-// Decodes the rdata field of an MX (mail exchange) record, returning the
-// priority and the name. See [[decode_name]] to decode the name. The return
-// value is borrowed from the rdata buffer.
-export fn decode_mx(rdata: []u8) ((u16, []u8) | format) = {
- if (len(rdata) < 2) {
- return format;
+fn decode_mx(dec: *decoder) (rdata | format) = {
+ return mx {
+ priority = decode_u16(dec)?,
+ name = decode_name(dec)?,
};
- let prio = endian::begetu16(rdata);
- return (prio, rdata[2..]);
};
// TODO: Expand breadth of supported rdata decoders
diff --git a/net/dns/encode.ha b/net/dns/encode.ha
@@ -85,15 +85,7 @@ fn rrecord_encode(enc: *encoder, r: *rrecord) (void | error) = {
encode_u16(enc, r.class)?;
encode_u32(enc, r.ttl)?;
- assert(len(r.rdata) <= 0xFFFF);
- encode_u16(enc, len(r.rdata): u16)?;
-
- if (len(enc.buf) <= enc.offs + len(r.rdata)) {
- return errors::overflow;
- };
-
- enc.buf[enc.offs..len(r.rdata)] = r.rdata[..];
- enc.offs += len(r.rdata);
+ abort(); // TODO
};
fn op_encode(op: *op) u16 =
diff --git a/net/dns/error.ha b/net/dns/error.ha
@@ -42,12 +42,14 @@ export fn strerror(err: error) const str = {
};
};
-fn check_rcode(rcode: rcode) (void | error) = switch (rcode) {
- rcode::NO_ERROR => void,
- rcode::FMT_ERROR => format,
- rcode::SERVER_FAILURE => server_failure,
- rcode::NAME_ERROR => name_error,
- rcode::NOT_IMPLEMENTED => not_implemented,
- rcode::REFUSED => refused,
- * => rcode: unknown_error,
+fn check_rcode(rcode: rcode) (void | error) = {
+ return switch (rcode) {
+ rcode::NO_ERROR => void,
+ rcode::FMT_ERROR => format,
+ rcode::SERVER_FAILURE => server_failure,
+ rcode::NAME_ERROR => name_error,
+ rcode::NOT_IMPLEMENTED => not_implemented,
+ rcode::REFUSED => refused,
+ * => rcode: unknown_error,
+ };
};
diff --git a/net/dns/types.ha b/net/dns/types.ha
@@ -1,4 +1,6 @@
-// Record type
+use net::ip;
+
+// Record type.
export type rtype = enum u16 {
A = 1,
NS = 2,
@@ -12,7 +14,7 @@ export type rtype = enum u16 {
DNSKEY = 48,
};
-// Question type
+// Question type (superset of [[rtype]]).
export type qtype = enum u16 {
A = 1,
NS = 2,
@@ -30,7 +32,7 @@ export type qtype = enum u16 {
ALL = 255,
};
-// Class type
+// Class type (e.g. Internet).
export type class = enum u16 {
IN = 1,
CS = 2,
@@ -38,7 +40,7 @@ export type class = enum u16 {
HS = 4,
};
-// Query class
+// Query class (superset of [[class]]).
export type qclass = enum u16 {
IN = 1,
CS = 2,
@@ -68,14 +70,14 @@ export type qr = enum u8 {
RESPONSE = 1,
};
-// Operation requested from resolver
+// Operation requested from resolver.
export type opcode = enum u8 {
QUERY = 0,
IQUERY = 1,
STATUS = 2,
};
-// Response code from resolver
+// Response code from resolver.
export type rcode = enum u8 {
NO_ERROR = 0,
FMT_ERROR = 1,
@@ -116,9 +118,27 @@ export type rrecord = struct {
rtype: rtype,
class: class,
ttl: u32,
- rdata: []u8,
+ rdata: rdata,
};
+// An A record.
+export type a = ip::addr4;
+
+// An AAAA record.
+export type aaaa = ip::addr6;
+
+// An MX record.
+export type mx = struct {
+ priority: u16,
+ name: []str,
+};
+
+// The raw rdata field for an [[rrecord]] with an unknown [[rtype]].
+export type unknown_rdata = []u8;
+
+// Tagged union of supported rdata types.
+export type rdata = (a | aaaa | mx | unknown_rdata);
+
// A DNS message, Hare representation. See [[encode]] and [[decode]] for the DNS
// representation.
export type message = struct {
@@ -129,13 +149,6 @@ export type message = struct {
additional: []rrecord,
};
-fn strings_free(in: []str) void = {
- for (let i = 0z; i < len(in); i += 1) {
- free(in[i]);
- };
- free(in);
-};
-
// Frees a [[message]] and the resources associated with it.
export fn message_free(msg: *message) void = {
for (let i = 0z; i < len(msg.questions); i += 1) {
@@ -144,22 +157,34 @@ export fn message_free(msg: *message) void = {
free(msg.questions);
for (let i = 0z; i < len(msg.answers); i += 1) {
- strings_free(msg.answers[i].name);
- free(msg.answers[i].rdata);
+ rrecord_finish(&msg.answers[i]);
};
free(msg.answers);
for (let i = 0z; i < len(msg.authority); i += 1) {
- strings_free(msg.authority[i].name);
- free(msg.authority[i].rdata);
+ rrecord_finish(&msg.authority[i]);
};
free(msg.authority);
for (let i = 0z; i < len(msg.additional); i += 1) {
- strings_free(msg.additional[i].name);
- free(msg.additional[i].rdata);
+ rrecord_finish(&msg.additional[i]);
};
free(msg.additional);
free(msg);
};
+
+fn strings_free(in: []str) void = {
+ for (let i = 0z; i < len(in); i += 1) {
+ free(in[i]);
+ };
+ free(in);
+};
+
+fn rrecord_finish(rr: *rrecord) void = {
+ strings_free(rr.name);
+ match (rr.rdata) {
+ mx: mx => strings_free(mx.name),
+ * => void,
+ };
+};