hare

[hare] The Hare programming language
git clone https://git.torresjrjr.com/hare.git
Log | Files | Refs | README | LICENSE

base64.ha (13005B)


      1 // SPDX-License-Identifier: MPL-2.0
      2 // (c) Hare authors <https://harelang.org>
      3 
      4 use ascii;
      5 use bytes;
      6 use errors;
      7 use io;
      8 use memio;
      9 use os;
     10 use strings;
     11 
// Character used to pad the final encoded group out to four characters.
def PADDING: u8 = '=';

// A base 64 alphabet. encmap maps each 6-bit value to its ASCII character;
// decmap maps an ASCII character back to its 6-bit value, with -1 (0xFF)
// marking characters that are not part of the alphabet. Fill one in with
// [[encoding_init]].
export type encoding = struct {
	encmap: [64]u8,
	decmap: [128]u8,
};

// Represents the standard base-64 encoding alphabet as defined in RFC 4648.
// Its tables are populated at start-up by this module's @init function.
export const std_encoding: encoding = encoding { ... };

// Represents the "base64url" alphabet as defined in RFC 4648, suitable for use
// in URLs and file paths. Its tables are populated at start-up by this
// module's @init function.
export const url_encoding: encoding = encoding { ... };
     25 
     26 // Initializes a new encoding based on the passed alphabet, which must be a
     27 // 64-byte ASCII string.
     28 export fn encoding_init(enc: *encoding, alphabet: str) void = {
     29 	const alphabet = strings::toutf8(alphabet);
     30 	enc.decmap[..] = [-1...];
     31 	assert(len(alphabet) == 64);
     32 	for (let i: u8 = 0; i < 64; i += 1) {
     33 		const ch = alphabet[i];
     34 		assert(ascii::valid(ch: rune) && enc.decmap[ch] == -1);
     35 		enc.encmap[i] = ch;
     36 		enc.decmap[ch] = i;
     37 	};
     38 };
     39 
     40 @init fn init() void = {
     41 	const std_alpha: str = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
     42 	const url_alpha: str = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
     43 	encoding_init(&std_encoding, std_alpha);
     44 	encoding_init(&url_encoding, url_alpha);
     45 };
     46 
// State of a base 64 encoding stream created by [[newencoder]].
export type encoder = struct {
	// Kept first: encode_writer and encode_closer cast the *io::stream they
	// receive to *encoder.
	stream: io::stream,
	// Destination for the encoded text.
	out: io::handle,
	enc: *encoding,
	// Raw input bytes of the group currently being collected.
	ibuf: [3]u8,
	// Encoded form of the most recent complete group.
	obuf: [4]u8,
	// Number of bytes held in ibuf.
	iavail: u8,
	// Number of bytes at the end of obuf not yet written to out.
	oavail: u8,
};

