hare

[hare] The Hare programming language
git clone https://git.torresjrjr.com/hare.git
Log | Files | Refs | README | LICENSE

base64.ha (12131B)


      1 // License: MPL-2.0
      2 // (c) 2022 Ajay R <ar324@protonmail.com>
      3 // (c) 2021 Drew DeVault <sir@cmpwn.com>
      4 // (c) 2021 Eyal Sawady <ecs@d2evs.net>
      5 // (c) 2021 Steven Guikal <void@fluix.one>
      6 // (c) 2021 Thomas Bracht Laumann Jespersen <t@laumann.xyz>
      7 use ascii;
      8 use bufio;
      9 use bytes;
     10 use errors;
     11 use io;
     12 use os;
     13 use strings;
     14 
// Character used to fill out the final 4-character group of encoded output
// when the input length is not a multiple of 3.
def PADDING: u8 = '=';
     16 
// A base-64 alphabet together with its reverse lookup tables. The tables are
// filled in by [[encoding_init]].
export type encoding = struct {
	// Maps a 6-bit value (0-63) to its alphabet character.
	encmap: [64]u8,
	// Maps an alphabet character back to its 6-bit value; only meaningful
	// where valid[ch] is true.
	decmap: [256]u8,
	// Whether a byte is a character of this alphabet.
	valid: [256]bool,
};
     22 
// Represents the standard base-64 encoding alphabet as defined in RFC 4648.
// Declared with zeroed tables; this module's @init function fills them in.
// NOTE(review): @init writes through &std_encoding / &url_encoding although
// they are declared const; confirm they end up in writable storage.
export const std_encoding: encoding = encoding { ... };

// Represents the "base64url" alphabet as defined in RFC 4648, suitable for use
// in URLs and file paths. Its tables are also filled in at @init time.
export const url_encoding: encoding = encoding { ... };
     29 
     30 // Initializes a new encoding based on the passed alphabet, which must be a
     31 // 64-byte ASCII string.
     32 export fn encoding_init(enc: *encoding, alphabet: str) void = {
     33 	const alphabet = strings::toutf8(alphabet);
     34 	assert(len(alphabet) == 64);
     35 	for (let i: u8 = 0; i < 64; i += 1) {
     36 		const ch = alphabet[i];
     37 		assert(ascii::valid(ch: u32: rune));
     38 		enc.encmap[i] = ch;
     39 		enc.decmap[ch] = i;
     40 		enc.valid[ch] = true;
     41 	};
     42 };
     43 
     44 @init fn init() void = {
     45 	const std_alpha: str = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
     46 	const url_alpha: str = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
     47 	encoding_init(&std_encoding, std_alpha);
     48 	encoding_init(&url_encoding, url_alpha);
     49 };
     50 
