hare

[hare] The Hare programming language
git clone https://git.torresjrjr.com/hare.git
Log | Files | Refs | README | LICENSE

encode.ha (3740B)


      1 // License: MPL-2.0
      2 // (c) 2021 Drew DeVault <sir@cmpwn.com>
      3 use endian;
      4 use errors;
      5 use fmt;
      6 use strings;
      7 
      8 type encoder = struct {
      9 	buf: []u8,
     10 	offs: size,
     11 };
     12 
     13 // Converts a human-readable domain name (e.g. "example.org") into a DNS-ready
     14 // name slice (e.g. ["example", "org"]). The slice returned must be freed by the
     15 // caller, but the members of the slice themselves are borrowed from the input.
     16 export fn parse_domain(in: str) []str = strings::split(in, ".");
     17 
     18 // Converts a DNS name slice (e.g. ["example", "org"]) into a human-readable
     19 // domain name (e.g. "example.org"). The return value must be freed by the
     20 // caller.
     21 export fn unparse_domain(in: []str) str = strings::join(".", in...);
     22 
     23 // Encodes a DNS message, returning its size, or an error.
     24 export fn encode(buf: []u8, msg: *message) (size | error) = {
     25 	let enc = encoder { buf = buf, offs = 0z };
     26 	encode_u16(&enc, msg.header.id)?;
     27 	encode_u16(&enc, encode_op(&msg.header.op))?;
     28 	encode_u16(&enc, msg.header.qdcount)?;
     29 	encode_u16(&enc, msg.header.ancount)?;
     30 	encode_u16(&enc, msg.header.nscount)?;
     31 	encode_u16(&enc, msg.header.arcount)?;
     32 
     33 	for (let i = 0z; i < len(msg.questions); i += 1) {
     34 		question_encode(&enc, &msg.questions[i])?;
     35 	};
     36 	for (let i = 0z; i < len(msg.answers); i += 1) {
     37 		rrecord_encode(&enc, &msg.answers[i])?;
     38 	};
     39 	for (let i = 0z; i < len(msg.authority); i += 1) {
     40 		rrecord_encode(&enc, &msg.authority[i])?;
     41 	};
     42 	for (let i = 0z; i < len(msg.additional); i += 1) {
     43 		rrecord_encode(&enc, &msg.additional[i])?;
     44 	};
     45 
     46 	return enc.offs;
     47 };
     48 
     49 fn encode_u8(enc: *encoder, val: u8) (void | error) = {
     50 	if (len(enc.buf) <= enc.offs + 1) {
     51 		return errors::overflow;
     52 	};
     53 	enc.buf[enc.offs] = val;
     54 	enc.offs += 1;
     55 };
     56 
     57 fn encode_u16(enc: *encoder, val: u16) (void | error) = {
     58 	if (len(enc.buf) <= enc.offs + 2) {
     59 		return errors::overflow;
     60 	};
     61 	endian::beputu16(enc.buf[enc.offs..], val);
     62 	enc.offs += 2;
     63 };
     64 
     65 fn encode_u32(enc: *encoder, val: u32) (void | error) = {
     66 	if (len(enc.buf) <= enc.offs + 4) {
     67 		return errors::overflow;
     68 	};
     69 	endian::beputu32(enc.buf[enc.offs..], val);
     70 	enc.offs += 4;
     71 };
     72 
     73 fn encode_labels(enc: *encoder, names: []str) (void | error) = {
     74 	// TODO: Assert that the labels are all valid ASCII?
     75 	for (let i = 0z; i < len(names); i += 1) {
     76 		// XXX: Should I return an error instead of asserting?
     77 		assert(len(names[i]) < 256);
     78 		if (len(enc.buf) <= enc.offs + 1 + len(names[i])) {
     79 			return errors::overflow;
     80 		};
     81 		encode_u8(enc, len(names[i]): u8)?;
     82 		let label = fmt::bsprintf(enc.buf[enc.offs..], "{}", names[i]);
     83 		enc.offs += len(label);
     84 	};
     85 };
     86 
     87 fn question_encode(enc: *encoder, q: *question) (void | error) = {
     88 	encode_labels(enc, q.qname)?;
     89 	encode_u8(enc, 0)?;
     90 	encode_u16(enc, q.qtype)?;
     91 	encode_u16(enc, q.qclass)?;
     92 };
     93 
     94 fn rrecord_encode(enc: *encoder, r: *rrecord) (void | error) = {
     95 	encode_labels(enc, r.name)?;
     96 	encode_u8(enc, 0)?;
     97 	encode_u16(enc, r.rtype)?;
     98 	encode_u16(enc, r.class)?;
     99 	encode_u32(enc, r.ttl)?;
    100 
    101 	abort(); // TODO
    102 };
    103 
    104 fn encode_op(op: *op) u16 =
    105 	(op.qr: u16 << 15u16) |
    106 	(op.opcode: u16 << 11u16) |
    107 	(if (op.aa) 0b0000010000000000u16 else 0u16) |
    108 	(if (op.tc) 0b0000001000000000u16 else 0u16) |
    109 	(if (op.rd) 0b0000000100000000u16 else 0u16) |
    110 	(if (op.ra) 0b0000000010000000u16 else 0u16) |
    111 	op.rcode: u16;
    112 
    113 @test fn opcode() void = {
    114 	let opcode = op {
    115 		qr = qr::RESPONSE,
    116 		opcode = opcode::IQUERY,
    117 		aa = false,
    118 		tc = true,
    119 		rd = false,
    120 		ra = true,
    121 		rcode = rcode::SERVER_FAILURE,
    122 	};
    123 	let enc = encode_op(&opcode);
    124 	let opcode2 = op { ... };
    125 	decode_op(enc, &opcode2);
    126 	assert(opcode.qr == opcode2.qr && opcode.opcode == opcode2.opcode &&
    127 		opcode.aa == opcode2.aa && opcode.tc == opcode2.tc &&
    128 		opcode.rd == opcode2.rd && opcode.ra == opcode2.ra &&
    129 		opcode.rcode == opcode2.rcode);
    130 };