commit 99249f26d4b9bfd77320ee8d198760e1344f1c11
parent 0db89b74908cdd8ce36bc9b9e16c72f967be437c
Author: Thomas Bracht Laumann Jespersen <t@laumann.xyz>
Date: Mon, 6 Dec 2021 10:18:33 +0100
encoding/base64: Provide a decoder as an io::stream
This implements decoding as an io::stream and reworks the utility functions
to use the stream.
Signed-off-by: Thomas Bracht Laumann Jespersen <t@laumann.xyz>
Diffstat:
1 file changed, 254 insertions(+), 94 deletions(-)
diff --git a/encoding/base64/base64.ha b/encoding/base64/base64.ha
@@ -3,6 +3,7 @@ use bytes;
use io;
use strio;
use strings;
+use errors;
// RFC 4648 standard "base64" base 64 encoding alphabet.
export const standard: []u8 = [
@@ -31,12 +32,6 @@ export const urlsafe: []u8 = [
// The padding character used at the end of encoding.
export def PADDING: u8 = '=': u32: u8;
-// Indicates that invalid input was found while decoding, either in the form of
-// characters outside of the base 64 alphabet, insufficient padding, or trailing
-// characters. Contains the index of the first invalid character, which may be
-// outside of the bounds of a truncated input.
-export type invalid = !size;
-
// Encodes a byte slice using a base 64 encoding alphabet, with padding, and
// writes it to an [[io::handle]]. The number of bytes written is returned.
export fn encode(
@@ -106,78 +101,19 @@ export fn decode(
alphabet: []u8,
in: io::handle,
out: io::handle,
-) (size | invalid | io::error) = {
- const INVALID_OR_PAD = 255u8;
- let decoder: [256]u8 = [INVALID_OR_PAD...];
- for (let i = 0z; i < len(alphabet); i += 1) {
- decoder[alphabet[i]] = i: u8;
- };
-
- let count = 0z;
- let z = 0z;
- for (true) {
- let buf: [4]u8 = [0...];
- match (io::read(in, buf)) {
- case size =>
- for (let i = 0z; i < 2; i += 1) {
- if (decoder[buf[i]] == INVALID_OR_PAD) {
- return (count + i): invalid;
- } else {
- buf[i] = decoder[buf[i]];
- };
- };
-
- if (decoder[buf[2]] == INVALID_OR_PAD) {
- if (buf[2] != PADDING) {
- return (count + 2z): invalid;
- };
- if (buf[3] != PADDING) {
- return (count + 3z): invalid;
- };
- z += io::write(out, [
- buf[0] << 2 | buf[1] >> 4,
- ])?;
- let extra: []u8 = [0];
- match (io::read(in, extra)) {
- case size =>
- return (count + 4z): invalid;
- case io::EOF =>
- return z;
- };
- } else {
- buf[2] = decoder[buf[2]];
- };
-
- if (decoder[buf[3]] == INVALID_OR_PAD) {
- if (buf[3] != PADDING) {
- return (count + 3z): invalid;
- };
- z += io::write(out, [
- buf[0] << 2 | buf[1] >> 4,
- buf[1] << 4 | buf[2] >> 2,
- ])?;
- let extra: []u8 = [0];
- match (io::read(in, extra)) {
- case size =>
- return (count + 4z): invalid;
- case io::EOF =>
- return z;
- };
- } else {
- buf[3] = decoder[buf[3]];
- };
-
- z += io::write(out, [
- buf[0] << 2 | buf[1] >> 4,
- buf[1] << 4 | buf[2] >> 2,
- buf[2] << 6 | buf[3],
- ])?;
- count += 4;
- case io::EOF =>
- break;
+) (size | errors::invalid | io::error) = {
+ let dec = decoder(alphabet, in);
+ match (io::copy(out, &dec)) {
+ case err: io::error =>
+ match (err) {
+ case errors::invalid =>
+ return errors::invalid;
+ case =>
+ return err;
};
+ case s: size =>
+ return s;
};
- return z;
};
// Decodes base 64-encoded data in the given base 64 alphabet, with padding,
@@ -186,23 +122,27 @@ export fn decode_static(
alphabet: []u8,
out: []u8,
in: io::handle,
-) (size | invalid) = {
+) (size | errors::invalid | io::error) = {
let buf = bufio::fixed(out, io::mode::WRITE);
defer io::close(buf);
- match (decode(alphabet, in, buf)) {
- case io::error =>
- abort();
- case z: invalid =>
- return z: invalid;
- case z: size =>
- return z;
+ let dec = decoder(alphabet, in);
+ match (io::copy(buf, &dec)) {
+ case err: io::error =>
+ match (err) {
+ case errors::invalid =>
+ return errors::invalid;
+ case =>
+ return err;
+ };
+ case s: size =>
+ return s;
};
};
// Decodes a string of base 64-encoded data in the given base 64 encoding
// alphabet, with padding, into a byte slice. The caller must free the return
// value.
-export fn decodestr(alphabet: []u8, in: str) ([]u8 | invalid) = {
+export fn decodestr(alphabet: []u8, in: str) ([]u8 | errors::invalid) = {
return decodeslice(alphabet, strings::toutf8(in));
};
@@ -212,23 +152,22 @@ export fn decodestr_static(
alphabet: []u8,
out: []u8,
in: str,
-) (size | invalid) = {
+) (size | errors::invalid) = {
return decodeslice_static(alphabet, out, strings::toutf8(in));
};
// Decodes a byte slice of base 64-encoded data in the given base 64 encoding
// alphabet, with padding, into a byte slice. The caller must free the return
// value.
-export fn decodeslice(alphabet: []u8, in: []u8) ([]u8 | invalid) = {
+export fn decodeslice(alphabet: []u8, in: []u8) ([]u8 | errors::invalid) = {
let out = bufio::dynamic(io::mode::WRITE);
let in = bufio::fixed(in, io::mode::READ);
defer io::close(in);
- match (decode(alphabet, in, out)) {
+ let dec = decoder(alphabet, in);
+ match (io::copy(out, &dec)) {
case io::error =>
- abort();
- case z: invalid =>
io::close(out);
- return z: invalid;
+ return errors::invalid;
case size =>
return bufio::finish(out);
};
@@ -240,10 +179,17 @@ export fn decodeslice_static(
alphabet: []u8,
out: []u8,
in: []u8,
-) (size | invalid) = {
+) (size | errors::invalid) = {
let in = bufio::fixed(in, io::mode::READ);
defer io::close(in); // bufio::finish?
- return decode_static(alphabet, out, in);
+ match (decode_static(alphabet, out, in)) {
+ case s: size =>
+ return s;
+ case errors::invalid =>
+ return errors::invalid;
+ case =>
+ abort();
+ };
};
@test fn decode() void = {
@@ -276,6 +222,220 @@ export fn decodeslice_static(
const badindex: [_]size = [1, 2, 3, 0, 0, 1, 3, 4];
for (let i = 0z; i < len(bad); i += 1) {
let result = decodestr(standard, bad[i]);
- assert(result as invalid == badindex[i]: invalid);
+ assert(result is errors::invalid);
+ };
+};
+
+const INVALID_OR_PAD: u8 = 255; // lookup-table marker: byte is not in the alphabet (padding included)
+
+// Initializes a new base64 decoder stream which reads encoded data from the given [[io::handle]].
+export fn decoder(alphabet: []u8, in: io::handle) decode_stream = {
+ let decoder: [256]u8 = [INVALID_OR_PAD...]; // reverse table: input byte -> index in alphabet
+ for (let i = 0z; i < len(alphabet); i += 1) {
+ decoder[alphabet[i]] = i: u8; // bytes absent from alphabet keep INVALID_OR_PAD
+ };
+ return decode_stream {
+ reader = &decodestream_reader,
+ input = in,
+ decoder = decoder,
+ ...
+ };
+};
+
+// A stream interface for base64 decoding. Wraps an [[io::handle]] and decodes
+// on the fly as the stream is read.
+export type decode_stream = struct {
+ io::stream,
+ input: io::handle, // source of the base64-encoded bytes
+ buf: [4]u8, // decoded bytes that have not been handed to the caller yet
+ avail: size, // How many bytes are already decoded, but didn't fit in a previous read
+ waseof: bool, // set once the end of the input has been seen
+ decoder: [256]u8, // maps an input byte to its 6-bit value, or INVALID_OR_PAD
+};
+
+// Reader for [[decode_stream]]: decodes base64 from s.input into out. Returns the
+// number of bytes stored, io::EOF at end of input, or errors::invalid if malformed.
+fn decodestream_reader(s: *io::stream, out: []u8) (size | io::EOF | io::error) = {
+ assert(len(out) > 0, "zero-length buffer provided");
+ let s = s : *decode_stream;
+ let z = 0z; // bytes stored into out so far during this call
+ let decoder = s.decoder;
+ let buf = s.buf; // local scratch for one 4-byte input group
+
+ // We may have already decoded some bytes that couldn't be pushed out
+ // in a previous call to read.
+ if (s.avail > 0) {
+ z += if (len(out) < s.avail) len(out) else s.avail;
+ out[..z] = s.buf[..z];
+ s.avail -= z;
+ s.buf[..s.avail] = s.buf[z..z+s.avail]; // move the remainder to the front
+ out = out[z..];
+ if (len(out) == 0) return z;
+ };
+
+ // Input is exhausted. Bytes drained above must still be reported before EOF.
+ if (s.waseof && z == 0) return io::EOF;
+ if (s.waseof) return z;
+
+ // If we get here, we have pushed out all cached bytes and are ready to
+ // read some more. Reset the scratch buffer here.
+ buf = [INVALID_OR_PAD...];
+ for (true) match (io::read(s.input, buf)) {
+ case size => // NOTE(review): the read count is ignored; a short read mid-stream is reported as invalid
+ for (let i = 0z; i < 2; i += 1) {
+ if (decoder[buf[i]] == INVALID_OR_PAD) {
+ return errors::invalid;
+ } else {
+ buf[i] = decoder[buf[i]];
+ };
+ };
+
+ if (decoder[buf[2]] == INVALID_OR_PAD) {
+ if (buf[2] != PADDING) {
+ return errors::invalid;
+ };
+ if (buf[3] != PADDING) {
+ return errors::invalid;
+ };
+ s.buf[0] = buf[0] << 2 | buf[1] >> 4;
+ s.avail += 1;
+ // End of stream...
+ let extra: []u8 = [0];
+ match (io::read(s.input, extra)) {
+ case size =>
+ return errors::invalid; // data after the padding
+ case io::EOF =>
+ s.waseof = true;
+ if (len(out) > 0) {
+ out[0] = s.buf[0];
+ z += 1;
+ s.avail = 0;
+ };
+ break;
+ case err: io::error =>
+ return err;
+ };
+ } else {
+ buf[2] = decoder[buf[2]];
+ };
+
+ if (decoder[buf[3]] == INVALID_OR_PAD) {
+ if (buf[3] != PADDING) {
+ return errors::invalid;
+ };
+ s.buf[..2] = [
+ buf[0] << 2 | buf[1] >> 4,
+ buf[1] << 4 | buf[2] >> 2,
+ ];
+ s.avail += 2;
+ let extra: []u8 = [0];
+ match (io::read(s.input, extra)) {
+ case size =>
+ return errors::invalid; // data after the padding
+ case io::EOF =>
+ let n = if (len(out) < s.avail) len(out) else s.avail;
+ out[..n] = s.buf[..n];
+ s.avail -= n;
+ out = out[n..];
+ s.buf[..s.avail] = s.buf[n..n+s.avail];
+ s.waseof = true;
+ z += n;
+ break;
+ case err: io::error =>
+ return err;
+ };
+ } else {
+ buf[3] = decoder[buf[3]];
+ };
+
+ s.buf[..3] = [
+ buf[0] << 2 | buf[1] >> 4,
+ buf[1] << 4 | buf[2] >> 2,
+ buf[2] << 6 | buf[3]
+ ];
+ s.avail += 3;
+
+ let n = if (len(out) < s.avail) len(out) else s.avail;
+ out[..n] = s.buf[..n];
+ s.avail -= n;
+ out = out[n..];
+ s.buf[..s.avail] = s.buf[n..n+s.avail];
+ z += n;
+
+ if (len(out) == 0) break;
+ // Clear stale decoded values so a short final read is rejected as invalid.
+ buf = [INVALID_OR_PAD...];
+ case io::EOF =>
+ s.waseof = true;
+ if (z == 0) {
+ return io::EOF;
+ };
+ break;
+ case err: io::error =>
+ return err;
+ };
+ return z;
+};
+
+@test fn decode_stream() void = {
+ // Each case pairs a base64 input with the text it is expected to decode to.
+ const cases: [](str, str) = [
+ ("Y2hhbmdlbQ==", "changem"),
+ ("Y2hhbmdlbWU=", "changeme"),
+ ("Y2hhbmdlbWVt", "changemem"),
+ ];
+
+ for (let i = 0z; i < len(cases); i += 1) {
+ let s = cases[i].0;
+ let expected = cases[i].1;
+
+ let b = strings::toutf8(s);
+ let input = bufio::fixed(b, io::mode::READ);
+
+ let dec = decoder(standard, input);
+ defer io::close(&dec);
+
+ let buf: [1]u8 = [0]; // one-byte reads make the decoder carry leftover bytes between calls
+ let out: []u8 = [];
+ defer free(out);
+
+ for (true) match (io::read(&dec, buf)) {
+ case z: size =>
+ append(out, buf[0]);
+ assert(z == 1);
+ case io::EOF =>
+ break;
+ case err: io::error =>
+ abort();
+ };
+
+ assert(bytes::equal(out, strings::toutf8(expected)));
+ };
+
+ // Repeat of the above, but with a larger buffer
+ for (let i = 0z; i < len(cases); i += 1) {
+ let s = cases[i].0;
+ let expected = cases[i].1;
+
+ let b = strings::toutf8(s);
+ let input = bufio::fixed(b, io::mode::READ);
+
+ let dec = decoder(standard, input);
+ defer io::close(&dec);
+
+ let buf: [24]u8 = [0...]; // longer than any expected output in cases
+ let out: []u8 = [];
+ defer free(out);
+
+ for (true) match (io::read(&dec, buf)) {
+ case z: size =>
+ append(out, buf[..z]...); // keep only the bytes this read produced
+ case io::EOF =>
+ break;
+ case err: io::error =>
+ abort();
+ };
+
+ assert(bytes::equal(out, strings::toutf8(expected)));
+ };
+};