// Stream state for a base-64 encoder; create one with [[newencoder]].
export type encoder = struct {
	// Must stay first: the *io::stream handed to encode_writer and
	// encode_closer is cast back to *encoder.
	stream: io::stream,
	// Destination for the encoded characters.
	out: io::handle,
	enc: *encoding,
	buf: [2]u8, // leftover input
	avail: size, // bytes available in buf
	// Error from a previous write to out, if any; encode_writer returns it
	// again on every later call.
	err: (void | io::error),
};
     59 
     60 const encoder_vtable: io::vtable = io::vtable {
     61 	writer = &encode_writer,
     62 	closer = &encode_closer,
     63 	...
     64 };
     65 
     66 // Creates a stream that encodes writes as base64 before writing them to a
     67 // secondary stream. The encoder stream must be closed to finalize any unwritten
     68 // bytes. Closing this stream will not close the underlying stream.
     69 export fn newencoder(
     70 	enc: *encoding,
     71 	out: io::handle,
     72 ) encoder = {
     73 	return encoder {
     74 		stream = &encoder_vtable,
     75 		out = out,
     76 		enc = enc,
     77 		err = void,
     78 		...
     79 	};
     80 };
     81 
     82 fn encode_writer(
     83 	s: *io::stream,
     84 	in: const []u8
     85 ) (size | io::error) = {
     86 	let s = s: *encoder;
     87 	match(s.err) {
     88 	case let err: io::error =>
     89 		return err;
     90 	case void =>
     91 		yield;
     92 	};
     93 	let l = len(in);
     94 	let i = 0z;
     95 	for (i + 2 < l + s.avail; i += 3) {
     96 		static let b: [3]u8 = [0...]; // 3 bytes get converted into 4 bytes
     97 		if (i < s.avail) {
     98 			for (let j = 0z; j < s.avail; j += 1) {
     99 				b[j] = s.buf[i];
    100 			};
    101 			for (let j = s.avail; j < 3; j += 1) {
    102 				b[j] = in[j - s.avail];
    103 			};
    104 		} else {
    105 			for (let j = 0z; j < 3; j += 1) {
    106 				b[j] = in[j - s.avail + i];
    107 			};
    108 		};
    109 		let encb: [4]u8 = [
    110 			s.enc.encmap[b[0] >> 2],
    111 			s.enc.encmap[(b[0] & 0x3) << 4 | b[1] >> 4],
    112 			s.enc.encmap[(b[1] & 0xf) << 2 | b[2] >> 6],
    113 			s.enc.encmap[b[2] & 0x3F],
    114 		];
    115 		match(io::write(s.out, encb)) {
    116 		case let err: io::error =>
    117 			s.err = err;
    118 			return err;
    119 		case size =>
    120 			yield;
    121 		};
    122 	};
    123 	// storing leftover bytes
    124 	if (l + s.avail < 3) {
    125 		for (let j = s.avail; j < s.avail + l; j += 1) {
    126 			s.buf[j] = in[j - s.avail];
    127 		};
    128 	} else {
    129 		const begin = (l + s.avail) / 3 * 3;
    130 		for (let j = begin; j < l + s.avail; j += 1) {
    131 			s.buf[j - begin] = in[j - s.avail];
    132 		};
    133 	};
    134 	s.avail = (l + s.avail) % 3;
    135 	return l;
    136 };
    137 
    138 fn encode_closer(s: *io::stream) (void | io::error) = {
    139 	let s = s: *encoder;
    140 	if (s.avail == 0) {
    141 		return;
    142 	};
    143 	static let b: [3]u8 = [0...]; // the 3 bytes that will be encoded into 4 bytes
    144 	for (let i = 0z; i < 3; i += 1) {
    145 		b[i] = if (i < s.avail) s.buf[i] else 0;
    146 	};
    147 	let encb: [4]u8 = [
    148 		s.enc.encmap[b[0] >> 2],
    149 		s.enc.encmap[(b[0] & 0x3) << 4 | b[1] >> 4],
    150 		s.enc.encmap[(b[1] & 0xf) << 2 | b[2] >> 6],
    151 		s.enc.encmap[b[2] & 0x3F],
    152 	];
    153 	// adding padding as input length was not a multiple of 3
    154 	//                        0  1  2
    155 	static const npa: []u8 = [0, 2, 1];
    156 	const np = npa[s.avail];
    157 	for (let i = 0z; i < np; i += 1) {
    158 		encb[3 - i] = PADDING;
    159 	};
    160 	io::writeall(s.out, encb)?;
    161 };
    162 
    163 // Encodes a byte slice in base 64, using the given encoding, returning a slice
    164 // of ASCII bytes. The caller must free the return value.
    165 export fn encodeslice(enc: *encoding, in: []u8) []u8 = {
    166 	let out = bufio::dynamic(io::mode::WRITE);
    167 	let encoder = newencoder(enc, &out);
    168 	io::writeall(&encoder, in)!;
    169 	io::close(&encoder)!;
    170 	return bufio::buffer(&out);
    171 };
    172 
    173 // Encodes base64 data using the given alphabet and writes it to a stream,
    174 // returning the number of bytes of data written (i.e. len(buf)).
    175 export fn encode(
    176 	out: io::handle,
    177 	enc: *encoding,
    178 	buf: []u8,
    179 ) (size | io::error) = {
    180 	const enc = newencoder(enc, out);
    181 	match (io::writeall(&enc, buf)) {
    182 	case let z: size =>
    183 		io::close(&enc)?;
    184 		return z;
    185 	case let err: io::error =>
    186 		io::close(&enc): void;
    187 		return err;
    188 	};
    189 };
    190 
    191 // Encodes a byte slice in base 64, using the given encoding, returning a
    192 // string. The caller must free the return value.
    193 export fn encodestr(enc: *encoding, in: []u8) str = {
    194 	return strings::fromutf8(encodeslice(enc, in));
    195 };
    196 
    197 @test fn encode() void = {
    198 	// RFC 4648 test vectors
    199 	const in: [_]u8 = ['f', 'o', 'o', 'b', 'a', 'r'];
    200 	const expect: [_]str = [
    201 		"",
    202 		"Zg==",
    203 		"Zm8=",
    204 		"Zm9v",
    205 		"Zm9vYg==",
    206 		"Zm9vYmE=",
    207 		"Zm9vYmFy"
    208 	];
    209 	for (let i = 0z; i <= len(in); i += 1) {
    210 		let out = bufio::dynamic(io::mode::WRITE);
    211 		let encoder = newencoder(&std_encoding, &out);
    212 		io::writeall(&encoder, in[..i])!;
    213 		io::close(&encoder)!;
    214 		let encb = bufio::buffer(&out);
    215 		defer free(encb);
    216 		assert(bytes::equal(encb, strings::toutf8(expect[i])));
    217 
    218 		// Testing encodestr should cover encodeslice too
    219 		let s = encodestr(&std_encoding, in[..i]);
    220 		defer free(s);
    221 		assert(s == expect[i]);
    222 	};
    223 };
    224 
