hare

[hare] The Hare programming language
git clone https://git.torresjrjr.com/hare.git
Log | Files | Refs | README | LICENSE

pss.ha (5235B)


      1 // SPDX-License-Identifier: MPL-2.0
      2 // (c) Hare authors <https://harelang.org>
      3 
      4 use bytes;
      5 use crypto::math;
      6 use endian;
      7 use errors;
      8 use hash;
      9 use io;
     10 use types;
     11 
     12 export type default = void;
     13 
     14 // Required minimum buffer size for [[pss_verify]]
     15 export def PSS_VERIFYBUFSZ = PUBEXP_BUFSZ + ((BITSZ + 7) / 8);
     16 
     17 // Required minimum buffer size for [[pss_sign]]
     18 export def PSS_SIGNBUFSZ = PRIVEXP_BUFSZ;
     19 
     20 // Signs the hash 'msghash' using the private key 'privkey' by applying the PSS
     21 // signature scheme as defined in RFC 8017. 'sig' must be in the the size of the
     22 // modulus n (see [[privkey_nsize]])
     23 //
     24 // It is recommended that 'hf' is the same hash function that was used to
     25 // generate 'msgmhash'. 'buf' needs to be at least the size of
     26 // [[PSS_SIGNBUFSZ]]. 'rand' must be an [[io::reader]] that returns a
     27 // cryptographiclly random data on read like [[crypto::random::stream]]. The
     28 // expected size of the salt is provided with 'saltsz'. Default is the maximum
     29 // possible salt size.
     30 //
     31 // Returns [[errors::invalid]], if one of the parameters are invalid.
     32 // [[errors::overflow]] is returned, if 'buf' is to small. Errors that occur by
     33 // reading from 'rand' are returned as [[io::error]].
     34 export fn pss_sign(
     35 	privkey: []u8,
     36 	msghash: []u8,
     37 	sig: []u8,
     38 	hf: *hash::hash,
     39 	rand: io::handle,
     40 	buf: []u8,
     41 	saltsz: (size | default) = default,
     42 ) (void | error | io::error) = {
     43 	let priv = privkey_params(privkey);
     44 
     45 	// Use var names that match the rfc.
     46 	let embits = priv.nbitlen - 1;
     47 	if (len(sig) != (embits + 7) / 8) {
     48 		return errors::invalid: error;
     49 	};
     50 	let em = sig;
     51 	let hlen = len(msghash);
     52 	let emlen = len(em);
     53 	const slen = dsaltsz(len(sig), hlen, saltsz);
     54 
     55 	if (emlen < hlen + slen + 2) {
     56 		return errors::invalid: error;
     57 	};
     58 
     59 	let db = em[..emlen - hlen - 1];
     60 	db[..] = [0...];
     61 	db[len(db) - slen - 1] = 0x01;
     62 	let salt = db[len(db) - slen..];
     63 	io::readall(rand, salt)?;
     64 
     65 	let h = em[emlen - hlen - 1..emlen - 1];
     66 	const padding: [8]u8 = [0...];
     67 	hash::reset(hf);
     68 	hash::write(hf, padding);
     69 	hash::write(hf, msghash);
     70 	hash::write(hf, salt);
     71 	hash::sum(hf, h);
     72 
     73 	mgfxor(db, hf, h, buf);
     74 
     75 	em[0] &= 0xff >> (8*emlen - embits): u8;
     76 	em[len(db)..emlen - 1] = h[..];
     77 	em[emlen - 1] = 0xbc;
     78 
     79 	privexp(&priv, em, buf)?;
     80 };
     81 
     82 fn dsaltsz(nsz: size, hsz: size, s: (size | default)) size = {
     83 	match (s) {
     84 	case let s: size =>
     85 		return s;
     86 	case default =>
     87 		return nsz - hsz - 2;
     88 	};
     89 };
     90 
     91 // Verifies a PSS signature 'sig' of the mesage hash 'msghash' using the public
     92 // key 'pupkey' as defined in RFC 8017.
     93 //
     94 // 'hf' must be the hash that was used to create the signature. 'buf' needs to
     95 // be at least the size of [[PSS_VERIFYBUFSZ]]. The expected size of the salt is
     96 // provided with 'saltsz'. Default is the maximum possible salt size. The
     97 // function will fail, if the signature's salt size does not match the expected.
     98 //
     99 // Returns [[badsig]], if the signature verification fails. [[errors::overflow]]
    100 // is returned, if 'buf' is to small.
    101 export fn pss_verify(
    102 	pubkey: []u8,
    103 	msghash: []u8,
    104 	sig: []u8,
    105 	hf: *hash::hash,
    106 	buf: []u8,
    107 	saltsz: (size | default) = default,
    108 ) (void | error) = {
    109 	let pub = pubkey_params(pubkey);
    110 	if (len(sig) != len(pub.n)) {
    111 		return badsig;
    112 	};
    113 
    114 	// rename some variables to match the ones in the RFC
    115 	let mhash = msghash;
    116 	const hlen = hash::sz(hf);
    117 	const slen = dsaltsz(len(pub.n), hlen, saltsz);
    118 	let em = buf[..len(sig)];
    119 	const emlen = len(em);
    120 	em[..] = sig[..];
    121 
    122 	if (emlen < hlen + slen + 2) {
    123 		return badsig;
    124 	};
    125 
    126 	let pubbuf = buf[len(sig)..];
    127 	match (pubexp(&pub, em, pubbuf)) {
    128 	case void => void;
    129 	case errors::invalid =>
    130 		return badsig;
    131 	case let e: error =>
    132 		return e;
    133 	};
    134 
    135 	if (em[emlen - 1] != 0xbc) {
    136 		return badsig;
    137 	};
    138 
    139 	const maskdbsz = emlen - hlen - 1;
    140 	let maskeddb = em[..maskdbsz];
    141 	let h = em[maskdbsz..maskdbsz + hlen];
    142 
    143 	const embitlen = pubkey_nbitlen(pubkey) - 1;
    144 	const zerobitsh = 8 - (8*len(em) - embitlen);
    145 	if (maskeddb[0] >> zerobitsh != 0) {
    146 		return badsig;
    147 	};
    148 
    149 	let db = maskeddb;
    150 	mgfxor(db, hf, h, pubbuf);
    151 	db[0] &= 0xff >> (8*len(em) - embitlen): u8;
    152 
    153 	const seppos = len(em) - hlen - slen - 2;
    154 	for (let i = 0z; i < seppos; i += 1) {
    155 		if (db[i] != 0x00) {
    156 			return badsig;
    157 		};
    158 	};
    159 	if (db[seppos] != 0x01) {
    160 		return badsig;
    161 	};
    162 
    163 	const salt = db[len(db) - slen..];
    164 	const padding: [8]u8 = [0...];
    165 	hash::reset(hf);
    166 	hash::write(hf, padding);
    167 	hash::write(hf, mhash);
    168 	hash::write(hf, salt);
    169 
    170 	let genh = pubbuf[..hlen];
    171 	hash::sum(hf, genh);
    172 
    173 	if (math::eqslice(genh, h) != 1) {
    174 		return badsig;
    175 	};
    176 };
    177 
    178 // dest = dest XOR mgf(h, seed, len(dest)). 'buf' must be hash::sz(h) bytes
    179 // long.
    180 fn mgfxor(dest: []u8, h: *hash::hash, seed: []u8, buf: []u8) void = {
    181 	assert(len(buf) >= hash::sz(h));
    182 
    183 	let ctrbuf: [4]u8 = [0...];
    184 	let sum = buf[..hash::sz(h)];
    185 	const iterations = (len(dest) + len(sum) - 1) / len(sum);
    186 
    187 	for (let ctr: u32 = 0; ctr < iterations; ctr += 1) {
    188 		endian::beputu32(ctrbuf, ctr);
    189 		hash::reset(h);
    190 		hash::write(h, seed);
    191 		hash::write(h, ctrbuf);
    192 		hash::sum(h, sum);
    193 
    194 		const start = ctr * len(sum);
    195 		const remain = len(dest) - start;
    196 		const chunksz = if (remain < len(sum)) remain else len(sum);
    197 
    198 		let chunk = dest[start..start + chunksz];
    199 		math::xor(chunk, chunk, sum[..chunksz]);
    200 	};
    201 
    202 	bytes::zero(sum);
    203 };