base64.ha (12956B)
// SPDX-License-Identifier: MPL-2.0
// (c) Hare authors <https://harelang.org>

use ascii;
use bytes;
use errors;
use io;
use memio;
use os;
use strings;

def PADDING: u8 = '=';

export type encoding = struct {
	// Maps each 6-bit value (0-63) to its alphabet character.
	encmap: [64]u8,
	// Maps each ASCII byte to its 6-bit value, or to -1 when the byte is
	// not part of the alphabet.
	decmap: [128]u8,
};

// Represents the standard base-64 encoding alphabet as defined in RFC 4648.
export const std_encoding: encoding = encoding { ... };

// Represents the "base64url" alphabet as defined in RFC 4648, suitable for use
// in URLs and file paths.
export const url_encoding: encoding = encoding { ... };

// Initializes a new encoding based on the passed alphabet, which must be a
// 64-byte ASCII string. Aborts (through an assertion) if the alphabet has the
// wrong length, contains a non-ASCII byte, or lists a character twice.
export fn encoding_init(enc: *encoding, alphabet: str) void = {
	const chars = strings::toutf8(alphabet);
	// Start with every byte marked as "not in the alphabet".
	enc.decmap[..] = [-1...];
	assert(len(chars) == 64);
	for (let value: u8 = 0; value < 64; value += 1) {
		const c = chars[value];
		// A decmap slot that is already set means a repeated character.
		assert(ascii::valid(c: rune) && enc.decmap[c] == -1);
		enc.encmap[value] = c;
		enc.decmap[c] = value;
	};
};

// Fills in the lookup tables of the built-in encodings at startup.
@init fn init() void = {
	encoding_init(&std_encoding,
		"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/");
	encoding_init(&url_encoding,
		"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_");
};

export type encoder = struct {
	stream: io::stream,
	out: io::handle,
	enc: *encoding,
	ibuf: [3]u8, // input bytes of the group being assembled
	obuf: [4]u8, // encoded group; its last oavail bytes are still unwritten
	iavail: u8, // number of valid bytes in ibuf
	oavail: u8, // number of bytes at the end of obuf not yet written
};

const encoder_vtable: io::vtable = io::vtable {
	writer = &encode_writer,
	closer = &encode_closer,
	...
};

// Creates a stream that encodes writes as base64 before writing them to a
// secondary stream. Afterwards [[io::close]] must be called to write any
// unwritten bytes, in case of padding. Closing this stream will not close the
// underlying stream. After a write returns an error, the stream must not be
// written to again or closed.
export fn newencoder(
	enc: *encoding,
	out: io::handle,
) encoder = {
	return encoder {
		stream = &encoder_vtable,
		out = out,
		enc = enc,
		...
	};
};

// Buffers input into 3-byte groups and writes each complete group, encoded,
// to the underlying stream. A trailing partial group stays in ibuf until the
// next write or until the stream is closed.
fn encode_writer(
	s: *io::stream,
	in: const []u8
) (size | io::error) = {
	let s = s: *encoder;
	let consumed = 0z;
	for (consumed < len(in)) {
		// Top up the current group from the caller's buffer.
		for (s.iavail < 3 && consumed < len(in)) {
			s.ibuf[s.iavail] = in[consumed];
			s.iavail += 1;
			consumed += 1;
		};

		// Not enough bytes for a full group yet: keep them for later.
		if (s.iavail != 3) {
			return consumed;
		};

		fillobuf(s);

		match (writeavail(s)) {
		case void => void;
		case let err: io::error =>
			// Report the error only if no progress was made.
			return if (consumed == 0) err else consumed;
		};
	};
	return consumed;
};

// Encodes the complete 3-byte group in ibuf into the 4 characters of obuf.
fn fillobuf(s: *encoder) void = {
	assert(s.iavail == 3);
	const g = s.ibuf[..];
	// Split the 24 input bits into four 6-bit alphabet indices.
	s.obuf[0] = s.enc.encmap[g[0] >> 2];
	s.obuf[1] = s.enc.encmap[(g[0] & 0x3) << 4 | g[1] >> 4];
	s.obuf[2] = s.enc.encmap[(g[1] & 0xf) << 2 | g[2] >> 6];
	s.obuf[3] = s.enc.encmap[g[2] & 0x3f];
	s.oavail = 4;
};

