base64.ha (13005B)
// SPDX-License-Identifier: MPL-2.0
// (c) Hare authors <https://harelang.org>

use ascii;
use bytes;
use errors;
use io;
use memio;
use os;
use strings;

// Character used to fill up the last four-character group of encoded output.
def PADDING: u8 = '=';

// A base-64 alphabet expressed as two lookup tables: encmap turns a 6-bit value
// into its ASCII symbol, decmap turns an ASCII symbol back into its 6-bit
// value. [[encoding_init]] sets every decmap slot that has no symbol to -1.
export type encoding = struct {
	encmap: [64]u8,
	decmap: [128]u8,
};

// The standard base-64 alphabet from RFC 4648.
export const std_encoding: encoding = encoding { ... };

// The "base64url" alphabet from RFC 4648, which is safe to embed in URLs and
// file paths.
export const url_encoding: encoding = encoding { ... };

// Fills in both lookup tables of an encoding from the given alphabet. The
// alphabet must be made of exactly 64 distinct ASCII characters; violating
// this trips an assertion.
export fn encoding_init(enc: *encoding, alphabet: str) void = {
	const symbols = strings::toutf8(alphabet);
	enc.decmap[..] = [-1...];
	assert(len(symbols) == 64);

	// value is the 6-bit number that the current symbol stands for.
	let value: u8 = 0;
	for (let sym .. symbols) {
		// The second operand rejects a symbol that was already seen.
		assert(ascii::valid(sym: rune) && enc.decmap[sym] == -1);
		enc.encmap[value] = sym;
		enc.decmap[sym] = value;
		value += 1;
	};
};

// Builds the tables of the two predefined encodings at program start.
@init fn init() void = {
	encoding_init(&std_encoding,
		"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/");
	encoding_init(&url_encoding,
		"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_");
};

// State of a base-64 encoding stream. ibuf collects input until a full group
// of three bytes is available (iavail counts the bytes collected so far); obuf
// holds the four encoded characters of the latest group, of which the last
// oavail have not yet been written to out.
export type encoder = struct {
	stream: io::stream,
	out: io::handle,
	enc: *encoding,
	ibuf: [3]u8,
	obuf: [4]u8,
	iavail: u8,
	oavail: u8,
};

const encoder_vtable: io::vtable = io::vtable {
	writer = &encode_writer,
	closer = &encode_closer,
	...
};

// Creates a stream that encodes writes as base64 before writing them to a
// secondary stream. Afterwards [[io::close]] must be called to write any
// unwritten bytes, in case of padding. Closing this stream will not close the
// underlying stream.
// After a write returns an error, the stream must not be written to again or
// closed.
export fn newencoder(
	enc: *encoding,
	out: io::handle,
) encoder = {
	return encoder {
		stream = &encoder_vtable,
		out = out,
		enc = enc,
		...
	};
};

// Writer for [[encoder]] streams. Input is collected in ibuf; each time three
// bytes are available they are encoded into obuf and written to the underlying
// stream. Returns the number of bytes of in that were consumed, which includes
// bytes that are so far only buffered. If the underlying stream fails after
// part of in has been consumed, that short count is returned instead of the
// error.
fn encode_writer(
	s: *io::stream,
	in: const []u8
) (size | io::error) = {
	let s = s: *encoder;

	// An earlier call may have reported a short write after the underlying
	// stream failed in the middle of obuf. Finish that group before taking
	// new input: calling fillobuf again would reset oavail to 4 and send
	// the characters that were already written a second time.
	if (s.oavail > 0) {
		writeavail(s)?;
	};

	let i = 0z;
	for (i < len(in)) {
		let b = s.ibuf[..];
		// fill ibuf
		for (let j = s.iavail; j < 3 && i < len(in); j += 1) {
			b[j] = in[i];
			i += 1;
			s.iavail += 1;
		};

		// Input ran out before the group was complete; keep what we
		// have for the next write or for close.
		if (s.iavail != 3) {
			return i;
		};

		fillobuf(s);

		match (writeavail(s)) {
		case let e: io::error =>
			// Only surface the error if nothing was consumed;
			// otherwise report the progress made so far.
			if (i == 0) {
				return e;
			};
			return i;
		case void => void;
		};
	};

	return i;
};

