hare

[hare] The Hare programming language
git clone https://git.torresjrjr.com/hare.git
Log | Files | Refs | README | LICENSE

base64.ha (12956B)


      1 // SPDX-License-Identifier: MPL-2.0
      2 // (c) Hare authors <https://harelang.org>
      3 
      4 use ascii;
      5 use bytes;
      6 use errors;
      7 use io;
      8 use memio;
      9 use os;
     10 use strings;
     11 
// The byte written in place of output characters when the final input group
// is shorter than three bytes, and recognised as padding when decoding.
def PADDING: u8 = '=';
     13 
// A base-64 alphabet together with its inverse lookup table. Use
// [[encoding_init]] to populate one from a 64-character alphabet.
export type encoding = struct {
	// Maps a 6-bit value (0-63) to its alphabet character.
	encmap: [64]u8,
	// Maps an ASCII byte to its 6-bit value; entries for bytes which are
	// not part of the alphabet hold -1.
	decmap: [128]u8,
};
     18 
// Represents the standard base-64 encoding alphabet as defined in RFC 4648.
// The lookup tables are filled in by this module's @init function.
export const std_encoding: encoding = encoding { ... };
     21 
// Represents the "base64url" alphabet as defined in RFC 4648, suitable for use
// in URLs and file paths. The lookup tables are filled in by this module's
// @init function.
export const url_encoding: encoding = encoding { ... };
     25 
     26 // Initializes a new encoding based on the passed alphabet, which must be a
     27 // 64-byte ASCII string.
     28 export fn encoding_init(enc: *encoding, alphabet: str) void = {
     29 	const alphabet = strings::toutf8(alphabet);
     30 	enc.decmap[..] = [-1...];
     31 	assert(len(alphabet) == 64);
     32 	for (let i: u8 = 0; i < 64; i += 1) {
     33 		const ch = alphabet[i];
     34 		assert(ascii::valid(ch: rune) && enc.decmap[ch] == -1);
     35 		enc.encmap[i] = ch;
     36 		enc.decmap[ch] = i;
     37 	};
     38 };
     39 
     40 @init fn init() void = {
     41 	const std_alpha: str = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
     42 	const url_alpha: str = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
     43 	encoding_init(&std_encoding, std_alpha);
     44 	encoding_init(&url_encoding, url_alpha);
     45 };
     46 
// State of a base-64 encoding stream. Create one with [[newencoder]].
export type encoder = struct {
	stream: io::stream,
	// Handle which receives the encoded output.
	out: io::handle,
	// Alphabet used for encoding.
	enc: *encoding,
	// Input bytes waiting for a complete 3-byte group.
	ibuf: [3]u8,
	// The most recently encoded 4-character group.
	obuf: [4]u8,
	// Number of valid bytes in ibuf.
	iavail: u8,
	// Number of trailing bytes of obuf which have not been written to out.
	oavail: u8,
};
     56 