// Writes the unwritten tail of obuf to the underlying stream, retrying short
// writes. Once the whole group is written, ibuf is marked empty.
fn writeavail(s: *encoder) (void | io::error) = {
	if (s.oavail == 0) {
		return;
	};

	for (s.oavail > 0) {
		const pending = s.obuf[len(s.obuf) - s.oavail..];
		const written = io::write(s.out, pending)?;
		s.oavail -= written: u8;
	};

	// The loop above only finishes once oavail has reached zero.
	s.iavail = 0;
};

// Flushes pending writes to the underlying stream, emitting the final padded
// group when the input length was not a multiple of 3.
fn encode_closer(s: *io::stream) (void | io::error) = {
	let s = s: *encoder;
	// The buffers are wiped only after everything has been flushed; they
	// are not wiped when a write fails.
	let done = false;
	defer if (done) clear(s);

	// An encoded group from an earlier write is still pending: flush it.
	if (s.oavail > 0) {
		for (s.oavail > 0) {
			writeavail(s)?;
		};
		done = true;
		return;
	};

	// The input length was a multiple of 3, so there is nothing left.
	if (s.iavail == 0) {
		done = true;
		return;
	};

	// Pad the final partial group. Indexed by the number of leftover bytes:
	// one byte needs "==", two bytes need "=".
	static const npad: []u8 = [0, 2, 1];
	const padding = npad[s.iavail];

	for (s.iavail < 3) {
		s.ibuf[s.iavail] = 0;
		s.iavail += 1;
	};

	fillobuf(s);
	for (let k = 0z; k < padding; k += 1) {
		s.obuf[3 - k] = PADDING;
	};

	for (s.oavail > 0) {
		writeavail(s)?;
	};
	done = true;
};

// Zeroes the encoder's input and output buffers.
fn clear(e: *encoder) void = {
	bytes::zero(e.ibuf);
	bytes::zero(e.obuf);
};

@test fn partialwrite() void = {
	const raw: [_]u8 = [
		0x00, 0x00, 0x00, 0x07, 0x73, 0x73, 0x68, 0x2d, 0x72, 0x73,
		0x61, 0x00,
	];
	const expected: str = `AAAAB3NzaC1yc2EA`;

	// Feed the input in chunks that do not line up with 3-byte groups.
	let sink = memio::dynamic();
	let e = newencoder(&std_encoding, &sink);
	io::writeall(&e, raw[..4])!;
	io::writeall(&e, raw[4..11])!;
	io::writeall(&e, raw[11..])!;
	io::close(&e)!;

	assert(memio::string(&sink)! == expected);
};

// Encodes a byte slice in base 64, using the given encoding, returning a slice
// of ASCII bytes. The caller must free the return value.
export fn encodeslice(enc: *encoding, in: []u8) []u8 = {
	let sink = memio::dynamic();
	let e = newencoder(enc, &sink);
	io::writeall(&e, in)!;
	io::close(&e)!;
	return memio::buffer(&sink);
};

// Encodes base64 data using the given alphabet and writes it to a stream,
// returning the number of bytes of data written (i.e. len(buf)).
export fn encode(
	out: io::handle,
	enc: *encoding,
	buf: []u8,
) (size | io::error) = {
	const enc = newencoder(enc, out);
	// The encoder's buffers hold pieces of the input and its encoding. Wipe
	// them on every exit path, including a failed io::close (which does not
	// wipe them itself). On success io::close has already wiped them, and
	// wiping twice is harmless.
	defer clear(&enc);
	const z = io::writeall(&enc, buf)?;
	io::close(&enc)?;
	return z;
};

// Encodes a byte slice in base 64, using the given encoding, returning a
// string. The caller must free the return value.
export fn encodestr(enc: *encoding, in: []u8) str = {
	return strings::fromutf8(encodeslice(enc, in))!;
};

