base32.ha (12025B)
// SPDX-License-Identifier: MPL-2.0
// (c) Hare authors <https://harelang.org>

use ascii;
use bytes;
use errors;
use io;
use memio;
use os;
use strings;

def PADDING: u8 = '=';

export type encoding = struct {
	encmap: [32]u8, // 5-bit value -> alphabet character
	decmap: [256]u8, // alphabet character -> 5-bit value
	valid: [256]bool, // true for every character in the alphabet
};

// Represents the standard base-32 encoding alphabet as defined in RFC 4648.
export const std_encoding: encoding = encoding { ... };

// Represents the "base32hex" alphabet as defined in RFC 4648.
export const hex_encoding: encoding = encoding { ... };

// Initializes a new encoding based on the passed alphabet, which must be a
// 32-byte ASCII string of distinct characters that does not contain the
// padding character '='. Any mapping previously stored in enc is discarded,
// so an encoding may safely be re-initialized with a different alphabet.
export fn encoding_init(enc: *encoding, alphabet: str) void = {
	const alphabet = strings::toutf8(alphabet);
	assert(len(alphabet) == 32);
	// Clear leftovers from an earlier initialization; otherwise characters
	// of the old alphabet would still be accepted by the decoder.
	enc.decmap = [0...];
	enc.valid = [false...];
	for (let i: u8 = 0; i < 32; i += 1) {
		const ch = alphabet[i];
		assert(ascii::valid(ch: rune));
		// '=' is reserved for padding and could not be told apart from it.
		assert(ch != PADDING, "base32 alphabet must not contain '='");
		// A repeated character would make decoding ambiguous.
		assert(!enc.valid[ch], "base32 alphabet has duplicate characters");
		enc.encmap[i] = ch;
		enc.decmap[ch] = i;
		enc.valid[ch] = true;
	};
};

@init fn init() void = {
	const std_alpha: str = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
	const hex_alpha: str = "0123456789ABCDEFGHIJKLMNOPQRSTUV";
	encoding_init(&std_encoding, std_alpha);
	encoding_init(&hex_encoding, hex_alpha);
};

export type encoder = struct {
	stream: io::stream,
	out: io::handle,
	enc: *encoding,
	buf: [4]u8, // leftover input
	avail: size, // bytes available in buf
	err: (void | io::error), // first write error; returned by later calls
};

const encoder_vtable: io::vtable = io::vtable {
	writer = &encode_writer,
	closer = &encode_closer,
	...
};

// Creates a stream that encodes writes as base-32 before writing them to a
// secondary stream. The encoder stream must be closed to finalize any unwritten
// bytes. Closing this stream will not close the underlying stream.
export fn newencoder(
	enc: *encoding,
	out: io::handle,
) encoder = {
	return encoder {
		stream = &encoder_vtable,
		out = out,
		enc = enc,
		err = void,
		...
	};
};

// Encodes one 5-byte group as 8 characters of enc's alphabet. Each output
// character carries 5 consecutive bits of the input, most significant first.
fn encode_quantum(enc: *encoding, b: [5]u8) [8]u8 = {
	return [
		enc.encmap[b[0] >> 3],
		enc.encmap[(b[0] & 0x7) << 2 | (b[1] & 0xC0) >> 6],
		enc.encmap[(b[1] & 0x3E) >> 1],
		enc.encmap[(b[1] & 0x1) << 4 | (b[2] & 0xF0) >> 4],
		enc.encmap[(b[2] & 0xF) << 1 | (b[3] & 0x80) >> 7],
		enc.encmap[(b[3] & 0x7C) >> 2],
		enc.encmap[(b[3] & 0x3) << 3 | (b[4] & 0xE0) >> 5],
		enc.encmap[b[4] & 0x1F],
	];
};

