encode.ha (4073B)
// License: MPL-2.0
// (c) 2021 Drew DeVault <sir@cmpwn.com>
use endian;
use errors;
use fmt;

// Write cursor over a caller-provided output buffer. The buffer is never
// grown; every write is bounds-checked against len(buf).
type encoder = struct {
	buf: []u8,
	// Index in buf of the next byte to be written.
	offs: size,
};

// Encodes a DNS message into buf, returning the number of bytes written, or
// an error. [[errors::overflow]] is returned if buf is too small to hold the
// encoded message.
//
// The header counts (qdcount, ancount, nscount, arcount) are written exactly
// as given in msg.header; they are not derived from the lengths of the
// question and record lists, so the caller must keep them consistent.
export fn encode(buf: []u8, msg: *message) (size | error) = {
	let enc = encoder { buf = buf, offs = 0z };
	encode_u16(&enc, msg.header.id)?;
	encode_u16(&enc, encode_op(&msg.header.op))?;
	encode_u16(&enc, msg.header.qdcount)?;
	encode_u16(&enc, msg.header.ancount)?;
	encode_u16(&enc, msg.header.nscount)?;
	encode_u16(&enc, msg.header.arcount)?;

	for (let i = 0z; i < len(msg.questions); i += 1) {
		question_encode(&enc, &msg.questions[i])?;
	};
	for (let i = 0z; i < len(msg.answers); i += 1) {
		rrecord_encode(&enc, &msg.answers[i])?;
	};
	for (let i = 0z; i < len(msg.authority); i += 1) {
		rrecord_encode(&enc, &msg.authority[i])?;
	};
	for (let i = 0z; i < len(msg.additional); i += 1) {
		rrecord_encode(&enc, &msg.additional[i])?;
	};

	return enc.offs;
};

// Appends one byte. Returns errors::overflow if it does not fit.
fn encode_u8(enc: *encoder, val: u8) (void | error) = {
	// An exact fit (offs + 1 == len(buf)) is valid.
	if (enc.offs + 1 > len(enc.buf)) {
		return errors::overflow;
	};
	enc.buf[enc.offs] = val;
	enc.offs += 1;
};

// Appends a 16-bit value in big-endian (network) byte order. Returns
// errors::overflow if it does not fit.
fn encode_u16(enc: *encoder, val: u16) (void | error) = {
	if (enc.offs + 2 > len(enc.buf)) {
		return errors::overflow;
	};
	endian::beputu16(enc.buf[enc.offs..], val);
	enc.offs += 2;
};

// Appends a 32-bit value in big-endian (network) byte order. Returns
// errors::overflow if it does not fit.
fn encode_u32(enc: *encoder, val: u32) (void | error) = {
	if (enc.offs + 4 > len(enc.buf)) {
		return errors::overflow;
	};
	endian::beputu32(enc.buf[enc.offs..], val);
	enc.offs += 4;
};

// Appends val verbatim. Returns errors::overflow if it does not fit.
fn encode_raw(enc: *encoder, val: []u8) (void | error) = {
	let end = enc.offs + len(val);
	if (end > len(enc.buf)) {
		return errors::overflow;
	};
	enc.buf[enc.offs..end] = val;
	enc.offs = end;
};

// Appends each label as a length octet followed by the label's bytes. The
// terminating zero-length root label is NOT written; callers append it.
//
// Returns errors::invalid for a label longer than 63 bytes, and
// errors::overflow if a label does not fit in the buffer. Both checks run
// before anything is written for that label.
fn encode_labels(enc: *encoder, names: []str) (void | error) = {
	// TODO: Assert that the labels are all valid ASCII?
	for (let i = 0z; i < len(names); i += 1) {
		let label = names[i];
		// RFC 1035 section 2.3.4 limits a label to 63 octets. The two high
		// bits of the length octet are reserved (0b11 marks a compression
		// pointer), so a longer label would be misread by decoders.
		if (len(label) > 63) {
			return errors::invalid;
		};
		if (enc.offs + 1 + len(label) > len(enc.buf)) {
			return errors::overflow;
		};
		encode_u8(enc, len(label): u8)?;
		// Copies the label's bytes into the buffer; space was checked above.
		let written = fmt::bsprintf(enc.buf[enc.offs..], "{}", label);
		enc.offs += len(written);
	};
};

// Appends a question entry: QNAME (labels plus root terminator), QTYPE,
// QCLASS.
fn question_encode(enc: *encoder, q: *question) (void | error) = {
	encode_labels(enc, q.qname)?;
	encode_u8(enc, 0)?;
	encode_u16(enc, q.qtype)?;
	encode_u16(enc, q.qclass)?;
};

