hare

The Hare programming language
git clone https://git.torresjrjr.com/hare.git
Log | Files | Refs | README | LICENSE

commit f199b4b03b394f26671447c9113eb8faf9303433
parent bee3c3a8b5847449ce6c4ac79e2f7354edc4f1db
Author: Drew DeVault <sir@cmpwn.com>
Date:   Tue, 22 Jun 2021 13:52:00 -0400

net::dns: refactor encode

Adds error handling and streamlines the code a bit.

Signed-off-by: Drew DeVault <sir@cmpwn.com>

Diffstat:
Mnet/dns/encode.ha | 128++++++++++++++++++++++++++++++++++++++++++++-----------------------------------
Mnet/dns/error.ha | 6++++--
Mnet/dns/query.ha | 2+-
3 files changed, 77 insertions(+), 59 deletions(-)

diff --git a/net/dns/encode.ha b/net/dns/encode.ha @@ -1,88 +1,104 @@ // TODO: Refactor me +use errors; use endian; use fmt; -// Encodes a DNS message, returning its size. -export fn encode(buf: []u8, msg: *message) size = { - let z = 0z; - endian::beputu16(buf[z..], msg.header.id); - z += 2; - endian::beputu16(buf[z..], op_encode(&msg.header.op)); - z += 2; - endian::beputu16(buf[z..], msg.header.qdcount); - z += 2; - endian::beputu16(buf[z..], msg.header.ancount); - z += 2; - endian::beputu16(buf[z..], msg.header.nscount); - z += 2; - endian::beputu16(buf[z..], msg.header.arcount); - z += 2; +type encoder = struct { + buf: []u8, + offs: size, +}; + +// Encodes a DNS message, returning its size, or an error. +export fn encode(buf: []u8, msg: *message) (size | error) = { + let enc = encoder { buf = buf, offs = 0z }; + encode_u16(&enc, msg.header.id)?; + encode_u16(&enc, op_encode(&msg.header.op))?; + encode_u16(&enc, msg.header.qdcount)?; + encode_u16(&enc, msg.header.ancount)?; + encode_u16(&enc, msg.header.nscount)?; + encode_u16(&enc, msg.header.arcount)?; for (let i = 0z; i < len(msg.questions); i += 1) { - z += question_encode(buf[z..], &msg.questions[i]); + question_encode(&enc, &msg.questions[i])?; }; for (let i = 0z; i < len(msg.answers); i += 1) { - z += rrecord_encode(buf[z..], &msg.answers[i]); + rrecord_encode(&enc, &msg.answers[i])?; }; for (let i = 0z; i < len(msg.authority); i += 1) { - z += rrecord_encode(buf[z..], &msg.authority[i]); + rrecord_encode(&enc, &msg.authority[i])?; }; for (let i = 0z; i < len(msg.additional); i += 1) { - z += rrecord_encode(buf[z..], &msg.additional[i]); + rrecord_encode(&enc, &msg.additional[i])?; + }; + + return enc.offs; +}; + +fn encode_u8(enc: *encoder, val: u8) (void | error) = { + if (len(enc.buf) <= enc.offs + 1) { + return errors::overflow; + }; + enc.buf[enc.offs] = val; + enc.offs += 1; +}; + +fn encode_u16(enc: *encoder, val: u16) (void | error) = { + if (len(enc.buf) <= enc.offs + 2) { + return errors::overflow; }; + endian::beputu16(enc.buf[enc.offs..], val); + enc.offs += 2; +}; - return z; +fn encode_u32(enc: *encoder, val: u32) (void | error) = { + if (len(enc.buf) <= enc.offs + 4) { + return errors::overflow; + }; + endian::beputu32(enc.buf[enc.offs..], val); + enc.offs += 4; }; -fn question_encode(buf: []u8, q: *question) size = { +fn question_encode(enc: *encoder, q: *question) (void | error) = { // TODO: Assert that the labels are all valid ASCII? - let z = 0z; for (let i = 0z; i < len(q.qname); i += 1) { assert(len(q.qname[i]) < 256); - buf[z] = len(q.qname[i]): u8; - z += 1; - let label = fmt::bsprintf(buf[z..], "{}", q.qname[i]); - z += len(label); + if (len(enc.buf) <= enc.offs + 1 + len(q.qname[i])) { + return errors::overflow; + }; + encode_u8(enc, len(q.qname[i]): u8)?; + let label = fmt::bsprintf(enc.buf[enc.offs..], "{}", q.qname[i]); + enc.offs += len(label); }; - // Root - buf[z] = 0; - z += 1; - // Trailers - endian::beputu16(buf[z..], q.qtype); - z += 2; - endian::beputu16(buf[z..], q.qclass); - z += 2; - return z; + encode_u8(enc, 0)?; + encode_u16(enc, q.qtype)?; + encode_u16(enc, q.qclass)?; }; -fn rrecord_encode(buf: []u8, r: *rrecord) size = { +fn rrecord_encode(enc: *encoder, r: *rrecord) (void | error) = { // TODO: Assert that the labels are all valid ASCII? - let z = 0z; for (let i = 0z; i < len(r.name); i += 1) { assert(len(r.name[i]) < 256); - buf[z] = len(r.name[i]): u8; - z += 1; - let label = fmt::bsprintf(buf[z..], "{}", r.name[i]); - z += len(label); + if (len(enc.buf) <= enc.offs + 1 + len(r.name[i])) { + return errors::overflow; + }; + encode_u8(enc, len(r.name[i]): u8)?; + let label = fmt::bsprintf(enc.buf[enc.offs..], "{}", r.name[i]); + enc.offs += len(label); }; - // Root - buf[z] = 0; - z += 1; - - endian::beputu16(buf[z..], r.rtype); - z += 2; - endian::beputu16(buf[z..], r.class); - z += 2; - endian::beputu32(buf[z..], r.ttl); - z += 4; + encode_u8(enc, 0)?; + encode_u16(enc, r.rtype)?; + encode_u16(enc, r.class)?; + encode_u32(enc, r.ttl)?; assert(len(r.rdata) <= 0xFFFF); - endian::beputu16(buf[z..], len(r.rdata): u16); - z += 2; + encode_u16(enc, len(r.rdata): u16)?; + + if (len(enc.buf) <= enc.offs + len(r.rdata)) { + return errors::overflow; + }; - buf[z..len(r.rdata)] = r.rdata[..]; - z += len(r.rdata); - return z; + enc.buf[enc.offs..len(r.rdata)] = r.rdata[..]; + enc.offs += len(r.rdata); }; fn op_encode(op: *op) u16 = diff --git a/net/dns/error.ha b/net/dns/error.ha @@ -1,3 +1,4 @@ +use errors; use fmt; use net; @@ -23,9 +24,9 @@ export type refused = !void; export type unknown_error = !u8; // All error types which might be returned from [[net::dns]] functions. -export type error = (format | server_failure | name_error +export type error = !(format | server_failure | name_error | not_implemented | refused | unknown_error - | net::error); + | errors::overflow | net::error); export fn strerror(err: error) const str = { static let buf: [64]u8 = [0...]; @@ -36,6 +37,7 @@ export fn strerror(err: error) const str = { not_implemented => "The name server does not support the requested kind of query", refused => "The name server refuses to perform the specified operation for policy reasons", ue: unknown_error => fmt::bsprintf(buf, "Unknown DNS error {}", ue: u8), + errors::overflow => "The encoded message would exceed the buffer size", err: net::error => net::strerror(err), }; }; diff --git a/net/dns/query.ha b/net/dns/query.ha @@ -16,7 +16,7 @@ export fn query(query: *message, addr: ip::addr...) (*message | error) = { // TODO: Use TCP for messages >512 bytes let sendbuf: [512]u8 = [0...]; - let z = encode(sendbuf, query); + let z = encode(sendbuf, query)?; // TODO: Query multiple servers udp::sendto(socket, sendbuf[..z], addr[0], 53)?;