// Writer for the encoder stream. The bytes buffered by earlier writes
// (s.buf[..s.avail]) followed by `in` form one logical byte sequence. Every
// complete 5-byte group is encoded and written to s.out; the remaining
// (< 5) bytes are kept in s.buf for the next write or for close.
fn encode_writer(
	s: *io::stream,
	in: const []u8
) (size | io::error) = {
	let s = s: *encoder;
	match (s.err) {
	case let err: io::error =>
		return err;
	case void => void;
	};
	const l = len(in);
	const total = l + s.avail;
	let b: [5]u8 = [0...];
	for (let i = 0z; i + 5 <= total; i += 5) {
		// Logical byte k is s.buf[k] for k < s.avail, else in[k - s.avail].
		for (let j = 0z; j < 5; j += 1) {
			const k = i + j;
			b[j] = if (k < s.avail) s.buf[k] else in[k - s.avail];
		};
		let encb = encode_quantum(s.enc, b);
		// writeall: a plain write may accept fewer than 8 bytes.
		match (io::writeall(s.out, encb)) {
		case let err: io::error =>
			s.err = err;
			return err;
		case size => void;
		};
	};
	// Store the trailing total % 5 bytes at the front of s.buf. Bytes that
	// came from s.buf itself are only left behind when nothing was encoded
	// (begin == 0), in which case they are already in place.
	const begin = total / 5 * 5;
	for (let k = begin; k < total; k += 1) {
		if (k >= s.avail) {
			s.buf[k - begin] = in[k - s.avail];
		};
	};
	s.avail = total % 5;
	return l;
};

// Closer for the encoder stream: encodes the buffered tail (if any), pads it
// to a full 8-character group and writes it. It does not close s.out. A
// second close writes nothing. If an earlier write failed, that error is
// returned instead of emitting a tail.
fn encode_closer(s: *io::stream) (void | io::error) = {
	let s = s: *encoder;
	match (s.err) {
	case let err: io::error =>
		return err;
	case void => void;
	};
	if (s.avail == 0) {
		return;
	};
	// Missing input bytes are treated as zero bits.
	let b: [5]u8 = [0...];
	b[..s.avail] = s.buf[..s.avail];
	let encb = encode_quantum(s.enc, b);
	// Number of '=' characters for 0..4 leftover bytes (RFC 4648, sec. 6).
	static const npa: [5]u8 = [0, 6, 4, 3, 1];
	const np: size = npa[s.avail];
	for (let i = 0z; i < np; i += 1) {
		encb[7 - i] = PADDING;
	};
	io::writeall(s.out, encb)?;
	// Only forget the tail once it has actually been written.
	s.avail = 0;
};

// Encodes a byte slice in base-32, using the given encoding, returning a slice
// of ASCII bytes. The caller must free the return value.
export fn encodeslice(enc: *encoding, in: []u8) []u8 = {
	let out = memio::dynamic();
	let w = newencoder(enc, &out);
	io::writeall(&w, in)!;
	io::close(&w)!;
	return memio::buffer(&out);
};

@test fn encode_chunked() void = {
	// Feed "foobar" in pieces so that bytes carried over between writes
	// take part in a complete 5-byte group.
	const pieces: [_]str = ["fo", "ob", "ar"];
	let out = memio::dynamic();
	let enc = newencoder(&std_encoding, &out);
	for (let i = 0z; i < len(pieces); i += 1) {
		io::writeall(&enc, strings::toutf8(pieces[i]))!;
	};
	io::close(&enc)!;
	// Closing again must not emit the padded tail a second time.
	io::close(&enc)!;
	let outb = memio::buffer(&out);
	defer free(outb);
	assert(bytes::equal(outb, strings::toutf8("MZXW6YTBOI======")));
};

// Encodes a byte slice in base-32, using the given encoding, returning a
// string. The caller must free the return value.
export fn encodestr(enc: *encoding, in: []u8) str = {
	const encoded = encodeslice(enc, in);
	// The encoder only emits alphabet characters and '=', so this is
	// expected to be valid UTF-8 for alphabets built by encoding_init.
	return strings::fromutf8(encoded)!;
};