// Stream state for a base-64 decoder; create one with [[newdecoder]].
export type decoder = struct {
	// Must stay first: the *io::stream handed to decode_reader is cast back
	// to *decoder.
	stream: io::stream,
	// Source of the encoded characters.
	in: io::handle,
	enc: *encoding,
	// leftover decoded output; decode_reader points this into its static
	// output buffer
	avail: []u8,
	pad: bool, // if padding was seen in a previous read
	// Sticky end-of-stream or error state, returned by every later read.
	state: (void | io::EOF | io::error),
};
    233 
// Stream operations backing [[decoder]]. Only a reader is provided; no closer
// is set.
const decoder_vtable: io::vtable = io::vtable {
	reader = &decode_reader,
	...
};
    238 
    239 // Creates a stream that reads and decodes base 64 data from a secondary stream.
    240 // This stream does not need to be closed, and closing it will not close the
    241 // underlying stream.
    242 export fn newdecoder(
    243 	enc: *encoding,
    244 	in: io::handle,
    245 ) decoder = {
    246 	return decoder {
    247 		stream = &decoder_vtable,
    248 		in = in,
    249 		enc = enc,
    250 		state = void,
    251 		...
    252 	};
    253 };
    254 
    255 fn decode_reader(
    256 	s: *io::stream,
    257 	out: []u8
    258 ) (size | io::EOF | io::error) = {
    259 	let s = s: *decoder;
    260 	let n = 0z;
    261 	let l = len(out);
    262 	match(s.state) {
    263 	case let err: (io::EOF | io ::error) =>
    264 		return err;
    265 	case void =>
    266 		yield;
    267 	};
    268 	if (len(s.avail) > 0) {
    269 		n += if (l < len(s.avail)) l else len(s.avail);
    270 		out[..n] = s.avail[0..n];
    271 		s.avail = s.avail[n..];
    272 		if (l == n) {
    273 			return n;
    274 		};
    275 	};
    276 	static let buf: [os::BUFSIZ]u8 = [0...];
    277 	static let obuf: [os::BUFSIZ / 4 * 3]u8 = [0...];
    278 	const nn = ((l - n) / 3 + 1) * 4; // 4 extra bytes may be read.
    279 	let nr = 0z;
    280 	for (nr < nn) {
    281 		match (io::read(s.in, buf[nr..])) {
    282 		case let n: size =>
    283 			nr += n;
    284 		case io::EOF =>
    285 			s.state = io::EOF;
    286 			break;
    287 		case let err: io::error =>
    288 			s.state = err;
    289 			return err;
    290 		};
    291 	};
    292 	if (nr % 4 != 0) {
    293 		s.state = errors::invalid;
    294 		return errors::invalid;
    295 	};
    296 	if (nr == 0) { // io::EOF already set
    297 		return n;
    298 	};
    299 	// Validating read buffer
    300 	let valid = true;
    301 	let np = 0; // Number of padding chars.
    302 	let p = true; // Pad allowed in buf
    303 	for (let i = nr; i > 0; i -= 1) {
    304 		const ch = buf[i - 1];
    305 		if (ch == PADDING) {
    306 			if(s.pad || !p) {
    307 				valid = false;
    308 				break;
    309 			};
    310 			np += 1;
    311 		} else {
    312 			if (!s.enc.valid[ch]) {
    313 				valid = false;
    314 				break;
    315 			};
    316 			// Disallow padding on seeing a non-padding char
    317 			p = false;
    318 		};
    319 	};
    320 	valid = valid && np <= 2;
    321 	if (np > 0) {
    322 		s.pad = true;
    323 	};
    324 	if (!valid) {
    325 		s.state = errors::invalid;
    326 		return errors::invalid;
    327 	};
    328 	for (let i = 0z; i < nr; i += 1) {
    329 		buf[i] = s.enc.decmap[buf[i]];
    330 	};
    331 	for (let i = 0z, j = 0z; i < nr) {
    332 		obuf[j] = buf[i] << 2 | buf[i + 1] >> 4;
    333 		obuf[j + 1] = buf[i + 1] << 4 | buf[i + 2] >> 2;
    334 		obuf[j + 2] = buf[i + 2] << 6 | buf[i + 3];
    335 
    336 		i += 4;
    337 		j += 3;
    338 	};
    339 	// Removing bytes added due to padding.
    340 	//                         0  1  2 // np
    341 	static const npr: [3]u8 = [0, 1, 2]; // bytes to discard
    342 	const navl = nr / 4 * 3 - npr[np];
    343 	const rem = if(l - n < navl) l - n else navl;
    344 	for (let i = n; i < n + rem; i += 1) {
    345 		out[i] = obuf[i - n];
    346 	};
    347 	s.avail = obuf[rem..navl];
    348 	return n + rem;
    349 };
    350 
    351 // Decodes a byte slice of ASCII-encoded base 64 data, using the given encoding,
    352 // returning a slice of decoded bytes. The caller must free the return value.
    353 export fn decodeslice(
    354 	enc: *encoding,
    355 	in: []u8,
    356 ) ([]u8 | errors::invalid) = {
    357 	let in = bufio::fixed(in, io::mode::READ);
    358 	let decoder = newdecoder(enc, &in);
    359 	let out = bufio::dynamic(io::mode::WRITE);
    360 	match (io::copy(&out, &decoder)) {
    361 	case io::error =>
    362 		io::close(&out)!;
    363 		return errors::invalid;
    364 	case size =>
    365 		return bufio::buffer(&out);
    366 	};
    367 };
    368 
    369 // Decodes a string of ASCII-encoded base 64 data, using the given encoding,
    370 // returning a slice of decoded bytes. The caller must free the return value.
    371 export fn decodestr(enc: *encoding, in: str) ([]u8 | errors::invalid) = {
    372 	return decodeslice(enc, strings::toutf8(in));
    373 };
    374 
    375 // Decodes base64 data from a stream using the given alphabet, returning the
    376 // number of bytes of bytes read (i.e. len(buf)).
    377 export fn decode(
    378 	in: io::handle,
    379 	enc: *encoding,
    380 	buf: []u8,
    381 ) (size | io::EOF | io::error) = {
    382 	const enc = newdecoder(enc, in);
    383 	match (io::readall(&enc, buf)) {
    384 	case let ret: (size | io::EOF) =>
    385 		io::close(&enc)?;
    386 		return ret;
    387 	case let err: io::error =>
    388 		io::close(&enc): void;
    389 		return err;
    390 	};
    391 };
    392 