@test fn encode() void = {
	// RFC 4648 test vectors
	const in: [_]u8 = ['f', 'o', 'o', 'b', 'a', 'r'];
	const expect: [_]str = [
		"",
		"Zg==",
		"Zm8=",
		"Zm9v",
		"Zm9vYg==",
		"Zm9vYmE=",
		"Zm9vYmFy"
	];
	for (let i = 0z; i <= len(in); i += 1) {
		let out = memio::dynamic();
		let encoder = newencoder(&std_encoding, &out);
		io::writeall(&encoder, in[..i])!;
		io::close(&encoder)!;
		let encb = memio::buffer(&out);
		defer free(encb);
		assert(bytes::equal(encb, strings::toutf8(expect[i])));

		// Testing encodestr should cover encodeslice too
		let s = encodestr(&std_encoding, in[..i]);
		defer free(s);
		assert(s == expect[i]);
	};
};

export type decoder = struct {
	stream: io::stream,
	in: io::handle,
	enc: *encoding,
	obuf: [3]u8, // leftover decoded output
	ibuf: [4]u8, // leftover encoded input (an incomplete group)
	iavail: u8, // number of valid bytes in ibuf
	oavail: u8, // number of valid bytes in obuf
	pad: bool, // if padding was seen in a previous read
};

const decoder_vtable: io::vtable = io::vtable {
	reader = &decode_reader,
	...
};

// Creates a stream that reads and decodes base 64 data from a secondary stream.
// This stream does not need to be closed, and closing it will not close the
// underlying stream. If a read returns an error, the stream must not be read
// from again.
export fn newdecoder(
	enc: *encoding,
	in: io::handle,
) decoder = {
	return decoder {
		stream = &decoder_vtable,
		in = in,
		enc = enc,
		...
	};
};

// Reads encoded input from the underlying stream and decodes it into out.
// Decoded bytes that do not fit in out are kept in obuf, and a trailing
// incomplete 4-character group is kept in ibuf, for the next call. Returns
// [[errors::invalid]] for malformed input.
fn decode_reader(
	s: *io::stream,
	out: []u8
) (size | io::EOF | io::error) = {
	let s = s: *decoder;
	if (len(out) == 0) {
		return 0z;
	};
	// Number of bytes delivered from output left over by a previous call.
	let n = 0z;
	if (s.oavail != 0) {
		if (len(out) <= s.oavail) {
			// The leftovers alone satisfy this read; shift the rest down.
			out[..] = s.obuf[..len(out)];
			s.obuf[..len(s.obuf) - len(out)] = s.obuf[len(out)..];
			s.oavail = s.oavail - len(out): u8;
			return len(out);
		};
		n = s.oavail;
		s.oavail = 0;
		out[..n] = s.obuf[..n];
		out = out[n..];
	};
	// Start from the incomplete input group carried over from the last call.
	let buf: [os::BUFSZ]u8 = [0...];
	buf[..s.iavail] = s.ibuf[..s.iavail];

	// Read only as much encoded input as is needed to fill the rest of out.
	let want = encodedsize(len(out));
	let nr = s.iavail: size;
	let lim = if (want > len(buf)) len(buf) else want;
	match (io::readall(s.in, buf[s.iavail..lim])) {
	case let r: size =>
		nr += r;
	case io::EOF =>
		// An incomplete group left over at the end of input is malformed.
		return if (s.iavail != 0) errors::invalid
			else if (n != 0) n
			else io::EOF;
	case let err: io::error =>
		// io::underread reports a partial read; count the bytes that did
		// arrive.
		if (!(err is io::underread)) {
			return err;
		};
		nr += err: io::underread;
	};
	// Padding ends the data, so any further input is an error.
	if (s.pad) {
		return errors::invalid;
	};
	// Keep the trailing incomplete group for the next call. Take the
	// remainder before narrowing to u8.
	s.iavail = (nr % 4): u8;
	s.ibuf[..s.iavail] = buf[nr - s.iavail..nr];
	nr -= s.iavail;
	if (nr == 0) {
		// No complete group yet. Still report any leftover output that was
		// already copied into out above.
		return n;
	};
	// Validate the read buffer and translate it to 6-bit values in place.
	let np = 0z; // Number of padding chars.
	for (let i = 0z; i < nr; i += 1) {
		if (buf[i] == PADDING) {
			// Padding must run to the end of the input and be at most two
			// characters long: "x===" cannot encode a whole byte.
			for (i + np < nr; np += 1) {
				if (np >= 2 || buf[i + np] != PADDING) {
					return errors::invalid;
				};
			};
			s.pad = true;
			break;
		};
		if (!ascii::valid(buf[i]: u32: rune) || s.enc.decmap[buf[i]] == -1) {
			return errors::invalid;
		};
		buf[i] = s.enc.decmap[buf[i]];
	};

	// Never produce more bytes than the input actually encodes.
	if (nr / 4 * 3 - np < len(out)) {
		out = out[..nr / 4 * 3 - np];
	};
	// Decode every complete group except the last directly into out.
	let i = 0z, j = 0z;
	nr -= 4;
	for (i < nr) {
		out[j] = buf[i] << 2 | buf[i + 1] >> 4;
		out[j + 1] = buf[i + 1] << 4 | buf[i + 2] >> 2;
		out[j + 2] = buf[i + 2] << 6 | buf[i + 3];

		i += 4;
		j += 3;
	};
	// The last group may not fit in out: decode it into obuf, copy what
	// fits, and keep the remainder (minus padding bytes) for the next call.
	s.obuf = [
		buf[i] << 2 | buf[i + 1] >> 4,
		buf[i + 1] << 4 | buf[i + 2] >> 2,
		buf[i + 2] << 6 | buf[i + 3],
	];
	out[j..] = s.obuf[..len(out) - j];
	s.oavail = (len(s.obuf) - (len(out) - j)): u8;
	s.obuf[..s.oavail] = s.obuf[len(s.obuf) - s.oavail..];
	s.oavail -= np: u8;
	return n + len(out);
};