// Encodes the three bytes held in ibuf into four alphabet characters in obuf
// and marks all four as pending output. ibuf must be full.
fn fillobuf(s: *encoder) void = {
	assert(s.iavail == 3);
	let b = s.ibuf[..];
	s.obuf[..] = [
		s.enc.encmap[b[0] >> 2],
		s.enc.encmap[(b[0] & 0x3) << 4 | b[1] >> 4],
		s.enc.encmap[(b[1] & 0xf) << 2 | b[2] >> 6],
		s.enc.encmap[b[2] & 0x3f],
	][..];
	s.oavail = 4;
};

// Writes the pending tail of obuf to the underlying stream, repeating the
// write until everything has been accepted or a write fails. Once obuf has
// been drained, ibuf is marked as consumed.
fn writeavail(s: *encoder) (void | io::error) = {
	if (s.oavail == 0) {
		return;
	};

	for (s.oavail > 0) {
		// The pending characters are always the last oavail of obuf.
		let n = io::write(s.out, s.obuf[len(s.obuf) - s.oavail..])?;
		s.oavail -= n: u8;
	};

	if (s.oavail == 0) {
		s.iavail = 0;
	};
};

// Flushes pending writes to the underlying stream.
fn encode_closer(s: *io::stream) (void | io::error) = {
	let e = s: *encoder;

	// Encoded characters are still waiting to be written: drain them and
	// stop. The buffers are only scrubbed once the write has succeeded.
	if (e.oavail > 0) {
		writeavail(e)?;
		clear(e);
		return;
	};

	// The input ended on a group boundary, so no padding is needed.
	if (e.iavail == 0) {
		clear(e);
		return;
	};

	// Number of padding characters for a given count of buffered input
	// bytes:                      0  1  2
	static const padcount: []u8 = [0, 2, 1];
	const npad = padcount[e.iavail];

	// Complete the group of three with zero bytes so it can be encoded.
	for (e.iavail < 3) {
		e.ibuf[e.iavail] = 0;
		e.iavail += 1;
	};

	fillobuf(e);
	// Overwrite the trailing characters, last one first, with padding.
	for (let k = 0z; k < npad; k += 1) {
		e.obuf[3 - k] = PADDING;
	};

	writeavail(e)?;
	clear(e);
};

// Zeroes the input and output buffers of an encoder.
fn clear(e: *encoder) void = {
	bytes::zero(e.obuf);
	bytes::zero(e.ibuf);
};

// Feeds the encoder in chunks that do not line up with the three-byte groups.
@test fn partialwrite() void = {
	const input: [_]u8 = [
		0x00, 0x00, 0x00, 0x07, 0x73, 0x73, 0x68, 0x2d, 0x72, 0x73,
		0x61, 0x00,
	];
	const want: str = `AAAAB3NzaC1yc2EA`;

	let sink = memio::dynamic();
	let enc = newencoder(&std_encoding, &sink);

	// Chunk boundaries: [0, 4), [4, 11) and [11, 12).
	const cuts: [_]size = [0, 4, 11, len(input)];
	for (let k = 1z; k < len(cuts); k += 1) {
		io::writeall(&enc, input[cuts[k - 1]..cuts[k]])!;
	};
	io::close(&enc)!;

	assert(memio::string(&sink)! == want);

	free(memio::buffer(&sink));
};

// Encodes a byte slice in base 64 with the given encoding and returns the
// result as a slice of ASCII bytes, which the caller must free.
export fn encodeslice(enc: *encoding, in: []u8) []u8 = {
	let sink = memio::dynamic();
	let stream = newencoder(enc, &sink);
	io::writeall(&stream, in)!;
	io::close(&stream)!;
	return memio::buffer(&sink);
};

// Encodes base64 data using the given alphabet and writes it to a stream,
// returning the number of bytes of data written (i.e. len(buf)).
export fn encode(
	out: io::handle,
	enc: *encoding,
	buf: []u8,
) (size | io::error) = {
	const enc = newencoder(enc, out);
	match (io::writeall(&enc, buf)) {
	case let z: size =>
		// Closing emits the final, possibly padded, group. A
		// successful close scrubs the encoder itself; on failure it
		// does not, so scrub it here just like the write-failure path
		// below.
		match (io::close(&enc)) {
		case void =>
			return z;
		case let err: io::error =>
			clear(&enc);
			return err;
		};
	case let err: io::error =>
		clear(&enc);
		return err;
	};
};

