hare

[hare] The Hare programming language
git clone https://git.torresjrjr.com/hare.git
Log | Files | Refs | README | LICENSE

base32.ha (12025B)


      1 // SPDX-License-Identifier: MPL-2.0
      2 // (c) Hare authors <https://harelang.org>
      3 
      4 use ascii;
      5 use bytes;
      6 use errors;
      7 use io;
      8 use memio;
      9 use os;
     10 use strings;
     11 
// The character used to pad the final group of encoded output up to a full
// 8-character group.
def PADDING: u8 = '=';
     13 
// Lookup tables describing one base-32 alphabet. Populate an instance with
// [[encoding_init]].
export type encoding = struct {
	// Maps a 5-bit value (0-31) to its alphabet character.
	encmap: [32]u8,
	// Maps an alphabet character back to its 5-bit value.
	decmap: [256]u8,
	// True at index ch if the byte ch is a member of the alphabet.
	valid: [256]bool,
};
     19 
// Represents the standard base-32 encoding alphabet as defined in RFC 4648.
// The tables are filled in by this module's @init function.
export const std_encoding: encoding = encoding { ... };
     22 
// Represents the "base32hex" alphabet as defined in RFC 4648. The tables are
// filled in by this module's @init function.
export const hex_encoding: encoding = encoding { ... };
     25 
     26 // Initializes a new encoding based on the passed alphabet, which must be a
     27 // 32-byte ASCII string.
     28 export fn encoding_init(enc: *encoding, alphabet: str) void = {
     29 	const alphabet = strings::toutf8(alphabet);
     30 	assert(len(alphabet) == 32);
     31 	for (let i: u8 = 0; i < 32; i += 1) {
     32 		const ch = alphabet[i];
     33 		assert(ascii::valid(ch: rune));
     34 		enc.encmap[i] = ch;
     35 		enc.decmap[ch] = i;
     36 		enc.valid[ch] = true;
     37 	};
     38 };
     39 
     40 @init fn init() void = {
     41 	const std_alpha: str = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
     42 	const hex_alpha: str = "0123456789ABCDEFGHIJKLMNOPQRSTUV";
     43 	encoding_init(&std_encoding, std_alpha);
     44 	encoding_init(&hex_encoding, hex_alpha);
     45 };
     46 
// State of a base-32 encoding stream created with [[newencoder]].
export type encoder = struct {
	stream: io::stream,
	// Handle that receives the encoded output.
	out: io::handle,
	// Alphabet tables used for encoding.
	enc: *encoding,
	buf: [4]u8, // leftover input
	avail: size, // bytes available in buf
	// Error from an earlier failed write to out, if any; once set, further
	// writes return it immediately.
	err: (void | io::error),
};
     55 