@test fn encode() void = {
	// RFC 4648 test vectors: every prefix of "foobar".
	const in: [_]u8 = ['f', 'o', 'o', 'b', 'a', 'r'];
	const expect: [_]str = [
		"",
		"MY======",
		"MZXQ====",
		"MZXW6===",
		"MZXW6YQ=",
		"MZXW6YTB",
		"MZXW6YTBOI======",
	];
	const expect_hex: [_]str = [
		"",
		"CO======",
		"CPNG====",
		"CPNMU===",
		"CPNMUOG=",
		"CPNMUOJ1",
		"CPNMUOJ1E8======",
	];
	const encs: [_]*encoding = [&std_encoding, &hex_encoding];
	for (let i = 0z; i <= len(in); i += 1) {
		const wants: [_]str = [expect[i], expect_hex[i]];
		for (let e = 0z; e < len(encs); e += 1) {
			// Streaming encoder path.
			let sink = memio::dynamic();
			let w = newencoder(encs[e], &sink);
			io::writeall(&w, in[..i]) as size;
			io::close(&w)!;
			let streamed = memio::buffer(&sink);
			assert(bytes::equal(streamed, strings::toutf8(wants[e])));
			free(streamed);

			// One-shot path; encodestr goes through encodeslice too.
			let oneshot = encodestr(encs[e], in[..i]);
			defer free(oneshot);
			assert(oneshot == wants[e]);
		};
	};
};

export type decoder = struct {
	stream: io::stream,
	in: io::handle,
	enc: *encoding,
	avail: []u8, // decoded bytes not yet handed to a reader
	pad: bool, // set once a padding character has been consumed
	state: (void | io::EOF | io::error), // sticky end/error condition
};

const decoder_vtable: io::vtable = io::vtable {
	reader = &decode_reader,
	...
};

// Creates a stream that reads and decodes base-32 data from a secondary stream.
// This stream does not need to be closed, and closing it will not close the
// underlying stream.
export fn newdecoder(
	enc: *encoding,
	in: io::handle,
) decoder = {
	return decoder {
		stream = &decoder_vtable,
		in = in,
		enc = enc,
		state = void,
		...
	};
};