// Operations supported by [[encoder]]: writing and closing.
const encoder_vtable: io::vtable = io::vtable {
	writer = &encode_writer,
	closer = &encode_closer,
	...
};
     62 
     63 // Creates a stream that encodes writes as base64 before writing them to a
     64 // secondary stream. Afterwards [[io::close]] must be called to write any
     65 // unwritten bytes, in case of padding. Closing this stream will not close the
     66 // underlying stream. After a write returns an error, the stream must not be
     67 // written to again or closed.
     68 export fn newencoder(
     69 	enc: *encoding,
     70 	out: io::handle,
     71 ) encoder = {
     72 	return encoder {
     73 		stream = &encoder_vtable,
     74 		out = out,
     75 		enc = enc,
     76 		...
     77 	};
     78 };
     79 
     80 fn encode_writer(
     81 	s: *io::stream,
     82 	in: const []u8
     83 ) (size | io::error) = {
     84 	let s = s: *encoder;
     85 	let i = 0z;
     86 	for (i < len(in)) {
     87 		let b = s.ibuf[..];
     88 		// fill ibuf
     89 		for (let j = s.iavail; j < 3 && i < len(in); j += 1) {
     90 			b[j] = in[i];
     91 			i += 1;
     92 			s.iavail += 1;
     93 		};
     94 
     95 		if (s.iavail != 3) {
     96 			return i;
     97 		};
     98 
     99 		fillobuf(s);
    100 
    101 		match (writeavail(s)) {
    102 		case let e: io::error =>
    103 			if (i == 0) {
    104 				return e;
    105 			};
    106 			return i;
    107 		case void => void;
    108 		};
    109 	};
    110 
    111 	return i;
    112 };
    113 
    114 fn fillobuf(s: *encoder) void = {
    115 	assert(s.iavail == 3);
    116 	let b = s.ibuf[..];
    117 	s.obuf[..] = [
    118 		s.enc.encmap[b[0] >> 2],
    119 		s.enc.encmap[(b[0] & 0x3) << 4 | b[1] >> 4],
    120 		s.enc.encmap[(b[1] & 0xf) << 2 | b[2] >> 6],
    121 		s.enc.encmap[b[2] & 0x3f],
    122 	][..];
    123 	s.oavail = 4;
    124 };
    125 
    126 fn writeavail(s: *encoder) (void | io::error) = {
    127 	if (s.oavail == 0) {
    128 		return;
    129 	};
    130 
    131 	for (s.oavail > 0) {
    132 		let n = io::write(s.out, s.obuf[len(s.obuf) - s.oavail..])?;
    133 		s.oavail -= n: u8;
    134 	};
    135 
    136 	if (s.oavail == 0) {
    137 		s.iavail = 0;
    138 	};
    139 };
    140 
    141 // Flushes pending writes to the underlying stream.
    142 fn encode_closer(s: *io::stream) (void | io::error) = {
    143 	let s = s: *encoder;
    144 	let finished = false;
    145 	defer if (finished) clear(s);
    146 
    147 	if (s.oavail > 0) {
    148 		for (s.oavail > 0) {
    149 			writeavail(s)?;
    150 		};
    151 		finished = true;
    152 		return;
    153 	};
    154 
    155 	if (s.iavail == 0) {
    156 		finished = true;
    157 		return;
    158 	};
    159 
    160 	// prepare padding as input length was not a multiple of 3
    161 	//                        0  1  2
    162 	static const npa: []u8 = [0, 2, 1];
    163 	const np = npa[s.iavail];
    164 
    165 	for (let i = s.iavail; i < 3; i += 1) {
    166 		s.ibuf[i] = 0;
    167 		s.iavail += 1;
    168 	};
    169 
    170 	fillobuf(s);
    171 	for (let i = 0z; i < np; i += 1) {
    172 		s.obuf[3 - i] = PADDING;
    173 	};
    174 
    175 	for (s.oavail > 0) {
    176 		writeavail(s)?;
    177 	};
    178 	finished = true;
    179 };
    180 
    181 fn clear(e: *encoder) void = {
    182 	bytes::zero(e.ibuf);
    183 	bytes::zero(e.obuf);
    184 };
    185 
    186 @test fn partialwrite() void = {
    187 	const raw: [_]u8 = [
    188 		0x00, 0x00, 0x00, 0x07, 0x73, 0x73, 0x68, 0x2d, 0x72, 0x73,
    189 		0x61, 0x00,
    190 	];
    191 	const expected: str = `AAAAB3NzaC1yc2EA`;
    192 
    193 	let buf = memio::dynamic();
    194 	let e = newencoder(&std_encoding, &buf);
    195 	io::writeall(&e, raw[..4])!;
    196 	io::writeall(&e, raw[4..11])!;
    197 	io::writeall(&e, raw[11..])!;
    198 	io::close(&e)!;
    199 
    200 	assert(memio::string(&buf)! == expected);
    201 
    202 	let encb = memio::buffer(&buf);
    203 	free(encb);
    204 };
    205 
    206 // Encodes a byte slice in base 64, using the given encoding, returning a slice
    207 // of ASCII bytes. The caller must free the return value.
    208 export fn encodeslice(enc: *encoding, in: []u8) []u8 = {
    209 	let out = memio::dynamic();
    210 	let encoder = newencoder(enc, &out);
    211 	io::writeall(&encoder, in)!;
    212 	io::close(&encoder)!;
    213 	return memio::buffer(&out);
    214 };
    215 
    216 // Encodes base64 data using the given alphabet and writes it to a stream,
    217 // returning the number of bytes of data written (i.e. len(buf)).
    218 export fn encode(
    219 	out: io::handle,
    220 	enc: *encoding,
    221 	buf: []u8,
    222 ) (size | io::error) = {
    223 	const enc = newencoder(enc, out);
    224 	match (io::writeall(&enc, buf)) {
    225 	case let z: size =>
    226 		io::close(&enc)?;
    227 		return z;
    228 	case let err: io::error =>
    229 		clear(&enc);
    230 		return err;
    231 	};
    232 };
    233 
    234 // Encodes a byte slice in base 64, using the given encoding, returning a
    235 // string. The caller must free the return value.
    236 export fn encodestr(enc: *encoding, in: []u8) str = {
    237 	return strings::fromutf8(encodeslice(enc, in))!;
    238 };
    239 
    240 @test fn encode() void = {
    241 	// RFC 4648 test vectors
    242 	const in: [_]u8 = ['f', 'o', 'o', 'b', 'a', 'r'];
    243 	const expect: [_]str = [
    244 		"",
    245 		"Zg==",
    246 		"Zm8=",
    247 		"Zm9v",
    248 		"Zm9vYg==",
    249 		"Zm9vYmE=",
    250 		"Zm9vYmFy"
    251 	];
    252 	for (let i = 0z; i <= len(in); i += 1) {
    253 		let out = memio::dynamic();
    254 		let encoder = newencoder(&std_encoding, &out);
    255 		io::writeall(&encoder, in[..i])!;
    256 		io::close(&encoder)!;
    257 		let encb = memio::buffer(&out);
    258 		defer free(encb);
    259 		assert(bytes::equal(encb, strings::toutf8(expect[i])));
    260 
    261 		// Testing encodestr should cover encodeslice too
    262 		let s = encodestr(&std_encoding, in[..i]);
    263 		defer free(s);
    264 		assert(s == expect[i]);
    265 	};
    266 };
    267 
