hare

[hare] The Hare programming language
git clone https://git.torresjrjr.com/hare.git
Log | Files | Refs | README | LICENSE

ip.ha (8497B)


      1 // SPDX-License-Identifier: MPL-2.0
      2 // (c) Hare authors <https://harelang.org>
      3 
      4 use bytes;
      5 use endian;
      6 use fmt;
      7 use io;
      8 use memio;
      9 use strconv;
     10 use strings;
     11 
     12 // An IPv4 address: four octets in the order they appear in dotted notation.
     13 export type addr4 = [4]u8;
     14 
     15 // An IPv6 address: sixteen octets, most significant first.
     16 export type addr6 = [16]u8;
     17 
     18 // An IP address of either family: an [[addr4]] or an [[addr6]].
     19 export type addr = (addr4 | addr6);
     20 
     21 // An IP subnet: a base address together with its netmask.
     22 export type subnet = struct {
     23 	addr: addr,
     24 	mask: addr,
     25 };
     26 
     27 // The IPv4 wildcard ("any") address, i.e. "0.0.0.0". Binding to this
     28 // address will listen on all available IPv4 interfaces on most systems.
     29 export const ANY_V4: addr4 = [0, 0, 0, 0];
     30 
     31 // The IPv6 wildcard ("any") address, i.e. "::". Binding to this
     32 // address will listen on all available IPv6 interfaces on most systems.
     33 export const ANY_V6: addr6 = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
     34 
     35 // The IPv4 loopback address, i.e. "127.0.0.1".
     36 export const LOCAL_V4: addr4 = [127, 0, 0, 1];
     37 
     38 // The IPv6 loopback address, i.e. "::1".
     39 export const LOCAL_V6: addr6 = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
     40 
     41 // Error returned when the input is not a valid address or subnet.
     42 export type invalid = !void;
     43 
     44 // Test if two [[addr]]s are equal.
     45 export fn equal(l: addr, r: addr) bool = {
     46 	match (l) {
     47 	case let l: addr4 =>
     48 		if (!(r is addr4)) {
     49 			return false;
     50 		};
     51 		let r = r as addr4;
     52 		return bytes::equal(l, r);
     53 	case let l: addr6 =>
     54 		if (!(r is addr6)) {
     55 			return false;
     56 		};
     57 		let r = r as addr6;
     58 		return bytes::equal(l, r);
     59 	};
     60 };
     61 
     62 // Parses an IPv4 address.
     63 export fn parsev4(st: str) (addr4 | invalid) = {
     64 	let ret: addr4 = [0...];
     65 	let tok = strings::tokenize(st, ".");
     66 	let i = 0z;
     67 	for (i < 4; i += 1) {
     68 		let s = wanttoken(&tok)?;
     69 		if (len(s) != 1 && strings::hasprefix(s, "0")) {
     70 			return invalid;
     71 		};
     72 		ret[i] = match (strconv::stou8(s)) {
     73 		case let term: u8 =>
     74 			yield term;
     75 		case =>
     76 			return invalid;
     77 		};
     78 	};
     79 	if (i < 4 || !(strings::next_token(&tok) is done)) {
     80 		return invalid;
     81 	};
     82 	return ret;
     83 };
     84 
     85 // Parses an IPv6 address.
     86 export fn parsev6(st: str) (addr6 | invalid) = {
     87 	let ret: addr6 = [0...];
     88 	if (st == "::") {
     89 		return ret;
     90 	};
     91 	let tok = strings::tokenize(st, ":");
     92 	let ells = -1;
     93 	if (strings::hasprefix(st, "::")) {
     94 		wanttoken(&tok)?;
     95 		wanttoken(&tok)?;
     96 		ells = 0;
     97 	} else if (strings::hasprefix(st, ":")) {
     98 		return invalid;
     99 	};
    100 	let i = 0;
    101 	for (i < 16) {
    102 		let s = match (strings::next_token(&tok)) {
    103 		case let s: str =>
    104 			yield s;
    105 		case done =>
    106 			break;
    107 		};
    108 		if (s == "") {
    109 			if (ells != -1) {
    110 				return invalid;
    111 			};
    112 			ells = i;
    113 			continue;
    114 		};
    115 		match (strconv::stou16(s, strconv::base::HEX)) {
    116 		case let val: u16 =>
    117 			endian::beputu16(ret[i..], val);
    118 			i += 2;
    119 		case =>
    120 			ret[i..i + 4] = parsev4(s)?;
    121 			i += 4;
    122 			break;
    123 		};
    124 	};
    125 	if (!(strings::next_token(&tok) is done)) {
    126 		return invalid;
    127 	};
    128 	if (ells >= 0) {
    129 		if (i >= 15) {
    130 			return invalid;
    131 		};
    132 		const n = i - ells;
    133 		ret[16 - n..16] = ret[ells..ells + n];
    134 		ret[ells..ells + n] = [0...];
    135 	} else if (i != 16) {
    136 		return invalid;
    137 	};
    138 
    139 	return ret;
    140 };
    141 
    142 
    143 // Parses an IP address.
    144 export fn parse(s: str) (addr | invalid) = {
    145 	match (parsev4(s)) {
    146 	case let v4: addr4 =>
    147 		return v4;
    148 	case invalid => void;
    149 	};
    150 	match (parsev6(s)) {
    151 	case let v6: addr6 =>
    152 		return v6;
    153 	case invalid => void;
    154 	};
    155 	return invalid;
    156 };
    157 
    158 fn fmtv4(s: io::handle, a: addr4) (size | io::error) = {
    159 	let ret = 0z;
    160 	for (let i = 0; i < 4; i += 1) {
    161 		if (i > 0) {
    162 			ret += fmt::fprintf(s, ".")?;
    163 		};
    164 		ret += fmt::fprintf(s, "{}", a[i])?;
    165 	};
    166 	return ret;
    167 };
    168 
    169 fn fmtv6(s: io::handle, a: addr6) (size | io::error) = {
    170 	let ret = 0z;
    171 	let zstart: int = -1;
    172 	let zend: int = -1;
    173 	for (let i = 0; i < 16; i += 2) {
    174 		let j = i;
    175 		for (j < 16 && a[j] == 0 && a[j + 1] == 0) {
    176 			j += 2;
    177 		};
    178 
    179 		if (j > i && j - i > zend - zstart) {
    180 			zstart = i;
    181 			zend = j;
    182 			i = j;
    183 		};
    184 	};
    185 
    186 	if (zend - zstart <= 2) {
    187 		zstart = -1;
    188 		zend = -1;
    189 	};
    190 
    191 	for (let i = 0; i < 16; i += 2) {
    192 		if (i == zstart) {
    193 			ret += fmt::fprintf(s, "::")?;
    194 			i = zend;
    195 			if (i >= 16)
    196 				break;
    197 		} else if (i > 0) {
    198 			ret += fmt::fprintf(s, ":")?;
    199 		};
    200 		let term = (a[i]: u16) << 8 | a[i + 1];
    201 		ret += fmt::fprintf(s, "{:x}", term)?;
    202 	};
    203 	return ret;
    204 };
    205 
    206 // Fills a netmask according to the CIDR value
    207 // e.g. 23 -> [0xFF, 0xFF, 0xFD, 0x00]
    208 fn fillmask(mask: []u8, val: u8) void = {
    209 	mask[..] = [0xFF...];
    210 	let i: int = len(mask): int - 1;
    211 	val = len(mask): u8 * 8 - val;
    212 	for (val >= 8) {
    213 		mask[i] = 0x00;
    214 		val -= 8;
    215 		i -= 1;
    216 	};
    217 	if (i >= 0) {
    218 		mask[i] = ~((1 << val) - 1);
    219 	};
    220 };
    221 
    222 // Returns an addr representing a netmask
    223 fn cidrmask(addr: addr, val: u8) (addr | invalid) = {
    224 	let a_len: u8 = match (addr) {
    225 	case addr4 =>
    226 		yield 4;
    227 	case addr6 =>
    228 		yield 16;
    229 	};
    230 
    231 	if (val > 8 * a_len)
    232 		return invalid;
    233 	if (a_len == 4) {
    234 		let ret: addr4 = [0...];
    235 		fillmask(ret[..], val);
    236 		return ret;
    237 	};
    238 	if (a_len == 16) {
    239 		let ret: addr6 = [0...];
    240 		fillmask(ret[..], val);
    241 		return ret;
    242 	};
    243 	return invalid;
    244 };
    245 
    246 // Parse an IP subnet in CIDR notation e.g. 192.168.1.0/24
    247 export fn parsecidr(st: str) (subnet | invalid) = {
    248 	let tok = strings::tokenize(st, "/");
    249 	let ips = wanttoken(&tok)?;
    250 	let addr = parse(ips)?;
    251 	let masks = wanttoken(&tok)?;
    252 	let val = match (strconv::stou8(masks)) {
    253 	case let x: u8 =>
    254 		yield x;
    255 	case =>
    256 		return invalid;
    257 	};
    258 	if (!(strings::next_token(&tok) is done)) {
    259 		return invalid;
    260 	};
    261 	return subnet {
    262 		addr = addr,
    263 		mask = cidrmask(addr, val)?
    264 	};
    265 };
    266 
    267 fn masklen(addr: []u8) (void | size) = {
    268 	let n = 0z;
    269 	for (let i = 0z; i < len(addr); i += 1) {
    270 		if (addr[i] == 0xff) {
    271 			n += 8;
    272 			continue;
    273 		};
    274 		let val = addr[i];
    275 		for (val & 0x80 != 0) {
    276 			n += 1;
    277 			val <<= 1;
    278 		};
    279 		if (val != 0)
    280 			return;
    281 		for (let j = i + 1; j < len(addr); j += 1) {
    282 			if (addr[j] != 0)
    283 				return;
    284 		};
    285 		break;
    286 	};
    287 	return n;
    288 };
    289 
    290 fn fmtmask(s: io::handle, mask: addr) (size | io::error) = {
    291 	let ret = 0z;
    292 	let slice = match (mask) {
    293 	case let v4: addr4 =>
    294 		yield v4[..];
    295 	case let v6: addr6 =>
    296 		yield v6[..];
    297 	};
    298 	match (masklen(slice)) {
    299 	case void =>
    300 		// Format as hex, if zero runs are not contiguous
    301 		// (like golang does)
    302 		for (let part .. slice) {
    303 			ret += fmt::fprintf(s, "{:x}", part)?;
    304 		};
    305 	case let n: size =>
    306 		// Standard CIDR integer
    307 		ret += fmt::fprintf(s, "{}", n)?;
    308 	};
    309 	return ret;
    310 };
    311 
    312 fn fmtsubnet(s: io::handle, subnet: subnet) (size | io::error) = {
    313 	let ret = 0z;
    314 	ret += fmt(s, subnet.addr)?;
    315 	ret += fmt::fprintf(s, "/")?;
    316 	ret += fmtmask(s, subnet.mask)?;
    317 	return ret;
    318 };
    319 
    320 // Formats an [[addr]] or [[subnet]] and prints it to a stream.
    321 export fn fmt(s: io::handle, item: (...addr | subnet)) (size | io::error) = {
    322 	match (item) {
    323 	case let v4: addr4 =>
    324 		return fmtv4(s, v4)?;
    325 	case let v6: addr6 =>
    326 		return fmtv6(s, v6)?;
    327 	case let sub: subnet =>
    328 		return fmtsubnet(s, sub);
    329 	};
    330 };
    331 
    332 // Formats an [[addr]] or [[subnet]] as a string. The return value is statically
    333 // allocated and will be overwritten on subsequent calls; see [[strings::dup]] to
    334 // extend its lifetime.
    335 export fn string(item: (...addr | subnet)) str = {
    336 	// Maximum length of an IPv6 address plus its netmask in hexadecimal
    337 	static let buf: [64]u8 = [0...];
    338 	let stream = memio::fixed(buf);
    339 	fmt(&stream, item) as size;
    340 	return memio::string(&stream)!;
    341 };
    342 
    343 fn wanttoken(tok: *strings::tokenizer) (str | invalid) = {
    344 	match (strings::next_token(tok)) {
    345 	case let s: str =>
    346 		return s;
    347 	case done =>
    348 		return invalid;
    349 	};
    350 };
    351 
    352 // Returns whether an [[addr]] (or another [[subnet]]) is contained
    353 // within a [[subnet]].
    354 export fn subnet_contains(sub: subnet, item: (addr | subnet)) bool = {
    355 	let a: subnet = match (item) {
    356 	case let a: addr =>
    357 		yield subnet {
    358 			addr = a,
    359 			mask = sub.mask,
    360 		};
    361 	case let sub: subnet =>
    362 		yield sub;
    363 	};
    364 	// Get byte slices for both addresses and masks.
    365 	let ipa = match (sub.addr) {
    366 		case let v4: addr4 => yield v4[..];
    367 		case let v6: addr6 => yield v6[..];
    368 	};
    369 	let maska = match (sub.mask) {
    370 		case let v4: addr4 => yield v4[..];
    371 		case let v6: addr6 => yield v6[..];
    372 	};
    373 	let ipb = match (a.addr) {
    374 		case let v4: addr4 => yield v4[..];
    375 		case let v6: addr6 => yield v6[..];
    376 	};
    377 	let maskb = match (a.mask) {
    378 		case let v4: addr4 => yield v4[..];
    379 		case let v6: addr6 => yield v6[..];
    380 	};
    381 	if (len(ipa) != len(ipb) || len(maska) != len(maskb) || len(ipa) != len(maska)) {
    382 		// Mismatched addr4 and addr6 addresses / masks.
    383 		return false;
    384 	};
    385 
    386 	for (let i = 0z; i < len(ipa); i += 1) {
    387 		if (ipa[i] & maska[i] != ipb[i] & maska[i] || maska[i] > maskb[i]) {
    388 			return false;
    389 		};
    390 	};
    391 	return true;
    392 };