hare

[hare] The Hare programming language
git clone https://git.torresjrjr.com/hare.git
Log | Files | Refs | README | LICENSE

decode.ha (4711B)


      1 // License: MPL-2.0
      2 // (c) 2021 Drew DeVault <sir@cmpwn.com>
      3 use ascii;
      4 use endian;
      5 use fmt;
      6 use net::ip;
      7 use strings;
      8 
// Cursor state used while decoding a single DNS message.
type decoder = struct {
	// The complete message. Name compression pointers are resolved as
	// offsets into this slice.
	buf: []u8,
	// The input not yet consumed; the decode_* helpers read from its front
	// and re-slice it past whatever they read.
	cur: []u8,
};
     13 
     14 // Decodes a DNS message, heap allocating the resources necessary to represent
     15 // it in Hare's type system. The caller must use [[message_free]] to free the
     16 // return value. To decode without use of the heap, see [[decoder_init]].
     17 export fn decode(buf: []u8) (*message | format) = {
     18 	let msg = alloc(message { ... });
     19 	let dec = decoder_init(buf);
     20 	decode_header(&dec, &msg.header)?;
     21 	for (let i = 0z; i < msg.header.qdcount; i += 1) {
     22 		append(msg.questions, decode_question(&dec)?);
     23 	};
     24 	decode_rrecords(&dec, msg.header.ancount, &msg.answers)?;
     25 	decode_rrecords(&dec, msg.header.nscount, &msg.authority)?;
     26 	decode_rrecords(&dec, msg.header.arcount, &msg.additional)?;
     27 	return msg;
     28 };
     29 
     30 fn decode_rrecords(
     31 	dec: *decoder,
     32 	count: u16,
     33 	out: *[]rrecord,
     34 ) (void | format) = {
     35 	for (let i = 0z; i < count; i += 1) {
     36 		append(out, decode_rrecord(dec)?);
     37 	};
     38 };
     39 
     40 fn decoder_init(buf: []u8) decoder = decoder {
     41 	buf = buf,
     42 	cur = buf,
     43 	...
     44 };
     45 
     46 fn decode_u8(dec: *decoder) (u8 | format) = {
     47 	if (len(dec.cur) < 1) {
     48 		return format;
     49 	};
     50 	const val = dec.cur[0];
     51 	dec.cur = dec.cur[1..];
     52 	return val;
     53 };
     54 
     55 fn decode_u16(dec: *decoder) (u16 | format) = {
     56 	if (len(dec.cur) < 2) {
     57 		return format;
     58 	};
     59 	const val = endian::begetu16(dec.cur);
     60 	dec.cur = dec.cur[2..];
     61 	return val;
     62 };
     63 
     64 fn decode_u32(dec: *decoder) (u32 | format) = {
     65 	if (len(dec.cur) < 4) {
     66 		return format;
     67 	};
     68 	const val = endian::begetu32(dec.cur);
     69 	dec.cur = dec.cur[4..];
     70 	return val;
     71 };
     72 
     73 fn decode_header(dec: *decoder, head: *header) (void | format) = {
     74 	head.id = decode_u16(dec)?;
     75 	const rawop = decode_u16(dec)?;
     76 	decode_op(rawop, &head.op);
     77 	head.qdcount = decode_u16(dec)?;
     78 	head.ancount = decode_u16(dec)?;
     79 	head.nscount = decode_u16(dec)?;
     80 	head.arcount = decode_u16(dec)?;
     81 };
     82 
     83 fn decode_op(in: u16, out: *op) void = {
     84 	out.qr = ((in & 0b1000000000000000) >> 15): qr;
     85 	out.opcode = ((in & 0b01111000000000u16) >> 11): opcode;
     86 	out.aa = in & 0b0000010000000000u16 != 0;
     87 	out.tc = in & 0b0000001000000000u16 != 0;
     88 	out.rd = in & 0b0000000100000000u16 != 0;
     89 	out.ra = in & 0b0000000010000000u16 != 0;
     90 	out.rcode = (in & 0b1111): rcode;
     91 };
     92 
     93 fn decode_name(dec: *decoder) ([]str | format) = {
     94 	let names: []str = [];
     95 	for (true) {
     96 		const z = dec.cur[0];
     97 		if (z & 0b11000000 == 0b11000000) {
     98 			const offs = decode_u16(dec)? & ~0b1100000000000000u16;
     99 			const sub = decoder {
    100 				buf = dec.buf,
    101 				cur = dec.buf[offs..],
    102 				...
    103 			};
    104 			append(names, decode_name(&sub)?...);
    105 			break;
    106 		};
    107 		dec.cur = dec.cur[1..];
    108 		if (z == 0) {
    109 			break;
    110 		};
    111 
    112 		const name = strings::fromutf8(dec.cur[..z]);
    113 		dec.cur = dec.cur[z..];
    114 		if (!ascii::validstr(name)) {
    115 			return format;
    116 		};
    117 
    118 		append(names, strings::dup(name));
    119 	};
    120 	return names;
    121 };
    122 
    123 fn decode_question(dec: *decoder) (question | format) = {
    124 	return question {
    125 		qname = decode_name(dec)?,
    126 		qtype = decode_u16(dec)?: qtype,
    127 		qclass = decode_u16(dec)?: qclass,
    128 	};
    129 };
    130 
    131 fn decode_rrecord(dec: *decoder) (rrecord | format) = {
    132 	const name = decode_name(dec)?;
    133 	const rtype = decode_u16(dec)?: rtype;
    134 	const class = decode_u16(dec)?: class;
    135 	const ttl = decode_u32(dec)?;
    136 	const rlen = decode_u16(dec)?;
    137 	const rdata = decode_rdata(dec, rtype, rlen)?;
    138 	return rrecord {
    139 		name = name,
    140 		rtype = rtype,
    141 		class = class,
    142 		ttl = ttl,
    143 		rdata = rdata
    144 	};
    145 };
    146 
    147 fn decode_rdata(dec: *decoder, rtype: rtype, rlen: size) (rdata | format) = {
    148 	switch (rtype) {
    149 	case rtype::A =>
    150 		return decode_a(dec);
    151 	case rtype::AAAA =>
    152 		return decode_aaaa(dec);
    153 	case rtype::MX =>
    154 		return decode_mx(dec);
    155 	case rtype::TXT =>
    156 		return decode_txt(dec);
    157 	case =>
    158 		let buf = dec.cur[..rlen];
    159 		dec.cur = dec.cur[rlen..];
    160 		return buf: unknown_rdata;
    161 	};
    162 };
    163 
    164 fn decode_a(dec: *decoder) (rdata | format) = {
    165 	if (len(dec.cur) < 4) {
    166 		return format;
    167 	};
    168 	let ip: ip::addr4 = [0...];
    169 	ip[..] = dec.cur[..4];
    170 	dec.cur = dec.cur[4..];
    171 	return ip: a;
    172 };
    173 
    174 fn decode_aaaa(dec: *decoder) (rdata | format) = {
    175 	if (len(dec.cur) < 16) {
    176 		return format;
    177 	};
    178 	let ip: ip::addr6 = [0...];
    179 	ip[..] = dec.cur[..16];
    180 	dec.cur = dec.cur[16..];
    181 	return ip: aaaa;
    182 };
    183 
    184 fn decode_mx(dec: *decoder) (rdata | format) = {
    185 	return mx {
    186 		priority = decode_u16(dec)?,
    187 		name = decode_name(dec)?,
    188 	};
    189 };
    190 
    191 fn decode_txt(dec: *decoder) (rdata | format) = {
    192 	let items: txt = [];
    193 	for (len(dec.cur) != 0) {
    194 		const ln = decode_u8(dec)?;
    195 		if (len(dec.cur) < ln) {
    196 			return format;
    197 		};
    198 		let item: []u8 = [];
    199 		append(item, dec.cur[..ln]...);
    200 		dec.cur = dec.cur[ln..];
    201 		append(items, item);
    202 	};
    203 	return items;
    204 };
    205 
    206 // TODO: Expand breadth of supported rdata decoders