hare

The Hare programming language
git clone https://git.torresjrjr.com/hare.git
Log | Files | Refs | README | LICENSE

commit 99249f26d4b9bfd77320ee8d198760e1344f1c11
parent 0db89b74908cdd8ce36bc9b9e16c72f967be437c
Author: Thomas Bracht Laumann Jespersen <t@laumann.xyz>
Date:   Mon,  6 Dec 2021 10:18:33 +0100

encoding/base64: Provide a decoder as an io::stream

This implements decoding as an io::stream and reworks utility functions
to work use the stream.

Signed-off-by: Thomas Bracht Laumann Jespersen <t@laumann.xyz>

Diffstat:
Mencoding/base64/base64.ha | 348+++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------
1 file changed, 254 insertions(+), 94 deletions(-)

diff --git a/encoding/base64/base64.ha b/encoding/base64/base64.ha @@ -3,6 +3,7 @@ use bytes; use io; use strio; use strings; +use errors; // RFC 4648 standard "base64" base 64 encoding alphabet. export const standard: []u8 = [ @@ -31,12 +32,6 @@ export const urlsafe: []u8 = [ // The padding character used at the end of encoding. export def PADDING: u8 = '=': u32: u8; -// Indicates that invalid input was found while decoding, either in the form of -// characters outside of the base 64 alphabet, insufficient padding, or trailing -// characters. Contains the index of the first invalid character, which may be -// outside of the bounds of a truncated input. -export type invalid = !size; - // Encodes a byte slice using a base 64 encoding alphabet, with padding, and // writes it to an [[io::handle]]. The number of bytes written is returned. export fn encode( @@ -106,78 +101,19 @@ export fn decode( alphabet: []u8, in: io::handle, out: io::handle, -) (size | invalid | io::error) = { - const INVALID_OR_PAD = 255u8; - let decoder: [256]u8 = [INVALID_OR_PAD...]; - for (let i = 0z; i < len(alphabet); i += 1) { - decoder[alphabet[i]] = i: u8; - }; - - let count = 0z; - let z = 0z; - for (true) { - let buf: [4]u8 = [0...]; - match (io::read(in, buf)) { - case size => - for (let i = 0z; i < 2; i += 1) { - if (decoder[buf[i]] == INVALID_OR_PAD) { - return (count + i): invalid; - } else { - buf[i] = decoder[buf[i]]; - }; - }; - - if (decoder[buf[2]] == INVALID_OR_PAD) { - if (buf[2] != PADDING) { - return (count + 2z): invalid; - }; - if (buf[3] != PADDING) { - return (count + 3z): invalid; - }; - z += io::write(out, [ - buf[0] << 2 | buf[1] >> 4, - ])?; - let extra: []u8 = [0]; - match (io::read(in, extra)) { - case size => - return (count + 4z): invalid; - case io::EOF => - return z; - }; - } else { - buf[2] = decoder[buf[2]]; - }; - - if (decoder[buf[3]] == INVALID_OR_PAD) { - if (buf[3] != PADDING) { - return (count + 3z): invalid; - }; - z += io::write(out, [ - buf[0] << 2 | buf[1] >> 4, - buf[1] << 4 | buf[2] >> 2, - ])?; - let extra: []u8 = [0]; - match (io::read(in, extra)) { - case size => - return (count + 4z): invalid; - case io::EOF => - return z; - }; - } else { - buf[3] = decoder[buf[3]]; - }; - - z += io::write(out, [ - buf[0] << 2 | buf[1] >> 4, - buf[1] << 4 | buf[2] >> 2, - buf[2] << 6 | buf[3], - ])?; - count += 4; - case io::EOF => - break; +) (size | errors::invalid | io::error) = { + let dec = decoder(alphabet, in); + match (io::copy(out, &dec)) { + case err: io::error => + match (err) { + case errors::invalid => + return errors::invalid; + case => + return err; }; + case s: size => + return s; }; - return z; }; // Decodes base 64-encoded data in the given base 64 alphabet, with padding, @@ -186,23 +122,27 @@ export fn decode_static( alphabet: []u8, out: []u8, in: io::handle, -) (size | invalid) = { +) (size | errors::invalid | io::error) = { let buf = bufio::fixed(out, io::mode::WRITE); defer io::close(buf); - match (decode(alphabet, in, buf)) { - case io::error => - abort(); - case z: invalid => - return z: invalid; - case z: size => - return z; + let dec = decoder(alphabet, in); + match (io::copy(buf, &dec)) { + case err: io::error => + match (err) { + case errors::invalid => + return errors::invalid; + case => + return err; + }; + case s: size => + return s; }; }; // Decodes a string of base 64-encoded data in the given base 64 encoding // alphabet, with padding, into a byte slice. The caller must free the return // value. -export fn decodestr(alphabet: []u8, in: str) ([]u8 | invalid) = { +export fn decodestr(alphabet: []u8, in: str) ([]u8 | errors::invalid) = { return decodeslice(alphabet, strings::toutf8(in)); }; @@ -212,23 +152,22 @@ export fn decodestr_static( alphabet: []u8, out: []u8, in: str, -) (size | invalid) = { +) (size | errors::invalid) = { return decodeslice_static(alphabet, out, strings::toutf8(in)); }; // Decodes a byte slice of base 64-encoded data in the given base 64 encoding // alphabet, with padding, into a byte slice. The caller must free the return // value. -export fn decodeslice(alphabet: []u8, in: []u8) ([]u8 | invalid) = { +export fn decodeslice(alphabet: []u8, in: []u8) ([]u8 | errors::invalid) = { let out = bufio::dynamic(io::mode::WRITE); let in = bufio::fixed(in, io::mode::READ); defer io::close(in); - match (decode(alphabet, in, out)) { + let dec = decoder(alphabet, in); + match (io::copy(out, &dec)) { case io::error => - abort(); - case z: invalid => io::close(out); - return z: invalid; + return errors::invalid; case size => return bufio::finish(out); }; @@ -240,10 +179,17 @@ export fn decodeslice_static( alphabet: []u8, out: []u8, in: []u8, -) (size | invalid) = { +) (size | errors::invalid) = { let in = bufio::fixed(in, io::mode::READ); defer io::close(in); // bufio::finish? - return decode_static(alphabet, out, in); + match (decode_static(alphabet, out, in)) { + case s: size => + return s; + case errors::invalid => + return errors::invalid; + case => + abort(); + }; }; @test fn decode() void = { @@ -276,6 +222,220 @@ export fn decodeslice_static( const badindex: [_]size = [1, 2, 3, 0, 0, 1, 3, 4]; for (let i = 0z; i < len(bad); i += 1) { let result = decodestr(standard, bad[i]); - assert(result as invalid == badindex[i]: invalid); + assert(result is errors::invalid); + }; +}; + +const INVALID_OR_PAD: u8 = 255; + +// Initialize a new base64 decoder stream wrapping the given [[io::handle]] +export fn decoder(alphabet: []u8, in: io::handle) decode_stream = { + let decoder: [256]u8 = [INVALID_OR_PAD...]; + for (let i = 0z; i < len(alphabet); i += 1) { + decoder[alphabet[i]] = i: u8; + }; + return decode_stream { + reader = &decodestream_reader, + input = in, + decoder = decoder, + ... + }; +}; + +// An stream interface for base64. Wraps an [[io::handle]] and does on-the-fly +// decoding with calls to read(). +export type decode_stream = struct { + io::stream, + input: io::handle, + buf: [4]u8, + avail: size, // How many bytes are already decoded, but didn't fit in a previous read + waseof: bool, + decoder: [256]u8, +}; + +fn decodestream_reader(s: *io::stream, out: []u8) (size | io::EOF | io::error) = { + assert(len(out) > 0, "zero-length buffer provided"); + let s = s : *decode_stream; + let z = 0z; + let decoder = s.decoder; + let buf = s.buf; + + // We may have already decoded some bytes that couldn't be pushed out + // in a previous call to read. + if (s.avail > 0) { + z += if (len(out) < s.avail) len(out) else s.avail; + out[..z] = s.buf[..z]; + s.avail -= z; + s.buf[..s.avail] = s.buf[z..z+s.avail]; + out = out[z..]; + if (len(out) == 0) { + return z; + }; + }; + + if (s.waseof) { + return io::EOF; + }; + + // If we get here, we have pushed out all cached bytes and are ready to + // read some more. Reset the internal buffer here. + buf = [INVALID_OR_PAD...]; + for (true) match (io::read(s.input, buf)) { + case size => + for (let i = 0z; i < 2; i += 1) { + if (decoder[buf[i]] == INVALID_OR_PAD) { + return errors::invalid; + } else { + buf[i] = decoder[buf[i]]; + }; + }; + + if (decoder[buf[2]] == INVALID_OR_PAD) { + if (buf[2] != PADDING) { + return errors::invalid; + }; + if (buf[3] != PADDING) { + return errors::invalid; + }; + s.buf[0] = buf[0] << 2 | buf[1] >> 4; + s.avail += 1; + // End of stream... + let extra: []u8 = [0]; + match (io::read(s.input, extra)) { + case size => + return errors::invalid; + case io::EOF => + s.waseof = true; + if (len(out) > 0) { + out[0] = s.buf[0]; + z += 1; + s.avail = 0; + }; + break; + case err: io::error => + return err; + }; + } else { + buf[2] = decoder[buf[2]]; + }; + + if (decoder[buf[3]] == INVALID_OR_PAD) { + if (buf[3] != PADDING) { + return errors::invalid; + }; + s.buf[..2] = [ + buf[0] << 2 | buf[1] >> 4, + buf[1] << 4 | buf[2] >> 2, + ]; + s.avail += 2; + let extra: []u8 = [0]; + match (io::read(s.input, extra)) { + case size => + return errors::invalid; + case io::EOF => + let n = if (len(out) < s.avail) len(out) else s.avail; + out[..n] = s.buf[..n]; + s.avail -= n; + out = out[n..]; + s.buf[..s.avail] = s.buf[n..n+s.avail]; + s.waseof = true; + z += n; + break; + case err: io::error => + return err; + }; + } else { + buf[3] = decoder[buf[3]]; + }; + + s.buf[..3] = [ + buf[0] << 2 | buf[1] >> 4, + buf[1] << 4 | buf[2] >> 2, + buf[2] << 6 | buf[3] + ]; + s.avail += 3; + + let n = if (len(out) < s.avail) len(out) else s.avail; + out[..n] = s.buf[..n]; + s.avail -= n; + out = out[n..]; + s.buf[..s.avail] = s.buf[n..n+s.avail]; + z += n; + + if (len(out) == 0) { + break; + }; + case io::EOF => + s.waseof = true; + if (z == 0) { + return io::EOF; + }; + break; + case err: io::error => + return err; + }; + return z; +}; + +@test fn decode_stream() void = { + + const cases: [](str, str) = [ + ("Y2hhbmdlbQ==", "changem"), + ("Y2hhbmdlbWU=", "changeme"), + ("Y2hhbmdlbWVt", "changemem"), + ]; + + for (let i = 0z; i < len(cases); i += 1) { + let s = cases[i].0; + let expected = cases[i].1; + + let b = strings::toutf8(s); + let input = bufio::fixed(b, io::mode::READ); + + let dec = decoder(standard, input); + defer io::close(&dec); + + let buf: [1]u8 = [0]; + let out: []u8 = []; + defer free(out); + + for (true) match (io::read(&dec, buf)) { + case z: size => + append(out, buf[0]); + assert(z == 1); + case io::EOF => + break; + case err: io::error => + abort(); + }; + + assert(bytes::equal(out, strings::toutf8(expected))); + }; + + // Repeat of the above, but with a larger buffer + for (let i = 0z; i < len(cases); i += 1) { + let s = cases[i].0; + let expected = cases[i].1; + + let b = strings::toutf8(s); + let input = bufio::fixed(b, io::mode::READ); + + let dec = decoder(standard, input); + defer io::close(&dec); + + let buf: [24]u8 = [0...]; + let out: []u8 = []; + defer free(out); + + for (true) match (io::read(&dec, buf)) { + case z: size => + append(out, buf[..z]...); + case io::EOF => + break; + case err: io::error => + abort(); + }; + + assert(bytes::equal(out, strings::toutf8(expected))); }; };