// Stream callbacks for [[encoder]]: writes are encoded by encode_writer and
// closing the stream flushes buffered data through encode_closer.
const encoder_vtable: io::vtable = io::vtable {
	writer = &encode_writer,
	closer = &encode_closer,
	...
};
     62 
     63 // Creates a stream that encodes writes as base64 before writing them to a
     64 // secondary stream. Afterwards [[io::close]] must be called to write any
     65 // unwritten bytes, in case of padding. Closing this stream will not close the
     66 // underlying stream. After a write returns an error, the stream must not be
     67 // written to again or closed.
     68 export fn newencoder(
     69 	enc: *encoding,
     70 	out: io::handle,
     71 ) encoder = {
     72 	return encoder {
     73 		stream = &encoder_vtable,
     74 		out = out,
     75 		enc = enc,
     76 		...
     77 	};
     78 };
     79 
     80 fn encode_writer(
     81 	s: *io::stream,
     82 	in: const []u8
     83 ) (size | io::error) = {
     84 	let s = s: *encoder;
     85 	let i = 0z;
     86 	for (i < len(in)) {
     87 		let b = s.ibuf[..];
     88 		// fill ibuf
     89 		for (let j = s.iavail; j < 3 && i < len(in); j += 1) {
     90 			b[j] = in[i];
     91 			i += 1;
     92 			s.iavail += 1;
     93 		};
     94 
     95 		if (s.iavail != 3) {
     96 			return i;
     97 		};
     98 
     99 		fillobuf(s);
    100 
    101 		match (writeavail(s)) {
    102 		case let e: io::error =>
    103 			if (i == 0) {
    104 				return e;
    105 			};
    106 			return i;
    107 		case void => void;
    108 		};
    109 	};
    110 
    111 	return i;
    112 };
    113 
    114 fn fillobuf(s: *encoder) void = {
    115 	assert(s.iavail == 3);
    116 	let b = s.ibuf[..];
    117 	s.obuf[..] = [
    118 		s.enc.encmap[b[0] >> 2],
    119 		s.enc.encmap[(b[0] & 0x3) << 4 | b[1] >> 4],
    120 		s.enc.encmap[(b[1] & 0xf) << 2 | b[2] >> 6],
    121 		s.enc.encmap[b[2] & 0x3f],
    122 	][..];
    123 	s.oavail = 4;
    124 };
    125 
    126 fn writeavail(s: *encoder) (void | io::error) = {
    127 	if (s.oavail == 0) {
    128 		return;
    129 	};
    130 
    131 	for (s.oavail > 0) {
    132 		let n = io::write(s.out, s.obuf[len(s.obuf) - s.oavail..])?;
    133 		s.oavail -= n: u8;
    134 	};
    135 
    136 	if (s.oavail == 0) {
    137 		s.iavail = 0;
    138 	};
    139 };
    140 
    141 // Flushes pending writes to the underlying stream.
    142 fn encode_closer(s: *io::stream) (void | io::error) = {
    143 	let s = s: *encoder;
    144 	let finished = false;
    145 	defer if (finished) clear(s);
    146 
    147 	if (s.oavail > 0) {
    148 		for (s.oavail > 0) {
    149 			writeavail(s)?;
    150 		};
    151 		finished = true;
    152 		return;
    153 	};
    154 
    155 	if (s.iavail == 0) {
    156 		finished = true;
    157 		return;
    158 	};
    159 
    160 	// prepare padding as input length was not a multiple of 3
    161 	//                        0  1  2
    162 	static const npa: []u8 = [0, 2, 1];
    163 	const np = npa[s.iavail];
    164 
    165 	for (let i = s.iavail; i < 3; i += 1) {
    166 		s.ibuf[i] = 0;
    167 		s.iavail += 1;
    168 	};
    169 
    170 	fillobuf(s);
    171 	for (let i = 0z; i < np; i += 1) {
    172 		s.obuf[3 - i] = PADDING;
    173 	};
    174 
    175 	for (s.oavail > 0) {
    176 		writeavail(s)?;
    177 	};
    178 	finished = true;
    179 };
    180 
    181 fn clear(e: *encoder) void = {
    182 	bytes::zero(e.ibuf);
    183 	bytes::zero(e.obuf);
    184 };
    185 
    186 @test fn partialwrite() void = {
    187 	const raw: [_]u8 = [
    188 		0x00, 0x00, 0x00, 0x07, 0x73, 0x73, 0x68, 0x2d, 0x72, 0x73,
    189 		0x61, 0x00,
    190 	];
    191 	const expected: str = `AAAAB3NzaC1yc2EA`;
    192 
    193 	let buf = memio::dynamic();
    194 	let e = newencoder(&std_encoding, &buf);
    195 	io::writeall(&e, raw[..4])!;
    196 	io::writeall(&e, raw[4..11])!;
    197 	io::writeall(&e, raw[11..])!;
    198 	io::close(&e)!;
    199 
    200 	assert(memio::string(&buf)! == expected);
    201 };
    202 
    203 // Encodes a byte slice in base 64, using the given encoding, returning a slice
    204 // of ASCII bytes. The caller must free the return value.
    205 export fn encodeslice(enc: *encoding, in: []u8) []u8 = {
    206 	let out = memio::dynamic();
    207 	let encoder = newencoder(enc, &out);
    208 	io::writeall(&encoder, in)!;
    209 	io::close(&encoder)!;
    210 	return memio::buffer(&out);
    211 };
    212 
    213 // Encodes base64 data using the given alphabet and writes it to a stream,
    214 // returning the number of bytes of data written (i.e. len(buf)).
    215 export fn encode(
    216 	out: io::handle,
    217 	enc: *encoding,
    218 	buf: []u8,
    219 ) (size | io::error) = {
    220 	const enc = newencoder(enc, out);
    221 	match (io::writeall(&enc, buf)) {
    222 	case let z: size =>
    223 		io::close(&enc)?;
    224 		return z;
    225 	case let err: io::error =>
    226 		clear(&enc);
    227 		return err;
    228 	};
    229 };
    230 
    231 // Encodes a byte slice in base 64, using the given encoding, returning a
    232 // string. The caller must free the return value.
    233 export fn encodestr(enc: *encoding, in: []u8) str = {
    234 	return strings::fromutf8(encodeslice(enc, in))!;
    235 };
    236 
    237 @test fn encode() void = {
    238 	// RFC 4648 test vectors
    239 	const in: [_]u8 = ['f', 'o', 'o', 'b', 'a', 'r'];
    240 	const expect: [_]str = [
    241 		"",
    242 		"Zg==",
    243 		"Zm8=",
    244 		"Zm9v",
    245 		"Zm9vYg==",
    246 		"Zm9vYmE=",
    247 		"Zm9vYmFy"
    248 	];
    249 	for (let i = 0z; i <= len(in); i += 1) {
    250 		let out = memio::dynamic();
    251 		let encoder = newencoder(&std_encoding, &out);
    252 		io::writeall(&encoder, in[..i])!;
    253 		io::close(&encoder)!;
    254 		let encb = memio::buffer(&out);
    255 		defer free(encb);
    256 		assert(bytes::equal(encb, strings::toutf8(expect[i])));
    257 
    258 		// Testing encodestr should cover encodeslice too
    259 		let s = encodestr(&std_encoding, in[..i]);
    260 		defer free(s);
    261 		assert(s == expect[i]);
    262 	};
    263 };
    264 