// State of a base 64 decoding stream created by [[newdecoder]].
export type decoder = struct {
	// Kept first: decode_reader casts the *io::stream it receives to
	// *decoder.
	stream: io::stream,
	// Source of the base 64 text.
	in: io::handle,
	enc: *encoding,
	obuf: [3]u8, // leftover decoded output
	// Trailing input characters (fewer than four) carried over from the
	// previous read.
	ibuf: [4]u8,
	// Number of valid bytes at the start of ibuf.
	iavail: u8,
	// Number of valid bytes at the start of obuf.
	oavail: u8,
	pad: bool, // if padding was seen in a previous read
};

// Operations supported by [[decoder]]: reading only.
const decoder_vtable: io::vtable = io::vtable {
	reader = &decode_reader,
	...
};
    283 
    284 // Creates a stream that reads and decodes base 64 data from a secondary stream.
    285 // This stream does not need to be closed, and closing it will not close the
    286 // underlying stream. If a read returns an error, the stream must not be read
    287 // from again.
    288 export fn newdecoder(
    289 	enc: *encoding,
    290 	in: io::handle,
    291 ) decoder = {
    292 	return decoder {
    293 		stream = &decoder_vtable,
    294 		in = in,
    295 		enc = enc,
    296 		...
    297 	};
    298 };
    299 
    300 fn decode_reader(
    301 	s: *io::stream,
    302 	out: []u8
    303 ) (size | io::EOF | io::error) = {
    304 	let s = s: *decoder;
    305 	if (len(out) == 0) {
    306 		return 0z;
    307 	};
    308 	let n = 0z;
    309 	if (s.oavail != 0) {
    310 		if (len(out) <= s.oavail) {
    311 			out[..] = s.obuf[..len(out)];
    312 			s.obuf[..len(s.obuf) - len(out)] = s.obuf[len(out)..];
    313 			s.oavail = s.oavail - len(out): u8;
    314 			return len(out);
    315 		};
    316 		n = s.oavail;
    317 		s.oavail = 0;
    318 		out[..n] = s.obuf[..n];
    319 		out = out[n..];
    320 	};
    321 	let buf: [os::BUFSZ]u8 = [0...];
    322 	buf[..s.iavail] = s.ibuf[..s.iavail];
    323 
    324 	let want = encodedsize(len(out));
    325 	let nr = s.iavail: size;
    326 	let lim = if (want > len(buf)) len(buf) else want;
    327 	match (io::readall(s.in, buf[s.iavail..lim])) {
    328 	case let n: size =>
    329 		nr += n;
    330 	case io::EOF =>
    331 		return if (s.iavail != 0) errors::invalid
    332 			else if (n != 0) n
    333 			else io::EOF;
    334 	case let err: io::error =>
    335 		if (!(err is io::underread)) {
    336 			return err;
    337 		};
    338 		nr += err: io::underread;
    339 	};
    340 	if (s.pad) {
    341 		return errors::invalid;
    342 	};
    343 	s.iavail = nr: u8 % 4;
    344 	s.ibuf[..s.iavail] = buf[nr - s.iavail..nr];
    345 	nr -= s.iavail;
    346 	if (nr == 0) {
    347 		return 0z;
    348 	};
    349 	// Validating read buffer
    350 	let np = 0z; // Number of padding chars.
    351 	for (let i = 0z; i < nr; i += 1) {
    352 		if (buf[i] == PADDING) {
    353 			for (i + np < nr; np += 1) {
    354 				if (np > 2 || buf[i + np] != PADDING) {
    355 					return errors::invalid;
    356 				};
    357 			};
    358 			s.pad = true;
    359 			break;
    360 		};
    361 		if (!ascii::valid(buf[i]: u32: rune) || s.enc.decmap[buf[i]] == -1) {
    362 			return errors::invalid;
    363 		};
    364 		buf[i] = s.enc.decmap[buf[i]];
    365 	};
    366 
    367 	if (nr / 4 * 3 - np < len(out)) {
    368 		out = out[..nr / 4 * 3 - np];
    369 	};
    370 	let i = 0z, j = 0z;
    371 	nr -= 4;
    372 	for (i < nr) {
    373 		out[j    ] = buf[i    ] << 2 | buf[i + 1] >> 4;
    374 		out[j + 1] = buf[i + 1] << 4 | buf[i + 2] >> 2;
    375 		out[j + 2] = buf[i + 2] << 6 | buf[i + 3];
    376 
    377 		i += 4;
    378 		j += 3;
    379 	};
    380 	s.obuf = [
    381 		buf[i    ] << 2 | buf[i + 1] >> 4,
    382 		buf[i + 1] << 4 | buf[i + 2] >> 2,
    383 		buf[i + 2] << 6 | buf[i + 3],
    384 	];
    385 	out[j..] = s.obuf[..len(out) - j];
    386 	s.oavail = (len(s.obuf) - (len(out) - j)): u8;
    387 	s.obuf[..s.oavail] = s.obuf[len(s.obuf) - s.oavail..];
    388 	s.oavail -= np: u8;
    389 	return n + len(out);
    390 };
    391 
    392 // Decodes a byte slice of ASCII-encoded base 64 data, using the given encoding,
    393 // returning a slice of decoded bytes. The caller must free the return value.
    394 export fn decodeslice(
    395 	enc: *encoding,
    396 	in: []u8,
    397 ) ([]u8 | errors::invalid) = {
    398 	if (len(in) == 0) {
    399 		return [];
    400 	};
    401 	if (len(in) % 4 != 0) {
    402 		return errors::invalid;
    403 	};
    404 	let ins = memio::fixed(in);
    405 	let decoder = newdecoder(enc, &ins);
    406 	let out = alloc([0u8...], decodedsize(len(in)))!;
    407 	let outs = memio::fixed(out);
    408 	match (io::copy(&outs, &decoder)) {
    409 	case io::error =>
    410 		free(out);
    411 		return errors::invalid;
    412 	case let sz: size =>
    413 		return memio::buffer(&outs)[..sz];
    414 	};
    415 };
    416 
    417 // Decodes a string of ASCII-encoded base 64 data, using the given encoding,
    418 // returning a slice of decoded bytes. The caller must free the return value.
    419 export fn decodestr(enc: *encoding, in: str) ([]u8 | errors::invalid) = {
    420 	return decodeslice(enc, strings::toutf8(in));
    421 };
    422 
    423 // Decodes base64 data from a stream using the given alphabet, returning the
    424 // number of bytes of bytes read (i.e. len(buf)).
    425 export fn decode(
    426 	in: io::handle,
    427 	enc: *encoding,
    428 	buf: []u8,
    429 ) (size | io::EOF | io::error) = {
    430 	const dec = newdecoder(enc, in);
    431 	return io::readall(&dec, buf);
    432 };
    433 
    434 @test fn decode() void = {
    435 	// RFC 4648 test vectors
    436 	const cases: [_](str, str, *encoding) = [
    437 		("", "", &std_encoding),
    438 		("Zg==", "f", &std_encoding),
    439 		("Zm8=", "fo", &std_encoding),
    440 		("Zm9v", "foo", &std_encoding),
    441 		("Zm9vYg==", "foob", &std_encoding),
    442 		("Zm9vYmE=", "fooba", &std_encoding),
    443 		("Zm9vYmFy", "foobar", &std_encoding),
    444 	];
    445 	const invalid: [_](str, *encoding) = [
    446 		// invalid padding
    447 		("=", &std_encoding),
    448 		("==", &std_encoding),
    449 	        ("===", &std_encoding),
    450 	        ("=====", &std_encoding),
    451 	        ("======", &std_encoding),
    452 	        // invalid characters
    453 	        ("@Zg=", &std_encoding),
    454 	        ("ê==", &std_encoding),
    455 	        ("êg==", &std_encoding),
    456 		("$3d==", &std_encoding),
    457 		("%3d==", &std_encoding),
    458 		("[==", &std_encoding),
    459 		("!", &std_encoding),
    460 	        // data after padding is encountered
    461 	        ("Zg===", &std_encoding),
    462 	        ("Zg====", &std_encoding),
    463 	        ("Zg==Zg==", &std_encoding),
    464 	        ("Zm8=Zm8=", &std_encoding),
    465 	];
    466 	let buf: [12]u8 = [0...];
    467 	for (let bufsz = 1z; bufsz <= 12; bufsz += 1) {
    468 		for (let (input, expected, encoding) .. cases) {
    469 			let in = memio::fixed(strings::toutf8(input));
    470 			let decoder = newdecoder(encoding, &in);
    471 			let buf = buf[..bufsz];
    472 			let decb: []u8 = [];
    473 			defer free(decb);
    474 			for (true) match (io::read(&decoder, buf)!) {
    475 			case let z: size =>
    476 				if (z > 0) {
    477 					append(decb, buf[..z]...)!;
    478 				};
    479 			case io::EOF =>
    480 				break;
    481 			};
    482 			assert(bytes::equal(decb, strings::toutf8(expected)));
    483 
    484 			// Testing decodestr should cover decodeslice too
    485 			let decb = decodestr(encoding, input) as []u8;
    486 			defer free(decb);
    487 			assert(bytes::equal(decb, strings::toutf8(expected)));
    488 		};
    489 
    490 		for (let (input, encoding) .. invalid) {
    491 			let in = memio::fixed(strings::toutf8(input));
    492 			let decoder = newdecoder(encoding, &in);
    493 			let buf = buf[..bufsz];
    494 			let valid = false;
    495 			for (true) match(io::read(&decoder, buf)) {
    496 			case errors::invalid =>
    497 				break;
    498 			case size =>
    499 				void;
    500 			case io::EOF =>
    501 				abort();
    502 			};
    503 
    504 			// Testing decodestr should cover decodeslice too
    505 			assert(decodestr(encoding, input) is errors::invalid);
    506 		};
    507 	};
    508 };
    509 
    510 // Given the length of the message, returns the size of its base64 encoding
    511 export fn encodedsize(sz: size) size = if (sz == 0) 0 else ((sz - 1)/ 3 + 1) * 4;
    512 
    513 // Given the size of base64 encoded data, returns maximal length of decoded message.
    514 // The message may be at most 2 bytes shorter than the returned value. Input
    515 // size must be a multiple of 4.
    516 export fn decodedsize(sz: size) size = {
    517 	assert(sz % 4 == 0);
    518 	return sz / 4 * 3;
    519 };
    520 
    521 @test fn sizecalc() void = {
    522 	let enc: [_](size, size) = [(1, 4), (2, 4), (3, 4), (4, 8), (10, 16),
    523 		(119, 160), (120, 160), (121, 164), (122, 164), (123, 164)
    524 	];
    525 	assert(encodedsize(0) == 0 && decodedsize(0) == 0);
    526 	for (let i = 0z; i < len(enc); i += 1) {
    527 		let (decoded, encoded) = enc[i];
    528 		assert(encodedsize(decoded) == encoded);
    529 		assert(decodedsize(encoded) == ((decoded - 1) / 3 + 1) * 3);
    530 	};
    531 };