// Decodes a byte slice of ASCII-encoded base 64 data, using the given encoding,
// returning a slice of decoded bytes. The caller must free the return value.
// Returns [[errors::invalid]] if the input is malformed or its length is not a
// multiple of 4.
export fn decodeslice(
	enc: *encoding,
	in: []u8,
) ([]u8 | errors::invalid) = {
	if (len(in) == 0) {
		return [];
	};
	if (len(in) % 4 != 0) {
		return errors::invalid;
	};
	let ins = memio::fixed(in);
	let decoder = newdecoder(enc, &ins);
	let out = alloc([0u8...], decodedsize(len(in)));
	let outs = memio::fixed(out);
	match (io::copy(&outs, &decoder)) {
	case io::error =>
		free(out);
		return errors::invalid;
	case let sz: size =>
		return memio::buffer(&outs)[..sz];
	};
};

@test fn decode_overpadded() void = {
	// A single data character cannot encode a whole byte, so a group with
	// three padding characters must be rejected, not decoded to nothing.
	assert(decodestr(&std_encoding, "A===") is errors::invalid);
	assert(decodestr(&std_encoding, "Zm9vA===") is errors::invalid);
};

@test fn decode_truncated() void = {
	let in = memio::fixed(strings::toutf8("Zm9vYmE"));
	let dec = newdecoder(&std_encoding, &in);
	let buf: [2]u8 = [0...];
	assert(io::read(&dec, buf[..])! as size == 2);
	assert(bytes::equal(buf[..], strings::toutf8("fo")));
	// The leftover "o" must still be delivered even though the input that
	// follows is an incomplete group.
	assert(io::read(&dec, buf[..])! as size == 1);
	assert(bytes::equal(buf[..1], strings::toutf8("o")));
	assert(io::read(&dec, buf[..]) is errors::invalid);
};

// Decodes a string of ASCII-encoded base 64 data, using the given encoding,
// returning a slice of decoded bytes. The caller must free the return value.
export fn decodestr(enc: *encoding, in: str) ([]u8 | errors::invalid) = {
	return decodeslice(enc, strings::toutf8(in));
};