// Encodes a byte slice in base 64, using the given encoding, returning a
// string. The caller must free the return value.
export fn encodestr(enc: *encoding, in: []u8) str = {
	return strings::fromutf8(encodeslice(enc, in))!;
};

// Checks every prefix of "foobar" against the RFC 4648 vectors, both through
// the streaming encoder and through [[encodestr]].
@test fn encode() void = {
	// RFC 4648 test vectors
	const in: [_]u8 = ['f', 'o', 'o', 'b', 'a', 'r'];
	const expect: [_]str = [
		"",
		"Zg==",
		"Zm8=",
		"Zm9v",
		"Zm9vYg==",
		"Zm9vYmE=",
		"Zm9vYmFy"
	];
	for (let i = 0z; i <= len(in); i += 1) {
		let out = memio::dynamic();
		let encoder = newencoder(&std_encoding, &out);
		io::writeall(&encoder, in[..i])!;
		io::close(&encoder)!;
		let encb = memio::buffer(&out);
		defer free(encb);
		assert(bytes::equal(encb, strings::toutf8(expect[i])));

		// Testing encodestr should cover encodeslice too
		let s = encodestr(&std_encoding, in[..i]);
		defer free(s);
		assert(s == expect[i]);
	};
};

// State of a base-64 decoding stream. ibuf keeps encoded bytes that do not yet
// form a complete group of four (iavail of them); obuf keeps decoded bytes that
// did not fit into the caller's buffer (oavail of them).
export type decoder = struct {
	stream: io::stream,
	in: io::handle,
	enc: *encoding,
	obuf: [3]u8, // leftover decoded output
	ibuf: [4]u8,
	iavail: u8,
	oavail: u8,
	pad: bool, // if padding was seen in a previous read
};

const decoder_vtable: io::vtable = io::vtable {
	reader = &decode_reader,
	...
};

// Creates a stream that reads and decodes base 64 data from a secondary stream.
// This stream does not need to be closed, and closing it will not close the
// underlying stream. If a read returns an error, the stream must not be read
// from again.
export fn newdecoder(
	enc: *encoding,
	in: io::handle,
) decoder = {
	return decoder {
		stream = &decoder_vtable,
		in = in,
		enc = enc,
		...
	};
};

