hare

[hare] The Hare programming language
git clone https://git.torresjrjr.com/hare.git
Log | Files | Refs | README | LICENSE

query.ha (3294B)


      1 // SPDX-License-Identifier: MPL-2.0
      2 // (c) Hare authors <https://harelang.org>
      3 
      4 use endian;
      5 use errors;
      6 use io;
      7 use net;
      8 use net::ip;
      9 use net::tcp;
     10 use net::udp;
     11 use time;
     12 use types;
     13 use unix::poll;
     14 use unix::resolvconf;
     15 
     16 // TODO: Let user customize this?
     17 def timeout: time::duration = 3 * time::SECOND;
     18 
     19 // Performs a DNS query using the provided list of DNS servers. The caller must
     20 // free the return value with [[message_free]].
     21 //
     22 // If no DNS servers are provided, the system default servers (if any) are used.
     23 export fn query(query: *message, servers: ip::addr...) (*message | error) = {
     24 	if (len(servers) == 0) {
     25 		servers = resolvconf::load();
     26 	};
     27 	if (len(servers) == 0) {
     28 		// Fall back to localhost
     29 		servers = [ip::LOCAL_V6, ip::LOCAL_V4];
     30 	};
     31 
     32 	let socket4 = udp::listen(ip::ANY_V4, 0)?;
     33 	defer net::close(socket4)!;
     34 	let socket6 = udp::listen(ip::ANY_V6, 0)?;
     35 	defer net::close(socket6)!;
     36 	const pollfd: [_]poll::pollfd = [
     37 		poll::pollfd {
     38 			fd = socket4,
     39 			events = poll::event::POLLIN,
     40 			...
     41 		},
     42 		poll::pollfd {
     43 			fd = socket6,
     44 			events = poll::event::POLLIN,
     45 			...
     46 		},
     47 	];
     48 
     49 	let buf: [512]u8 = [0...];
     50 	let z = encode(buf, query)?;
     51 
     52 	// We send requests in parallel to all configured servers and take the
     53 	// first one which sends us a reasonable answer.
     54 	for (let i = 0z; i < len(servers); i += 1) match (servers[i]) {
     55 	case ip::addr4 =>
     56 		udp::sendto(socket4, buf[..z], servers[i], 53)?;
     57 	case ip::addr6 =>
     58 		udp::sendto(socket6, buf[..z], servers[i], 53)?;
     59 	};
     60 
     61 	let header = header { ... };
     62 	let src: ip::addr = ip::ANY_V4;
     63 	for (true) {
     64 		let nevent = poll::poll(pollfd, timeout)!;
     65 		if (nevent == 0) {
     66 			return errors::timeout;
     67 		};
     68 
     69 		if (pollfd[0].revents & poll::event::POLLIN != 0) {
     70 			z = udp::recvfrom(socket4, buf, &src, null)?;
     71 		};
     72 		if (pollfd[1].revents & poll::event::POLLIN != 0) {
     73 			z = udp::recvfrom(socket6, buf, &src, null)?;
     74 		};
     75 
     76 		let expected = false;
     77 		for (let i = 0z; i < len(servers); i += 1) {
     78 			if (ip::equal(src, servers[i])) {
     79 				expected = true;
     80 				break;
     81 			};
     82 		};
     83 		if (!expected) {
     84 			continue;
     85 		};
     86 
     87 		const dec = decoder_init(buf[..z]);
     88 		decode_header(&dec, &header)?;
     89 		if (header.id == query.header.id && header.op.qr == qr::RESPONSE) {
     90 			break;
     91 		};
     92 	};
     93 
     94 	if (!header.op.tc) {
     95 		check_rcode(header.op.rcode)?;
     96 		return decode(buf[..z])?;
     97 	};
     98 
     99 	// Response was truncated, retry over TCP. In TCP mode, the
    100 	// query is preceded by two bytes indicating the query length
    101 	z = encode(buf, query)?;
    102 	if (z > types::U16_MAX) {
    103 		return errors::overflow;
    104 	};
    105 	let zbuf: [2]u8 = [0...];
    106 	endian::beputu16(zbuf, z: u16);
    107 	let socket = tcp::connect(src, 53)?;
    108 	defer net::close(socket)!;
    109 
    110 	io::writeall(socket, zbuf)!;
    111 	io::writeall(socket, buf[..z])!;
    112 
    113 	let rz: u16 = match (io::readall(socket, zbuf)?) {
    114 	case let s: size =>
    115 		if (s != 2) {
    116 			return format;
    117 		};
    118 		yield endian::begetu16(zbuf);
    119 	case =>
    120 		return format;
    121 	};
    122 	let tcpbuf: []u8 = alloc([0...], rz);
    123 	defer free(tcpbuf);
    124 
    125 	match (io::readall(socket, tcpbuf)?) {
    126 	case let s: size =>
    127 		if (s != rz) {
    128 			return format;
    129 		};
    130 	case =>
    131 		return format;
    132 	};
    133 
    134 	const dec = decoder_init(tcpbuf);
    135 	decode_header(&dec, &header)?;
    136 	if ((header.id != query.header.id) || header.op.tc) {
    137 		return format;
    138 	};
    139 	check_rcode(header.op.rcode)?;
    140 	return decode(tcpbuf)?;
    141 };