encode.ha (4405B)
// SPDX-License-Identifier: MPL-2.0
// (c) Hare authors <https://harelang.org>

use endian;
use errors;
use fmt;

// Write cursor over a caller-provided output buffer. "offs" is the index into
// "buf" at which the next byte will be written.
type encoder = struct {
	buf: []u8,
	offs: size,
};

// Encodes a DNS message into buf, returning the number of bytes written, or an
// error.
//
// The header is written first: the ID, the packed flags word (see encode_op),
// then the four section counts exactly as stored in msg.header (they are not
// derived from the lengths of the section slices). The questions, answers,
// authority and additional records follow, in that order.
//
// Returns errors::overflow if buf is too small, format if a label is longer
// than 63 bytes, and errors::invalid if a length field cannot represent the
// data it describes.
export fn encode(buf: []u8, msg: *message) (size | error) = {
	let enc = encoder { buf = buf, offs = 0z };
	encode_u16(&enc, msg.header.id)?;
	encode_u16(&enc, encode_op(&msg.header.op))?;
	encode_u16(&enc, msg.header.qdcount)?;
	encode_u16(&enc, msg.header.ancount)?;
	encode_u16(&enc, msg.header.nscount)?;
	encode_u16(&enc, msg.header.arcount)?;

	for (let i = 0z; i < len(msg.questions); i += 1) {
		question_encode(&enc, &msg.questions[i])?;
	};
	for (let i = 0z; i < len(msg.answers); i += 1) {
		rrecord_encode(&enc, &msg.answers[i])?;
	};
	for (let i = 0z; i < len(msg.authority); i += 1) {
		rrecord_encode(&enc, &msg.authority[i])?;
	};
	for (let i = 0z; i < len(msg.additional); i += 1) {
		rrecord_encode(&enc, &msg.additional[i])?;
	};

	return enc.offs;
};

// Writes a single byte at the cursor and advances it by one. Returns
// errors::overflow if no byte remains in the buffer.
fn encode_u8(enc: *encoder, val: u8) (void | error) = {
	// Strict "<": a value that exactly fills the buffer is allowed, matching
	// the check in encode_raw.
	if (len(enc.buf) < enc.offs + 1) {
		return errors::overflow;
	};
	enc.buf[enc.offs] = val;
	enc.offs += 1;
};

// Writes a 16-bit value in big-endian byte order at the cursor and advances it
// by two. Returns errors::overflow if fewer than two bytes remain.
fn encode_u16(enc: *encoder, val: u16) (void | error) = {
	if (len(enc.buf) < enc.offs + 2) {
		return errors::overflow;
	};
	endian::beputu16(enc.buf[enc.offs..], val);
	enc.offs += 2;
};

// Writes a 32-bit value in big-endian byte order at the cursor and advances it
// by four. Returns errors::overflow if fewer than four bytes remain.
fn encode_u32(enc: *encoder, val: u32) (void | error) = {
	if (len(enc.buf) < enc.offs + 4) {
		return errors::overflow;
	};
	endian::beputu32(enc.buf[enc.offs..], val);
	enc.offs += 4;
};

// Copies val verbatim to the cursor and advances it by len(val). Returns
// errors::overflow if val does not fit in the remaining space.
fn encode_raw(enc: *encoder, val: []u8) (void | error) = {
	let end = enc.offs + len(val);
	if (len(enc.buf) < end) {
		return errors::overflow;
	};
	enc.buf[enc.offs..end] = val;
	enc.offs += len(val);
};

// Writes a domain name as a sequence of length-prefixed labels followed by a
// terminating zero byte. Returns format if any label is longer than 63 bytes,
// or errors::overflow if the buffer is too small.
fn encode_labels(enc: *encoder, names: []str) (void | error) = {
	// TODO: Assert that the labels are all valid ASCII?
	for (let i = 0z; i < len(names); i += 1) {
		if (len(names[i]) > 63) {
			return format;
		};
		// Check the length byte and the label text together, so that
		// fmt::bsprintf below is never handed a buffer that is too small.
		if (len(enc.buf) < enc.offs + 1 + len(names[i])) {
			return errors::overflow;
		};
		encode_u8(enc, len(names[i]): u8)?;
		let label = fmt::bsprintf(enc.buf[enc.offs..], "{}", names[i]);
		enc.offs += len(label);
	};
	encode_u8(enc, 0)?;
};

