hare

[hare] The Hare programming language
git clone https://git.torresjrjr.com/hare.git
Log | Files | Refs | README | LICENSE

commit 7ed1697b32d5460819e27202ad7d3b66615b8418
parent 87e128dec27e8d0c269365ad52a169c5e8375573
Author: Conrad Hoffmann <ch@bitfehler.net>
Date:   Mon, 10 Jul 2023 15:44:23 +0200

net::dns: retry over TCP on truncated response

Signed-off-by: Conrad Hoffmann <ch@bitfehler.net>

Diffstat:
Mnet/dns/error.ha | 5++++-
Mnet/dns/query.ha | 53++++++++++++++++++++++++++++++++++++++++++++++++++---
2 files changed, 54 insertions(+), 4 deletions(-)

diff --git a/net/dns/error.ha b/net/dns/error.ha @@ -3,6 +3,7 @@ // (c) 2021 Ember Sawady <ecs@d2evs.net> use errors; use fmt; +use io; use net; // The DNS message was poorly formatted. @@ -29,7 +30,7 @@ export type unknown_error = !u8; // All error types which might be returned from [[net::dns]] functions. export type error = !(format | server_failure | name_error | not_implemented | refused | unknown_error - | errors::overflow | errors::timeout | net::error); + | errors::overflow | errors::timeout | net::error | io::error); export fn strerror(err: error) const str = { static let buf: [64]u8 = [0...]; @@ -52,6 +53,8 @@ export fn strerror(err: error) const str = { return "The DNS request timed out"; case let err: net::error => return net::strerror(err); + case let err: io::error => + return io::strerror(err); }; }; diff --git a/net/dns/query.ha b/net/dns/query.ha @@ -1,10 +1,14 @@ // License: MPL-2.0 // (c) 2021 Drew DeVault <sir@cmpwn.com> +use endian; use errors; +use io; use net; use net::ip; use net::udp; +use net::tcp; use time; +use types; use unix::poll; use unix::resolvconf; @@ -55,13 +59,13 @@ export fn query(query: *message, servers: ip::addr...) (*message | error) = { }; let header = header { ... }; + let src: ip::addr = ip::ANY_V4; for (true) { let nevent = poll::poll(pollfd, timeout)!; if (nevent == 0) { return errors::timeout; }; - let src: ip::addr = ip::ANY_V4; if (pollfd[0].revents & poll::event::POLLIN != 0) { z = udp::recvfrom(socket4, buf, &src, null)?; }; @@ -87,8 +91,51 @@ export fn query(query: *message, servers: ip::addr...) (*message | error) = { }; }; - assert(!header.op.tc, "TODO: Retry with TCP for truncated DNS response"); + if (!header.op.tc) { + check_rcode(header.op.rcode)?; + return decode(buf[..z])?; + }; + + // Response was truncated, retry over TCP. In TCP mode, the + // query is preceded by two bytes indicating the query length + z = encode(buf, query)?; + if (z > types::U16_MAX) { + return errors::overflow; + }; + let zbuf: [2]u8 = [0...]; + endian::beputu16(zbuf, z: u16); + let socket = tcp::connect(src, 53)?; + defer net::close(socket)!; + io::writeall(socket, zbuf)!; + io::writeall(socket, buf[..z])!; + + let rz: u16 = match (io::readall(socket, zbuf)?) { + case let s: size => + if (s != 2) { + return format; + }; + yield endian::begetu16(zbuf); + case => + return format; + }; + let tcpbuf: []u8 = alloc([0...], rz); + defer free(tcpbuf); + + match (io::readall(socket, tcpbuf)?) { + case let s: size => + if (s != rz) { + return format; + }; + case => + return format; + }; + + const dec = decoder_init(tcpbuf); + decode_header(&dec, &header)?; + if ((header.id != query.header.id) || header.op.tc) { + return format; + }; check_rcode(header.op.rcode)?; - return decode(buf[..z])?; + return decode(tcpbuf)?; };