// Appends a resource record: NAME, TYPE, CLASS, TTL, RDLENGTH, RDATA.
// Returns errors::overflow if the encoded rdata exceeds the 16-bit RDLENGTH
// field or the buffer.
fn rrecord_encode(enc: *encoder, r: *rrecord) (void | error) = {
	encode_labels(enc, r.name)?;
	encode_u8(enc, 0)?;
	encode_u16(enc, r.rtype)?;
	encode_u16(enc, r.class)?;
	encode_u32(enc, r.ttl)?;
	// RDLENGTH is only known after the rdata is written. Keep a copy of the
	// cursor positioned at the RDLENGTH slot, write a placeholder, and patch
	// it afterwards.
	let ln_enc = *enc;
	encode_u16(enc, 0)?;
	encode_rdata(enc, r.rdata)?;
	let rdata_len = enc.offs - ln_enc.offs - 2;
	// RDLENGTH is 16 bits wide; refuse rather than silently truncate.
	if (rdata_len > 0xFFFFz) {
		return errors::overflow;
	};
	encode_u16(&ln_enc, rdata_len: u16)?;
};

// Appends the RDATA for a record. Only unknown_rdata (copied verbatim) and
// txt are supported so far; any other variant aborts.
fn encode_rdata(enc: *encoder, rdata: rdata) (void | error) = {
	match (rdata) {
	case let d: unknown_rdata =>
		return encode_raw(enc, d);
	case let d: txt =>
		return encode_txt(enc, d);
	case =>
		abort(); // TODO
	};
};

// Appends TXT rdata: each string as a length octet followed by its bytes.
// Returns errors::invalid for a string longer than 255 bytes, since its
// length would not fit in one octet.
fn encode_txt(enc: *encoder, txt: txt) (void | error) = {
	for (let i = 0z; i < len(txt); i += 1) {
		if (len(txt[i]) > 255) {
			return errors::invalid;
		};
		encode_u8(enc, len(txt[i]): u8)?;
		encode_raw(enc, txt[i])?;
	};
};

// Packs the header flags word. Layout from most to least significant bit:
// QR (bit 15), OPCODE (bits 11-14), AA (10), TC (9), RD (8), RA (7), three
// bits left zero (4-6), RCODE (bits 0-3).
fn encode_op(op: *op) u16 =
	(op.qr: u16 << 15u16) |
	(op.opcode: u16 << 11u16) |
	(if (op.aa) 0b0000010000000000u16 else 0u16) |
	(if (op.tc) 0b0000001000000000u16 else 0u16) |
	(if (op.rd) 0b0000000100000000u16 else 0u16) |
	(if (op.ra) 0b0000000010000000u16 else 0u16) |
	op.rcode: u16;

@test fn opcode() void = {
	let opcode = op {
		qr = qr::RESPONSE,
		opcode = opcode::IQUERY,
		aa = false,
		tc = true,
		rd = false,
		ra = true,
		rcode = rcode::SERVFAIL,
	};
	let enc = encode_op(&opcode);
	let opcode2 = op { ... };
	decode_op(enc, &opcode2);
	assert(opcode.qr == opcode2.qr && opcode.opcode == opcode2.opcode &&
		opcode.aa == opcode2.aa && opcode.tc == opcode2.tc &&
		opcode.rd == opcode2.rd && opcode.ra == opcode2.ra &&
		opcode.rcode == opcode2.rcode);
};

// Regression: a buffer of exactly the right size must be accepted.
@test fn encode_exact_fit() void = {
	let buf: [4]u8 = [0...];
	let enc = encoder { buf = buf[..], offs = 0z };
	encode_u16(&enc, 0x1234)!;
	encode_u8(&enc, 0x56)!;
	encode_u8(&enc, 0x78)!;
	assert(enc.offs == 4);
	assert(buf[0] == 0x12 && buf[1] == 0x34
		&& buf[2] == 0x56 && buf[3] == 0x78);
	// The buffer is now full; one more byte must not fit.
	assert(encode_u8(&enc, 0) is errors::overflow);
};

// Regression: labels over 63 bytes are rejected without writing anything.
@test fn encode_label_limit() void = {
	let buf: [128]u8 = [0...];
	let enc = encoder { buf = buf[..], offs = 0z };
	const long = "aaaaaaaa" "aaaaaaaa" "aaaaaaaa" "aaaaaaaa"
		"aaaaaaaa" "aaaaaaaa" "aaaaaaaa" "aaaaaaaa";
	assert(len(long) == 64);
	let names: [1]str = [long];
	assert(encode_labels(&enc, names[..]) is errors::invalid);
	assert(enc.offs == 0);
};