hare

[hare] The Hare programming language
git clone https://git.torresjrjr.com/hare.git
Log | Files | Refs | README | LICENSE

base32.ha (12504B)


      1 // License: MPL-2.0
      2 // (c) 2022 Ajay R <ar324@protonmail.com>
      3 use ascii;
      4 use bufio;
      5 use bytes;
      6 use errors;
      7 use io;
      8 use os;
      9 use strings;
     10 
        // Byte used to fill the tail of a final, partially filled 8-character output
        // group; it is also the byte the decoder recognizes as padding.
     11 def PADDING: u8 = '=';
     12 
        // A base-32 alphabet together with the lookup tables derived from it. All
        // three tables are filled in by encoding_init.
     13 export type encoding = struct {
        	// Maps a 5-bit value (0-31) to the alphabet byte that encodes it.
     14 	encmap: [32]u8,
        	// Maps an alphabet byte back to its 5-bit value. Entries for bytes
        	// outside the alphabet are never written by encoding_init.
     15 	decmap: [256]u8,
        	// Set to true for every byte that is a member of the alphabet.
     16 	valid: [256]bool,
     17 };
     18 
     19 // Represents the standard base-32 encoding alphabet as defined in RFC 4648.
        // Its lookup tables are filled in by this module's @init function.
     20 export const std_encoding: encoding = encoding { ... };
     21 
     22 // Represents the "base32hex" alphabet as defined in RFC 4648.
        // Its lookup tables are filled in by this module's @init function.
     23 export const hex_encoding: encoding = encoding { ... };
     24 
     25 // Initializes a new encoding based on the passed alphabet, which must be a
     26 // 32-byte ASCII string.
     27 export fn encoding_init(enc: *encoding, alphabet: str) void = {
     28 	const alphabet = strings::toutf8(alphabet);
     29 	assert(len(alphabet) == 32);
     30 	for (let i: u8 = 0; i < 32; i += 1) {
     31 		const ch = alphabet[i];
     32 		assert(ascii::valid(ch: u32: rune));
     33 		enc.encmap[i] = ch;
     34 		enc.decmap[ch] = i;
     35 		enc.valid[ch] = true;
     36 	};
     37 };
     38 
     39 @init fn init() void = {
     40 	const std_alpha: str = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
     41 	const hex_alpha: str = "0123456789ABCDEFGHIJKLMNOPQRSTUV";
     42 	encoding_init(&std_encoding, std_alpha);
     43 	encoding_init(&hex_encoding, hex_alpha);
     44 };
     45 
        // State of a base-32 encoding stream created by newencoder.
     46 export type encoder = struct {
     47 	stream: io::stream,
        	// Handle that receives the encoded output.
     48 	out: io::handle,
        	// Alphabet tables used for encoding.
     49 	enc: *encoding,
     50 	buf: [4]u8, // leftover input
     51 	avail: size, // bytes available in buf
        	// The error returned by a failed write to out, or void if none has
        	// occurred. Once set, later writes return it immediately.
     52 	err: (void | io::error),
     53 };
     54 
        // Stream operations of the encoder: writes go through encode_writer and
        // closing goes through encode_closer, which emits the final padded group.
     55 const encoder_vtable: io::vtable = io::vtable {
     56 	writer = &encode_writer,
     57 	closer = &encode_closer,
     58 	...
     59 };
     60 
     61 // Creates a stream that encodes writes as base-32 before writing them to a
     62 // secondary stream. The encoder stream must be closed to finalize any unwritten
     63 // bytes. Closing this stream will not close the underlying stream.
     64 export fn newencoder(
     65 	enc: *encoding,
     66 	out: io::handle,
     67 ) encoder = {
     68 	return encoder {
     69 		stream = &encoder_vtable,
     70 		out = out,
     71 		enc = enc,
     72 		err = void,
     73 		...
     74 	};
     75 };
     76 
     77 fn encode_writer(
     78 	s: *io::stream,
     79 	in: const []u8
     80 ) (size | io::error) = {
     81 	let s = s: *encoder;
     82 	match(s.err) {
     83 	case let err: io::error =>
     84 		return err;
     85 	case void =>
     86 		yield;
     87 	};
     88 	let l = len(in);
     89 	let i = 0z;
     90 	for (i + 4 < l + s.avail; i += 5) {
     91 		static let b: [5]u8 = [0...]; // 5 bytes -> (enc) 8 bytes
     92 		if (i < s.avail) {
     93 			for (let j = 0z; j < s.avail; j += 1) {
     94 				b[j] = s.buf[i];
     95 			};
     96 			for (let j = s.avail; j < 5; j += 1) {
     97 				b[j] = in[j - s.avail];
     98 			};
     99 		} else {
    100 			for (let j = 0z; j < 5; j += 1) {
    101 				b[j] = in[j - s.avail + i];
    102 			};
    103 		};
    104 		let encb: [8]u8 = [
    105 			s.enc.encmap[b[0] >> 3],
    106 			s.enc.encmap[(b[0] & 0x7) << 2 | (b[1] & 0xC0) >> 6],
    107 			s.enc.encmap[(b[1] & 0x3E) >> 1],
    108 			s.enc.encmap[(b[1] & 0x1) << 4 | (b[2] & 0xF0) >> 4],
    109 			s.enc.encmap[(b[2] & 0xF) << 1 | (b[3] & 0x80) >> 7],
    110 			s.enc.encmap[(b[3] & 0x7C) >> 2],
    111 			s.enc.encmap[(b[3] & 0x3) << 3 | (b[4] & 0xE0) >> 5],
    112 			s.enc.encmap[b[4] & 0x1F],
    113 		];
    114 		match(io::write(s.out, encb)) {
    115 		case let err: io::error =>
    116 			s.err = err;
    117 			return err;
    118 		case size =>
    119 			yield;
    120 		};
    121 	};
    122 	// storing leftover bytes
    123 	if (l + s.avail < 5) {
    124 		for (let j = s.avail; j < s.avail + l; j += 1) {
    125 			s.buf[j] = in[j - s.avail];
    126 		};
    127 	} else {
    128 		const begin = (l + s.avail) / 5 * 5;
    129 		for (let j = begin; j < l + s.avail; j += 1) {
    130 			s.buf[j - begin] = in[j - s.avail];
    131 		};
    132 	};
    133 	s.avail = (l + s.avail) % 5;
    134 	return l;
    135 };
    136 
    137 fn encode_closer(s: *io::stream) (void | io::error) = {
    138 	let s = s: *encoder;
    139 	if (s.avail == 0) {
    140 		return;
    141 	};
    142 	static let b: [5]u8 = [0...]; // the 5 bytes that will be encoded into 8 bytes
    143 	for (let i = 0z; i < 5; i += 1) {
    144 		b[i] = if (i < s.avail) s.buf[i] else 0;
    145 	};
    146 	let encb: [8]u8 = [
    147 		s.enc.encmap[b[0] >> 3],
    148 		s.enc.encmap[(b[0] & 0x7) << 2 | (b[1] & 0xC0) >> 6],
    149 		s.enc.encmap[(b[1] & 0x3E) >> 1],
    150 		s.enc.encmap[(b[1] & 0x1) << 4 | (b[2] & 0xF0) >> 4],
    151 		s.enc.encmap[(b[2] & 0xF) << 1 | (b[3] & 0x80) >> 7],
    152 		s.enc.encmap[(b[3] & 0x7C) >> 2],
    153 		s.enc.encmap[(b[3] & 0x3) << 3 | (b[4] & 0xE0) >> 5],
    154 		s.enc.encmap[b[4] & 0x1F],
    155 	];
    156 	// adding padding as input length was not a multiple of 5
    157 	//                        0  1  2  3  4
    158 	static const npa: []u8 = [0, 6, 4, 3, 1];
    159 	const np = npa[s.avail];
    160 	for (let i = 0z; i < np; i += 1) {
    161 		encb[7 - i] = PADDING;
    162 	};
    163 	io::writeall(s.out, encb)?;
    164 };
    165 
    166 // Encodes a byte slice in base-32, using the given encoding, returning a slice
    167 // of ASCII bytes. The caller must free the return value.
    168 export fn encodeslice(enc: *encoding, in: []u8) []u8 = {
    169 	let out = bufio::dynamic(io::mode::WRITE);
    170 	let encoder = newencoder(enc, &out);
    171 	io::writeall(&encoder, in)!;
    172 	io::close(&encoder)!;
    173 	return bufio::buffer(&out);
    174 };
    175 
    176 // Encodes a byte slice in base-32, using the given encoding, returning a
    177 // string. The caller must free the return value.
    178 export fn encodestr(enc: *encoding, in: []u8) str = {
    179 	return strings::fromutf8(encodeslice(enc, in));
    180 };
    181 
    182 @test fn encode() void = {
    183 	// RFC 4648 test vectors
    184 	const in: [_]u8 = ['f', 'o', 'o', 'b', 'a', 'r'];
    185 	const expect: [_]str = [
    186 		"",
    187 		"MY======",
    188 		"MZXQ====",
    189 		"MZXW6===",
    190 		"MZXW6YQ=",
    191 		"MZXW6YTB",
    192 		"MZXW6YTBOI======",
    193 	];
    194 	const expect_hex: [_]str = [
    195 		"",
    196 		"CO======",
    197 		"CPNG====",
    198 		"CPNMU===",
    199 		"CPNMUOG=",
    200 		"CPNMUOJ1",
    201 		"CPNMUOJ1E8======",
    202 	];
    203 	for (let i = 0z; i <= len(in); i += 1) {
    204 		let out = bufio::dynamic(io::mode::RDWR);
    205 		let enc = newencoder(&std_encoding, &out);
    206 		io::writeall(&enc, in[..i]) as size;
    207 		io::close(&enc)!;
    208 		let outb = bufio::buffer(&out);
    209 		assert(bytes::equal(outb, strings::toutf8(expect[i])));
    210 		free(outb);
    211 		// Testing encodestr should cover encodeslice too
    212 		let s = encodestr(&std_encoding, in[..i]);
    213 		defer free(s);
    214 		assert(s == expect[i]);
    215 
    216 		out = bufio::dynamic(io::mode::RDWR);
    217 		enc = newencoder(&hex_encoding, &out);
    218 		io::writeall(&enc, in[..i]) as size;
    219 		io::close(&enc)!;
    220 		let outb = bufio::buffer(&out);
    221 		assert(bytes::equal(outb, strings::toutf8(expect_hex[i])));
    222 		free(outb);
    223 		let s = encodestr(&hex_encoding, in[..i]);
    224 		defer free(s);
    225 		assert(s == expect_hex[i]);
    226 	};
    227 };
    228 
        // State of a base-32 decoding stream created by newdecoder.
    229 export type decoder = struct {
    230 	stream: io::stream,
        	// Handle the encoded input is read from.
    231 	in: io::handle,
        	// Alphabet tables used to validate and decode the input.
    232 	enc: *encoding,
    233 	avail: []u8, // leftover decoded output
    234 	pad: bool, // if padding was seen in a previous read
        	// void while decoding may continue; io::EOF or the error once the
        	// input has ended or failed. A non-void state is returned from every
        	// later read.
    235 	state: (void | io::EOF | io::error),
    236 };
    237 
        // Stream operations of the decoder: only reading is provided, through
        // decode_reader.
    238 const decoder_vtable: io::vtable = io::vtable {
    239 	reader = &decode_reader,
    240 	...
    241 };
    242 
    243 // Creates a stream that reads and decodes base-32 data from a secondary stream.
    244 // This stream does not need to be closed, and closing it will not close the
    245 // underlying stream.
    246 export fn newdecoder(
    247 	enc: *encoding,
    248 	in: io::handle,
    249 ) decoder = {
    250 	return decoder {
    251 		stream = &decoder_vtable,
    252 		in = in,
    253 		enc = enc,
    254 		state = void,
    255 		...
    256 	};
    257 };
    258 
    259 fn decode_reader(
    260 	s: *io::stream,
    261 	out: []u8
    262 ) (size | io::EOF | io::error) = {
    263 	let s = s: *decoder;
    264 	let n = 0z;
    265 	let l = len(out);
    266 	match(s.state) {
    267 	case let err: (io::EOF | io ::error) =>
    268 		return err;
    269 	case void =>
    270 		yield;
    271 	};
    272 	if (len(s.avail) > 0) {
    273 		n += if (l < len(s.avail)) l else len(s.avail);
    274 		out[..n] = s.avail[0..n];
    275 		s.avail = s.avail[n..];
    276 		if (l == n) {
    277 			return n;
    278 		};
    279 	};
    280 	static let buf: [os::BUFSIZ]u8 = [0...];
    281 	static let obuf: [os::BUFSIZ / 8 * 5]u8 = [0...];
    282 	const nn = ((l - n) / 5 + 1) * 8; // 8 extra bytes may be read.
    283 	let nr = 0z;
    284 	for (nr < nn) {
    285 		match (io::read(s.in, buf[nr..])) {
    286 		case let n: size =>
    287 			nr += n;
    288 		case io::EOF =>
    289 			s.state = io::EOF;
    290 			break;
    291 		case let err: io::error =>
    292 			s.state = err;
    293 			return err;
    294 		};
    295 	};
    296 	if (nr % 8 != 0) {
    297 		s.state = errors::invalid;
    298 		return errors::invalid;
    299 	};
    300 	if (nr == 0) { // io::EOF already set
    301 		return n;
    302 	};
    303 	// Validating read buffer
    304 	let valid = true;
    305 	let np = 0; // Number of padding chars.
    306 	let p = true; // Pad allowed in buf
    307 	for (let i = nr; i > 0; i -= 1) {
    308 		const ch = buf[i - 1];
    309 		if (ch == PADDING) {
    310 			if(s.pad || !p) {
    311 				valid = false;
    312 				break;
    313 			};
    314 			np += 1;
    315 		} else {
    316 			if (!s.enc.valid[ch]) {
    317 				valid = false;
    318 				break;
    319 			};
    320 			// Disallow padding on seeing a non-padding char
    321 			p = false;
    322 		};
    323 	};
    324 	valid = valid && np <= 6 && np != 2 && np != 5;
    325 	if (np > 0) {
    326 		s.pad = true;
    327 	};
    328 	if (!valid) {
    329 		s.state = errors::invalid;
    330 		return errors::invalid;
    331 	};
    332 	for (let i = 0z; i < nr; i += 1) {
    333 		buf[i] = s.enc.decmap[buf[i]];
    334 	};
    335 	for (let i = 0z, j = 0z; i < nr) {
    336 		obuf[j] = (buf[i] << 3) | (buf[i + 1] & 0x1C) >> 2;
    337 		obuf[j + 1] =
    338 			(buf[i + 1] & 0x3) << 6 | buf[i + 2] << 1 | (buf[i + 3] & 0x10) >> 4;
    339 		obuf[j + 2] = (buf[i + 3] & 0x0F) << 4 | (buf[i + 4] & 0x1E) >> 1;
    340 		obuf[j + 3] =
    341 			(buf[i + 4] & 0x1) << 7 | buf[i + 5] << 2 | (buf[i + 6] & 0x18) >> 3;
    342 		obuf[j + 4] = (buf[i + 6] & 0x7) << 5 | buf[i + 7];
    343 		i += 8;
    344 		j += 5;
    345 	};
    346 	// Removing bytes added due to padding.
    347 	//                         0  1  2  3  4  5  6   // np
    348 	static const npr: [7]u8 = [0, 1, 0, 2, 3, 0, 4]; // bytes to discard
    349 	const navl = nr / 8 * 5 - npr[np];
    350 	const rem = if(l - n < navl) l - n else navl;
    351 	for (let i = n; i < n + rem; i += 1) {
    352 		out[i] = obuf[i - n];
    353 	};
    354 	s.avail = obuf[rem..navl];
    355 	return n + rem;
    356 };
    357 
    358 // Decodes a byte slice of ASCII-encoded base-32 data, using the given encoding,
    359 // returning a slice of decoded bytes. The caller must free the return value.
    360 export fn decodeslice(
    361 	enc: *encoding,
    362 	in: []u8,
    363 ) ([]u8 | errors::invalid) = {
    364 	let in = bufio::fixed(in, io::mode::READ);
    365 	let decoder = newdecoder(enc, &in);
    366 	let out = bufio::dynamic(io::mode::WRITE);
    367 	match (io::copy(&out, &decoder)) {
    368 	case io::error =>
    369 		io::close(&out)!;
    370 		return errors::invalid;
    371 	case size =>
    372 		return bufio::buffer(&out);
    373 	};
    374 };
    375 
    376 // Decodes a string of ASCII-encoded base-32 data, using the given encoding,
    377 // returning a slice of decoded bytes. The caller must free the return value.
    378 export fn decodestr(enc: *encoding, in: str) ([]u8 | errors::invalid) = {
    379 	return decodeslice(enc, strings::toutf8(in));
    380 };
    381 
        // Checks the streaming decoder and decodestr against the RFC 4648 vectors for
        // both predefined alphabets, then checks that malformed inputs are rejected.
    382 @test fn decode() void = {
        	// Each case is (encoded input, expected decoded output, alphabet).
    383 	const cases: [_](str, str, *encoding) = [
    384 		("", "", &std_encoding),
    385 		("MY======", "f", &std_encoding),
    386 		("MZXQ====", "fo", &std_encoding),
    387 		("MZXW6===", "foo", &std_encoding),
    388 		("MZXW6YQ=", "foob", &std_encoding),
    389 		("MZXW6YTB", "fooba", &std_encoding),
    390 		("MZXW6YTBOI======", "foobar", &std_encoding),
    391 		("", "", &hex_encoding),
    392 		("CO======", "f", &hex_encoding),
    393 		("CPNG====", "fo", &hex_encoding),
    394 		("CPNMU===", "foo", &hex_encoding),
    395 		("CPNMUOG=", "foob", &hex_encoding),
    396 		("CPNMUOJ1", "fooba", &hex_encoding),
    397 		("CPNMUOJ1E8======", "foobar", &hex_encoding),
    398 	];
        	// Read through the decoder one byte at a time, which exercises the
        	// leftover-output path of the reader.
    399 	for (let i = 0z; i < len(cases); i += 1) {
    400 		let in = bufio::fixed(strings::toutf8(cases[i].0), io::mode::READ);
    401 		let dec = newdecoder(cases[i].2, &in);
    402 		let buf: [1]u8 = [0];
    403 		let out: []u8 = [];
    404 		defer free(out);
    405 		for (true) match (io::read(&dec, buf)!) {
    406 		case let z: size =>
        			// A read may return zero bytes; only append real output.
    407 			if (z > 0) {
    408 				append(out, buf[0]);
    409 			};
    410 		case io::EOF =>
    411 			break;
    412 		};
    413 		assert(bytes::equal(out, strings::toutf8(cases[i].1)));
    414 
    415 		// Testing decodestr should cover decodeslice too
    416 		let decb = decodestr(cases[i].2, cases[i].0) as []u8;
    417 		defer free(decb);
    418 		assert(bytes::equal(decb, strings::toutf8(cases[i].1)));
    419 	};
    420 	// Repeat of the above, but with a larger buffer
    421 	for (let i = 0z; i < len(cases); i += 1) {
    422 		let in = bufio::fixed(strings::toutf8(cases[i].0), io::mode::READ);
    423 		let dec = newdecoder(cases[i].2, &in);
    424 		let buf: [1024]u8 = [0...];
    425 		let out: []u8 = [];
    426 		defer free(out);
    427 		for (true) match (io::read(&dec, buf)!) {
    428 		case let z: size =>
    429 			if (z > 0) {
    430 				append(out, buf[..z]...);
    431 			};
    432 		case io::EOF =>
    433 			break;
    434 		};
    435 		assert(bytes::equal(out, strings::toutf8(cases[i].1)));
    436 	};
    437 
        	// Inputs that must be rejected, each paired with the alphabet to use.
    438 	const invalid: [_](str, *encoding) = [
    439 		// invalid padding
    440 		("=", &std_encoding),
    441 		("==", &std_encoding),
    442 		("===", &std_encoding),
    443 		("=====", &std_encoding),
    444 		("======", &std_encoding),
    445 		("=======", &std_encoding),
    446 		("========", &std_encoding),
    447 		("=========", &std_encoding),
    448 		// invalid characters
    449 		("1ZXW6YQ=", &std_encoding),
    450 		("êZXW6YQ=", &std_encoding),
    451 		("MZXW1YQ=", &std_encoding),
    452 		// data after padding is encountered
    453 		("CO======CO======", &std_encoding),
    454 		("CPNG====CPNG====", &std_encoding),
    455 	];
        	// No read may return decoded bytes for any of these inputs: a successful
        	// read sets valid, and the assertion below then fails.
        	// NOTE(review): for the "data after padding" inputs this presumably
        	// relies on the decoder consuming the whole input in its first read -
        	// confirm.
    456 	for (let i = 0z; i < len(invalid); i += 1) {
    457 		let in = bufio::fixed(strings::toutf8(invalid[i].0), io::mode::READ);
    458 		let dec = newdecoder(invalid[i].1, &in);
    459 		let buf: [1]u8 = [0...];
    460 		let valid = false;
    461 		for (true) match(io::read(&dec, buf)) {
    462 		case errors::invalid =>
    463 			break;
    464 		case size =>
    465 			valid = true;
    466 		case io::EOF =>
    467 			break;
    468 		};
    469 		assert(valid == false, "valid is not false");
    470 
    471 		// Testing decodestr should cover decodeslice too
    472 		assert(decodestr(invalid[i].1, invalid[i].0) is errors::invalid);
    473 	};
    474 };