// Stream operations implemented by [[encoder]]: writing encodes, closing
// flushes the final (padded) group.
const encoder_vtable: io::vtable = io::vtable {
	writer = &encode_writer,
	closer = &encode_closer,
	...
};
     61 
     62 // Creates a stream that encodes writes as base-32 before writing them to a
     63 // secondary stream. The encoder stream must be closed to finalize any unwritten
     64 // bytes. Closing this stream will not close the underlying stream.
     65 export fn newencoder(
     66 	enc: *encoding,
     67 	out: io::handle,
     68 ) encoder = {
     69 	return encoder {
     70 		stream = &encoder_vtable,
     71 		out = out,
     72 		enc = enc,
     73 		err = void,
     74 		...
     75 	};
     76 };
     77 
     78 fn encode_writer(
     79 	s: *io::stream,
     80 	in: const []u8
     81 ) (size | io::error) = {
     82 	let s = s: *encoder;
     83 	match(s.err) {
     84 	case let err: io::error =>
     85 		return err;
     86 	case void => void;
     87 	};
     88 	let l = len(in);
     89 	let i = 0z;
     90 	for (i + 4 < l + s.avail; i += 5) {
     91 		static let b: [5]u8 = [0...]; // 5 bytes -> (enc) 8 bytes
     92 		if (i < s.avail) {
     93 			for (let j = 0z; j < s.avail; j += 1) {
     94 				b[j] = s.buf[i];
     95 			};
     96 			for (let j = s.avail; j < 5; j += 1) {
     97 				b[j] = in[j - s.avail];
     98 			};
     99 		} else {
    100 			for (let j = 0z; j < 5; j += 1) {
    101 				b[j] = in[j - s.avail + i];
    102 			};
    103 		};
    104 		let encb: [8]u8 = [
    105 			s.enc.encmap[b[0] >> 3],
    106 			s.enc.encmap[(b[0] & 0x7) << 2 | (b[1] & 0xC0) >> 6],
    107 			s.enc.encmap[(b[1] & 0x3E) >> 1],
    108 			s.enc.encmap[(b[1] & 0x1) << 4 | (b[2] & 0xF0) >> 4],
    109 			s.enc.encmap[(b[2] & 0xF) << 1 | (b[3] & 0x80) >> 7],
    110 			s.enc.encmap[(b[3] & 0x7C) >> 2],
    111 			s.enc.encmap[(b[3] & 0x3) << 3 | (b[4] & 0xE0) >> 5],
    112 			s.enc.encmap[b[4] & 0x1F],
    113 		];
    114 		match(io::write(s.out, encb)) {
    115 		case let err: io::error =>
    116 			s.err = err;
    117 			return err;
    118 		case size => void;
    119 		};
    120 	};
    121 	// storing leftover bytes
    122 	if (l + s.avail < 5) {
    123 		for (let j = s.avail; j < s.avail + l; j += 1) {
    124 			s.buf[j] = in[j - s.avail];
    125 		};
    126 	} else {
    127 		const begin = (l + s.avail) / 5 * 5;
    128 		for (let j = begin; j < l + s.avail; j += 1) {
    129 			s.buf[j - begin] = in[j - s.avail];
    130 		};
    131 	};
    132 	s.avail = (l + s.avail) % 5;
    133 	return l;
    134 };
    135 
    136 fn encode_closer(s: *io::stream) (void | io::error) = {
    137 	let s = s: *encoder;
    138 	if (s.avail == 0) {
    139 		return;
    140 	};
    141 	static let b: [5]u8 = [0...]; // the 5 bytes that will be encoded into 8 bytes
    142 	for (let i = 0z; i < 5; i += 1) {
    143 		b[i] = if (i < s.avail) s.buf[i] else 0;
    144 	};
    145 	let encb: [8]u8 = [
    146 		s.enc.encmap[b[0] >> 3],
    147 		s.enc.encmap[(b[0] & 0x7) << 2 | (b[1] & 0xC0) >> 6],
    148 		s.enc.encmap[(b[1] & 0x3E) >> 1],
    149 		s.enc.encmap[(b[1] & 0x1) << 4 | (b[2] & 0xF0) >> 4],
    150 		s.enc.encmap[(b[2] & 0xF) << 1 | (b[3] & 0x80) >> 7],
    151 		s.enc.encmap[(b[3] & 0x7C) >> 2],
    152 		s.enc.encmap[(b[3] & 0x3) << 3 | (b[4] & 0xE0) >> 5],
    153 		s.enc.encmap[b[4] & 0x1F],
    154 	];
    155 	// adding padding as input length was not a multiple of 5
    156 	//                        0  1  2  3  4
    157 	static const npa: []u8 = [0, 6, 4, 3, 1];
    158 	const np = npa[s.avail];
    159 	for (let i = 0z; i < np; i += 1) {
    160 		encb[7 - i] = PADDING;
    161 	};
    162 	io::writeall(s.out, encb)?;
    163 };
    164 
    165 // Encodes a byte slice in base-32, using the given encoding, returning a slice
    166 // of ASCII bytes. The caller must free the return value.
    167 export fn encodeslice(enc: *encoding, in: []u8) []u8 = {
    168 	let out = memio::dynamic();
    169 	let encoder = newencoder(enc, &out);
    170 	io::writeall(&encoder, in)!;
    171 	io::close(&encoder)!;
    172 	return memio::buffer(&out);
    173 };
    174 
    175 // Encodes a byte slice in base-32, using the given encoding, returning a
    176 // string. The caller must free the return value.
    177 export fn encodestr(enc: *encoding, in: []u8) str = {
    178 	return strings::fromutf8(encodeslice(enc, in))!;
    179 };
    180 
    181 @test fn encode() void = {
    182 	// RFC 4648 test vectors
    183 	const in: [_]u8 = ['f', 'o', 'o', 'b', 'a', 'r'];
    184 	const expect: [_]str = [
    185 		"",
    186 		"MY======",
    187 		"MZXQ====",
    188 		"MZXW6===",
    189 		"MZXW6YQ=",
    190 		"MZXW6YTB",
    191 		"MZXW6YTBOI======",
    192 	];
    193 	const expect_hex: [_]str = [
    194 		"",
    195 		"CO======",
    196 		"CPNG====",
    197 		"CPNMU===",
    198 		"CPNMUOG=",
    199 		"CPNMUOJ1",
    200 		"CPNMUOJ1E8======",
    201 	];
    202 	for (let i = 0z; i <= len(in); i += 1) {
    203 		let out = memio::dynamic();
    204 		let enc = newencoder(&std_encoding, &out);
    205 		io::writeall(&enc, in[..i]) as size;
    206 		io::close(&enc)!;
    207 		let outb = memio::buffer(&out);
    208 		assert(bytes::equal(outb, strings::toutf8(expect[i])));
    209 		free(outb);
    210 		// Testing encodestr should cover encodeslice too
    211 		let s = encodestr(&std_encoding, in[..i]);
    212 		defer free(s);
    213 		assert(s == expect[i]);
    214 
    215 		out = memio::dynamic();
    216 		enc = newencoder(&hex_encoding, &out);
    217 		io::writeall(&enc, in[..i]) as size;
    218 		io::close(&enc)!;
    219 		let outb = memio::buffer(&out);
    220 		assert(bytes::equal(outb, strings::toutf8(expect_hex[i])));
    221 		free(outb);
    222 		let s = encodestr(&hex_encoding, in[..i]);
    223 		defer free(s);
    224 		assert(s == expect_hex[i]);
    225 	};
    226 };
    227 
