hare

[hare] The Hare programming language
git clone https://git.torresjrjr.com/hare.git
Log | Files | Refs | README | LICENSE

decode.ha (9875B)


      1 // SPDX-License-Identifier: MPL-2.0
      2 // (c) Hare authors <https://harelang.org>
      3 
      4 use ascii;
      5 use endian;
      6 use fmt;
      7 use net::ip;
      8 use strings;
      9 
     10 type decoder = struct {
     11 	buf: []u8,
     12 	cur: []u8,
     13 };
     14 
     15 // Decodes a DNS message, heap allocating the resources necessary to represent
     16 // it in Hare's type system. The caller must use [[message_free]] to free the
     17 // return value.
     18 export fn decode(buf: []u8) (*message | format) = {
     19 	let success = false;
     20 	let msg = alloc(message { ... });
     21 	defer if (!success) message_free(msg);
     22 	let dec = decoder_init(buf);
     23 	decode_header(&dec, &msg.header)?;
     24 	for (let i = 0z; i < msg.header.qdcount; i += 1) {
     25 		append(msg.questions, decode_question(&dec)?);
     26 	};
     27 	decode_rrecords(&dec, msg.header.ancount, &msg.answers)?;
     28 	decode_rrecords(&dec, msg.header.nscount, &msg.authority)?;
     29 	decode_rrecords(&dec, msg.header.arcount, &msg.additional)?;
     30 	success = true;
     31 	return msg;
     32 };
     33 
     34 fn decode_rrecords(
     35 	dec: *decoder,
     36 	count: u16,
     37 	out: *[]rrecord,
     38 ) (void | format) = {
     39 	for (let i = 0z; i < count; i += 1) {
     40 		append(out, decode_rrecord(dec)?);
     41 	};
     42 };
     43 
     44 fn decoder_init(buf: []u8) decoder = decoder {
     45 	buf = buf,
     46 	cur = buf,
     47 	...
     48 };
     49 
     50 fn decode_u8(dec: *decoder) (u8 | format) = {
     51 	if (len(dec.cur) < 1) {
     52 		return format;
     53 	};
     54 	const val = dec.cur[0];
     55 	dec.cur = dec.cur[1..];
     56 	return val;
     57 };
     58 
     59 fn decode_u16(dec: *decoder) (u16 | format) = {
     60 	if (len(dec.cur) < 2) {
     61 		return format;
     62 	};
     63 	const val = endian::begetu16(dec.cur);
     64 	dec.cur = dec.cur[2..];
     65 	return val;
     66 };
     67 
     68 fn decode_u32(dec: *decoder) (u32 | format) = {
     69 	if (len(dec.cur) < 4) {
     70 		return format;
     71 	};
     72 	const val = endian::begetu32(dec.cur);
     73 	dec.cur = dec.cur[4..];
     74 	return val;
     75 };
     76 
     77 fn decode_u48(dec: *decoder) (u64 | format) = {
     78 	if (len(dec.cur) < 6) {
     79 		return format;
     80 	};
     81 	let buf: [8]u8 = [0...];
     82 	buf[2..] = dec.cur[..6];
     83 	const val = endian::begetu64(buf[..]);
     84 	dec.cur = dec.cur[6..];
     85 	return val;
     86 };
     87 
     88 fn decode_header(dec: *decoder, head: *header) (void | format) = {
     89 	head.id = decode_u16(dec)?;
     90 	const rawop = decode_u16(dec)?;
     91 	decode_op(rawop, &head.op);
     92 	head.qdcount = decode_u16(dec)?;
     93 	head.ancount = decode_u16(dec)?;
     94 	head.nscount = decode_u16(dec)?;
     95 	head.arcount = decode_u16(dec)?;
     96 };
     97 
     98 fn decode_op(in: u16, out: *op) void = {
     99 	out.qr = ((in & 0b1000000000000000) >> 15): qr;
    100 	out.opcode = ((in & 0b0111100000000000u16) >> 11): opcode;
    101 	out.aa = in & 0b0000010000000000u16 != 0;
    102 	out.tc = in & 0b0000001000000000u16 != 0;
    103 	out.rd = in & 0b0000000100000000u16 != 0;
    104 	out.ra = in & 0b0000000010000000u16 != 0;
    105 	out.rcode = (in & 0b1111): rcode;
    106 };
    107 
    108 fn decode_name(dec: *decoder) ([]str | format) = {
    109 	let success = false;
    110 	let names: []str = [];
    111 	defer if (!success) strings::freeall(names);
    112 	let totalsize = 0z;
    113 	let sub = decoder {
    114 		buf = dec.buf,
    115 		...
    116 	};
    117 	for (let i = 0z; i < len(dec.buf); i += 2) {
    118 		if (len(dec.cur) < 1) {
    119 			return format;
    120 		};
    121 		const z = dec.cur[0];
    122 		if (z & 0b11000000 == 0b11000000) {
    123 			const offs = decode_u16(dec)? & ~0b1100000000000000u16;
    124 			if (len(dec.buf) < offs) {
    125 				return format;
    126 			};
    127 			sub.cur = dec.buf[offs..];
    128 			dec = &sub;
    129 			continue;
    130 		};
    131 		dec.cur = dec.cur[1..];
    132 		totalsize += z + 1;
    133 		if (totalsize > 255) {
    134 			return format;
    135 		};
    136 		if (z == 0) {
    137 			success = true;
    138 			return names;
    139 		};
    140 
    141 		if (len(dec.cur) < z) {
    142 			return format;
    143 		};
    144 		const name = match (strings::fromutf8(dec.cur[..z])) {
    145 		case let name: str =>
    146 			yield name;
    147 		case =>
    148 			return format;
    149 		};
    150 		dec.cur = dec.cur[z..];
    151 		if (!ascii::validstr(name)) {
    152 			return format;
    153 		};
    154 
    155 		append(names, strings::dup(name));
    156 	};
    157 	return format;
    158 };
    159 
    160 fn decode_question(dec: *decoder) (question | format) = {
    161 	let success = false;
    162 	const qname = decode_name(dec)?;
    163 	defer if (!success) strings::freeall(qname);
    164 	const qtype = decode_u16(dec)?: qtype;
    165 	const qclass = decode_u16(dec)?: qclass;
    166 	success = true;
    167 	return question {
    168 		qname = qname,
    169 		qtype = qtype,
    170 		qclass = qclass,
    171 	};
    172 };
    173 
    174 fn decode_rrecord(dec: *decoder) (rrecord | format) = {
    175 	let success = false;
    176 	const name = decode_name(dec)?;
    177 	defer if (!success) strings::freeall(name);
    178 	const rtype = decode_u16(dec)?: rtype;
    179 	const class = decode_u16(dec)?: class;
    180 	const ttl = decode_u32(dec)?;
    181 	const rlen = decode_u16(dec)?;
    182 	const rdata = decode_rdata(dec, rtype, rlen)?;
    183 	success = true;
    184 	return rrecord {
    185 		name = name,
    186 		rtype = rtype,
    187 		class = class,
    188 		ttl = ttl,
    189 		rdata = rdata
    190 	};
    191 };
    192 
    193 fn decode_rdata(dec: *decoder, rtype: rtype, rlen: size) (rdata | format) = {
    194 	if (len(dec.cur) < rlen) {
    195 		return format;
    196 	};
    197 	let sub = decoder {
    198 		cur = dec.cur[..rlen],
    199 		buf = dec.buf,
    200 	};
    201 	dec.cur = dec.cur[rlen..];
    202 	switch (rtype) {
    203 	case rtype::A =>
    204 		return decode_a(&sub);
    205 	case rtype::AAAA =>
    206 		return decode_aaaa(&sub);
    207 	case rtype::CAA =>
    208 		return decode_caa(&sub);
    209 	case rtype::CNAME =>
    210 		return decode_cname(&sub);
    211 	case rtype::DNSKEY =>
    212 		return decode_dnskey(&sub);
    213 	case rtype::MX =>
    214 		return decode_mx(&sub);
    215 	case rtype::NS =>
    216 		return decode_ns(&sub);
    217 	case rtype::OPT =>
    218 		return decode_opt(&sub);
    219 	case rtype::NSEC =>
    220 		return decode_nsec(&sub);
    221 	case rtype::PTR =>
    222 		return decode_ptr(&sub);
    223 	case rtype::RRSIG =>
    224 		return decode_rrsig(&sub);
    225 	case rtype::SOA =>
    226 		return decode_soa(&sub);
    227 	case rtype::SRV =>
    228 		return decode_srv(&sub);
    229 	case rtype::SSHFP =>
    230 		return decode_sshfp(&sub);
    231 	case rtype::TSIG =>
    232 		return decode_tsig(&sub);
    233 	case rtype::TXT =>
    234 		return decode_txt(&sub);
    235 	case =>
    236 		return sub.cur: unknown_rdata;
    237 	};
    238 };
    239 
    240 fn decode_a(dec: *decoder) (rdata | format) = {
    241 	if (len(dec.cur) < 4) {
    242 		return format;
    243 	};
    244 	let ip: ip::addr4 = [0...];
    245 	ip[..] = dec.cur[..4];
    246 	dec.cur = dec.cur[4..];
    247 	return ip: a;
    248 };
    249 
    250 fn decode_aaaa(dec: *decoder) (rdata | format) = {
    251 	if (len(dec.cur) < 16) {
    252 		return format;
    253 	};
    254 	let ip: ip::addr6 = [0...];
    255 	ip[..] = dec.cur[..16];
    256 	dec.cur = dec.cur[16..];
    257 	return ip: aaaa;
    258 };
    259 
    260 fn decode_caa(dec: *decoder) (rdata | format) = {
    261 	let flags = decode_u8(dec)?;
    262 	let tag_len = decode_u8(dec)?;
    263 
    264 	if (len(dec.cur) < tag_len) {
    265 		return format;
    266 	};
    267 	let tag = match(strings::fromutf8(dec.cur[..tag_len])) {
    268 	case let t: str =>
    269 		yield t;
    270 	case =>
    271 		return format;
    272 	};
    273 	let value = match (strings::fromutf8(dec.cur[tag_len..])) {
    274 	case let v: str =>
    275 		yield v;
    276 	case =>
    277 		return format;
    278 	};
    279 
    280 	return caa {
    281 		flags = flags,
    282 		tag = strings::dup(tag),
    283 		value = strings::dup(value),
    284 	};
    285 };
    286 
    287 fn decode_cname(dec: *decoder) (rdata | format) = {
    288 	return cname {
    289 		name = decode_name(dec)?,
    290 	};
    291 };
    292 
    293 fn decode_dnskey(dec: *decoder) (rdata | format) = {
    294 	let r = dnskey {
    295 		flags = decode_u16(dec)?,
    296 		protocol = decode_u8(dec)?,
    297 		algorithm = decode_u8(dec)?,
    298 		key = [],
    299 	};
    300 	append(r.key, dec.cur[..]...);
    301 	return r;
    302 };
    303 
    304 fn decode_mx(dec: *decoder) (rdata | format) = {
    305 	return mx {
    306 		priority = decode_u16(dec)?,
    307 		name = decode_name(dec)?,
    308 	};
    309 };
    310 
    311 fn decode_ns(dec: *decoder) (rdata | format) = {
    312 	return ns {
    313 		name = decode_name(dec)?,
    314 	};
    315 };
    316 
    317 fn decode_opt(dec: *decoder) (rdata | format) = {
    318 	let success = false;
    319 	let r = opt {
    320 		options = [],
    321 	};
    322 	defer if (!success) {
    323 		for (let i = 0z; i < len(r.options); i += 1) {
    324 			free(r.options[i].data);
    325 		};
    326 		free(r.options);
    327 	};
    328 	for (len(dec.cur) > 0) {
    329 		let o = edns_opt {
    330 			code = decode_u16(dec)?,
    331 			data = [],
    332 		};
    333 		let sz = decode_u16(dec)?;
    334 		if (len(dec.cur) < sz) {
    335 			return format;
    336 		};
    337 		append(o.data, dec.cur[..sz]...);
    338 		dec.cur = dec.cur[sz..];
    339 		append(r.options, o);
    340 	};
    341 	success = true;
    342 	return r;
    343 };
    344 
    345 fn decode_nsec(dec: *decoder) (rdata | format) = {
    346 	let r = nsec {
    347 		next_domain = decode_name(dec)?,
    348 		type_bitmaps = [],
    349 	};
    350 	append(r.type_bitmaps, dec.cur[..]...);
    351 	return r;
    352 };
    353 
    354 fn decode_ptr(dec: *decoder) (rdata | format) = {
    355 	return ptr {
    356 		name = decode_name(dec)?,
    357 	};
    358 };
    359 
    360 fn decode_rrsig(dec: *decoder) (rdata | format) = {
    361 	let r = rrsig {
    362 		type_covered = decode_u16(dec)?,
    363 		algorithm = decode_u8(dec)?,
    364 		labels = decode_u8(dec)?,
    365 		orig_ttl = decode_u32(dec)?,
    366 		sig_expiration = decode_u32(dec)?,
    367 		sig_inception = decode_u32(dec)?,
    368 		key_tag = decode_u16(dec)?,
    369 		signer_name = decode_name(dec)?,
    370 		signature = [],
    371 	};
    372 	append(r.signature, dec.cur[..]...);
    373 	return r;
    374 };
    375 
    376 fn decode_soa(dec: *decoder) (rdata | format) = {
    377 	return soa {
    378 		mname = decode_name(dec)?,
    379 		rname = decode_name(dec)?,
    380 		serial = decode_u32(dec)?,
    381 		refresh = decode_u32(dec)?,
    382 		retry = decode_u32(dec)?,
    383 		expire = decode_u32(dec)?,
    384 	};
    385 };
    386 
    387 fn decode_srv(dec: *decoder) (rdata | format) = {
    388 	return srv {
    389 		priority = decode_u16(dec)?,
    390 		weight = decode_u16(dec)?,
    391 		port = decode_u16(dec)?,
    392 		target = decode_name(dec)?,
    393 	};
    394 };
    395 
    396 fn decode_sshfp(dec: *decoder) (rdata | format) = {
    397 	let r = sshfp {
    398 		algorithm = decode_u8(dec)?,
    399 		fp_type = decode_u8(dec)?,
    400 		fingerprint = [],
    401 	};
    402 	append(r.fingerprint, dec.cur[..]...);
    403 	return r;
    404 };
    405 
    406 fn decode_tsig(dec: *decoder) (rdata | format) = {
    407 	let success = false;
    408 	let r = tsig {
    409 		algorithm = decode_name(dec)?,
    410 		...
    411 	};
    412 	defer if (!success) free(r.algorithm);
    413 
    414 	r.time_signed = decode_u48(dec)?;
    415 	r.fudge = decode_u16(dec)?;
    416 	r.mac_len = decode_u16(dec)?;
    417 
    418 	if (len(dec.cur) < r.mac_len) {
    419 		return format;
    420 	};
    421 	append(r.mac, dec.cur[..r.mac_len]...);
    422 	defer if (!success) free(r.mac);
    423 	dec.cur = dec.cur[r.mac_len..];
    424 
    425 	r.orig_id = decode_u16(dec)?;
    426 	r.error = decode_u16(dec)?;
    427 	r.other_len = decode_u16(dec)?;
    428 
    429 	if (len(dec.cur) != r.other_len) {
    430 		return format;
    431 	};
    432 	if (r.other_len > 0) {
    433 		append(r.other_data, dec.cur[..]...);
    434 	};
    435 
    436 	success = true;
    437 	return r;
    438 };
    439 
    440 fn decode_txt(dec: *decoder) (rdata | format) = {
    441 	let success = false;
    442 	let items: txt = [];
    443 	defer if (!success) bytes_free(items);
    444 	for (len(dec.cur) != 0) {
    445 		const ln = decode_u8(dec)?;
    446 		if (len(dec.cur) < ln) {
    447 			return format;
    448 		};
    449 		let item: []u8 = [];
    450 		append(item, dec.cur[..ln]...);
    451 		dec.cur = dec.cur[ln..];
    452 		append(items, item);
    453 	};
    454 	success = true;
    455 	return items;
    456 };
    457 
    458 // TODO: Expand breadth of supported rdata decoders