commit f199b4b03b394f26671447c9113eb8faf9303433
parent bee3c3a8b5847449ce6c4ac79e2f7354edc4f1db
Author: Drew DeVault <sir@cmpwn.com>
Date: Tue, 22 Jun 2021 13:52:00 -0400
net::dns: refactor encode
Adds error handling and streamlines the code a bit.
Signed-off-by: Drew DeVault <sir@cmpwn.com>
Diffstat:
3 files changed, 77 insertions(+), 59 deletions(-)
diff --git a/net/dns/encode.ha b/net/dns/encode.ha
@@ -1,88 +1,104 @@
// TODO: Refactor me
+use errors;
use endian;
use fmt;
-// Encodes a DNS message, returning its size.
-export fn encode(buf: []u8, msg: *message) size = {
- let z = 0z;
- endian::beputu16(buf[z..], msg.header.id);
- z += 2;
- endian::beputu16(buf[z..], op_encode(&msg.header.op));
- z += 2;
- endian::beputu16(buf[z..], msg.header.qdcount);
- z += 2;
- endian::beputu16(buf[z..], msg.header.ancount);
- z += 2;
- endian::beputu16(buf[z..], msg.header.nscount);
- z += 2;
- endian::beputu16(buf[z..], msg.header.arcount);
- z += 2;
+type encoder = struct {
+ buf: []u8,
+ offs: size,
+};
+
+// Encodes a DNS message, returning its size, or an error.
+export fn encode(buf: []u8, msg: *message) (size | error) = {
+ let enc = encoder { buf = buf, offs = 0z };
+ encode_u16(&enc, msg.header.id)?;
+ encode_u16(&enc, op_encode(&msg.header.op))?;
+ encode_u16(&enc, msg.header.qdcount)?;
+ encode_u16(&enc, msg.header.ancount)?;
+ encode_u16(&enc, msg.header.nscount)?;
+ encode_u16(&enc, msg.header.arcount)?;
for (let i = 0z; i < len(msg.questions); i += 1) {
- z += question_encode(buf[z..], &msg.questions[i]);
+ question_encode(&enc, &msg.questions[i])?;
};
for (let i = 0z; i < len(msg.answers); i += 1) {
- z += rrecord_encode(buf[z..], &msg.answers[i]);
+ rrecord_encode(&enc, &msg.answers[i])?;
};
for (let i = 0z; i < len(msg.authority); i += 1) {
- z += rrecord_encode(buf[z..], &msg.authority[i]);
+ rrecord_encode(&enc, &msg.authority[i])?;
};
for (let i = 0z; i < len(msg.additional); i += 1) {
- z += rrecord_encode(buf[z..], &msg.additional[i]);
+ rrecord_encode(&enc, &msg.additional[i])?;
+ };
+
+ return enc.offs;
+};
+
+fn encode_u8(enc: *encoder, val: u8) (void | error) = {
+ if (len(enc.buf) <= enc.offs + 1) {
+ return errors::overflow;
+ };
+ enc.buf[enc.offs] = val;
+ enc.offs += 1;
+};
+
+fn encode_u16(enc: *encoder, val: u16) (void | error) = {
+ if (len(enc.buf) <= enc.offs + 2) {
+ return errors::overflow;
};
+ endian::beputu16(enc.buf[enc.offs..], val);
+ enc.offs += 2;
+};
- return z;
+fn encode_u32(enc: *encoder, val: u32) (void | error) = {
+ if (len(enc.buf) <= enc.offs + 4) {
+ return errors::overflow;
+ };
+ endian::beputu32(enc.buf[enc.offs..], val);
+ enc.offs += 4;
};
-fn question_encode(buf: []u8, q: *question) size = {
+fn question_encode(enc: *encoder, q: *question) (void | error) = {
// TODO: Assert that the labels are all valid ASCII?
- let z = 0z;
for (let i = 0z; i < len(q.qname); i += 1) {
assert(len(q.qname[i]) < 256);
- buf[z] = len(q.qname[i]): u8;
- z += 1;
- let label = fmt::bsprintf(buf[z..], "{}", q.qname[i]);
- z += len(label);
+ if (len(enc.buf) <= enc.offs + 1 + len(q.qname[i])) {
+ return errors::overflow;
+ };
+ encode_u8(enc, len(q.qname[i]): u8)?;
+ let label = fmt::bsprintf(enc.buf[enc.offs..], "{}", q.qname[i]);
+ enc.offs += len(label);
};
- // Root
- buf[z] = 0;
- z += 1;
- // Trailers
- endian::beputu16(buf[z..], q.qtype);
- z += 2;
- endian::beputu16(buf[z..], q.qclass);
- z += 2;
- return z;
+ encode_u8(enc, 0)?;
+ encode_u16(enc, q.qtype)?;
+ encode_u16(enc, q.qclass)?;
};
-fn rrecord_encode(buf: []u8, r: *rrecord) size = {
+fn rrecord_encode(enc: *encoder, r: *rrecord) (void | error) = {
// TODO: Assert that the labels are all valid ASCII?
- let z = 0z;
for (let i = 0z; i < len(r.name); i += 1) {
assert(len(r.name[i]) < 256);
- buf[z] = len(r.name[i]): u8;
- z += 1;
- let label = fmt::bsprintf(buf[z..], "{}", r.name[i]);
- z += len(label);
+ if (len(enc.buf) <= enc.offs + 1 + len(r.name[i])) {
+ return errors::overflow;
+ };
+ encode_u8(enc, len(r.name[i]): u8)?;
+ let label = fmt::bsprintf(enc.buf[enc.offs..], "{}", r.name[i]);
+ enc.offs += len(label);
};
- // Root
- buf[z] = 0;
- z += 1;
-
- endian::beputu16(buf[z..], r.rtype);
- z += 2;
- endian::beputu16(buf[z..], r.class);
- z += 2;
- endian::beputu32(buf[z..], r.ttl);
- z += 4;
+ encode_u8(enc, 0)?;
+ encode_u16(enc, r.rtype)?;
+ encode_u16(enc, r.class)?;
+ encode_u32(enc, r.ttl)?;
assert(len(r.rdata) <= 0xFFFF);
- endian::beputu16(buf[z..], len(r.rdata): u16);
- z += 2;
+ encode_u16(enc, len(r.rdata): u16)?;
+
+ if (len(enc.buf) <= enc.offs + len(r.rdata)) {
+ return errors::overflow;
+ };
- buf[z..len(r.rdata)] = r.rdata[..];
- z += len(r.rdata);
- return z;
+ enc.buf[enc.offs..len(r.rdata)] = r.rdata[..];
+ enc.offs += len(r.rdata);
};
fn op_encode(op: *op) u16 =
diff --git a/net/dns/error.ha b/net/dns/error.ha
@@ -1,3 +1,4 @@
+use errors;
use fmt;
use net;
@@ -23,9 +24,9 @@ export type refused = !void;
export type unknown_error = !u8;
// All error types which might be returned from [[net::dns]] functions.
-export type error = (format | server_failure | name_error
+export type error = !(format | server_failure | name_error
| not_implemented | refused | unknown_error
- | net::error);
+ | errors::overflow | net::error);
export fn strerror(err: error) const str = {
static let buf: [64]u8 = [0...];
@@ -36,6 +37,7 @@ export fn strerror(err: error) const str = {
not_implemented => "The name server does not support the requested kind of query",
refused => "The name server refuses to perform the specified operation for policy reasons",
ue: unknown_error => fmt::bsprintf(buf, "Unknown DNS error {}", ue: u8),
+ errors::overflow => "The encoded message would exceed the buffer size",
err: net::error => net::strerror(err),
};
};
diff --git a/net/dns/query.ha b/net/dns/query.ha
@@ -16,7 +16,7 @@ export fn query(query: *message, addr: ip::addr...) (*message | error) = {
// TODO: Use TCP for messages >512 bytes
let sendbuf: [512]u8 = [0...];
- let z = encode(sendbuf, query);
+ let z = encode(sendbuf, query)?;
// TODO: Query multiple servers
udp::sendto(socket, sendbuf[..z], addr[0], 53)?;