// State of a base-32 decoding stream created with [[newdecoder]].
export type decoder = struct {
	stream: io::stream,
	// Handle the encoded input is read from.
	in: io::handle,
	// Alphabet tables used for validation and decoding.
	enc: *encoding,
	avail: []u8, // leftover decoded output
	pad: bool, // if padding was seen in a previous read
	// void while decoding may continue; otherwise the EOF or error that
	// every further read returns.
	state: (void | io::EOF | io::error),
};
    236 
// Stream operations implemented by [[decoder]]; only reading is provided.
const decoder_vtable: io::vtable = io::vtable {
	reader = &decode_reader,
	...
};
    241 
    242 // Creates a stream that reads and decodes base-32 data from a secondary stream.
    243 // This stream does not need to be closed, and closing it will not close the
    244 // underlying stream.
    245 export fn newdecoder(
    246 	enc: *encoding,
    247 	in: io::handle,
    248 ) decoder = {
    249 	return decoder {
    250 		stream = &decoder_vtable,
    251 		in = in,
    252 		enc = enc,
    253 		state = void,
    254 		...
    255 	};
    256 };
    257 
    258 fn decode_reader(
    259 	s: *io::stream,
    260 	out: []u8
    261 ) (size | io::EOF | io::error) = {
    262 	let s = s: *decoder;
    263 	let n = 0z;
    264 	let l = len(out);
    265 	match(s.state) {
    266 	case let err: (io::EOF | io ::error) =>
    267 		return err;
    268 	case void => void;
    269 	};
    270 	if (len(s.avail) > 0) {
    271 		n += if (l < len(s.avail)) l else len(s.avail);
    272 		out[..n] = s.avail[0..n];
    273 		s.avail = s.avail[n..];
    274 		if (l == n) {
    275 			return n;
    276 		};
    277 	};
    278 	static let buf: [os::BUFSZ]u8 = [0...];
    279 	static let obuf: [os::BUFSZ / 8 * 5]u8 = [0...];
    280 	const nn = ((l - n) / 5 + 1) * 8; // 8 extra bytes may be read.
    281 	let nr = 0z;
    282 	for (nr < nn) {
    283 		match (io::read(s.in, buf[nr..])) {
    284 		case let n: size =>
    285 			nr += n;
    286 		case io::EOF =>
    287 			s.state = io::EOF;
    288 			break;
    289 		case let err: io::error =>
    290 			s.state = err;
    291 			return err;
    292 		};
    293 	};
    294 	if (nr % 8 != 0) {
    295 		s.state = errors::invalid;
    296 		return errors::invalid;
    297 	};
    298 	if (nr == 0) { // io::EOF already set
    299 		return n;
    300 	};
    301 	// Validating read buffer
    302 	let valid = true;
    303 	let np = 0; // Number of padding chars.
    304 	let p = true; // Pad allowed in buf
    305 	for (let i = nr; i > 0; i -= 1) {
    306 		const ch = buf[i - 1];
    307 		if (ch == PADDING) {
    308 			if(s.pad || !p) {
    309 				valid = false;
    310 				break;
    311 			};
    312 			np += 1;
    313 		} else {
    314 			if (!s.enc.valid[ch]) {
    315 				valid = false;
    316 				break;
    317 			};
    318 			// Disallow padding on seeing a non-padding char
    319 			p = false;
    320 		};
    321 	};
    322 	valid = valid && np <= 6 && np != 2 && np != 5;
    323 	if (np > 0) {
    324 		s.pad = true;
    325 	};
    326 	if (!valid) {
    327 		s.state = errors::invalid;
    328 		return errors::invalid;
    329 	};
    330 	for (let i = 0z; i < nr; i += 1) {
    331 		buf[i] = s.enc.decmap[buf[i]];
    332 	};
    333 	for (let i = 0z, j = 0z; i < nr) {
    334 		obuf[j] = (buf[i] << 3) | (buf[i + 1] & 0x1C) >> 2;
    335 		obuf[j + 1] =
    336 			(buf[i + 1] & 0x3) << 6 | buf[i + 2] << 1 | (buf[i + 3] & 0x10) >> 4;
    337 		obuf[j + 2] = (buf[i + 3] & 0x0F) << 4 | (buf[i + 4] & 0x1E) >> 1;
    338 		obuf[j + 3] =
    339 			(buf[i + 4] & 0x1) << 7 | buf[i + 5] << 2 | (buf[i + 6] & 0x18) >> 3;
    340 		obuf[j + 4] = (buf[i + 6] & 0x7) << 5 | buf[i + 7];
    341 		i += 8;
    342 		j += 5;
    343 	};
    344 	// Removing bytes added due to padding.
    345 	//                         0  1  2  3  4  5  6   // np
    346 	static const npr: [7]u8 = [0, 1, 0, 2, 3, 0, 4]; // bytes to discard
    347 	const navl = nr / 8 * 5 - npr[np];
    348 	const rem = if(l - n < navl) l - n else navl;
    349 	out[n..n + rem] = obuf[..rem];
    350 	s.avail = obuf[rem..navl];
    351 	return n + rem;
    352 };
    353 
    354 // Decodes a byte slice of ASCII-encoded base-32 data, using the given encoding,
    355 // returning a slice of decoded bytes. The caller must free the return value.
    356 export fn decodeslice(
    357 	enc: *encoding,
    358 	in: []u8,
    359 ) ([]u8 | errors::invalid) = {
    360 	let in = memio::fixed(in);
    361 	let decoder = newdecoder(enc, &in);
    362 	let out = memio::dynamic();
    363 	match (io::copy(&out, &decoder)) {
    364 	case io::error =>
    365 		io::close(&out)!;
    366 		return errors::invalid;
    367 	case size =>
    368 		return memio::buffer(&out);
    369 	};
    370 };
    371 
    372 // Decodes a string of ASCII-encoded base-32 data, using the given encoding,
    373 // returning a slice of decoded bytes. The caller must free the return value.
    374 export fn decodestr(enc: *encoding, in: str) ([]u8 | errors::invalid) = {
    375 	return decodeslice(enc, strings::toutf8(in));
    376 };
    377 
    378 @test fn decode() void = {
    379 	const cases: [_](str, str, *encoding) = [
    380 		("", "", &std_encoding),
    381 		("MY======", "f", &std_encoding),
    382 		("MZXQ====", "fo", &std_encoding),
    383 		("MZXW6===", "foo", &std_encoding),
    384 		("MZXW6YQ=", "foob", &std_encoding),
    385 		("MZXW6YTB", "fooba", &std_encoding),
    386 		("MZXW6YTBOI======", "foobar", &std_encoding),
    387 		("", "", &hex_encoding),
    388 		("CO======", "f", &hex_encoding),
    389 		("CPNG====", "fo", &hex_encoding),
    390 		("CPNMU===", "foo", &hex_encoding),
    391 		("CPNMUOG=", "foob", &hex_encoding),
    392 		("CPNMUOJ1", "fooba", &hex_encoding),
    393 		("CPNMUOJ1E8======", "foobar", &hex_encoding),
    394 	];
    395 	for (let i = 0z; i < len(cases); i += 1) {
    396 		let in = memio::fixed(strings::toutf8(cases[i].0));
    397 		let dec = newdecoder(cases[i].2, &in);
    398 		let out: []u8 = io::drain(&dec)!;
    399 		defer free(out);
    400 		assert(bytes::equal(out, strings::toutf8(cases[i].1)));
    401 
    402 		// Testing decodestr should cover decodeslice too
    403 		let decb = decodestr(cases[i].2, cases[i].0) as []u8;
    404 		defer free(decb);
    405 		assert(bytes::equal(decb, strings::toutf8(cases[i].1)));
    406 	};
    407 	// Repeat of the above, but with a larger buffer
    408 	for (let i = 0z; i < len(cases); i += 1) {
    409 		let in = memio::fixed(strings::toutf8(cases[i].0));
    410 		let dec = newdecoder(cases[i].2, &in);
    411 		let out: []u8 = io::drain(&dec)!;
    412 		defer free(out);
    413 		assert(bytes::equal(out, strings::toutf8(cases[i].1)));
    414 	};
    415 
    416 	const invalid: [_](str, *encoding) = [
    417 		// invalid padding
    418 		("=", &std_encoding),
    419 		("==", &std_encoding),
    420 		("===", &std_encoding),
    421 		("=====", &std_encoding),
    422 		("======", &std_encoding),
    423 		("=======", &std_encoding),
    424 		("========", &std_encoding),
    425 		("=========", &std_encoding),
    426 		// invalid characters
    427 		("1ZXW6YQ=", &std_encoding),
    428 		("êZXW6YQ=", &std_encoding),
    429 		("MZXW1YQ=", &std_encoding),
    430 		// data after padding is encountered
    431 		("CO======CO======", &std_encoding),
    432 		("CPNG====CPNG====", &std_encoding),
    433 	];
    434 	for (let i = 0z; i < len(invalid); i += 1) {
    435 		let in = memio::fixed(strings::toutf8(invalid[i].0));
    436 		let dec = newdecoder(invalid[i].1, &in);
    437 		let buf: [1]u8 = [0...];
    438 		let valid = false;
    439 		for (true) match(io::read(&dec, buf)) {
    440 		case errors::invalid =>
    441 			break;
    442 		case size =>
    443 			valid = true;
    444 		case io::EOF =>
    445 			break;
    446 		};
    447 		assert(valid == false, "valid is not false");
    448 
    449 		// Testing decodestr should cover decodeslice too
    450 		assert(decodestr(invalid[i].1, invalid[i].0) is errors::invalid);
    451 	};
    452 };