hare

[hare] The Hare programming language
git clone https://git.torresjrjr.com/hare.git
Log | Files | Refs | README | LICENSE

encode.ha (4405B)


      1 // SPDX-License-Identifier: MPL-2.0
      2 // (c) Hare authors <https://harelang.org>
      3 
      4 use endian;
      5 use errors;
      6 use fmt;
      7 
      8 type encoder = struct {
      9 	buf: []u8,
     10 	offs: size,
     11 };
     12 
     13 // Encodes a DNS message, returning its size, or an error.
     14 export fn encode(buf: []u8, msg: *message) (size | error) = {
     15 	let enc = encoder { buf = buf, offs = 0z };
     16 	encode_u16(&enc, msg.header.id)?;
     17 	encode_u16(&enc, encode_op(&msg.header.op))?;
     18 	encode_u16(&enc, msg.header.qdcount)?;
     19 	encode_u16(&enc, msg.header.ancount)?;
     20 	encode_u16(&enc, msg.header.nscount)?;
     21 	encode_u16(&enc, msg.header.arcount)?;
     22 
     23 	for (let i = 0z; i < len(msg.questions); i += 1) {
     24 		question_encode(&enc, &msg.questions[i])?;
     25 	};
     26 	for (let i = 0z; i < len(msg.answers); i += 1) {
     27 		rrecord_encode(&enc, &msg.answers[i])?;
     28 	};
     29 	for (let i = 0z; i < len(msg.authority); i += 1) {
     30 		rrecord_encode(&enc, &msg.authority[i])?;
     31 	};
     32 	for (let i = 0z; i < len(msg.additional); i += 1) {
     33 		rrecord_encode(&enc, &msg.additional[i])?;
     34 	};
     35 
     36 	return enc.offs;
     37 };
     38 
     39 fn encode_u8(enc: *encoder, val: u8) (void | error) = {
     40 	if (len(enc.buf) <= enc.offs + 1) {
     41 		return errors::overflow;
     42 	};
     43 	enc.buf[enc.offs] = val;
     44 	enc.offs += 1;
     45 };
     46 
     47 fn encode_u16(enc: *encoder, val: u16) (void | error) = {
     48 	if (len(enc.buf) <= enc.offs + 2) {
     49 		return errors::overflow;
     50 	};
     51 	endian::beputu16(enc.buf[enc.offs..], val);
     52 	enc.offs += 2;
     53 };
     54 
     55 fn encode_u32(enc: *encoder, val: u32) (void | error) = {
     56 	if (len(enc.buf) <= enc.offs + 4) {
     57 		return errors::overflow;
     58 	};
     59 	endian::beputu32(enc.buf[enc.offs..], val);
     60 	enc.offs += 4;
     61 };
     62 
     63 fn encode_raw(enc: *encoder, val: []u8) (void | error) = {
     64 	let end = enc.offs + len(val);
     65 	if (len(enc.buf) < end) {
     66 		return errors::overflow;
     67 	};
     68 	enc.buf[enc.offs..end] = val;
     69 	enc.offs += len(val);
     70 };
     71 
     72 fn encode_labels(enc: *encoder, names: []str) (void | error) = {
     73 	// TODO: Assert that the labels are all valid ASCII?
     74 	for (let i = 0z; i < len(names); i += 1) {
     75 		if (len(names[i]) > 63) {
     76 			return format;
     77 		};
     78 		if (len(enc.buf) <= enc.offs + 1 + len(names[i])) {
     79 			return errors::overflow;
     80 		};
     81 		encode_u8(enc, len(names[i]): u8)?;
     82 		let label = fmt::bsprintf(enc.buf[enc.offs..], "{}", names[i]);
     83 		enc.offs += len(label);
     84 	};
     85 	encode_u8(enc, 0)?;
     86 };
     87 
     88 fn question_encode(enc: *encoder, q: *question) (void | error) = {
     89 	encode_labels(enc, q.qname)?;
     90 	encode_u16(enc, q.qtype)?;
     91 	encode_u16(enc, q.qclass)?;
     92 };
     93 
     94 fn rrecord_encode(enc: *encoder, r: *rrecord) (void | error) = {
     95 	encode_labels(enc, r.name)?;
     96 	encode_u16(enc, r.rtype)?;
     97 	encode_u16(enc, r.class)?;
     98 	encode_u32(enc, r.ttl)?;
     99 	let ln_enc = *enc; // save state for rdata len
    100 	encode_u16(enc, 0)?; // write dummy rdata len
    101 	encode_rdata(enc, r.rdata)?; // write rdata
    102 	let rdata_len = enc.offs - ln_enc.offs - 2;
    103 	encode_u16(&ln_enc, rdata_len: u16)?; // write rdata len to its place
    104 };
    105 
    106 fn encode_rdata(enc: *encoder, rdata: rdata) (void | error) = {
    107 	match (rdata) {
    108 	case let d: unknown_rdata =>
    109 		return encode_raw(enc, d);
    110 	case let d: opt =>
    111 		return encode_opt(enc, d);
    112 	case let d: txt =>
    113 		return encode_txt(enc, d);
    114 	case =>
    115 		abort(); // TODO
    116 	};
    117 };
    118 
    119 fn encode_opt(enc: *encoder, opt: opt) (void | error) = {
    120 	for (let i = 0z; i < len(opt.options); i += 1) {
    121 		if (len(opt.options[i].data) > 65535) {
    122 			return errors::invalid;
    123 		};
    124 		encode_u16(enc, opt.options[i].code)?;
    125 		encode_u16(enc, len(opt.options[i].data): u16)?;
    126 		encode_raw(enc, opt.options[i].data)?;
    127 	};
    128 };
    129 
    130 fn encode_txt(enc: *encoder, txt: txt) (void | error) = {
    131 	for (let i = 0z; i < len(txt); i += 1) {
    132 		if (len(txt[i]) > 255) return errors::invalid;
    133 		encode_u8(enc, len(txt[i]): u8)?;
    134 		encode_raw(enc, txt[i])?;
    135 	};
    136 };
    137 
    138 fn encode_op(op: *op) u16 =
    139 	(op.qr: u16 << 15u16) |
    140 	(op.opcode: u16 << 11u16) |
    141 	(if (op.aa) 0b0000010000000000u16 else 0u16) |
    142 	(if (op.tc) 0b0000001000000000u16 else 0u16) |
    143 	(if (op.rd) 0b0000000100000000u16 else 0u16) |
    144 	(if (op.ra) 0b0000000010000000u16 else 0u16) |
    145 	op.rcode: u16;
    146 
    147 @test fn opcode() void = {
    148 	let opcode = op {
    149 		qr = qr::RESPONSE,
    150 		opcode = opcode::IQUERY,
    151 		aa = false,
    152 		tc = true,
    153 		rd = false,
    154 		ra = true,
    155 		rcode = rcode::SERVFAIL,
    156 	};
    157 	let enc = encode_op(&opcode);
    158 	let opcode2 = op { ... };
    159 	decode_op(enc, &opcode2);
    160 	assert(opcode.qr == opcode2.qr && opcode.opcode == opcode2.opcode &&
    161 		opcode.aa == opcode2.aa && opcode.tc == opcode2.tc &&
    162 		opcode.rd == opcode2.rd && opcode.ra == opcode2.ra &&
    163 		opcode.rcode == opcode2.rcode);
    164 };