hare

[hare] The Hare programming language
git clone https://git.torresjrjr.com/hare.git
Log | Files | Refs | README | LICENSE

commit 8eb08f665aae1f499b0e16a53c55b2996f856ca7
parent c0432c642c35ac59b88f2ab9ad7f8898dbd3799e
Author: Bor Grošelj Simić <bgs@turminal.net>
Date:   Thu, 11 Apr 2024 01:46:30 +0200

encoding::base64: rewrite decoder

It uses io::readall now instead of doing repeated reads on its own,
doesn't use static variables for storing state between reads and handles
short/partial reads much better.

References: https://todo.sr.ht/~sircmpwn/hare/819
Signed-off-by: Bor Grošelj Simić <bgs@turminal.net>

Diffstat:
M encoding/base64/base64.ha | 285 ++++++++++++++++++++++++++++++++++++++-----------------------------------------
1 file changed, 136 insertions(+), 149 deletions(-)

diff --git a/encoding/base64/base64.ha b/encoding/base64/base64.ha @@ -52,7 +52,6 @@ export type encoder = struct { iavail: size, obuf: [4]u8, oavail: size, - err: (void | io::error), }; const encoder_vtable: io::vtable = io::vtable { @@ -73,7 +72,6 @@ export fn newencoder( stream = &encoder_vtable, out = out, enc = enc, - err = void, ... }; }; @@ -83,13 +81,6 @@ fn encode_writer( in: const []u8 ) (size | io::error) = { let s = s: *encoder; - match(s.err) { - case let err: io::error => - s.err = void; - return err; - case void => void; - }; - let i = 0z; for (i < len(in)) { let b = s.ibuf[..]; @@ -111,7 +102,6 @@ fn encode_writer( if (i == 0) { return e; }; - s.err = e; return i; case void => void; }; @@ -153,13 +143,6 @@ fn encode_closer(s: *io::stream) (void | io::error) = { let finished = false; defer if (finished) clear(s); - match (s.err) { - case let e: io::error => - s.err = void; - return e; - case void => void; - }; - if (s.oavail > 0) { for (s.oavail > 0) { writeavail(s)?; @@ -282,9 +265,11 @@ export type decoder = struct { stream: io::stream, in: io::handle, enc: *encoding, - avail: []u8, // leftover decoded output + obuf: [3]u8, // leftover decoded output + ibuf: [4]u8, + iavail: u8, + oavail: u8, pad: bool, // if padding was seen in a previous read - state: (void | io::EOF | io::error), }; const decoder_vtable: io::vtable = io::vtable { @@ -303,7 +288,6 @@ export fn newdecoder( stream = &decoder_vtable, in = in, enc = enc, - state = void, ... }; }; @@ -313,101 +297,91 @@ fn decode_reader( out: []u8 ) (size | io::EOF | io::error) = { let s = s: *decoder; - let n = 0z; - let l = len(out); - match(s.state) { - case let err: (io::EOF | io ::error) => - return err; - case void => void; + if (len(out) == 0) { + return 0z; }; - if (len(s.avail) > 0) { - n += if (l < len(s.avail)) l else len(s.avail); - out[..n] = s.avail[0..n]; - s.avail = s.avail[n..]; - if (l == n) { - return n; + let n = 0z; + if (s.oavail != 0) { + if (len(out) <= s.oavail) { + out[..] 
= s.obuf[..len(out)]; + s.obuf[..len(s.obuf) - len(out)] = s.obuf[len(out)..]; + s.oavail = s.oavail - len(out): u8; + return len(out); }; - }; - static let buf: [os::BUFSZ]u8 = [0...]; - static let obuf: [os::BUFSZ / 4 * 3]u8 = [0...]; - const nn = ((l - n) / 3 + 1) * 4; // 4 extra bytes may be read. - let nr = 0z; - for (nr < nn) { - match (io::read(s.in, buf[nr..])) { - case let n: size => - if (n == 0) { - break; - }; - nr += n; - case io::EOF => - s.state = io::EOF; - break; - case let err: io::error => - s.state = err; + n = s.oavail; + s.oavail = 0; + out[..n] = s.obuf[..n]; + out = out[n..]; + }; + let buf: [os::BUFSZ]u8 = [0...]; + buf[..s.iavail] = s.ibuf[..s.iavail]; + + let want = encodedsize(len(out)); + let nr = s.iavail: size; + let lim = if (want > len(buf)) len(buf) else want; + match (io::readall(s.in, buf[s.iavail..lim])) { + case let n: size => + nr += n; + case io::EOF => + return if (s.iavail != 0) errors::invalid + else if (n != 0) n + else io::EOF; + case let err: io::error => + if (!(err is io::underread)) { return err; }; + nr += err: io::underread; }; - if (nr % 4 != 0) { - s.state = errors::invalid; + if (s.pad) { return errors::invalid; }; - if (nr == 0) { // io::EOF already set - return n; + s.iavail = nr: u8 % 4; + s.ibuf[..s.iavail] = buf[nr - s.iavail..nr]; + nr -= s.iavail; + if (nr == 0) { + return 0z; }; // Validating read buffer - let valid = true; - let np = 0; // Number of padding chars. - let p = true; // Pad allowed in buf - for (let i = nr; i > 0; i -= 1) { - const ch = buf[i - 1]; - if (ch >= 128) { - return errors::invalid; - }; - if (ch == PADDING) { - if(s.pad || !p) { - valid = false; - break; - }; - np += 1; - } else { - if (s.enc.decmap[ch] == -1) { - valid = false; - break; + let np = 0z; // Number of padding chars. 
+ for (let i = 0z; i < nr; i += 1) { + if (buf[i] == PADDING) { + for (i + np < nr; np += 1) { + if (np > 2 || buf[i + np] != PADDING) { + return errors::invalid; + }; }; - // Disallow padding on seeing a non-padding char - p = false; + s.pad = true; + break; }; - }; - valid = valid && np <= 2; - if (np > 0) { - s.pad = true; - }; - if (!valid) { - s.state = errors::invalid; - return errors::invalid; - }; - for (let i = 0z; i < nr; i += 1) { - if (buf[i] >= 128) { + if (!ascii::valid(buf[i]: u32: rune) || s.enc.decmap[buf[i]] == -1) { return errors::invalid; }; buf[i] = s.enc.decmap[buf[i]]; }; - for (let i = 0z, j = 0z; i < nr) { - obuf[j] = buf[i] << 2 | buf[i + 1] >> 4; - obuf[j + 1] = buf[i + 1] << 4 | buf[i + 2] >> 2; - obuf[j + 2] = buf[i + 2] << 6 | buf[i + 3]; + + if (nr / 4 * 3 - np < len(out)) { + out = out[..nr / 4 * 3 - np]; + }; + let i = 0z, j = 0z; + nr -= 4; + for (i < nr) { + out[j ] = buf[i ] << 2 | buf[i + 1] >> 4; + out[j + 1] = buf[i + 1] << 4 | buf[i + 2] >> 2; + out[j + 2] = buf[i + 2] << 6 | buf[i + 3]; i += 4; j += 3; }; - // Removing bytes added due to padding. - // 0 1 2 // np - static const npr: [3]u8 = [0, 1, 2]; // bytes to discard - const navl = nr / 4 * 3 - npr[np]; - const rem = if(l - n < navl) l - n else navl; - out[n..n + rem] = obuf[..rem]; - s.avail = obuf[rem..navl]; - return n + rem; + s.obuf = [ + buf[i ] << 2 | buf[i + 1] >> 4, + buf[i + 1] << 4 | buf[i + 2] >> 2, + buf[i + 2] << 6 | buf[i + 3], + ]; + out[j..] 
= s.obuf[..len(out) - j]; + s.oavail = (len(s.obuf) - (len(out) - j)): u8; + s.obuf[..s.oavail] = s.obuf[len(s.obuf) - s.oavail..]; + s.oavail -= np: u8; + return n + len(out); }; // Decodes a byte slice of ASCII-encoded base 64 data, using the given encoding, @@ -416,15 +390,22 @@ export fn decodeslice( enc: *encoding, in: []u8, ) ([]u8 | errors::invalid) = { - let in = memio::fixed(in); - let decoder = newdecoder(enc, &in); - let out = memio::dynamic(); - match (io::copy(&out, &decoder)) { + if (len(in) == 0) { + return []; + }; + if (len(in) % 4 != 0) { + return errors::invalid; + }; + let ins = memio::fixed(in); + let decoder = newdecoder(enc, &ins); + let out = alloc([0u8...], decodedsize(len(in))); + let outs = memio::fixed(out); + match (io::copy(&outs, &decoder)) { case io::error => - io::close(&out)!; + free(out); return errors::invalid; - case size => - return memio::buffer(&out); + case let sz: size => + return memio::buffer(&outs)[..sz]; }; }; @@ -441,15 +422,8 @@ export fn decode( enc: *encoding, buf: []u8, ) (size | io::EOF | io::error) = { - const enc = newdecoder(enc, in); - match (io::readall(&enc, buf)) { - case let ret: (size | io::EOF) => - io::close(&enc)?; - return ret; - case let err: io::error => - io::close(&enc): void; - return err; - }; + const dec = newdecoder(enc, in); + return io::readall(&dec, buf); }; @test fn decode() void = { @@ -463,54 +437,67 @@ export fn decode( ("Zm9vYmE=", "fooba", &std_encoding), ("Zm9vYmFy", "foobar", &std_encoding), ]; - for (let i = 0z; i < len(cases); i += 1) { - let in = memio::fixed(strings::toutf8(cases[i].0)); - let decoder = newdecoder(cases[i].2, &in); - let decb: []u8 = io::drain(&decoder)!; - defer free(decb); - assert(bytes::equal(decb, strings::toutf8(cases[i].1))); - - // Testing decodestr should cover decodeslice too - let decb = decodestr(cases[i].2, cases[i].0) as []u8; - defer free(decb); - assert(bytes::equal(decb, strings::toutf8(cases[i].1))); - }; - // Repeat of the above, but with a 
larger buffer - for (let i = 0z; i < len(cases); i += 1) { - let in = memio::fixed(strings::toutf8(cases[i].0)); - let decoder = newdecoder(cases[i].2, &in); - let decb: []u8 = io::drain(&decoder)!; - defer free(decb); - assert(bytes::equal(decb, strings::toutf8(cases[i].1))); - }; - - const invalid: [_]str = [ + const invalid: [_](str, *encoding) = [ // invalid padding - "=", "==", "===", "=====", "======", - // invalid characters - "@Zg=", "êg=", "êg==", "$3d==", "%3d==", "[==", "!", - // data after padding is encountered - "Zg==Zg==", "Zm8=Zm8=", + ("=", &std_encoding), + ("==", &std_encoding), + ("===", &std_encoding), + ("=====", &std_encoding), + ("======", &std_encoding), + // invalid characters + ("@Zg=", &std_encoding), + ("ê==", &std_encoding), + ("êg==", &std_encoding), + ("$3d==", &std_encoding), + ("%3d==", &std_encoding), + ("[==", &std_encoding), + ("!", &std_encoding), + // data after padding is encountered + ("Zg===", &std_encoding), + ("Zg====", &std_encoding), + ("Zg==Zg==", &std_encoding), + ("Zm8=Zm8=", &std_encoding), ]; - const encodings: [_]*encoding = [&std_encoding, &url_encoding]; - for (let i = 0z; i < len(invalid); i += 1) { - for (let enc = 0z; enc < 2; enc += 1) { - let in = memio::fixed(strings::toutf8(invalid[i])); - let decoder = newdecoder(encodings[enc], &in); - let buf: [1]u8 = [0...]; + let buf: [12]u8 = [0...]; + for (let bufsz = 1z; bufsz <= 12; bufsz += 1) { + for (let (input, expected, encoding) .. cases) { + let in = memio::fixed(strings::toutf8(input)); + let decoder = newdecoder(encoding, &in); + let buf = buf[..bufsz]; + let decb: []u8 = []; + defer free(decb); + for (true) match (io::read(&decoder, buf)!) 
{ + case let z: size => + if (z > 0) { + append(decb, buf[..z]...); + }; + case io::EOF => + break; + }; + assert(bytes::equal(decb, strings::toutf8(expected))); + + // Testing decodestr should cover decodeslice too + let decb = decodestr(encoding, input) as []u8; + defer free(decb); + assert(bytes::equal(decb, strings::toutf8(expected))); + }; + + for (let (input, encoding) .. invalid) { + let in = memio::fixed(strings::toutf8(input)); + let decoder = newdecoder(encoding, &in); + let buf = buf[..bufsz]; let valid = false; for (true) match(io::read(&decoder, buf)) { case errors::invalid => break; case size => - valid = true; + void; case io::EOF => - break; + abort(); }; - assert(valid == false, "valid is not false"); // Testing decodestr should cover decodeslice too - assert(decodestr(encodings[enc], invalid[i]) is errors::invalid); + assert(decodestr(encoding, input) is errors::invalid); }; }; };