hare

[hare] The Hare programming language
git clone https://git.torresjrjr.com/hare.git
Log | Files | Refs | README | LICENSE

decode.ha (9880B)


      1 // SPDX-License-Identifier: MPL-2.0
      2 // (c) Hare authors <https://harelang.org>
      3 
      4 use ascii;
      5 use endian;
      6 use net::ip;
      7 use strings;
      8 
// Cursor over a DNS message that is being decoded.
type decoder = struct {
	// The entire message. Compression pointers inside names are offsets
	// from the start of this buffer (see decode_name).
	buf: []u8,
	// The bytes not yet consumed; the decode_u* helpers read from the
	// front of this slice and advance it.
	cur: []u8,
};
     13 
     14 // Decodes a DNS message, heap allocating the resources necessary to represent
     15 // it in Hare's type system. The caller must use [[message_free]] to free the
     16 // return value.
     17 export fn decode(buf: []u8) (*message | format) = {
     18 	let success = false;
     19 	let msg = alloc(message { ... })!;
     20 	defer if (!success) message_free(msg);
     21 	let dec = decoder_init(buf);
     22 	decode_header(&dec, &msg.header)?;
     23 	for (let i = 0z; i < msg.header.qdcount; i += 1) {
     24 		append(msg.questions, decode_question(&dec)?)!;
     25 	};
     26 	decode_rrecords(&dec, msg.header.ancount, &msg.answers)?;
     27 	decode_rrecords(&dec, msg.header.nscount, &msg.authority)?;
     28 	decode_rrecords(&dec, msg.header.arcount, &msg.additional)?;
     29 	success = true;
     30 	return msg;
     31 };
     32 
     33 fn decode_rrecords(
     34 	dec: *decoder,
     35 	count: u16,
     36 	out: *[]rrecord,
     37 ) (void | format) = {
     38 	for (let i = 0z; i < count; i += 1) {
     39 		append(out, decode_rrecord(dec)?)!;
     40 	};
     41 };
     42 
     43 fn decoder_init(buf: []u8) decoder = decoder {
     44 	buf = buf,
     45 	cur = buf,
     46 	...
     47 };
     48 
     49 fn decode_u8(dec: *decoder) (u8 | format) = {
     50 	if (len(dec.cur) < 1) {
     51 		return format;
     52 	};
     53 	const val = dec.cur[0];
     54 	dec.cur = dec.cur[1..];
     55 	return val;
     56 };
     57 
     58 fn decode_u16(dec: *decoder) (u16 | format) = {
     59 	if (len(dec.cur) < 2) {
     60 		return format;
     61 	};
     62 	const val = endian::begetu16(dec.cur);
     63 	dec.cur = dec.cur[2..];
     64 	return val;
     65 };
     66 
     67 fn decode_u32(dec: *decoder) (u32 | format) = {
     68 	if (len(dec.cur) < 4) {
     69 		return format;
     70 	};
     71 	const val = endian::begetu32(dec.cur);
     72 	dec.cur = dec.cur[4..];
     73 	return val;
     74 };
     75 
     76 fn decode_u48(dec: *decoder) (u64 | format) = {
     77 	if (len(dec.cur) < 6) {
     78 		return format;
     79 	};
     80 	let buf: [8]u8 = [0...];
     81 	buf[2..] = dec.cur[..6];
     82 	const val = endian::begetu64(buf[..]);
     83 	dec.cur = dec.cur[6..];
     84 	return val;
     85 };
     86 
     87 fn decode_header(dec: *decoder, head: *header) (void | format) = {
     88 	head.id = decode_u16(dec)?;
     89 	const rawop = decode_u16(dec)?;
     90 	decode_op(rawop, &head.op);
     91 	head.qdcount = decode_u16(dec)?;
     92 	head.ancount = decode_u16(dec)?;
     93 	head.nscount = decode_u16(dec)?;
     94 	head.arcount = decode_u16(dec)?;
     95 };
     96 
     97 fn decode_op(in: u16, out: *op) void = {
     98 	out.qr = ((in & 0b1000000000000000) >> 15): qr;
     99 	out.opcode = ((in & 0b0111100000000000u16) >> 11): opcode;
    100 	out.aa = in & 0b0000010000000000u16 != 0;
    101 	out.tc = in & 0b0000001000000000u16 != 0;
    102 	out.rd = in & 0b0000000100000000u16 != 0;
    103 	out.ra = in & 0b0000000010000000u16 != 0;
    104 	out.rcode = (in & 0b1111): rcode;
    105 };
    106 
    107 fn decode_name(dec: *decoder) ([]str | format) = {
    108 	let success = false;
    109 	let names: []str = [];
    110 	defer if (!success) strings::freeall(names);
    111 	let totalsize = 0z;
    112 	let sub = decoder {
    113 		buf = dec.buf,
    114 		...
    115 	};
    116 	for (let i = 0z; i < len(dec.buf); i += 2) {
    117 		if (len(dec.cur) < 1) {
    118 			return format;
    119 		};
    120 		const z = dec.cur[0];
    121 		if (z & 0b11000000 == 0b11000000) {
    122 			const offs = decode_u16(dec)? & ~0b1100000000000000u16;
    123 			if (len(dec.buf) < offs) {
    124 				return format;
    125 			};
    126 			sub.cur = dec.buf[offs..];
    127 			dec = &sub;
    128 			continue;
    129 		};
    130 		dec.cur = dec.cur[1..];
    131 		totalsize += z + 1;
    132 		if (totalsize > 255) {
    133 			return format;
    134 		};
    135 		if (z == 0) {
    136 			success = true;
    137 			return names;
    138 		};
    139 
    140 		if (len(dec.cur) < z) {
    141 			return format;
    142 		};
    143 		const name = match (strings::fromutf8(dec.cur[..z])) {
    144 		case let name: str =>
    145 			yield name;
    146 		case =>
    147 			return format;
    148 		};
    149 		dec.cur = dec.cur[z..];
    150 		if (!ascii::validstr(name)) {
    151 			return format;
    152 		};
    153 
    154 		append(names, strings::dup(name))!;
    155 	};
    156 	return format;
    157 };
    158 
    159 fn decode_question(dec: *decoder) (question | format) = {
    160 	let success = false;
    161 	const qname = decode_name(dec)?;
    162 	defer if (!success) strings::freeall(qname);
    163 	const qtype = decode_u16(dec)?: qtype;
    164 	const qclass = decode_u16(dec)?: qclass;
    165 	success = true;
    166 	return question {
    167 		qname = qname,
    168 		qtype = qtype,
    169 		qclass = qclass,
    170 	};
    171 };
    172 
    173 fn decode_rrecord(dec: *decoder) (rrecord | format) = {
    174 	let success = false;
    175 	const name = decode_name(dec)?;
    176 	defer if (!success) strings::freeall(name);
    177 	const rtype = decode_u16(dec)?: rtype;
    178 	const class = decode_u16(dec)?: class;
    179 	const ttl = decode_u32(dec)?;
    180 	const rlen = decode_u16(dec)?;
    181 	const rdata = decode_rdata(dec, rtype, rlen)?;
    182 	success = true;
    183 	return rrecord {
    184 		name = name,
    185 		rtype = rtype,
    186 		class = class,
    187 		ttl = ttl,
    188 		rdata = rdata
    189 	};
    190 };
    191 
    192 fn decode_rdata(dec: *decoder, rtype: rtype, rlen: size) (rdata | format) = {
    193 	if (len(dec.cur) < rlen) {
    194 		return format;
    195 	};
    196 	let sub = decoder {
    197 		cur = dec.cur[..rlen],
    198 		buf = dec.buf,
    199 	};
    200 	dec.cur = dec.cur[rlen..];
    201 	switch (rtype) {
    202 	case rtype::A =>
    203 		return decode_a(&sub);
    204 	case rtype::AAAA =>
    205 		return decode_aaaa(&sub);
    206 	case rtype::CAA =>
    207 		return decode_caa(&sub);
    208 	case rtype::CNAME =>
    209 		return decode_cname(&sub);
    210 	case rtype::DNSKEY =>
    211 		return decode_dnskey(&sub);
    212 	case rtype::MX =>
    213 		return decode_mx(&sub);
    214 	case rtype::NS =>
    215 		return decode_ns(&sub);
    216 	case rtype::OPT =>
    217 		return decode_opt(&sub);
    218 	case rtype::NSEC =>
    219 		return decode_nsec(&sub);
    220 	case rtype::PTR =>
    221 		return decode_ptr(&sub);
    222 	case rtype::RRSIG =>
    223 		return decode_rrsig(&sub);
    224 	case rtype::SOA =>
    225 		return decode_soa(&sub);
    226 	case rtype::SRV =>
    227 		return decode_srv(&sub);
    228 	case rtype::SSHFP =>
    229 		return decode_sshfp(&sub);
    230 	case rtype::TSIG =>
    231 		return decode_tsig(&sub);
    232 	case rtype::TXT =>
    233 		return decode_txt(&sub);
    234 	case =>
    235 		return sub.cur: unknown_rdata;
    236 	};
    237 };
    238 
    239 fn decode_a(dec: *decoder) (rdata | format) = {
    240 	if (len(dec.cur) < 4) {
    241 		return format;
    242 	};
    243 	let ip: ip::addr4 = [0...];
    244 	ip[..] = dec.cur[..4];
    245 	dec.cur = dec.cur[4..];
    246 	return ip: a;
    247 };
    248 
    249 fn decode_aaaa(dec: *decoder) (rdata | format) = {
    250 	if (len(dec.cur) < 16) {
    251 		return format;
    252 	};
    253 	let ip: ip::addr6 = [0...];
    254 	ip[..] = dec.cur[..16];
    255 	dec.cur = dec.cur[16..];
    256 	return ip: aaaa;
    257 };
    258 
    259 fn decode_caa(dec: *decoder) (rdata | format) = {
    260 	let flags = decode_u8(dec)?;
    261 	let tag_len = decode_u8(dec)?;
    262 
    263 	if (len(dec.cur) < tag_len) {
    264 		return format;
    265 	};
    266 	let tag = match(strings::fromutf8(dec.cur[..tag_len])) {
    267 	case let t: str =>
    268 		yield t;
    269 	case =>
    270 		return format;
    271 	};
    272 	let value = match (strings::fromutf8(dec.cur[tag_len..])) {
    273 	case let v: str =>
    274 		yield v;
    275 	case =>
    276 		return format;
    277 	};
    278 
    279 	return caa {
    280 		flags = flags,
    281 		tag = strings::dup(tag),
    282 		value = strings::dup(value),
    283 	};
    284 };
    285 
    286 fn decode_cname(dec: *decoder) (rdata | format) = {
    287 	return cname {
    288 		name = decode_name(dec)?,
    289 	};
    290 };
    291 
    292 fn decode_dnskey(dec: *decoder) (rdata | format) = {
    293 	let r = dnskey {
    294 		flags = decode_u16(dec)?,
    295 		protocol = decode_u8(dec)?,
    296 		algorithm = decode_u8(dec)?,
    297 		key = [],
    298 	};
    299 	append(r.key, dec.cur[..]...)!;
    300 	return r;
    301 };
    302 
    303 fn decode_mx(dec: *decoder) (rdata | format) = {
    304 	return mx {
    305 		priority = decode_u16(dec)?,
    306 		name = decode_name(dec)?,
    307 	};
    308 };
    309 
    310 fn decode_ns(dec: *decoder) (rdata | format) = {
    311 	return ns {
    312 		name = decode_name(dec)?,
    313 	};
    314 };
    315 
    316 fn decode_opt(dec: *decoder) (rdata | format) = {
    317 	let success = false;
    318 	let r = opt {
    319 		options = [],
    320 	};
    321 	defer if (!success) {
    322 		for (let i = 0z; i < len(r.options); i += 1) {
    323 			free(r.options[i].data);
    324 		};
    325 		free(r.options);
    326 	};
    327 	for (len(dec.cur) > 0) {
    328 		let o = edns_opt {
    329 			code = decode_u16(dec)?,
    330 			data = [],
    331 		};
    332 		let sz = decode_u16(dec)?;
    333 		if (len(dec.cur) < sz) {
    334 			return format;
    335 		};
    336 		append(o.data, dec.cur[..sz]...)!;
    337 		dec.cur = dec.cur[sz..];
    338 		append(r.options, o)!;
    339 	};
    340 	success = true;
    341 	return r;
    342 };
    343 
    344 fn decode_nsec(dec: *decoder) (rdata | format) = {
    345 	let r = nsec {
    346 		next_domain = decode_name(dec)?,
    347 		type_bitmaps = [],
    348 	};
    349 	append(r.type_bitmaps, dec.cur[..]...)!;
    350 	return r;
    351 };
    352 
    353 fn decode_ptr(dec: *decoder) (rdata | format) = {
    354 	return ptr {
    355 		name = decode_name(dec)?,
    356 	};
    357 };
    358 
    359 fn decode_rrsig(dec: *decoder) (rdata | format) = {
    360 	let r = rrsig {
    361 		type_covered = decode_u16(dec)?,
    362 		algorithm = decode_u8(dec)?,
    363 		labels = decode_u8(dec)?,
    364 		orig_ttl = decode_u32(dec)?,
    365 		sig_expiration = decode_u32(dec)?,
    366 		sig_inception = decode_u32(dec)?,
    367 		key_tag = decode_u16(dec)?,
    368 		signer_name = decode_name(dec)?,
    369 		signature = [],
    370 	};
    371 	append(r.signature, dec.cur[..]...)!;
    372 	return r;
    373 };
    374 
    375 fn decode_soa(dec: *decoder) (rdata | format) = {
    376 	return soa {
    377 		mname = decode_name(dec)?,
    378 		rname = decode_name(dec)?,
    379 		serial = decode_u32(dec)?,
    380 		refresh = decode_u32(dec)?,
    381 		retry = decode_u32(dec)?,
    382 		expire = decode_u32(dec)?,
    383 	};
    384 };
    385 
    386 fn decode_srv(dec: *decoder) (rdata | format) = {
    387 	return srv {
    388 		priority = decode_u16(dec)?,
    389 		weight = decode_u16(dec)?,
    390 		port = decode_u16(dec)?,
    391 		target = decode_name(dec)?,
    392 	};
    393 };
    394 
    395 fn decode_sshfp(dec: *decoder) (rdata | format) = {
    396 	let r = sshfp {
    397 		algorithm = decode_u8(dec)?,
    398 		fp_type = decode_u8(dec)?,
    399 		fingerprint = [],
    400 	};
    401 	append(r.fingerprint, dec.cur[..]...)!;
    402 	return r;
    403 };
    404 
    405 fn decode_tsig(dec: *decoder) (rdata | format) = {
    406 	let success = false;
    407 	let r = tsig {
    408 		algorithm = decode_name(dec)?,
    409 		...
    410 	};
    411 	defer if (!success) free(r.algorithm);
    412 
    413 	r.time_signed = decode_u48(dec)?;
    414 	r.fudge = decode_u16(dec)?;
    415 	r.mac_len = decode_u16(dec)?;
    416 
    417 	if (len(dec.cur) < r.mac_len) {
    418 		return format;
    419 	};
    420 	append(r.mac, dec.cur[..r.mac_len]...)!;
    421 	defer if (!success) free(r.mac);
    422 	dec.cur = dec.cur[r.mac_len..];
    423 
    424 	r.orig_id = decode_u16(dec)?;
    425 	r.error = decode_u16(dec)?;
    426 	r.other_len = decode_u16(dec)?;
    427 
    428 	if (len(dec.cur) != r.other_len) {
    429 		return format;
    430 	};
    431 	if (r.other_len > 0) {
    432 		append(r.other_data, dec.cur[..]...)!;
    433 	};
    434 
    435 	success = true;
    436 	return r;
    437 };
    438 
    439 fn decode_txt(dec: *decoder) (rdata | format) = {
    440 	let success = false;
    441 	let items: txt = [];
    442 	defer if (!success) bytes_free(items);
    443 	for (len(dec.cur) != 0) {
    444 		const ln = decode_u8(dec)?;
    445 		if (len(dec.cur) < ln) {
    446 			return format;
    447 		};
    448 		let item: []u8 = [];
    449 		append(item, dec.cur[..ln]...)!;
    450 		dec.cur = dec.cur[ln..];
    451 		append(items, item)!;
    452 	};
    453 	success = true;
    454 	return items;
    455 };
    456 
    457 // TODO: Expand breadth of supported rdata decoders