hare

[hare] The Hare programming language
git clone https://git.torresjrjr.com/hare.git
Log | Files | Refs | README | LICENSE

encode.ha (4073B)


      1 // License: MPL-2.0
      2 // (c) 2021 Drew DeVault <sir@cmpwn.com>
      3 use endian;
      4 use errors;
      5 use fmt;
      6 
// Write cursor used while serializing a DNS message into a caller-supplied
// buffer. The encode_* helpers never grow buf; they return errors::overflow
// when a write would not fit.
type encoder = struct {
	// Destination buffer supplied by the caller of encode().
	buf: []u8,
	// Index of the next byte to write, i.e. the number of bytes written so far.
	offs: size,
};
     11 
     12 // Encodes a DNS message, returning its size, or an error.
     13 export fn encode(buf: []u8, msg: *message) (size | error) = {
     14 	let enc = encoder { buf = buf, offs = 0z };
     15 	encode_u16(&enc, msg.header.id)?;
     16 	encode_u16(&enc, encode_op(&msg.header.op))?;
     17 	encode_u16(&enc, msg.header.qdcount)?;
     18 	encode_u16(&enc, msg.header.ancount)?;
     19 	encode_u16(&enc, msg.header.nscount)?;
     20 	encode_u16(&enc, msg.header.arcount)?;
     21 
     22 	for (let i = 0z; i < len(msg.questions); i += 1) {
     23 		question_encode(&enc, &msg.questions[i])?;
     24 	};
     25 	for (let i = 0z; i < len(msg.answers); i += 1) {
     26 		rrecord_encode(&enc, &msg.answers[i])?;
     27 	};
     28 	for (let i = 0z; i < len(msg.authority); i += 1) {
     29 		rrecord_encode(&enc, &msg.authority[i])?;
     30 	};
     31 	for (let i = 0z; i < len(msg.additional); i += 1) {
     32 		rrecord_encode(&enc, &msg.additional[i])?;
     33 	};
     34 
     35 	return enc.offs;
     36 };
     37 
     38 fn encode_u8(enc: *encoder, val: u8) (void | error) = {
     39 	if (len(enc.buf) <= enc.offs + 1) {
     40 		return errors::overflow;
     41 	};
     42 	enc.buf[enc.offs] = val;
     43 	enc.offs += 1;
     44 };
     45 
     46 fn encode_u16(enc: *encoder, val: u16) (void | error) = {
     47 	if (len(enc.buf) <= enc.offs + 2) {
     48 		return errors::overflow;
     49 	};
     50 	endian::beputu16(enc.buf[enc.offs..], val);
     51 	enc.offs += 2;
     52 };
     53 
     54 fn encode_u32(enc: *encoder, val: u32) (void | error) = {
     55 	if (len(enc.buf) <= enc.offs + 4) {
     56 		return errors::overflow;
     57 	};
     58 	endian::beputu32(enc.buf[enc.offs..], val);
     59 	enc.offs += 4;
     60 };
     61 
     62 fn encode_raw(enc: *encoder, val: []u8) (void | error) = {
     63 	let end = enc.offs + len(val);
     64 	if (len(enc.buf) <= end) {
     65 		return errors::overflow;
     66 	};
     67 	enc.buf[enc.offs..end] = val;
     68 	enc.offs += len(val);
     69 };
     70 
     71 fn encode_labels(enc: *encoder, names: []str) (void | error) = {
     72 	// TODO: Assert that the labels are all valid ASCII?
     73 	for (let i = 0z; i < len(names); i += 1) {
     74 		// XXX: Should I return an error instead of asserting?
     75 		assert(len(names[i]) < 256);
     76 		if (len(enc.buf) <= enc.offs + 1 + len(names[i])) {
     77 			return errors::overflow;
     78 		};
     79 		encode_u8(enc, len(names[i]): u8)?;
     80 		let label = fmt::bsprintf(enc.buf[enc.offs..], "{}", names[i]);
     81 		enc.offs += len(label);
     82 	};
     83 };
     84 
     85 fn question_encode(enc: *encoder, q: *question) (void | error) = {
     86 	encode_labels(enc, q.qname)?;
     87 	encode_u8(enc, 0)?;
     88 	encode_u16(enc, q.qtype)?;
     89 	encode_u16(enc, q.qclass)?;
     90 };
     91 
     92 fn rrecord_encode(enc: *encoder, r: *rrecord) (void | error) = {
     93 	encode_labels(enc, r.name)?;
     94 	encode_u8(enc, 0)?;
     95 	encode_u16(enc, r.rtype)?;
     96 	encode_u16(enc, r.class)?;
     97 	encode_u32(enc, r.ttl)?;
     98 	let ln_enc = *enc; // save state for rdata len
     99 	encode_u16(enc, 0)?; // write dummy rdata len
    100 	encode_rdata(enc, r.rdata)?; // write rdata
    101 	let rdata_len = enc.offs - ln_enc.offs - 2;
    102 	encode_u16(&ln_enc, rdata_len: u16)?; // write rdata len to its place
    103 };
    104 
    105 fn encode_rdata(enc: *encoder, rdata: rdata) (void | error) = {
    106 	match (rdata) {
    107 	case let d: unknown_rdata =>
    108 		return encode_raw(enc, d);
    109 	case let d: txt =>
    110 		return encode_txt(enc, d);
    111 	case =>
    112 		abort(); // TODO
    113 	};
    114 };
    115 
    116 fn encode_txt(enc: *encoder, txt: txt) (void | error) = {
    117 	for (let i = 0z; i < len(txt); i += 1) {
    118 		if (len(txt[i]) > 255) return errors::invalid;
    119 		encode_u8(enc, len(txt[i]): u8)?;
    120 		encode_raw(enc, txt[i])?;
    121 	};
    122 };
    123 
    124 fn encode_op(op: *op) u16 =
    125 	(op.qr: u16 << 15u16) |
    126 	(op.opcode: u16 << 11u16) |
    127 	(if (op.aa) 0b0000010000000000u16 else 0u16) |
    128 	(if (op.tc) 0b0000001000000000u16 else 0u16) |
    129 	(if (op.rd) 0b0000000100000000u16 else 0u16) |
    130 	(if (op.ra) 0b0000000010000000u16 else 0u16) |
    131 	op.rcode: u16;
    132 
    133 @test fn opcode() void = {
    134 	let opcode = op {
    135 		qr = qr::RESPONSE,
    136 		opcode = opcode::IQUERY,
    137 		aa = false,
    138 		tc = true,
    139 		rd = false,
    140 		ra = true,
    141 		rcode = rcode::SERVFAIL,
    142 	};
    143 	let enc = encode_op(&opcode);
    144 	let opcode2 = op { ... };
    145 	decode_op(enc, &opcode2);
    146 	assert(opcode.qr == opcode2.qr && opcode.opcode == opcode2.opcode &&
    147 		opcode.aa == opcode2.aa && opcode.tc == opcode2.tc &&
    148 		opcode.rd == opcode2.rd && opcode.ra == opcode2.ra &&
    149 		opcode.rcode == opcode2.rcode);
    150 };