// Reader for [[decoder]] streams. Hands out decoded bytes left over from the
// previous call first, then reads encoded data from the underlying stream in
// groups of four characters, validates it and decodes it into out. Returns
// [[errors::invalid]] if the data contains characters outside the alphabet,
// has malformed padding, continues after padding, or ends in the middle of a
// group.
fn decode_reader(
	s: *io::stream,
	out: []u8
) (size | io::EOF | io::error) = {
	let s = s: *decoder;
	if (len(out) == 0) {
		return 0z;
	};
	// n counts the bytes served from obuf before any new data is decoded.
	let n = 0z;
	if (s.oavail != 0) {
		if (len(out) <= s.oavail) {
			// The leftovers alone satisfy this read; shift the
			// rest of obuf to the front for next time.
			out[..] = s.obuf[..len(out)];
			s.obuf[..len(s.obuf) - len(out)] = s.obuf[len(out)..];
			s.oavail = s.oavail - len(out): u8;
			return len(out);
		};
		n = s.oavail;
		s.oavail = 0;
		out[..n] = s.obuf[..n];
		out = out[n..];
	};
	// Start the read buffer with the incomplete group kept in ibuf.
	let buf: [os::BUFSZ]u8 = [0...];
	buf[..s.iavail] = s.ibuf[..s.iavail];

	// Ask for as many encoded bytes as are needed to fill out, capped by
	// the size of the read buffer.
	let want = encodedsize(len(out));
	let nr = s.iavail: size;
	let lim = if (want > len(buf)) len(buf) else want;
	match (io::readall(s.in, buf[s.iavail..lim])) {
	case let n: size =>
		nr += n;
	case io::EOF =>
		// Ending inside a group is an error; otherwise report the
		// leftovers already copied, or EOF if there were none.
		return if (s.iavail != 0) errors::invalid
			else if (n != 0) n
			else io::EOF;
	case let err: io::error =>
		if (!(err is io::underread)) {
			return err;
		};
		// A short read still delivered this many bytes.
		nr += err: io::underread;
	};
	// Anything that follows padding is invalid.
	if (s.pad) {
		return errors::invalid;
	};
	// Keep the trailing incomplete group for the next call. Truncating nr
	// to u8 does not change its remainder modulo 4.
	s.iavail = nr: u8 % 4;
	s.ibuf[..s.iavail] = buf[nr - s.iavail..nr];
	nr -= s.iavail;
	if (nr == 0) {
		return 0z;
	};
	// Validating read buffer
	let np = 0z; // Number of padding chars.
	for (let i = 0z; i < nr; i += 1) {
		if (buf[i] == PADDING) {
			// Padding must run to the end of the data and may be at
			// most two characters long. np is checked before it is
			// incremented, so the test has to be ">= 2": reaching
			// this line with np == 2 means a third byte follows
			// the start of the padding.
			for (i + np < nr; np += 1) {
				if (np >= 2 || buf[i + np] != PADDING) {
					return errors::invalid;
				};
			};
			s.pad = true;
			break;
		};
		if (!ascii::valid(buf[i]: u32: rune) || s.enc.decmap[buf[i]] == -1) {
			return errors::invalid;
		};
		// Replace each symbol by its 6-bit value in place.
		buf[i] = s.enc.decmap[buf[i]];
	};

	// Never hand out more bytes than the validated groups decode to.
	if (nr / 4 * 3 - np < len(out)) {
		out = out[..nr / 4 * 3 - np];
	};
	let i = 0z, j = 0z;
	// Decode every group but the last directly into out.
	nr -= 4;
	for (i < nr) {
		out[j] = buf[i] << 2 | buf[i + 1] >> 4;
		out[j + 1] = buf[i + 1] << 4 | buf[i + 2] >> 2;
		out[j + 2] = buf[i + 2] << 6 | buf[i + 3];

		i += 4;
		j += 3;
	};
	// The last group goes through obuf, because out may have room for only
	// part of it and padding may shorten it.
	s.obuf = [
		buf[i] << 2 | buf[i + 1] >> 4,
		buf[i + 1] << 4 | buf[i + 2] >> 2,
		buf[i + 2] << 6 | buf[i + 3],
	];
	out[j..] = s.obuf[..len(out) - j];
	s.oavail = (len(s.obuf) - (len(out) - j)): u8;
	s.obuf[..s.oavail] = s.obuf[len(s.obuf) - s.oavail..];
	// Bytes that correspond to padding are not real output.
	s.oavail -= np: u8;
	return n + len(out);
};

// Decodes a byte slice of ASCII-encoded base 64 data, using the given encoding,
// returning a slice of decoded bytes. The caller must free the return value.
export fn decodeslice(
	enc: *encoding,
	in: []u8,
) ([]u8 | errors::invalid) = {
	if (len(in) == 0) {
		return [];
	};
	if (len(in) % 4 != 0) {
		return errors::invalid;
	};
	let ins = memio::fixed(in);
	let decoder = newdecoder(enc, &ins);
	// Sized for the unpadded case; the result is trimmed to the number of
	// bytes actually decoded.
	let out = alloc([0u8...], decodedsize(len(in)))!;
	let outs = memio::fixed(out);
	match (io::copy(&outs, &decoder)) {
	case io::error =>
		free(out);
		return errors::invalid;
	case let sz: size =>
		return memio::buffer(&outs)[..sz];
	};
};

// Decodes a string of ASCII-encoded base 64 data, using the given encoding,
// returning a slice of decoded bytes. The caller must free the return value.
export fn decodestr(enc: *encoding, in: str) ([]u8 | errors::invalid) = {
	return decodeslice(enc, strings::toutf8(in));
};