// Decodes base64 data from a stream using the given alphabet into buf,
// returning the number of bytes read (i.e. len(buf)). Malformed input is
// reported as [[errors::invalid]].
export fn decode(
	in: io::handle,
	enc: *encoding,
	buf: []u8,
) (size | io::EOF | io::error) = {
	const dec = newdecoder(enc, in);
	return io::readall(&dec, buf);
};

@test fn decode() void = {
	// RFC 4648 test vectors
	const cases: [_](str, str, *encoding) = [
		("", "", &std_encoding),
		("Zg==", "f", &std_encoding),
		("Zm8=", "fo", &std_encoding),
		("Zm9v", "foo", &std_encoding),
		("Zm9vYg==", "foob", &std_encoding),
		("Zm9vYmE=", "fooba", &std_encoding),
		("Zm9vYmFy", "foobar", &std_encoding),
	];
	const invalid: [_](str, *encoding) = [
		// invalid padding
		("=", &std_encoding),
		("==", &std_encoding),
		("===", &std_encoding),
		("=====", &std_encoding),
		("======", &std_encoding),
		// invalid characters
		("@Zg=", &std_encoding),
		("ê==", &std_encoding),
		("êg==", &std_encoding),
		("$3d==", &std_encoding),
		("%3d==", &std_encoding),
		("[==", &std_encoding),
		("!", &std_encoding),
		// data after padding is encountered
		("Zg===", &std_encoding),
		("Zg====", &std_encoding),
		("Zg==Zg==", &std_encoding),
		("Zm8=Zm8=", &std_encoding),
	];
	let buf: [12]u8 = [0...];
	// Exercise every read size so that leftover input and output are
	// carried across calls in as many ways as possible.
	for (let bufsz = 1z; bufsz <= 12; bufsz += 1) {
		for (let (input, expected, encoding) .. cases) {
			let in = memio::fixed(strings::toutf8(input));
			let decoder = newdecoder(encoding, &in);
			let buf = buf[..bufsz];
			let decb: []u8 = [];
			defer free(decb);
			for (true) match (io::read(&decoder, buf)!) {
			case let z: size =>
				if (z > 0) {
					append(decb, buf[..z]...);
				};
			case io::EOF =>
				break;
			};
			assert(bytes::equal(decb, strings::toutf8(expected)));

			// Testing decodestr should cover decodeslice too
			let decb = decodestr(encoding, input) as []u8;
			defer free(decb);
			assert(bytes::equal(decb, strings::toutf8(expected)));
		};

		for (let (input, encoding) .. invalid) {
			let in = memio::fixed(strings::toutf8(input));
			let decoder = newdecoder(encoding, &in);
			let buf = buf[..bufsz];
			for (true) match (io::read(&decoder, buf)) {
			case errors::invalid =>
				break;
			case size =>
				void;
			case io::EOF =>
				abort();
			};

			// Testing decodestr should cover decodeslice too
			assert(decodestr(encoding, input) is errors::invalid);
		};
	};
};

// Given the length of the message, returns the size of its base64 encoding
export fn encodedsize(sz: size) size = if (sz == 0) 0 else ((sz - 1) / 3 + 1) * 4;

// Given the size of base64 encoded data, returns maximal length of decoded message.
// The message may be at most 2 bytes shorter than the returned value. Input
// size must be a multiple of 4.
export fn decodedsize(sz: size) size = {
	assert(sz % 4 == 0);
	return sz / 4 * 3;
};

@test fn sizecalc() void = {
	let enc: [_](size, size) = [(1, 4), (2, 4), (3, 4), (4, 8), (10, 16),
		(119, 160), (120, 160), (121, 164), (122, 164), (123, 164)
	];
	assert(encodedsize(0) == 0 && decodedsize(0) == 0);
	for (let i = 0z; i < len(enc); i += 1) {
		let (decoded, encoded) = enc[i];
		assert(encodedsize(decoded) == encoded);
		assert(decodedsize(encoded) == ((decoded - 1) / 3 + 1) * 3);
	};
};