// Reader for the decoder stream. It first hands back output left over from
// the previous call, then reads whole 8-character groups from s.in, validates
// and decodes them, and keeps any decoded bytes that do not fit in `out` in
// s.avail. It returns io::EOF once the input is exhausted and no output is
// pending.
//
// NOTE(review): buf and obuf are static, so they are shared by every decoder,
// and s.avail points into obuf. Interleaving reads on two decoders can
// therefore clobber one decoder's pending output. Moving these buffers into
// the decoder struct would fix this.
fn decode_reader(
	s: *io::stream,
	out: []u8
) (size | io::EOF | io::error) = {
	let s = s: *decoder;
	let n = 0z;
	const l = len(out);

	// Deliver output decoded by an earlier call first.
	if (len(s.avail) > 0) {
		n = if (l < len(s.avail)) l else len(s.avail);
		out[..n] = s.avail[..n];
		s.avail = s.avail[n..];
		if (l == n) {
			return n;
		};
	};
	match (s.state) {
	case let err: (io::EOF | io::error) =>
		if (n > 0) {
			return n;
		};
		return err;
	case void => void;
	};

	static let buf: [os::BUFSZ]u8 = [0...];
	static let obuf: [os::BUFSZ / 8 * 5]u8 = [0...];
	// Only whole 8-character groups are decoded, so never use more of buf
	// than a multiple of 8.
	const maxrd = len(buf) / 8 * 8;
	// Enough input to fill the rest of `out` plus up to one extra group,
	// capped at what buf can hold.
	let nn = ((l - n) / 5 + 1) * 8;
	if (nn > maxrd) {
		nn = maxrd;
	};
	let nr = 0z;
	// Keep reading until nn characters are buffered AND they form whole
	// groups, so a short read from s.in is not mistaken for bad input.
	// Because nn <= maxrd and maxrd % 8 == 0, buf[nr..maxrd] is never empty
	// here.
	for (nr < nn || nr % 8 != 0) {
		match (io::read(s.in, buf[nr..maxrd])) {
		case let z: size =>
			nr += z;
		case io::EOF =>
			s.state = io::EOF;
			break;
		case let err: io::error =>
			s.state = err;
			return err;
		};
	};
	if (nr % 8 != 0) {
		s.state = errors::invalid;
		return errors::invalid;
	};
	if (nr == 0) {
		// The underlying stream is exhausted (s.state is io::EOF).
		if (n > 0) {
			return n;
		};
		return io::EOF;
	};

	// Validate the input, scanning backwards so that trailing padding is
	// counted first. Once padding has been consumed (in this read or an
	// earlier one), no further characters of any kind are allowed.
	let valid = !s.pad;
	let np = 0z; // number of trailing padding characters
	let padok = true; // padding is still allowed at this position
	for (let i = nr; valid && i > 0; i -= 1) {
		const ch = buf[i - 1];
		if (ch == PADDING) {
			if (padok) {
				np += 1;
			} else {
				valid = false;
			};
		} else if (s.enc.valid[ch]) {
			// Padding may not precede a data character.
			padok = false;
		} else {
			valid = false;
		};
	};
	// A final group may carry only 0, 1, 3, 4 or 6 padding characters.
	valid = valid && np <= 6 && np != 2 && np != 5;
	if (np > 0) {
		s.pad = true;
	};
	if (!valid) {
		s.state = errors::invalid;
		return errors::invalid;
	};

	// Map characters to 5-bit values ('=' maps to 0 via decmap's default).
	for (let i = 0z; i < nr; i += 1) {
		buf[i] = s.enc.decmap[buf[i]];
	};
	// Pack each group of eight 5-bit values into five bytes.
	for (let i = 0z, j = 0z; i < nr) {
		obuf[j] = (buf[i] << 3) | (buf[i + 1] & 0x1C) >> 2;
		obuf[j + 1] =
			(buf[i + 1] & 0x3) << 6 | buf[i + 2] << 1 | (buf[i + 3] & 0x10) >> 4;
		obuf[j + 2] = (buf[i + 3] & 0x0F) << 4 | (buf[i + 4] & 0x1E) >> 1;
		obuf[j + 3] =
			(buf[i + 4] & 0x1) << 7 | buf[i + 5] << 2 | (buf[i + 6] & 0x18) >> 3;
		obuf[j + 4] = (buf[i + 6] & 0x7) << 5 | buf[i + 7];
		i += 8;
		j += 5;
	};
	// Decoded bytes produced only by padding, indexed by np.
	//                          0  1  2  3  4  5  6
	static const npr: [7]u8 = [0, 1, 0, 2, 3, 0, 4];
	const navl = nr / 8 * 5 - npr[np];
	const rem = if (l - n < navl) l - n else navl;
	out[n..n + rem] = obuf[..rem];
	s.avail = obuf[rem..navl];
	return n + rem;
};

@test fn decode_eof() void = {
	// Reading one byte at a time yields every decoded byte, followed by
	// io::EOF (rather than an empty read) once the input is exhausted.
	let in = memio::fixed(strings::toutf8("MZXW6YTBOI======"));
	let dec = newdecoder(&std_encoding, &in);
	let buf: [1]u8 = [0...];
	const want = strings::toutf8("foobar");
	for (let i = 0z; i < len(want); i += 1) {
		assert(io::read(&dec, buf) as size == 1);
		assert(buf[0] == want[i]);
	};
	assert(io::read(&dec, buf) is io::EOF);
	// The end-of-file condition is sticky.
	assert(io::read(&dec, buf) is io::EOF);
};

