query.ha (3294B)
// SPDX-License-Identifier: MPL-2.0
// (c) Hare authors <https://harelang.org>

use endian;
use errors;
use io;
use net;
use net::ip;
use net::tcp;
use net::udp;
use time;
use types;
use unix::poll;
use unix::resolvconf;

// TODO: Let user customize this?
// How long a single poll for a UDP reply may block before [[query]] gives up.
def timeout: time::duration = 3 * time::SECOND;

// Performs a DNS query using the provided list of DNS servers. The caller must
// free the return value with [[message_free]].
//
// If no DNS servers are provided, the system default servers (if any) are used.
// If none are configured either, localhost ([[net::ip::LOCAL_V6]] and
// [[net::ip::LOCAL_V4]]) is queried.
//
// The query is sent over UDP (port 53) to every server at once. The first
// datagram that comes from one of those servers, carries the query's ID and is
// marked as a response is used. If that reply has the truncation (TC) bit set,
// the query is repeated over TCP to the server that answered.
//
// Returns [[errors::timeout]] if no UDP datagram arrives within [[timeout]].
// Returns [[format]] if the TCP exchange is cut short or its reply does not
// match the query.
export fn query(query: *message, servers: ip::addr...) (*message | error) = {
	if (len(servers) == 0) {
		servers = resolvconf::load();
	};
	if (len(servers) == 0) {
		// Fall back to localhost
		servers = [ip::LOCAL_V6, ip::LOCAL_V4];
	};

	// One UDP socket per address family, each bound to port 0, so both
	// IPv4 and IPv6 servers can be queried.
	let socket4 = udp::listen(ip::ANY_V4, 0)?;
	defer net::close(socket4)!;
	let socket6 = udp::listen(ip::ANY_V6, 0)?;
	defer net::close(socket6)!;
	const pollfd: [_]poll::pollfd = [
		poll::pollfd {
			fd = socket4,
			events = poll::event::POLLIN,
			...
		},
		poll::pollfd {
			fd = socket6,
			events = poll::event::POLLIN,
			...
		},
	];

	let buf: [512]u8 = [0...];
	let z = encode(buf, query)?;

	// We send requests in parallel to all configured servers and take the
	// first one which sends us a reasonable answer.
	for (let i = 0z; i < len(servers); i += 1) match (servers[i]) {
	case ip::addr4 =>
		udp::sendto(socket4, buf[..z], servers[i], 53)?;
	case ip::addr6 =>
		udp::sendto(socket6, buf[..z], servers[i], 53)?;
	};

	let header = header { ... };
	// Sender of the most recently received datagram. Presumably
	// ip::ANY_V4 is never a configured server, so the initial value
	// never matches.
	let src: ip::addr = ip::ANY_V4;
	for (true) {
		// NOTE(review): the timeout applies to each poll call, not to
		// the query as a whole. A steady stream of ignored datagrams
		// therefore keeps this loop waiting.
		let nevent = poll::poll(pollfd, timeout)!;
		if (nevent == 0) {
			return errors::timeout;
		};

		// NOTE(review): if both sockets are readable at once, the
		// IPv4 datagram read here is overwritten by the IPv6 one
		// before it is examined.
		if (pollfd[0].revents & poll::event::POLLIN != 0) {
			z = udp::recvfrom(socket4, buf, &src, null)?;
		};
		if (pollfd[1].revents & poll::event::POLLIN != 0) {
			z = udp::recvfrom(socket6, buf, &src, null)?;
		};

		// Ignore datagrams from hosts we did not query.
		let expected = false;
		for (let i = 0z; i < len(servers); i += 1) {
			if (ip::equal(src, servers[i])) {
				expected = true;
				break;
			};
		};
		if (!expected) {
			continue;
		};

		// Accept only a response that carries our query's ID.
		const dec = decoder_init(buf[..z]);
		decode_header(&dec, &header)?;
		if (header.id == query.header.id && header.op.qr == qr::RESPONSE) {
			break;
		};
	};

	if (!header.op.tc) {
		check_rcode(header.op.rcode)?;
		return decode(buf[..z])?;
	};

	// Response was truncated, retry over TCP. In TCP mode, the
	// query is preceded by two bytes indicating the query length.
	// buf now holds the UDP reply, so the query is encoded again.
	z = encode(buf, query)?;
	if (z > types::U16_MAX) {
		return errors::overflow;
	};
	let zbuf: [2]u8 = [0...];
	endian::beputu16(zbuf, z: u16);
	let socket = tcp::connect(src, 53)?;
	defer net::close(socket)!;

	// Return write failures (e.g. the server resetting the connection)
	// to the caller instead of aborting the program.
	io::writeall(socket, zbuf)?;
	io::writeall(socket, buf[..z])?;

	// The reply is likewise prefixed with its big-endian length.
	let rz: u16 = match (io::readall(socket, zbuf)?) {
	case let s: size =>
		if (s != 2) {
			return format;
		};
		yield endian::begetu16(zbuf);
	case =>
		return format;
	};
	let tcpbuf: []u8 = alloc([0...], rz);
	defer free(tcpbuf);

	match (io::readall(socket, tcpbuf)?) {
	case let s: size =>
		if (s != rz) {
			return format;
		};
	case =>
		return format;
	};

	// Apply the same acceptance rules as the UDP path (matching ID,
	// marked as a response). A TCP reply must not be truncated.
	const dec = decoder_init(tcpbuf);
	decode_header(&dec, &header)?;
	if ((header.id != query.header.id)
			|| (header.op.qr != qr::RESPONSE)
			|| header.op.tc) {
		return format;
	};
	check_rcode(header.op.rcode)?;
	return decode(tcpbuf)?;
};