@test fn decode() void = {
	// RFC 4648 test vectors: (encoded input, expected output, alphabet).
	const cases: [_](str, str, *encoding) = [
		("", "", &std_encoding),
		("Zg==", "f", &std_encoding),
		("Zm8=", "fo", &std_encoding),
		("Zm9v", "foo", &std_encoding),
		("Zm9vYg==", "foob", &std_encoding),
		("Zm9vYmE=", "fooba", &std_encoding),
		("Zm9vYmFy", "foobar", &std_encoding),
	];
	// Read the decoder one byte at a time, which exercises the leftover
	// output it keeps between reads.
	for (let i = 0z; i < len(cases); i += 1) {
		let in = bufio::fixed(strings::toutf8(cases[i].0), io::mode::READ);
		let decoder = newdecoder(cases[i].2, &in);
		let buf: [1]u8 = [0];
		let decb: []u8 = [];
		defer free(decb);
		for (true) match (io::read(&decoder, buf)!) {
		case let z: size =>
			if (z > 0) {
				append(decb, buf[0]);
			};
		case io::EOF =>
			break;
		};
		assert(bytes::equal(decb, strings::toutf8(cases[i].1)));

		// Testing decodestr should cover decodeslice too
		let decb = decodestr(cases[i].2, cases[i].0) as []u8;
		defer free(decb);
		assert(bytes::equal(decb, strings::toutf8(cases[i].1)));
	};
	// Repeat of the above, but with a larger buffer
	for (let i = 0z; i < len(cases); i += 1) {
		let in = bufio::fixed(strings::toutf8(cases[i].0), io::mode::READ);
		let decoder = newdecoder(cases[i].2, &in);
		let buf: [1024]u8 = [0...];
		let decb: []u8 = [];
		defer free(decb);
		for (true) match (io::read(&decoder, buf)!) {
		case let z: size =>
			if (z > 0) {
				append(decb, buf[..z]...);
			};
		case io::EOF =>
			break;
		};
		assert(bytes::equal(decb, strings::toutf8(cases[i].1)));
	};

	const invalid: [_](str, *encoding) = [
		// invalid padding
		("=", &std_encoding),
		("==", &std_encoding),
		("===", &std_encoding),
		("=====", &std_encoding),
		("======", &std_encoding),
		// invalid characters
		("@Zg=", &std_encoding),
		("êg==", &std_encoding),
		// data after padding is encountered
		("Zg==Zg==", &std_encoding),
		("Zm8=Zm8=", &std_encoding),
	];
	// Each input must fail without the decoder ever returning data. For the
	// "data after padding" cases this holds only because decode_reader reads
	// ahead past the 1 byte requested and validates everything it read
	// before returning any bytes.
	for (let i = 0z; i < len(invalid); i += 1) {
		let in = bufio::fixed(strings::toutf8(invalid[i].0), io::mode::READ);
		let decoder = newdecoder(invalid[i].1, &in);
		let buf: [1]u8 = [0...];
		let valid = false;
		for (true) match(io::read(&decoder, buf)) {
		case errors::invalid =>
			break;
		case size =>
			valid = true;
		case io::EOF =>
			break;
		};
		assert(valid == false, "valid is not false");

		// Testing decodestr should cover decodeslice too
		assert(decodestr(invalid[i].1, invalid[i].0) is errors::invalid);
	};
};