// Decodes a byte slice of ASCII-encoded base-32 data, using the given encoding,
// returning a slice of decoded bytes. The caller must free the return value.
export fn decodeslice(
	enc: *encoding,
	in: []u8,
) ([]u8 | errors::invalid) = {
	let in = memio::fixed(in);
	let decoder = newdecoder(enc, &in);
	let out = memio::dynamic();
	match (io::copy(&out, &decoder)) {
	case io::error =>
		// Reading from a fixed buffer cannot fail, so any error here comes
		// from the decoder rejecting the input.
		io::close(&out)!;
		return errors::invalid;
	case size =>
		return memio::buffer(&out);
	};
};

// Decodes a string of ASCII-encoded base-32 data, using the given encoding,
// returning a slice of decoded bytes. The caller must free the return value.
export fn decodestr(enc: *encoding, in: str) ([]u8 | errors::invalid) = {
	return decodeslice(enc, strings::toutf8(in));
};

@test fn decode() void = {
	const cases: [_](str, str, *encoding) = [
		("", "", &std_encoding),
		("MY======", "f", &std_encoding),
		("MZXQ====", "fo", &std_encoding),
		("MZXW6===", "foo", &std_encoding),
		("MZXW6YQ=", "foob", &std_encoding),
		("MZXW6YTB", "fooba", &std_encoding),
		("MZXW6YTBOI======", "foobar", &std_encoding),
		("", "", &hex_encoding),
		("CO======", "f", &hex_encoding),
		("CPNG====", "fo", &hex_encoding),
		("CPNMU===", "foo", &hex_encoding),
		("CPNMUOG=", "foob", &hex_encoding),
		("CPNMUOJ1", "fooba", &hex_encoding),
		("CPNMUOJ1E8======", "foobar", &hex_encoding),
	];
	for (let i = 0z; i < len(cases); i += 1) {
		let in = memio::fixed(strings::toutf8(cases[i].0));
		let dec = newdecoder(cases[i].2, &in);
		let out: []u8 = io::drain(&dec)!;
		defer free(out);
		assert(bytes::equal(out, strings::toutf8(cases[i].1)));

		// Testing decodestr should cover decodeslice too
		let decb = decodestr(cases[i].2, cases[i].0) as []u8;
		defer free(decb);
		assert(bytes::equal(decb, strings::toutf8(cases[i].1)));
	};
	// Repeat of the above, but reading through a one-byte buffer so that
	// decoded output is handed back across many reads.
	for (let i = 0z; i < len(cases); i += 1) {
		let in = memio::fixed(strings::toutf8(cases[i].0));
		let dec = newdecoder(cases[i].2, &in);
		let sink = memio::dynamic();
		let buf: [1]u8 = [0...];
		for (true) match (io::read(&dec, buf)!) {
		case let z: size =>
			io::writeall(&sink, buf[..z])!;
		case io::EOF =>
			break;
		};
		let got = memio::buffer(&sink);
		defer free(got);
		assert(bytes::equal(got, strings::toutf8(cases[i].1)));
	};

	const invalid: [_](str, *encoding) = [
		// invalid padding
		("=", &std_encoding),
		("==", &std_encoding),
		("===", &std_encoding),
		("=====", &std_encoding),
		("======", &std_encoding),
		("=======", &std_encoding),
		("========", &std_encoding),
		("=========", &std_encoding),
		// invalid characters
		("1ZXW6YQ=", &std_encoding),
		("êZXW6YQ=", &std_encoding),
		("MZXW1YQ=", &std_encoding),
		// data after padding is encountered
		("CO======CO======", &std_encoding),
		("CPNG====CPNG====", &std_encoding),
	];
	for (let i = 0z; i < len(invalid); i += 1) {
		let in = memio::fixed(strings::toutf8(invalid[i].0));
		let dec = newdecoder(invalid[i].1, &in);
		let buf: [1]u8 = [0...];
		let valid = false;
		for (true) match(io::read(&dec, buf)) {
		case errors::invalid =>
			break;
		case size =>
			valid = true;
		case io::EOF =>
			break;
		};
		assert(valid == false, "valid is not false");

		// Testing decodestr should cover decodeslice too
		assert(decodestr(invalid[i].1, invalid[i].0) is errors::invalid);
	};
};