// State of a base-64 decoding stream. Create one with [[newdecoder]].
export type decoder = struct {
	stream: io::stream,
	// Handle the encoded input is read from.
	in: io::handle,
	// Alphabet used for decoding.
	enc: *encoding,
	obuf: [3]u8, // leftover decoded output
	// Encoded input bytes left over from a read which did not end on a
	// 4-byte group boundary.
	ibuf: [4]u8,
	// Number of valid bytes in ibuf.
	iavail: u8,
	// Number of valid bytes at the start of obuf.
	oavail: u8,
	pad: bool, // if padding was seen in a previous read
};
    275 
// Stream callbacks for [[decoder]]: only reading is provided, through
// decode_reader.
const decoder_vtable: io::vtable = io::vtable {
	reader = &decode_reader,
	...
};
    280 
    281 // Creates a stream that reads and decodes base 64 data from a secondary stream.
    282 // This stream does not need to be closed, and closing it will not close the
    283 // underlying stream. If a read returns an error, the stream must not be read
    284 // from again.
    285 export fn newdecoder(
    286 	enc: *encoding,
    287 	in: io::handle,
    288 ) decoder = {
    289 	return decoder {
    290 		stream = &decoder_vtable,
    291 		in = in,
    292 		enc = enc,
    293 		...
    294 	};
    295 };
    296 
    297 fn decode_reader(
    298 	s: *io::stream,
    299 	out: []u8
    300 ) (size | io::EOF | io::error) = {
    301 	let s = s: *decoder;
    302 	if (len(out) == 0) {
    303 		return 0z;
    304 	};
    305 	let n = 0z;
    306 	if (s.oavail != 0) {
    307 		if (len(out) <= s.oavail) {
    308 			out[..] = s.obuf[..len(out)];
    309 			s.obuf[..len(s.obuf) - len(out)] = s.obuf[len(out)..];
    310 			s.oavail = s.oavail - len(out): u8;
    311 			return len(out);
    312 		};
    313 		n = s.oavail;
    314 		s.oavail = 0;
    315 		out[..n] = s.obuf[..n];
    316 		out = out[n..];
    317 	};
    318 	let buf: [os::BUFSZ]u8 = [0...];
    319 	buf[..s.iavail] = s.ibuf[..s.iavail];
    320 
    321 	let want = encodedsize(len(out));
    322 	let nr = s.iavail: size;
    323 	let lim = if (want > len(buf)) len(buf) else want;
    324 	match (io::readall(s.in, buf[s.iavail..lim])) {
    325 	case let n: size =>
    326 		nr += n;
    327 	case io::EOF =>
    328 		return if (s.iavail != 0) errors::invalid
    329 			else if (n != 0) n
    330 			else io::EOF;
    331 	case let err: io::error =>
    332 		if (!(err is io::underread)) {
    333 			return err;
    334 		};
    335 		nr += err: io::underread;
    336 	};
    337 	if (s.pad) {
    338 		return errors::invalid;
    339 	};
    340 	s.iavail = nr: u8 % 4;
    341 	s.ibuf[..s.iavail] = buf[nr - s.iavail..nr];
    342 	nr -= s.iavail;
    343 	if (nr == 0) {
    344 		return 0z;
    345 	};
    346 	// Validating read buffer
    347 	let np = 0z; // Number of padding chars.
    348 	for (let i = 0z; i < nr; i += 1) {
    349 		if (buf[i] == PADDING) {
    350 			for (i + np < nr; np += 1) {
    351 				if (np > 2 || buf[i + np] != PADDING) {
    352 					return errors::invalid;
    353 				};
    354 			};
    355 			s.pad = true;
    356 			break;
    357 		};
    358 		if (!ascii::valid(buf[i]: u32: rune) || s.enc.decmap[buf[i]] == -1) {
    359 			return errors::invalid;
    360 		};
    361 		buf[i] = s.enc.decmap[buf[i]];
    362 	};
    363 
    364 	if (nr / 4 * 3 - np < len(out)) {
    365 		out = out[..nr / 4 * 3 - np];
    366 	};
    367 	let i = 0z, j = 0z;
    368 	nr -= 4;
    369 	for (i < nr) {
    370 		out[j    ] = buf[i    ] << 2 | buf[i + 1] >> 4;
    371 		out[j + 1] = buf[i + 1] << 4 | buf[i + 2] >> 2;
    372 		out[j + 2] = buf[i + 2] << 6 | buf[i + 3];
    373 
    374 		i += 4;
    375 		j += 3;
    376 	};
    377 	s.obuf = [
    378 		buf[i    ] << 2 | buf[i + 1] >> 4,
    379 		buf[i + 1] << 4 | buf[i + 2] >> 2,
    380 		buf[i + 2] << 6 | buf[i + 3],
    381 	];
    382 	out[j..] = s.obuf[..len(out) - j];
    383 	s.oavail = (len(s.obuf) - (len(out) - j)): u8;
    384 	s.obuf[..s.oavail] = s.obuf[len(s.obuf) - s.oavail..];
    385 	s.oavail -= np: u8;
    386 	return n + len(out);
    387 };
    388 
    389 // Decodes a byte slice of ASCII-encoded base 64 data, using the given encoding,
    390 // returning a slice of decoded bytes. The caller must free the return value.
    391 export fn decodeslice(
    392 	enc: *encoding,
    393 	in: []u8,
    394 ) ([]u8 | errors::invalid) = {
    395 	if (len(in) == 0) {
    396 		return [];
    397 	};
    398 	if (len(in) % 4 != 0) {
    399 		return errors::invalid;
    400 	};
    401 	let ins = memio::fixed(in);
    402 	let decoder = newdecoder(enc, &ins);
    403 	let out = alloc([0u8...], decodedsize(len(in)));
    404 	let outs = memio::fixed(out);
    405 	match (io::copy(&outs, &decoder)) {
    406 	case io::error =>
    407 		free(out);
    408 		return errors::invalid;
    409 	case let sz: size =>
    410 		return memio::buffer(&outs)[..sz];
    411 	};
    412 };
    413 
    414 // Decodes a string of ASCII-encoded base 64 data, using the given encoding,
    415 // returning a slice of decoded bytes. The caller must free the return value.
    416 export fn decodestr(enc: *encoding, in: str) ([]u8 | errors::invalid) = {
    417 	return decodeslice(enc, strings::toutf8(in));
    418 };
    419 
    420 // Decodes base64 data from a stream using the given alphabet, returning the
    421 // number of bytes of bytes read (i.e. len(buf)).
    422 export fn decode(
    423 	in: io::handle,
    424 	enc: *encoding,
    425 	buf: []u8,
    426 ) (size | io::EOF | io::error) = {
    427 	const dec = newdecoder(enc, in);
    428 	return io::readall(&dec, buf);
    429 };
    430 
    431 @test fn decode() void = {
    432 	// RFC 4648 test vectors
    433 	const cases: [_](str, str, *encoding) = [
    434 		("", "", &std_encoding),
    435 		("Zg==", "f", &std_encoding),
    436 		("Zm8=", "fo", &std_encoding),
    437 		("Zm9v", "foo", &std_encoding),
    438 		("Zm9vYg==", "foob", &std_encoding),
    439 		("Zm9vYmE=", "fooba", &std_encoding),
    440 		("Zm9vYmFy", "foobar", &std_encoding),
    441 	];
    442 	const invalid: [_](str, *encoding) = [
    443 		// invalid padding
    444 		("=", &std_encoding),
    445 		("==", &std_encoding),
    446 	        ("===", &std_encoding),
    447 	        ("=====", &std_encoding),
    448 	        ("======", &std_encoding),
    449 	        // invalid characters
    450 	        ("@Zg=", &std_encoding),
    451 	        ("ê==", &std_encoding),
    452 	        ("êg==", &std_encoding),
    453 		("$3d==", &std_encoding),
    454 		("%3d==", &std_encoding),
    455 		("[==", &std_encoding),
    456 		("!", &std_encoding),
    457 	        // data after padding is encountered
    458 	        ("Zg===", &std_encoding),
    459 	        ("Zg====", &std_encoding),
    460 	        ("Zg==Zg==", &std_encoding),
    461 	        ("Zm8=Zm8=", &std_encoding),
    462 	];
    463 	let buf: [12]u8 = [0...];
    464 	for (let bufsz = 1z; bufsz <= 12; bufsz += 1) {
    465 		for (let (input, expected, encoding) .. cases) {
    466 			let in = memio::fixed(strings::toutf8(input));
    467 			let decoder = newdecoder(encoding, &in);
    468 			let buf = buf[..bufsz];
    469 			let decb: []u8 = [];
    470 			defer free(decb);
    471 			for (true) match (io::read(&decoder, buf)!) {
    472 			case let z: size =>
    473 				if (z > 0) {
    474 					append(decb, buf[..z]...);
    475 				};
    476 			case io::EOF =>
    477 				break;
    478 			};
    479 			assert(bytes::equal(decb, strings::toutf8(expected)));
    480 
    481 			// Testing decodestr should cover decodeslice too
    482 			let decb = decodestr(encoding, input) as []u8;
    483 			defer free(decb);
    484 			assert(bytes::equal(decb, strings::toutf8(expected)));
    485 		};
    486 
    487 		for (let (input, encoding) .. invalid) {
    488 			let in = memio::fixed(strings::toutf8(input));
    489 			let decoder = newdecoder(encoding, &in);
    490 			let buf = buf[..bufsz];
    491 			let valid = false;
    492 			for (true) match(io::read(&decoder, buf)) {
    493 			case errors::invalid =>
    494 				break;
    495 			case size =>
    496 				void;
    497 			case io::EOF =>
    498 				abort();
    499 			};
    500 
    501 			// Testing decodestr should cover decodeslice too
    502 			assert(decodestr(encoding, input) is errors::invalid);
    503 		};
    504 	};
    505 };
    506 
    507 // Given the length of the message, returns the size of its base64 encoding
    508 export fn encodedsize(sz: size) size = if (sz == 0) 0 else ((sz - 1)/ 3 + 1) * 4;
    509 
    510 // Given the size of base64 encoded data, returns maximal length of decoded message.
    511 // The message may be at most 2 bytes shorter than the returned value. Input
    512 // size must be a multiple of 4.
    513 export fn decodedsize(sz: size) size = {
    514 	assert(sz % 4 == 0);
    515 	return sz / 4 * 3;
    516 };
    517 
    518 @test fn sizecalc() void = {
    519 	let enc: [_](size, size) = [(1, 4), (2, 4), (3, 4), (4, 8), (10, 16),
    520 		(119, 160), (120, 160), (121, 164), (122, 164), (123, 164)
    521 	];
    522 	assert(encodedsize(0) == 0 && decodedsize(0) == 0);
    523 	for (let i = 0z; i < len(enc); i += 1) {
    524 		let (decoded, encoded) = enc[i];
    525 		assert(encodedsize(decoded) == encoded);
    526 		assert(decodedsize(encoded) == ((decoded - 1) / 3 + 1) * 3);
    527 	};
    528 };