// Decodes base64 data from a stream using the given alphabet, returning the
// number of bytes read (i.e. len(buf)).
export fn decode(
	in: io::handle,
	enc: *encoding,
	buf: []u8,
) (size | io::EOF | io::error) = {
	const dec = newdecoder(enc, in);
	return io::readall(&dec, buf);
};

// Decodes valid and invalid inputs through the streaming decoder with every
// read-buffer size from 1 to 12 bytes, and through [[decodestr]].
@test fn decode() void = {
	// RFC 4648 test vectors
	const cases: [_](str, str, *encoding) = [
		("", "", &std_encoding),
		("Zg==", "f", &std_encoding),
		("Zm8=", "fo", &std_encoding),
		("Zm9v", "foo", &std_encoding),
		("Zm9vYg==", "foob", &std_encoding),
		("Zm9vYmE=", "fooba", &std_encoding),
		("Zm9vYmFy", "foobar", &std_encoding),
	];
	const invalid: [_](str, *encoding) = [
		// invalid padding
		("=", &std_encoding),
		("==", &std_encoding),
		("===", &std_encoding),
		("=====", &std_encoding),
		("======", &std_encoding),
		// invalid characters
		("@Zg=", &std_encoding),
		("ê==", &std_encoding),
		("êg==", &std_encoding),
		("$3d==", &std_encoding),
		("%3d==", &std_encoding),
		("[==", &std_encoding),
		("!", &std_encoding),
		// data after padding is encountered
		("Zg===", &std_encoding),
		("Zg====", &std_encoding),
		("Zg==Zg==", &std_encoding),
		("Zm8=Zm8=", &std_encoding),
	];
	let buf: [12]u8 = [0...];
	for (let bufsz = 1z; bufsz <= 12; bufsz += 1) {
		for (let (input, expected, encoding) .. cases) {
			let in = memio::fixed(strings::toutf8(input));
			let decoder = newdecoder(encoding, &in);
			let buf = buf[..bufsz];
			let decb: []u8 = [];
			defer free(decb);
			for (true) match (io::read(&decoder, buf)!) {
			case let z: size =>
				if (z > 0) {
					append(decb, buf[..z]...)!;
				};
			case io::EOF =>
				break;
			};
			assert(bytes::equal(decb, strings::toutf8(expected)));

			// Testing decodestr should cover decodeslice too
			let decb = decodestr(encoding, input) as []u8;
			defer free(decb);
			assert(bytes::equal(decb, strings::toutf8(expected)));
		};

		for (let (input, encoding) .. invalid) {
			let in = memio::fixed(strings::toutf8(input));
			let decoder = newdecoder(encoding, &in);
			let buf = buf[..bufsz];
			// Keep reading until the decoder rejects the input;
			// reaching EOF first means bad data was accepted.
			for (true) match(io::read(&decoder, buf)) {
			case errors::invalid =>
				break;
			case size =>
				void;
			case io::EOF =>
				abort();
			};

			// Testing decodestr should cover decodeslice too
			assert(decodestr(encoding, input) is errors::invalid);
		};
	};
};

// Given the length of the message, returns the size of its base64 encoding.
export fn encodedsize(sz: size) size = if (sz == 0) 0 else ((sz - 1)/ 3 + 1) * 4;

// Given the size of base64 encoded data, returns maximal length of decoded message.
// The message may be at most 2 bytes shorter than the returned value. Input
// size must be a multiple of 4.
export fn decodedsize(sz: size) size = {
	assert(sz % 4 == 0);
	return sz / 4 * 3;
};

// Checks [[encodedsize]] and [[decodedsize]] against a table of
// (decoded length, encoded length) pairs.
@test fn sizecalc() void = {
	let enc: [_](size, size) = [(1, 4), (2, 4), (3, 4), (4, 8), (10, 16),
		(119, 160), (120, 160), (121, 164), (122, 164), (123, 164)
	];
	assert(encodedsize(0) == 0 && decodedsize(0) == 0);
	for (let i = 0z; i < len(enc); i += 1) {
		let (decoded, encoded) = enc[i];
		assert(encodedsize(decoded) == encoded);
		// decodedsize returns the maximum, i.e. decoded rounded up to
		// a multiple of three.
		assert(decodedsize(encoded) == ((decoded - 1) / 3 + 1) * 3);
	};
};