// Writes one question entry: the name labels, then qtype and qclass as 16-bit
// big-endian values.
fn question_encode(enc: *encoder, q: *question) (void | error) = {
	encode_labels(enc, q.qname)?;
	encode_u16(enc, q.qtype)?;
	encode_u16(enc, q.qclass)?;
};

// Writes one resource record: the name labels, rtype, class and ttl, then a
// 16-bit rdata length followed by the rdata itself.
//
// The rdata length is not known until the rdata has been encoded, so a
// placeholder is written first and patched afterwards. Returns errors::invalid
// if the encoded rdata is longer than 65535 bytes and therefore cannot be
// described by the 16-bit length field.
fn rrecord_encode(enc: *encoder, r: *rrecord) (void | error) = {
	encode_labels(enc, r.name)?;
	encode_u16(enc, r.rtype)?;
	encode_u16(enc, r.class)?;
	encode_u32(enc, r.ttl)?;
	// Copy of the cursor positioned at the length field. The copy refers to
	// the same underlying buffer, so writing through it patches the message
	// in place without moving the main cursor.
	let ln_enc = *enc; // save state for rdata len
	encode_u16(enc, 0)?; // write dummy rdata len
	encode_rdata(enc, r.rdata)?; // write rdata
	// Bytes written since the saved position, minus the two-byte placeholder.
	let rdata_len = enc.offs - ln_enc.offs - 2;
	if (rdata_len > 65535) {
		// Casting to u16 would silently truncate the length.
		return errors::invalid;
	};
	encode_u16(&ln_enc, rdata_len: u16)?; // write rdata len to its place
};

// Writes the rdata portion of a resource record, dispatching on its type.
// Only unknown_rdata, opt and txt are implemented; any other rdata type aborts
// the program.
fn encode_rdata(enc: *encoder, rdata: rdata) (void | error) = {
	match (rdata) {
	case let d: unknown_rdata =>
		return encode_raw(enc, d);
	case let d: opt =>
		return encode_opt(enc, d);
	case let d: txt =>
		return encode_txt(enc, d);
	case =>
		abort(); // TODO
	};
};

// Writes each option of an opt rdata as a 16-bit code, a 16-bit data length and
// the option data. Returns errors::invalid if an option's data is longer than
// 65535 bytes.
fn encode_opt(enc: *encoder, opt: opt) (void | error) = {
	for (let i = 0z; i < len(opt.options); i += 1) {
		if (len(opt.options[i].data) > 65535) {
			return errors::invalid;
		};
		encode_u16(enc, opt.options[i].code)?;
		encode_u16(enc, len(opt.options[i].data): u16)?;
		encode_raw(enc, opt.options[i].data)?;
	};
};

// Writes each item of a txt rdata as a one-byte length followed by the item's
// bytes. Returns errors::invalid if an item is longer than 255 bytes.
fn encode_txt(enc: *encoder, txt: txt) (void | error) = {
	for (let i = 0z; i < len(txt); i += 1) {
		if (len(txt[i]) > 255) return errors::invalid;
		encode_u8(enc, len(txt[i]): u8)?;
		encode_raw(enc, txt[i])?;
	};
};

// Packs the fields of a header op into the 16-bit flags word: qr is shifted to
// bit 15, opcode is shifted left by 11, the aa, tc, rd and ra flags set bits
// 10, 9, 8 and 7 respectively, and rcode is OR'd into the low bits. No masking
// is applied to opcode or rcode.
fn encode_op(op: *op) u16 =
	(op.qr: u16 << 15u16) |
	(op.opcode: u16 << 11u16) |
	(if (op.aa) 0b0000010000000000u16 else 0u16) |
	(if (op.tc) 0b0000001000000000u16 else 0u16) |
	(if (op.rd) 0b0000000100000000u16 else 0u16) |
	(if (op.ra) 0b0000000010000000u16 else 0u16) |
	op.rcode: u16;

// Round-trips an op through encode_op and decode_op and checks that every field
// survives unchanged.
@test fn opcode() void = {
	let opcode = op {
		qr = qr::RESPONSE,
		opcode = opcode::IQUERY,
		aa = false,
		tc = true,
		rd = false,
		ra = true,
		rcode = rcode::SERVFAIL,
	};
	let enc = encode_op(&opcode);
	let opcode2 = op { ... };
	decode_op(enc, &opcode2);
	assert(opcode.qr == opcode2.qr && opcode.opcode == opcode2.opcode &&
		opcode.aa == opcode2.aa && opcode.tc == opcode2.tc &&
		opcode.rd == opcode2.rd && opcode.ra == opcode2.ra &&
		opcode.rcode == opcode2.rcode);
};